diff --git a/nemo_curator/filters/bitext_filter.py b/nemo_curator/filters/bitext_filter.py
index f85e68408..edfda0fb7 100644
--- a/nemo_curator/filters/bitext_filter.py
+++ b/nemo_curator/filters/bitext_filter.py
@@ -15,10 +15,11 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, Union
 
+from dask.array import logical_and
 from dask.typing import no_default
 
 from nemo_curator.datasets.parallel_dataset import ParallelDataset
-from nemo_curator.utils.module_utils import is_batched
+from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY, is_batched
 
 
 class BitextFilter(ABC):
@@ -41,6 +42,7 @@ def __init__(
         score_field: Optional[str] = None,
         score_type: Union[type, str] = None,
         invert=False,
+        add_skip_label_only: bool = False,
     ):
         """Args:
             src_field (str, optional): The field the source documents will be read from. Defaults to "src".
@@ -64,6 +66,7 @@ def __init__(
         self.score_field = score_field
         self.score_type = score_type
         self.invert = invert
+        self.add_skip_label_only = add_skip_label_only
 
     def __call__(
         self,
@@ -89,21 +92,44 @@ def __call__(
         fields.append(self.tgt_field)
         fields.extend(self.metadata_fields)
 
+        if self.add_skip_label_only:
+            if SKIP_LABEL_KEY not in dataset.df.columns:
+                dataset.df[SKIP_LABEL_KEY] = 0
+            if REASON_LABEL_KEY not in dataset.df.columns:
+                dataset.df[REASON_LABEL_KEY] = None
+
+            # although the full dataset is passed in, scores only need to be computed
+            # on the rows that are still remaining (i.e. not yet skipped)
+            kept_mask = dataset.df[SKIP_LABEL_KEY] == 0
+            df = dataset.df[kept_mask].copy()
+        else:
+            df = dataset.df
+
         if is_batched(self.score_bitext):
-            scores = dataset.df[fields].map_partitions(
+            scores = df[fields].map_partitions(
                 self._score_bitext_wrapper,
                 metadata_field_name_mapping=self.metadata_field_name_mapping,
                 meta=meta,
             )
         else:
-            scores = dataset.df[fields].apply(
+            scores = df[fields].apply(
                 self._score_bitext_wrapper,
                 metadata_field_name_mapping=self.metadata_field_name_mapping,
                 axis=1,
                 meta=meta,
             )
 
-        if self.score_field is not None:
+        if self.score_field is not None and self.add_skip_label_only:
+            dataset.df[self.score_field] = None
+
+            def update_score(partition, mask_partition, score_partition):
+                partition.loc[mask_partition, self.score_field] = score_partition
+                return partition
+
+            dataset.df = dataset.df.map_partitions(
+                update_score, kept_mask, scores, meta=dataset.df
+            )
+        elif self.score_field is not None:
             dataset.df[self.score_field] = scores
 
         if is_batched(self.keep_bitext):
@@ -113,7 +139,36 @@ def __call__(
         if self.invert:
             bool_mask = ~bool_mask
 
-        return ParallelDataset(dataset.df[bool_mask])
+        def update_skipme(partition, kept_mask_partition, score_bool_mask_partition):
+            partition.loc[kept_mask_partition, SKIP_LABEL_KEY] = [
+                1 if skip else 0 for skip in ~score_bool_mask_partition
+            ]
+            return partition
+
+        def update_reason(partition, kept_mask_partition, reason):
+            # the filtering reason is recorded only for entries that
+            # 1. were kept before this filter ran, and
+            # 2. were thrown out by this filter
+            new_skip = [
+                True if skip == 1 else False for skip in partition[SKIP_LABEL_KEY]
+            ]
+            new_mask = logical_and(kept_mask_partition.values, new_skip)
+            partition.loc[new_mask, REASON_LABEL_KEY] = reason
+            return partition
+
+        if self.add_skip_label_only:
+            dataset.df = dataset.df.map_partitions(
+                update_skipme,
+                kept_mask,
+                bool_mask,  # invert, if any, was already applied to bool_mask above
+                meta=dataset.df,
+            )
+            dataset.df = dataset.df.map_partitions(
+                update_reason, kept_mask, self.__class__.__name__, meta=dataset.df
+            )
+            return ParallelDataset(dataset.df)
+        else:
+            return ParallelDataset(dataset.df[bool_mask])
 
     def _score_bitext_wrapper(
         self,
diff --git a/nemo_curator/modules/filter.py b/nemo_curator/modules/filter.py
index 888252193..08fda842b 100644
--- a/nemo_curator/modules/filter.py
+++ b/nemo_curator/modules/filter.py
@@ -23,7 +23,7 @@
 from nemo_curator.datasets.parallel_dataset import ParallelDataset
 from nemo_curator.filters import DocumentFilter
 from nemo_curator.modules.base import BaseModule
-from nemo_curator.utils.module_utils import is_batched
+from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY, is_batched
 
 # Override so that pd.NA is not passed during the metadata inference
 make_array_nonempty.register(
@@ -168,6 +168,7 @@ def __init__(
         score_field: Optional[str] = None,
         score_type: Union[type, str] = None,
         invert: bool = False,
+        add_skip_label_only: bool = False,
     ):
         """
         Constructs a ScoreFilter module.
@@ -185,6 +186,7 @@ def __init__(
         self.score_field = score_field
         self.score_type = score_type
         self.invert = invert
+        self.add_skip_label_only = add_skip_label_only
 
     def compute_filter_mask(self, dataset: DocumentDataset):
         """Compute the bool mask to filter the dataset.
@@ -200,16 +202,39 @@ def compute_filter_mask(self, dataset: DocumentDataset):
         else:
             meta = no_default
 
+        if self.add_skip_label_only:
+            if SKIP_LABEL_KEY not in dataset.df.columns:
+                dataset.df[SKIP_LABEL_KEY] = 0
+            if REASON_LABEL_KEY not in dataset.df.columns:
+                dataset.df[REASON_LABEL_KEY] = None
+
+            # although the full dataset is passed in, scores only need to be computed
+            # on the rows that are still remaining (i.e. not yet skipped)
+            kept_mask = dataset.df[SKIP_LABEL_KEY] == 0
+            df = dataset.df[kept_mask].copy()
+        else:
+            df = dataset.df
+
         if is_batched(self.filter_obj.score_document):
-            scores = dataset.df[self.text_field].map_partitions(
+            scores = df[self.text_field].map_partitions(
                 self.filter_obj.score_document, meta=meta
             )
         else:
-            scores = dataset.df[self.text_field].apply(
+            scores = df[self.text_field].apply(
                 self.filter_obj.score_document, meta=meta
             )
 
-        if self.score_field is not None:
+        if self.score_field is not None and self.add_skip_label_only:
+            dataset.df[self.score_field] = None
+
+            def update_score(partition, mask_partition, score_partition):
+                partition.loc[mask_partition, self.score_field] = score_partition
+                return partition
+
+            dataset.df = dataset.df.map_partitions(
+                update_score, kept_mask, scores, meta=dataset.df
+            )
+        elif self.score_field is not None:
             dataset.df[self.score_field] = scores
 
         if is_batched(self.filter_obj.keep_document):
@@ -234,7 +259,41 @@ def call(self, dataset: DocumentDataset) -> DocumentDataset:
             DocumentDataset: A dataset with the score and filter applied
         """
         bool_mask = self.compute_filter_mask(dataset)
-        return DocumentDataset(dataset.df[bool_mask])
+
+        def update_skipme(partition, kept_mask_partition, score_bool_mask_partition):
+            partition.loc[kept_mask_partition, SKIP_LABEL_KEY] = [
+                1 if skip else 0 for skip in ~score_bool_mask_partition
+            ]
+            return partition
+
+        def update_reason(partition, kept_mask_partition, reason):
+            # the filtering reason is recorded only for entries that
+            # 1. were kept before this filter ran, and
+            # 2. were thrown out by this filter
+            new_skip = [
+                True if skip == 1 else False for skip in partition[SKIP_LABEL_KEY]
+            ]
+            new_mask = logical_and(kept_mask_partition.values, new_skip)
+            partition.loc[new_mask, REASON_LABEL_KEY] = reason
+            return partition
+
+        if self.add_skip_label_only:
+            kept_mask = dataset.df[SKIP_LABEL_KEY] == 0
+            dataset.df = dataset.df.map_partitions(
+                update_skipme,
+                kept_mask,
+                bool_mask,  # invert, if any, was already applied in compute_filter_mask
+                meta=dataset.df,
+            )
+            dataset.df = dataset.df.map_partitions(
+                update_reason,
+                kept_mask,
+                self.filter_obj.__class__.__name__,
+                meta=dataset.df,
+            )
+            return DocumentDataset(dataset.df)
+        else:
+            return DocumentDataset(dataset.df[bool_mask])
 
 
 class ParallelScoreFilter(BaseModule):
@@ -248,6 +307,7 @@ def __init__(
         tgt_score=None,
         score_type=None,
         invert=False,
+        add_skip_label_only: bool = False,
     ):
         """A filter object wrapper class for applying *monolingual* filter objects on bitext.
         If either side of the bitext is discarded, the whole bitext pair is discarded.
@@ -269,17 +329,25 @@ def __init__(
         """
         super().__init__(input_backend=src_filter_obj.backend)
         self.source_score_filter = ScoreFilter(
-            src_filter_obj, src_field, src_score, score_type, invert
+            src_filter_obj,
+            src_field,
+            src_score,
+            score_type,
+            invert,
+            add_skip_label_only,
         )
         self.target_score_filter = ScoreFilter(
-            tgt_filter_obj, tgt_field, tgt_score, score_type, invert
+            tgt_filter_obj,
+            tgt_field,
+            tgt_score,
+            score_type,
+            invert,
+            add_skip_label_only,
         )
+        self.add_skip_label_only = add_skip_label_only
 
-    def call(self, dataset: ParallelDataset):
-        src_bool_mask = self.source_score_filter.compute_filter_mask(dataset)
-        tgt_bool_mask = self.target_score_filter.compute_filter_mask(dataset)
-
+    def call(self, dataset: ParallelDataset) -> ParallelDataset:
         # remove lines together if one of them is filtered
-        bool_mask = logical_and(src_bool_mask, tgt_bool_mask)
-
-        return ParallelDataset(dataset.df[bool_mask])
+        ds1 = self.source_score_filter(dataset)
+        ds2 = self.target_score_filter(ds1)
+        return ParallelDataset(ds2.df)
diff --git a/nemo_curator/utils/module_utils.py b/nemo_curator/utils/module_utils.py
index 388a949f6..340418397 100644
--- a/nemo_curator/utils/module_utils.py
+++ b/nemo_curator/utils/module_utils.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 import math
 
+SKIP_LABEL_KEY = "_skipme"
+REASON_LABEL_KEY = "reason"
+
 
 def is_batched(function):
     return hasattr(function, "batched") and function.batched
diff --git a/tutorials/bitext_cleaning/main.py b/tutorials/bitext_cleaning/main.py
index 58e52ed50..c66eef6e6 100644
--- a/tutorials/bitext_cleaning/main.py
+++ b/tutorials/bitext_cleaning/main.py
@@ -56,6 +56,7 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas
                 tgt_lang=TGT_LANG,
                 score_field="length_ratio",
                 score_type=float,
+                add_skip_label_only=True,
             ),
             ParallelScoreFilter(
                 HistogramFilter(lang=SRC_LANG),
@@ -63,6 +64,7 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas
                 src_score="src_hist",
                 tgt_score="tgt_hist",
                 score_type=int,
+                add_skip_label_only=True,
            ),
         ]
     )
@@ -75,6 +77,7 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas
                 gpu=gpu,
                 metadata_fields=["src_lang", "tgt_lang"],
                 score_type=float,
+                add_skip_label_only=True,
             )
         )
     else:
@@ -115,7 +118,8 @@ def run_curation_pipeline(args: Any, src_file: str, tgt_file: str) -> None:
         shutil.rmtree(out_path)
     os.makedirs(out_path)
 
-    dataset.to_bitext(out_path, write_to_filename=True)
+    # dataset.to_bitext(out_path, write_to_filename=True)
+    dataset.to_json(out_path)
 
     client.close()
 
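
---

For reviewers, a minimal sketch of the annotate-instead-of-drop workflow this diff enables. The `add_skip_label_only` flag, the `_skipme`/`reason` columns, and the constants come from the diff above; `WordCountFilter`, `Sequential`, the `min_words` threshold, and the toy data are illustrative stand-ins, not part of this change:

```python
import dask.dataframe as dd
import pandas as pd

from nemo_curator import ScoreFilter, Sequential
from nemo_curator.datasets import DocumentDataset
from nemo_curator.filters import WordCountFilter
from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY

# Toy two-row dataset: one row fails the word-count threshold, one passes.
pdf = pd.DataFrame({"text": ["too short", "long enough to keep " * 20]})
dataset = DocumentDataset(dd.from_pandas(pdf, npartitions=1))

pipeline = Sequential(
    [
        ScoreFilter(
            WordCountFilter(min_words=50),
            text_field="text",
            score_field="word_count",
            add_skip_label_only=True,  # annotate rows instead of dropping them
        ),
    ]
)

annotated = pipeline(dataset).df.compute()
# Every input row survives; rejected rows carry _skipme == 1 and the
# rejecting filter's class name ("WordCountFilter") in the "reason" column.
print(annotated[["text", SKIP_LABEL_KEY, REASON_LABEL_KEY]])
kept = annotated[annotated[SKIP_LABEL_KEY] == 0]
```

When several such filters are chained (as in the tutorial pipeline), each downstream filter computes scores only over the rows whose `_skipme` is still 0, so `reason` records the first filter that rejected each row.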