From dbcdcc8dc72b405fff696833fa506176b7de4e5d Mon Sep 17 00:00:00 2001
From: Shuoyang Ding <shuoyangd@nvidia.com>
Date: Fri, 21 Feb 2025 15:23:38 -0800
Subject: [PATCH 1/2] add skipme functionality

Signed-off-by: Shuoyang Ding <shuoyangd@nvidia.com>
---
 nemo_curator/filters/bitext_filter.py | 65 ++++++++++++++++--
 nemo_curator/modules/filter.py        | 96 +++++++++++++++++++++++----
 nemo_curator/utils/module_utils.py    |  3 +
 tutorials/bitext_cleaning/main.py     |  6 +-
 4 files changed, 150 insertions(+), 20 deletions(-)
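
Note on intended usage: with add_skip_label_only=True a filter annotates rows
instead of dropping them. A rejected row gets _skipme set to 1 and its "reason"
column set to the class name of the filter that rejected it, and later filters
only score rows whose _skipme is still 0. A minimal sketch of how the flag is
meant to be used; the input file and the choice of WordCountFilter are
illustrative, not part of this patch:

    from nemo_curator import ScoreFilter
    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.filters import WordCountFilter

    dataset = DocumentDataset.read_json("books.jsonl")  # hypothetical input
    filter_step = ScoreFilter(
        WordCountFilter(min_words=80),
        text_field="text",
        score_field="word_count",
        add_skip_label_only=True,  # annotate instead of dropping rows
    )
    annotated = filter_step(dataset)
    # annotated.df keeps every row; rejected rows have _skipme == 1 and
    # reason == "WordCountFilter", surviving rows keep _skipme == 0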

diff --git a/nemo_curator/filters/bitext_filter.py b/nemo_curator/filters/bitext_filter.py
index f85e68408..edfda0fb7 100644
--- a/nemo_curator/filters/bitext_filter.py
+++ b/nemo_curator/filters/bitext_filter.py
@@ -15,10 +15,11 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, Union
 
+from dask.array import logical_and
 from dask.typing import no_default
 
 from nemo_curator.datasets.parallel_dataset import ParallelDataset
-from nemo_curator.utils.module_utils import is_batched
+from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY, is_batched
 
 
 class BitextFilter(ABC):
@@ -41,6 +42,7 @@ def __init__(
         score_field: Optional[str] = None,
         score_type: Union[type, str] = None,
         invert=False,
+        add_skip_label_only: bool = False,
     ):
         """Args:
             src_field (str, optional): The field the source documents will be read from. Defaults to "src".
@@ -64,6 +66,7 @@ def __init__(
         self.score_field = score_field
         self.score_type = score_type
         self.invert = invert
+        self.add_skip_label_only = add_skip_label_only
 
     def __call__(
         self,
@@ -89,21 +92,44 @@ def __call__(
         fields.append(self.tgt_field)
         fields.extend(self.metadata_fields)
 
+        if self.add_skip_label_only:
+            if SKIP_LABEL_KEY not in dataset.df.columns:
+                dataset.df[SKIP_LABEL_KEY] = 0
+            if REASON_LABEL_KEY not in dataset.df.columns:
+                dataset.df[REASON_LABEL_KEY] = None
+
+            # although the full dataset is passed in, scores do not need to be computed on all of it;
+            # only entries that have not already been skipped need to be processed
+            kept_mask = dataset.df[SKIP_LABEL_KEY] == 0
+            df = dataset.df[kept_mask].copy()
+        else:
+            df = dataset.df
+
         if is_batched(self.score_bitext):
-            scores = dataset.df[fields].map_partitions(
+            scores = df[fields].map_partitions(
                 self._score_bitext_wrapper,
                 metadata_field_name_mapping=self.metadata_field_name_mapping,
                 meta=meta,
             )
         else:
-            scores = dataset.df[fields].apply(
+            scores = df[fields].apply(
                 self._score_bitext_wrapper,
                 metadata_field_name_mapping=self.metadata_field_name_mapping,
                 axis=1,
                 meta=meta,
             )
 
-        if self.score_field is not None:
+        if self.score_field is not None and self.add_skip_label_only:
+            dataset.df[self.score_field] = None
+
+            def update_score(partition, mask_partition, score_partition):
+                partition.loc[mask_partition, self.score_field] = score_partition
+                return partition
+
+            dataset.df = dataset.df.map_partitions(
+                update_score, kept_mask, scores, meta=dataset.df
+            )
+        elif self.score_field is not None:
             dataset.df[self.score_field] = scores
 
         if is_batched(self.keep_bitext):
@@ -113,7 +139,36 @@ def __call__(
         if self.invert:
             bool_mask = ~bool_mask
 
-        return ParallelDataset(dataset.df[bool_mask])
+        def update_skipme(partition, kept_mask_partition, score_bool_mask_partition):
+            partition.loc[kept_mask_partition, SKIP_LABEL_KEY] = [
+                1 if skip else 0 for skip in ~score_bool_mask_partition
+            ]
+            return partition
+
+        def update_reason(partition, kept_mask_partition, reason):
+            # the filtering reason needs to be updated only for entries that
+            # 1. were kept by all previous filters, and
+            # 2. were thrown out by this filter
+            new_skip = [
+                skip == 1 for skip in partition[SKIP_LABEL_KEY]
+            ]
+            new_mask = logical_and(kept_mask_partition.values, new_skip)
+            partition.loc[new_mask, REASON_LABEL_KEY] = reason
+            return partition
+
+        if self.add_skip_label_only:
+            dataset.df = dataset.df.map_partitions(
+                update_skipme,
+                kept_mask,
+                bool_mask,  # bool_mask already reflects self.invert (see above); inverting again would undo it
+                meta=dataset.df,
+            )
+            dataset.df = dataset.df.map_partitions(
+                update_reason, kept_mask, self.__class__.__name__, meta=dataset.df
+            )
+            return ParallelDataset(dataset.df)
+        else:
+            return ParallelDataset(dataset.df[bool_mask])
 
     def _score_bitext_wrapper(
         self,
diff --git a/nemo_curator/modules/filter.py b/nemo_curator/modules/filter.py
index 888252193..08fda842b 100644
--- a/nemo_curator/modules/filter.py
+++ b/nemo_curator/modules/filter.py
@@ -23,7 +23,7 @@
 from nemo_curator.datasets.parallel_dataset import ParallelDataset
 from nemo_curator.filters import DocumentFilter
 from nemo_curator.modules.base import BaseModule
-from nemo_curator.utils.module_utils import is_batched
+from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY, is_batched
 
 # Override so that pd.NA is not passed during the metadata inference
 make_array_nonempty.register(
@@ -168,6 +168,7 @@ def __init__(
         score_field: Optional[str] = None,
         score_type: Union[type, str] = None,
         invert: bool = False,
+        add_skip_label_only: bool = False,
     ):
         """
         Constructs a ScoreFilter module.
@@ -185,6 +186,7 @@ def __init__(
         self.score_field = score_field
         self.score_type = score_type
         self.invert = invert
+        self.add_skip_label_only = add_skip_label_only
 
     def compute_filter_mask(self, dataset: DocumentDataset):
         """Compute the bool mask to filter the dataset.
@@ -200,16 +202,39 @@ def compute_filter_mask(self, dataset: DocumentDataset):
         else:
             meta = no_default
 
+        if self.add_skip_label_only:
+            if SKIP_LABEL_KEY not in dataset.df.columns:
+                dataset.df[SKIP_LABEL_KEY] = 0
+            if REASON_LABEL_KEY not in dataset.df.columns:
+                dataset.df[REASON_LABEL_KEY] = None
+
+            # although the full dataset is passed in, scores do not need to be computed on all of it;
+            # only entries that have not already been skipped need to be processed
+            kept_mask = dataset.df[SKIP_LABEL_KEY] == 0
+            df = dataset.df[kept_mask].copy()
+        else:
+            df = dataset.df
+
         if is_batched(self.filter_obj.score_document):
-            scores = dataset.df[self.text_field].map_partitions(
+            scores = df[self.text_field].map_partitions(
                 self.filter_obj.score_document, meta=meta
             )
         else:
-            scores = dataset.df[self.text_field].apply(
+            scores = df[self.text_field].apply(
                 self.filter_obj.score_document, meta=meta
             )
 
-        if self.score_field is not None:
+        if self.score_field is not None and self.add_skip_label_only:
+            dataset.df[self.score_field] = None
+
+            def update_score(partition, mask_partition, score_partition):
+                partition.loc[mask_partition, self.score_field] = score_partition
+                return partition
+
+            dataset.df = dataset.df.map_partitions(
+                update_score, kept_mask, scores, meta=dataset.df
+            )
+        elif self.score_field is not None:
             dataset.df[self.score_field] = scores
 
         if is_batched(self.filter_obj.keep_document):
@@ -234,7 +259,41 @@ def call(self, dataset: DocumentDataset) -> DocumentDataset:
             DocumentDataset: A dataset with the score and filter applied
         """
         bool_mask = self.compute_filter_mask(dataset)
-        return DocumentDataset(dataset.df[bool_mask])
+
+        def update_skipme(partition, kept_mask_partition, score_bool_mask_partition):
+            partition.loc[kept_mask_partition, SKIP_LABEL_KEY] = [
+                1 if skip else 0 for skip in ~score_bool_mask_partition
+            ]
+            return partition
+
+        def update_reason(partition, kept_mask_partition, reason):
+            # the filtering reason needs to be updated only for entries that
+            # 1. were kept by all previous filters, and
+            # 2. were thrown out by this filter
+            new_skip = [
+                skip == 1 for skip in partition[SKIP_LABEL_KEY]
+            ]
+            new_mask = logical_and(kept_mask_partition.values, new_skip)
+            partition.loc[new_mask, REASON_LABEL_KEY] = reason
+            return partition
+
+        if self.add_skip_label_only:
+            kept_mask = dataset.df[SKIP_LABEL_KEY] == 0
+            dataset.df = dataset.df.map_partitions(
+                update_skipme,
+                kept_mask,
+                bool_mask,  # compute_filter_mask already applies self.invert; inverting again would undo it
+                meta=dataset.df,
+            )
+            dataset.df = dataset.df.map_partitions(
+                update_reason,
+                kept_mask,
+                self.filter_obj.__class__.__name__,
+                meta=dataset.df,
+            )
+            return DocumentDataset(dataset.df)
+        else:
+            return DocumentDataset(dataset.df[bool_mask])
 
 
 class ParallelScoreFilter(BaseModule):
@@ -248,6 +307,7 @@ def __init__(
         tgt_score=None,
         score_type=None,
         invert=False,
+        add_skip_label_only: bool = False,
     ):
         """A filter object wrapper class for applying *monolingual* filter objects on bitext.
         If either side of the bitext is discarded, the whole bitext pair is discarded.
@@ -269,17 +329,25 @@ def __init__(
         """
         super().__init__(input_backend=src_filter_obj.backend)
         self.source_score_filter = ScoreFilter(
-            src_filter_obj, src_field, src_score, score_type, invert
+            src_filter_obj,
+            src_field,
+            src_score,
+            score_type,
+            invert,
+            add_skip_label_only,
         )
         self.target_score_filter = ScoreFilter(
-            tgt_filter_obj, tgt_field, tgt_score, score_type, invert
+            tgt_filter_obj,
+            tgt_field,
+            tgt_score,
+            score_type,
+            invert,
+            add_skip_label_only,
         )
+        self.add_skip_label_only = add_skip_label_only
 
-    def call(self, dataset: ParallelDataset):
-        src_bool_mask = self.source_score_filter.compute_filter_mask(dataset)
-        tgt_bool_mask = self.target_score_filter.compute_filter_mask(dataset)
-
+    def call(self, dataset: ParallelDataset) -> ParallelDataset:
         # remove lines together if one of them is filtered
-        bool_mask = logical_and(src_bool_mask, tgt_bool_mask)
-
-        return ParallelDataset(dataset.df[bool_mask])
+        ds1 = self.source_score_filter(dataset)
+        ds2 = self.target_score_filter(ds1)
+        return ParallelDataset(ds2.df)
diff --git a/nemo_curator/utils/module_utils.py b/nemo_curator/utils/module_utils.py
index 388a949f6..340418397 100644
--- a/nemo_curator/utils/module_utils.py
+++ b/nemo_curator/utils/module_utils.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 import math
 
+SKIP_LABEL_KEY = "_skipme"
+REASON_LABEL_KEY = "reason"
+
 
 def is_batched(function):
     return hasattr(function, "batched") and function.batched
diff --git a/tutorials/bitext_cleaning/main.py b/tutorials/bitext_cleaning/main.py
index 58e52ed50..5b94266c7 100644
--- a/tutorials/bitext_cleaning/main.py
+++ b/tutorials/bitext_cleaning/main.py
@@ -56,6 +56,7 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas
                 tgt_lang=TGT_LANG,
                 score_field="length_ratio",
                 score_type=float,
+                add_skip_label_only=True,
             ),
             ParallelScoreFilter(
                 HistogramFilter(lang=SRC_LANG),
@@ -63,6 +64,7 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas
                 src_score="src_hist",
                 tgt_score="tgt_hist",
                 score_type=int,
+                add_skip_label_only=True,
             ),
         ]
     )
@@ -75,6 +77,7 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas
                 gpu=gpu,
                 metadata_fields=["src_lang", "tgt_lang"],
                 score_type=float,
+                add_skip_label_only=True,
             )
         )
     else:
@@ -115,7 +118,8 @@ def run_curation_pipeline(args: Any, src_file: str, tgt_file: str) -> None:
         shutil.rmtree(out_path)
 
     os.makedirs(out_path)
-    dataset.to_bitext(out_path, write_to_filename=True)
+    # dataset.to_bitext(out_path, write_to_filename=True)
+    dataset.to_json(out_path, write_to_filename=True)
     client.close()
 
 

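Note on the ParallelScoreFilter change above: instead of computing the source
and target masks separately and AND-ing them, call() now runs the source-side
ScoreFilter and then the target-side ScoreFilter on its output, so with
add_skip_label_only=True both passes annotate the same dataframe and a row
survives only if neither side rejects it. A rough sketch mirroring the
tutorial; the language codes are illustrative and the dataset is assumed to
have been loaded elsewhere:

    from nemo_curator import ParallelScoreFilter
    from nemo_curator.filters import HistogramFilter

    # parallel_dataset: a ParallelDataset loaded as in the tutorial's pipeline
    bitext_filter = ParallelScoreFilter(
        HistogramFilter(lang="en"),  # source-side monolingual filter
        HistogramFilter(lang="de"),  # target-side monolingual filter
        src_score="src_hist",
        tgt_score="tgt_hist",
        score_type=int,
        add_skip_label_only=True,
    )
    annotated = bitext_filter(parallel_dataset)
    # a bitext pair is marked _skipme == 1 if either side fails its filter
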
From 841b18f25c1e8463b05f112187081268ac401140 Mon Sep 17 00:00:00 2001
From: Shuoyang Ding <shuoyangd@nvidia.com>
Date: Fri, 21 Feb 2025 16:22:32 -0800
Subject: [PATCH 2/2] to_json does not handle multiple shards for the same file
 yet

Signed-off-by: Shuoyang Ding <shuoyangd@nvidia.com>
---
 tutorials/bitext_cleaning/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tutorials/bitext_cleaning/main.py b/tutorials/bitext_cleaning/main.py
index 5b94266c7..c66eef6e6 100644
--- a/tutorials/bitext_cleaning/main.py
+++ b/tutorials/bitext_cleaning/main.py
@@ -119,7 +119,7 @@ def run_curation_pipeline(args: Any, src_file: str, tgt_file: str) -> None:
 
     os.makedirs(out_path)
     # dataset.to_bitext(out_path, write_to_filename=True)
-    dataset.to_json(out_path, write_to_filename=True)
+    dataset.to_json(out_path)
     client.close()
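
Note on reading the tutorial output back: because rejected rows are now kept
(only annotated) and the result is written as plain JSONL, downstream consumers
filter on the skip label themselves. A minimal read-back sketch; the glob and
the use of dask are illustrative, not part of the patch:

    import dask.dataframe as dd

    df = dd.read_json("<out_path>/*", lines=True)  # curated tutorial output
    kept = df[df["_skipme"] == 0]     # rows that passed every filter
    dropped = df[df["_skipme"] == 1]  # "reason" names the filter that rejected each row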