Add option to skip data by adding a flag instead of removing them #566

Open · wants to merge 2 commits into base: main
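This PR threads a new `add_skip_label_only` flag through `BitextFilter`, `ScoreFilter`, and `ParallelScoreFilter`. With the flag set, rejected rows are no longer dropped from the dataset; they are marked with `_skipme = 1`, and the class name of the rejecting filter is recorded in a `reason` column. A minimal pandas sketch of the pattern (the toy filter and data are made up for illustration; the column names match the constants this PR introduces):

```python
import pandas as pd

df = pd.DataFrame({"text": ["good line", "", "another good line"]})

# The first filter to touch the dataset initializes the bookkeeping columns.
df["_skipme"] = 0
df["reason"] = None

# A toy filter: keep non-empty text.
keep = df["text"].str.len() > 0

# Flag rejected rows instead of dropping them, and record which filter
# rejected them. Later filters only score rows where _skipme == 0.
df.loc[~keep, "_skipme"] = 1
df.loc[~keep, "reason"] = "NonEmptyFilter"

print(df)  # all three rows survive; the empty row carries _skipme=1
```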
65 changes: 60 additions & 5 deletions nemo_curator/filters/bitext_filter.py
@@ -15,10 +15,11 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union

from dask.array import logical_and
from dask.typing import no_default

from nemo_curator.datasets.parallel_dataset import ParallelDataset
- from nemo_curator.utils.module_utils import is_batched
+ from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY, is_batched


class BitextFilter(ABC):
@@ -41,6 +42,7 @@ def __init__(
score_field: Optional[str] = None,
score_type: Union[type, str] = None,
invert=False,
add_skip_label_only: bool = False,
):
"""Args:
src_field (str, optional): The field the source documents will be read from. Defaults to "src".
@@ -64,6 +66,7 @@ def __init__(
self.score_field = score_field
self.score_type = score_type
self.invert = invert
self.add_skip_label_only = add_skip_label_only

def __call__(
self,
@@ -89,21 +92,44 @@ def __call__(
fields.append(self.tgt_field)
fields.extend(self.metadata_fields)

if self.add_skip_label_only:
if SKIP_LABEL_KEY not in dataset.df.columns:
dataset.df[SKIP_LABEL_KEY] = 0
if REASON_LABEL_KEY not in dataset.df.columns:
dataset.df[REASON_LABEL_KEY] = None

# Although the full dataset is passed in, scores only need to be computed
# on the rows that are still kept, i.e. not yet flagged by an earlier filter.
kept_mask = dataset.df[SKIP_LABEL_KEY] == 0
df = dataset.df[kept_mask].copy()
else:
df = dataset.df

if is_batched(self.score_bitext):
- scores = dataset.df[fields].map_partitions(
+ scores = df[fields].map_partitions(
self._score_bitext_wrapper,
metadata_field_name_mapping=self.metadata_field_name_mapping,
meta=meta,
)
else:
- scores = dataset.df[fields].apply(
+ scores = df[fields].apply(
self._score_bitext_wrapper,
metadata_field_name_mapping=self.metadata_field_name_mapping,
axis=1,
meta=meta,
)

- if self.score_field is not None:
+ if self.score_field is not None and self.add_skip_label_only:
dataset.df[self.score_field] = None

def update_score(partition, mask_partition, score_partition):
partition.loc[mask_partition, self.score_field] = score_partition
return partition

dataset.df = dataset.df.map_partitions(
update_score, kept_mask, scores, meta=dataset.df
)
elif self.score_field is not None:
dataset.df[self.score_field] = scores

if is_batched(self.keep_bitext):
@@ -113,7 +139,36 @@ def __call__(
if self.invert:
bool_mask = ~bool_mask

- return ParallelDataset(dataset.df[bool_mask])
def update_skipme(partition, kept_mask_partition, score_bool_mask_partition):
partition.loc[kept_mask_partition, SKIP_LABEL_KEY] = [
1 if skip else 0 for skip in ~score_bool_mask_partition
]
return partition

def update_reason(partition, kept_mask_partition, reason):
# Update the filtering reason only for entries that (1) were kept before
# this filter ran and (2) were rejected by this filter.
new_skip = [
True if skip == 1 else False for skip in partition[SKIP_LABEL_KEY]
]
new_mask = logical_and(kept_mask_partition.values, new_skip)
partition.loc[new_mask, REASON_LABEL_KEY] = reason
return partition

if self.add_skip_label_only:
dataset.df = dataset.df.map_partitions(
update_skipme,
kept_mask,
~bool_mask if self.invert else bool_mask,
meta=dataset.df,
)
dataset.df = dataset.df.map_partitions(
update_reason, kept_mask, self.__class__.__name__, meta=dataset.df
)
return ParallelDataset(dataset.df)
else:
return ParallelDataset(dataset.df[bool_mask])

def _score_bitext_wrapper(
self,
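The score write-back above relies on two alignment guarantees: dask slices extra `map_partitions` arguments partition by partition alongside the main frame, and pandas `.loc` aligns the score series to the kept rows by index. A toy reconstruction of that pattern (the data and the `length_score` column are made up; `map(len)` stands in for `score_bitext`):

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"src": ["aa", "b", "cccc", "dd"], "_skipme": [0, 1, 0, 0]})
ddf = dd.from_pandas(pdf, npartitions=2)

kept_mask = ddf["_skipme"] == 0
# Score only the rows that are still kept; the filtered frame retains
# the original index, which is what the .loc alignment depends on.
scores = ddf[kept_mask]["src"].map(len, meta=("src", "int64"))

def update_score(partition, mask_partition, score_partition):
    # Previously skipped rows are untouched and keep their None score.
    partition.loc[mask_partition, "length_score"] = score_partition
    return partition

ddf["length_score"] = None
ddf = ddf.map_partitions(update_score, kept_mask, scores, meta=ddf)
print(ddf.compute())
```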
96 changes: 82 additions & 14 deletions nemo_curator/modules/filter.py
@@ -23,7 +23,7 @@
from nemo_curator.datasets.parallel_dataset import ParallelDataset
from nemo_curator.filters import DocumentFilter
from nemo_curator.modules.base import BaseModule
- from nemo_curator.utils.module_utils import is_batched
+ from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY, is_batched

# Override so that pd.NA is not passed during the metadata inference
make_array_nonempty.register(
@@ -168,6 +168,7 @@ def __init__(
score_field: Optional[str] = None,
score_type: Union[type, str] = None,
invert: bool = False,
add_skip_label_only: bool = False,
):
"""
Constructs a ScoreFilter module.
@@ -185,6 +186,7 @@ def __init__(
self.score_field = score_field
self.score_type = score_type
self.invert = invert
self.add_skip_label_only = add_skip_label_only

def compute_filter_mask(self, dataset: DocumentDataset):
"""Compute the bool mask to filter the dataset.
@@ -200,16 +202,39 @@
else:
meta = no_default

if self.add_skip_label_only:
if SKIP_LABEL_KEY not in dataset.df.columns:
dataset.df[SKIP_LABEL_KEY] = 0
if REASON_LABEL_KEY not in dataset.df.columns:
dataset.df[REASON_LABEL_KEY] = None

# Although the full dataset is passed in, scores only need to be computed
# on the rows that are still kept, i.e. not yet flagged by an earlier filter.
kept_mask = dataset.df._skipme == 0
df = dataset.df[kept_mask].copy()
else:
df = dataset.df

if is_batched(self.filter_obj.score_document):
- scores = dataset.df[self.text_field].map_partitions(
+ scores = df[self.text_field].map_partitions(
self.filter_obj.score_document, meta=meta
)
else:
- scores = dataset.df[self.text_field].apply(
+ scores = df[self.text_field].apply(
self.filter_obj.score_document, meta=meta
)

- if self.score_field is not None:
+ if self.score_field is not None and self.add_skip_label_only:
dataset.df[self.score_field] = None

def update_score(partition, mask_partition, score_partition):
partition.loc[mask_partition, self.score_field] = score_partition
return partition

dataset.df = dataset.df.map_partitions(
update_score, kept_mask, scores, meta=dataset.df
)
elif self.score_field is not None:
dataset.df[self.score_field] = scores

if is_batched(self.filter_obj.keep_document):
@@ -234,7 +259,41 @@ def call(self, dataset: DocumentDataset) -> DocumentDataset:
DocumentDataset: A dataset with the score and filter applied
"""
bool_mask = self.compute_filter_mask(dataset)
- return DocumentDataset(dataset.df[bool_mask])

def update_skipme(partition, kept_mask_partition, score_bool_mask_partition):
partition.loc[kept_mask_partition, SKIP_LABEL_KEY] = [
1 if skip else 0 for skip in ~score_bool_mask_partition
]
return partition

def update_reason(partition, kept_mask_partition, reason):
# Update the filtering reason only for entries that (1) were kept before
# this filter ran and (2) were rejected by this filter.
new_skip = [
True if skip == 1 else False for skip in partition[SKIP_LABEL_KEY]
]
new_mask = logical_and(kept_mask_partition.values, new_skip)
partition.loc[new_mask, REASON_LABEL_KEY] = reason
return partition

if self.add_skip_label_only:
kept_mask = dataset.df._skipme == 0
dataset.df = dataset.df.map_partitions(
update_skipme,
kept_mask,
~bool_mask if self.invert else bool_mask,
meta=dataset.df,
)
dataset.df = dataset.df.map_partitions(
update_reason,
kept_mask,
self.filter_obj.__class__.__name__,
meta=dataset.df,
)
return DocumentDataset(dataset.df)
else:
return DocumentDataset(dataset.df[bool_mask])


class ParallelScoreFilter(BaseModule):
@@ -248,6 +307,7 @@ def __init__(
tgt_score=None,
score_type=None,
invert=False,
add_skip_label_only: bool = False,
):
"""A filter object wrapper class for applying *monolingual* filter objects on bitext.
If either side of the bitext is discarded, the whole bitext pair is discarded.
@@ -269,17 +329,25 @@
"""
super().__init__(input_backend=src_filter_obj.backend)
self.source_score_filter = ScoreFilter(
- src_filter_obj, src_field, src_score, score_type, invert
+ src_filter_obj,
+ src_field,
+ src_score,
+ score_type,
+ invert,
+ add_skip_label_only,
)
self.target_score_filter = ScoreFilter(
- tgt_filter_obj, tgt_field, tgt_score, score_type, invert
+ tgt_filter_obj,
+ tgt_field,
+ tgt_score,
+ score_type,
+ invert,
+ add_skip_label_only,
)
self.add_skip_label_only = add_skip_label_only

- def call(self, dataset: ParallelDataset):
-     src_bool_mask = self.source_score_filter.compute_filter_mask(dataset)
-     tgt_bool_mask = self.target_score_filter.compute_filter_mask(dataset)
-     # remove lines together if one of them is filtered
-     bool_mask = logical_and(src_bool_mask, tgt_bool_mask)
-     return ParallelDataset(dataset.df[bool_mask])
+ def call(self, dataset: ParallelDataset) -> ParallelDataset:
+     ds1 = self.source_score_filter(dataset)
+     ds2 = self.target_score_filter(ds1)
+     return ParallelDataset(ds2.df)
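For reference, end-to-end usage of the new flag on monolingual data might look like the sketch below. `ScoreFilter`, `DocumentDataset.read_json`, and `WordCountFilter` already exist in NeMo Curator; the input path and field names are placeholders:

```python
from nemo_curator import ScoreFilter
from nemo_curator.datasets import DocumentDataset
from nemo_curator.filters import WordCountFilter

dataset = DocumentDataset.read_json("input_data/")  # placeholder path

filter_step = ScoreFilter(
    WordCountFilter(min_words=5),
    text_field="text",
    score_field="word_count",
    add_skip_label_only=True,  # new in this PR: flag rows instead of dropping them
)
result = filter_step(dataset)

df = result.df.compute()
print(df["_skipme"].value_counts())                   # kept vs. flagged rows
print(df.loc[df["_skipme"] == 1, "reason"].unique())  # e.g. ["WordCountFilter"]
```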
3 changes: 3 additions & 0 deletions nemo_curator/utils/module_utils.py
@@ -13,6 +13,9 @@
# limitations under the License.
import math

SKIP_LABEL_KEY = "_skipme"
REASON_LABEL_KEY = "reason"


def is_batched(function):
return hasattr(function, "batched") and function.batched
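Because the two labels are ordinary column names, downstream code can recover the old drop-the-rows behavior in one step. A small sketch:

```python
import pandas as pd

from nemo_curator.utils.module_utils import REASON_LABEL_KEY, SKIP_LABEL_KEY

df = pd.DataFrame(
    {
        "text": ["keep me", "drop me"],
        SKIP_LABEL_KEY: [0, 1],
        REASON_LABEL_KEY: [None, "SomeFilter"],
    }
)

# Materialize only the surviving rows and drop the bookkeeping columns.
kept = df[df[SKIP_LABEL_KEY] == 0].drop(columns=[SKIP_LABEL_KEY, REASON_LABEL_KEY])
print(kept)
```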
6 changes: 5 additions & 1 deletion tutorials/bitext_cleaning/main.py
@@ -56,13 +56,15 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDataset:
tgt_lang=TGT_LANG,
score_field="length_ratio",
score_type=float,
add_skip_label_only=True,
),
ParallelScoreFilter(
HistogramFilter(lang=SRC_LANG),
HistogramFilter(lang=TGT_LANG),
src_score="src_hist",
tgt_score="tgt_hist",
score_type=int,
add_skip_label_only=True,
),
]
)
@@ -75,6 +77,7 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDataset:
gpu=gpu,
metadata_fields=["src_lang", "tgt_lang"],
score_type=float,
add_skip_label_only=True,
)
)
else:
@@ -115,7 +118,8 @@ def run_curation_pipeline(args: Any, src_file: str, tgt_file: str) -> None:
shutil.rmtree(out_path)

os.makedirs(out_path)
- dataset.to_bitext(out_path, write_to_filename=True)
+ # dataset.to_bitext(out_path, write_to_filename=True)
+ dataset.to_json(out_path)
client.close()


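With the tutorial writing JSONL instead of bitext, the skip annotations can be audited straight from the output. A sketch (the output filename is hypothetical; `to_json` writes one JSONL file per partition):

```python
import pandas as pd

# Hypothetical output file produced by dataset.to_json(out_path).
df = pd.read_json("data/output/curated.jsonl", lines=True)

# Per-filter rejection counts, with the rejected rows still available.
print(df.loc[df["_skipme"] == 1, "reason"].value_counts())

# Scores remain inspectable for the kept rows, e.g. the length ratio.
print(df.loc[df["_skipme"] == 0, "length_ratio"].describe())
```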