Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/data processing #110

Closed
wants to merge 6 commits into the base branch from the feature/data-processing branch
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions config/asr_finetuning.yaml
Original file line number Diff line number Diff line change
@@ -6,14 +6,12 @@ defaults:
- wikipedia
- common_voice
- reddit
- experiment_tracking: wandb
- experiment_tracking: mlflow
- override hydra/job_logging: custom
- _self_

seed: 4242

experiment_tracking: null

evaluation_dataset:
id: alexandrainst/coral
subset: read_aloud
1 change: 1 addition & 0 deletions config/evaluation.yaml
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ max_seconds_per_example: 10
clean_text: true
lower_case: true
characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü'
normalize_audio: false

# Evaluation parameters
model_id: null
1 change: 1 addition & 0 deletions config/model/test-wav2vec2.yaml
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ freeze_feature_encoder: true
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/test-whisper.yaml
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ freeze_feature_encoder: true
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/wav2vec2-large.yaml
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/wav2vec2-medium.yaml
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/wav2vec2-small.yaml
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/whisper-large-turbo.yaml
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/whisper-large.yaml
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/whisper-medium.yaml
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/whisper-small.yaml
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/whisper-xsmall.yaml
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions config/model/whisper-xxsmall.yaml
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ freeze_feature_encoder: false
# Data hyperparameters
lower_case: true
clean_text: true
normalize_audio: false

# Model hyperparameters
sampling_rate: 16_000
1 change: 1 addition & 0 deletions src/coral/compute_metrics.py
Original file line number Diff line number Diff line change
@@ -180,6 +180,7 @@ def compute_metrics_of_dataset_using_pipeline(
clean_text=True,
lower_case=True,
convert_numerals=True,
normalize_audio=False,
processor=None,
)["text"]

46 changes: 34 additions & 12 deletions src/coral/data.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
IterableDataset,
IterableDatasetDict,
NamedSplit,
get_dataset_config_info,
interleave_datasets,
load_dataset,
)
@@ -97,6 +98,7 @@ def load_data_for_finetuning(
is_main_process = os.getenv("RANK", "0") == "0"

all_datasets: list[IterableDataset] | list[Dataset] = list()
len_datasets: list[int] = list()
for dataset_name, dataset_config in config.datasets.items():
if is_main_process:
logger.info(f"Loading dataset {dataset_name!r}")
@@ -138,6 +140,8 @@ def load_data_for_finetuning(
cache_dir=config.cache_dir,
)

len_datasets.append(len(ds))

# Load dataset from the Hugging Face Hub. The HUGGINGFACE_HUB_TOKEN is only
# used during CI - normally it is expected that the user is logged in to the
# Hugging Face Hub using the `huggingface-cli login` command.
@@ -152,6 +156,14 @@ def load_data_for_finetuning(
cache_dir=config.cache_dir,
)

len_datasets.append(
get_dataset_config_info(
dataset_config.id, config_name=dataset_config.subset
)
.splits[dataset_config.train_name]
.num_examples
)

assert isinstance(
ds, Dataset | IterableDataset
), f"Unsupported dataset type: {type(ds)}"
@@ -193,15 +205,13 @@ def load_data_for_finetuning(
if config.dataset_probabilities is None and len(all_datasets) > 1:
logger.warning(
"No dataset probabilities were specified for the training split. "
"This means that each dataset will be sampled with equal "
"probability, which means that the smaller datasets will be "
"sampled more often than the larger datasets. This is probably "
"not what you want."
"This means that each dataset will be sampled according to their "
"relative sizes, which might not be what you want."
)

probabilities = config.dataset_probabilities
if probabilities is None:
probabilities = [1 / len(all_datasets)] * len(all_datasets)
probabilities = [n / sum(len_datasets) for n in len_datasets]
probabilities[-1] = 1 - sum(probabilities[:-1])
elif sum(probabilities) != 1:
raise ValueError(
@@ -223,12 +233,13 @@ def load_data_for_finetuning(
clean_text=config.model.clean_text,
lower_case=config.model.lower_case,
characters_to_keep=config.characters_to_keep,
remove_input_dataset_columns=True,
text_column="text",
audio_column="audio",
convert_numerals=False,
remove_input_dataset_columns=True,
processor=processor,
normalize_audio=config.model.normalize_audio,
num_proc=config.dataset_num_workers,
processor=processor,
)

data_dict = dict(train=train)
@@ -268,12 +279,13 @@ def load_data_for_finetuning(
clean_text=config.model.clean_text,
lower_case=config.model.lower_case,
characters_to_keep=config.characters_to_keep,
remove_input_dataset_columns=True,
text_column="text",
audio_column="audio",
convert_numerals=False,
remove_input_dataset_columns=True,
processor=processor,
normalize_audio=config.model.normalize_audio,
num_proc=config.dataset_num_workers,
processor=processor,
)
dataset["val"] = val

@@ -342,10 +354,11 @@ def load_dataset_for_evaluation(config: DictConfig) -> Dataset:
clean_text=config.clean_text,
lower_case=config.lower_case,
characters_to_keep=config.characters_to_keep,
remove_input_dataset_columns=False,
text_column=config.text_column,
audio_column=config.audio_column,
remove_input_dataset_columns=False,
convert_numerals=True,
normalize_audio=config.normalize_audio,
)

if config.cache_dir:
@@ -453,10 +466,11 @@ def process_dataset(
clean_text: bool,
lower_case: bool,
characters_to_keep: Iterable[str] | None,
text_column: str,
remove_input_dataset_columns: bool,
text_column: str,
audio_column: str | None,
convert_numerals: bool,
normalize_audio: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to include a default here, as you're including the argument explicitly when calling the function. Keeping it as a default could lead to silent errors:

Suggested change
normalize_audio: bool = False,
normalize_audio: bool,

num_proc: int | None = None,
processor: Callable | None = None,
) -> Data:
@@ -483,6 +497,8 @@ def process_dataset(
does not have an audio column.
convert_numerals:
Whether to convert numerals to words.
normalize_audio:
Whether to normalize the audio.
num_proc (optional):
The number of processes to use for processing the dataset. If `None`, then
no multiprocessing is used. Defaults to `None`.
@@ -507,6 +523,7 @@ def process_dataset(
clean_text=clean_text,
lower_case=lower_case,
convert_numerals=convert_numerals,
normalize_audio=normalize_audio,
processor=processor,
)
if isinstance(dataset, Dataset | DatasetDict):
@@ -531,6 +548,7 @@ def process_example(
clean_text: bool,
lower_case: bool,
convert_numerals: bool,
normalize_audio: bool,
processor: Callable | None,
) -> dict:
"""Helper function which cleans a single example.
@@ -554,6 +572,8 @@ def process_example(
Whether to make the text lower case.
convert_numerals:
Whether to convert numerals to words.
normalize_audio:
Whether to normalize the audio.
processor:
The processor to use for processing the audio and transcriptions. If `None`,
then the processor is not used. Requires `audio_column` to be specified.
@@ -610,7 +630,9 @@ def process_example(
# Prepare audio
audio = example[audio_column]
sampling_rate = audio["sampling_rate"]
processed = processor(audio["array"], sampling_rate=sampling_rate)
processed = processor(
audio["array"], sampling_rate=sampling_rate, do_normalize=normalize_audio
)
if "input_values" in processed:
example["input_values"] = processed.input_values[0]
example["num_seconds"] = len(example["input_values"]) / sampling_rate
5 changes: 3 additions & 2 deletions src/coral/ngram.py
Original file line number Diff line number Diff line change
@@ -223,12 +223,13 @@ def get_sentence_corpus_path(config: DictConfig) -> Path:
dataset = process_dataset(
dataset=dataset,
clean_text=config.model.clean_text,
lower_case=config.model.lower_case,
characters_to_keep=config.characters_to_keep,
text_column="text",
remove_input_dataset_columns=False,
text_column="text",
audio_column=None,
convert_numerals=False,
lower_case=config.model.lower_case,
normalize_audio=config.model.normalize_audio,
)
assert isinstance(dataset, Dataset)

5 changes: 3 additions & 2 deletions src/coral/validation.py
Original file line number Diff line number Diff line change
@@ -67,12 +67,13 @@ def add_validations(
processed_dataset = process_dataset(
dataset=dataset,
clean_text=clean_text,
lower_case=lower_case,
characters_to_keep=characters_to_keep,
remove_input_dataset_columns=True,
text_column=text_column,
audio_column=audio_column,
convert_numerals=False,
remove_input_dataset_columns=True,
lower_case=lower_case,
normalize_audio=False,
)

logger.info(f"Loading the {model_id!r} ASR model...")
6 changes: 4 additions & 2 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -35,12 +35,13 @@ def test_process_dataset(self, dataset):
processed_dataset = process_dataset(
dataset=dataset,
clean_text=True,
lower_case=True,
characters_to_keep=None,
remove_input_dataset_columns=False,
text_column="text",
audio_column=None,
convert_numerals=False,
remove_input_dataset_columns=False,
lower_case=True,
normalize_audio=False,
)
processed_samples = {sample["text"] for sample in processed_dataset}
expected_samples = {
@@ -213,6 +214,7 @@ def test_clean_example(
clean_text=True,
lower_case=lower_case,
convert_numerals=False,
normalize_audio=False,
processor=None,
)[text_column]
assert cleaned_transcription == expected