Sentiment analysis example (awslabs#69)
* edit installation instructions in readme

* bump up version

* improve sentiment analysis example
gianlucadetommaso authored May 17, 2023
1 parent ab13c1d commit f719bec
Showing 3 changed files with 70 additions and 14 deletions.
1 change: 1 addition & 0 deletions examples/index.rst
@@ -21,3 +21,4 @@ In this section we show some examples of how to use Fortuna in classification an
scaling_up_bayesian_inference
mnist_classification_sghmc
sgmcmc_diagnostics
sentiment_analysis
81 changes: 67 additions & 14 deletions examples/sentiment_analysis.pct.py
@@ -1,11 +1,12 @@
 # ---
 # jupyter:
 #   jupytext:
+#     notebook_metadata_filter: nbsphinx
 #     text_representation:
 #       extension: .py
-#       format_name: percent
-#       format_version: '1.3'
-#       jupytext_version: 1.14.1
+#       format_name: light
+#       format_version: '1.5'
+#       jupytext_version: 1.14.5
 #   kernelspec:
 #     display_name: python3
 #     language: python
@@ -14,7 +15,17 @@
# execute: never
# ---

# %%
# # Sentiment analysis

# In this notebook we show how to download a pre-trained model and a dataset from Hugging Face,
# and how to calibrate the model for a sentiment analysis task by fine-tuning part of its parameters.
#
# The following cell makes several configuration choices.
# By default, we use a [BERT](https://huggingface.co/docs/transformers/model_doc/bert#transformers.FlaxBertModel) model
# and the [imdb](https://huggingface.co/datasets/imdb) dataset. We also choose how to split the data,
# the batch size, and the optimization settings.

# +
pretrained_model_name_or_path = "bert-base-cased"
dataset_name = "imdb"
text_columns = ("text",)
@@ -34,13 +45,21 @@
learning_rate = 2e-5
max_grad_norm = 1.0
early_stopping_patience = 1
# -
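
The diff collapses the middle of this configuration cell. For orientation only, here is a hypothetical sketch of the kind of settings it defines; every name and value below is an assumption inferred from how the variables are used later in the notebook, not the commit's actual collapsed lines.

# Hypothetical sketch -- NOT the collapsed lines of the commit.
target_column = "label"        # imdb stores targets in a "label" column
calib_split = "train"          # assumed split used for calibration
val_split = "test[:50%]"       # assumed split used for validation
test_split = "test[50%:]"      # assumed split used for final evaluation
num_labels = 2                 # imdb is a binary sentiment task
max_sequence_length = 512
per_device_batch_size = 32
seed = 0
num_train_epochs = 2
num_warmup_steps = 0
weight_decay = 0.01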

# ## Prepare the data

# First things first: from Hugging Face we instantiate a tokenizer for the pre-trained model in use.

# %%
# +
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# -
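
As a quick, illustrative sanity check (not part of the commit), the tokenizer maps raw text to the encodings the model expects; for a cased BERT checkpoint these include input IDs, token type IDs, and an attention mask.

# Illustrative only -- not part of the commit.
encoded = tokenizer("A wonderfully acted, moving film.", truncation=True)
print(list(encoded.keys()))  # ['input_ids', 'token_type_ids', 'attention_mask'] for BERT
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"])[:6])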

# %%
# Then, we download the calibration, validation, and test datasets.

# +
from datasets import DatasetDict, load_dataset

datasets = DatasetDict(
@@ -50,8 +69,12 @@
"test": load_dataset(dataset_name, split=test_split),
}
)
# -
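
Each split is a Hugging Face `datasets.Dataset`; for imdb, every record holds a raw review under `text` and a binary sentiment under `label`. An illustrative inspection (not part of the commit):

# Illustrative only -- not part of the commit.
print(datasets["test"].column_names)  # ['text', 'label']
print(datasets["test"][0]["label"])   # 0 = negative, 1 = positive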

# It's time for Fortuna to come into play. First, we instantiate a sequence classification dataset object,
# then we tokenize the datasets, and finally we construct calibration, validation, and test data loaders.

# %%
# +
from fortuna.data.dataset.huggingface_datasets import (
HuggingFaceSequenceClassificationDataset,
)
@@ -92,20 +115,36 @@
rng=rng,
verbose=True,
)
# -
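
Most of this cell is collapsed in the diff. The sketch below only illustrates the shape of the three steps described above; the method and argument names are assumptions inferred from the imports and the visible tail of the cell, not verbatim Fortuna API.

# Sketch only -- method and argument names are assumptions, not verbatim API.
import jax

hf_dataset = HuggingFaceSequenceClassificationDataset(
    tokenizer=tokenizer,
    padding="max_length",
    max_length=max_sequence_length,
    num_unique_labels=num_labels,
)
datasets = hf_dataset.get_tokenized_datasets(
    datasets, text_columns=text_columns, target_column=target_column
)
rng = jax.random.PRNGKey(seed)  # the `rng` passed to the data loaders
calib_data_loader = hf_dataset.get_data_loader(
    datasets["calibration"],
    per_device_batch_size=per_device_batch_size,
    shuffle=True,
    drop_last=True,
    rng=rng,
    verbose=True,
)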

# ## Define the transformer model

# From the [transformers](https://huggingface.co/docs/transformers/index) library of Hugging Face,
# we instantiate the pre-trained transformer of interest.

# %%
# +
from transformers import FlaxAutoModelForSequenceClassification

model = FlaxAutoModelForSequenceClassification.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path, num_labels=num_labels
)
# -
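
As an aside (not in the commit), the Flax model can already be called on a tokenized batch; its logits are what the calibration below operates on.

# Illustrative only -- not part of the commit.
batch = tokenizer(["A wonderful film!"], return_tensors="np", padding=True)
logits = model(**batch).logits  # shape (1, num_labels)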

# %%
# We now pass the model to `CalibClassifier` to instantiate a calibration model. Out of the box, you will be able to
# fine-tune your model with custom loss functions and on arbitrary subsets of the model parameters.

# +
from fortuna.calib_model import CalibClassifier

calib_model = CalibClassifier(model=model)
# -
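
For example, a custom loss can be any function from model outputs and targets to a scalar. A minimal cross-entropy sketch in plain JAX follows; the exact signature `calibrate` expects is not shown in this diff, so treat it as an assumption.

# Sketch of a custom loss -- the expected signature is an assumption.
import jax.numpy as jnp
import optax

def cross_entropy_loss(outputs, targets):
    # `outputs`: logits of shape (batch, num_labels); `targets`: integer labels.
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(outputs, targets))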

# ## Calibrate!

# %%
# We first construct an optimizer. We use Fortuna's functionality to define a learning rate scheduler for
# [AdamW](https://arxiv.org/pdf/1711.05101.pdf).

# +
from fortuna.utils.optimizer import (
linear_scheduler_with_warmup,
decay_mask_without_layer_norm_fn,
@@ -123,8 +162,13 @@
weight_decay=weight_decay,
mask=decay_mask_without_layer_norm_fn,
)
# -
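
The middle of this cell is also collapsed. Given the imports and the visible tail, it plausibly builds an `optax.adamw` transform whose learning rate comes from `linear_scheduler_with_warmup`; the scheduler's argument names below are assumptions.

# Sketch of the collapsed construction -- scheduler arguments are assumptions.
import optax

optimizer = optax.adamw(
    learning_rate=linear_scheduler_with_warmup(
        learning_rate=learning_rate,
        num_inputs_train=len(datasets["calibration"]),
        train_total_batch_size=per_device_batch_size,
        num_train_epochs=num_train_epochs,
        num_warmup_steps=num_warmup_steps,
    ),
    weight_decay=weight_decay,
    mask=decay_mask_without_layer_norm_fn,  # skip weight decay on layer-norm params
)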

# We then configure the calibration process: hyperparameters, metrics to monitor, early stopping,
# the optimizer, and which parameters we want to calibrate. Here, we choose to calibrate only the parameters whose
# path contains "classifier", i.e. only the parameters of the last layer.

# %%
# +
from fortuna.calib_model import Config, Optimizer, Monitor, Hyperparameters
from fortuna.metric.classification import accuracy, brier_score

@@ -151,13 +195,22 @@ def brier(preds, uncertainties, targets):
freeze_fun=lambda path, v: "trainable" if "classifier" in path else "frozen",
),
)
# -
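
Only the tail of the configuration is visible. Given the imports above, its overall shape is plausibly the following; every field name not visible in the diff is an assumption.

# Plausible shape of the collapsed configuration -- unseen field names are assumptions.
config = Config(
    hyperparameters=Hyperparameters(max_grad_norm=max_grad_norm),
    monitor=Monitor(
        metrics=(accuracy, brier),
        early_stopping_patience=early_stopping_patience,
    ),
    optimizer=Optimizer(
        method=optimizer,
        n_epochs=num_train_epochs,
        freeze_fun=lambda path, v: "trainable" if "classifier" in path else "frozen",
    ),
)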

# Finally, we calibrate! By default, the method employs a
# [focal loss](https://proceedings.neurips.cc/paper/2020/file/aeb7b30ef1d024a76f21a1d40e30c302-Paper.pdf),
# but feel free to pass your favourite one!

# %%
status = calib_model.calibrate(
calib_data_loader=calib_data_loader, val_data_loader=val_data_loader, config=config
)

# %%
# ## Compute metrics

# We now compute the accuracy and the [Expected Calibration Error](http://proceedings.mlr.press/v70/guo17a/guo17a.pdf)
# (ECE) to evaluate how the method performs on the test data.

# +
from fortuna.metric.classification import expected_calibration_error

test_inputs_loader = test_data_loader.to_inputs_loader()
@@ -175,7 +228,7 @@ def brier(preds, uncertainties, targets):
probs=means,
targets=test_targets,
)
# -
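
The middle of this cell is collapsed too. From the visible tail, `means` holds predictive class probabilities over the test inputs and `test_targets` the ground-truth labels; a sketch of the plausible computation (method names are assumptions):

# Sketch of the collapsed computation -- method names are assumptions.
means = calib_model.predictive.mean(inputs_loader=test_inputs_loader)
test_targets = test_data_loader.to_array_targets()
acc = accuracy(preds=means.argmax(-1), targets=test_targets)
ece = expected_calibration_error(
    preds=means.argmax(-1),
    probs=means,
    targets=test_targets,
)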

# %%
print(f"Accuracy on test set: {acc}.")
print(f"ECE on test set: {ece}.")
2 changes: 2 additions & 0 deletions fortuna/data/dataset/data_collator.py
@@ -1,3 +1,5 @@
# Code adapted from https://github.com/huggingface/transformers/blob/main/examples/flax/

from typing import (
Dict,
List,
