diff --git a/nemo_curator/filters/classifier_filter.py b/nemo_curator/filters/classifier_filter.py index d34140771..e791cae9b 100644 --- a/nemo_curator/filters/classifier_filter.py +++ b/nemo_curator/filters/classifier_filter.py @@ -68,14 +68,14 @@ def _load_model(self): class FastTextLangId(DocumentFilter): - def __init__(self, model_path=None, min_langid_score=0.3): + def __init__(self, model_path=None, min_langid_score=0.3, lang=None): if model_path is None: raise ValueError( "Must provide a valid path to a FastText model " "to identify languages with this filter" ) self._model_path = model_path - self._lang_code = None + self._lang_code = lang self._cutoff = min_langid_score self._name = "lang_id" @@ -91,14 +91,17 @@ def _score_document(text): pp = text.strip().replace("\n", " ") label, score = model.predict(pp, k=1) score = score[0] - lang_code = label[0][-2:].upper() + lang_code = label[0][-2:].lower() return [score, lang_code] return df.apply(_score_document) def keep_document(self, score): - return score[0] >= self._cutoff + if self._lang_code: + return score[1] == self._lang_code + else: + return score[0] >= self._cutoff def _load_model(self): return fasttext.load_model(self._model_path) diff --git a/tutorials/bitext_cleaning/main.py b/tutorials/bitext_cleaning/main.py index 58e52ed50..7751baf00 100644 --- a/tutorials/bitext_cleaning/main.py +++ b/tutorials/bitext_cleaning/main.py @@ -23,6 +23,7 @@ from nemo_curator import ParallelScoreFilter, Sequential from nemo_curator.datasets.parallel_dataset import ParallelDataset from nemo_curator.filters import ( + FastTextLangId, HistogramFilter, LengthRatioFilter, QualityEstimationFilter, @@ -38,6 +39,10 @@ SCRIPT_DIR_PATH = os.path.dirname(os.path.abspath(__file__)) DATA_DIR = os.path.join(SCRIPT_DIR_PATH, "data") +# If you want to test FastText language ID, +# download the model from here first then update this with your local model path (https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz) +FAST_TEXT_MODEL_DIR = "" + def download_files() -> str: downloader = TedTalksDownloader(DATA_DIR) @@ -67,6 +72,15 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas ] ) + if FAST_TEXT_MODEL_DIR: + filters.modules.append( + ParallelScoreFilter( + FastTextLangId(model_path=FAST_TEXT_MODEL_DIR, lang=SRC_LANG), + FastTextLangId(model_path=FAST_TEXT_MODEL_DIR, lang=TGT_LANG), + score_type=str, + ) + ) + if gpu: filters.modules.append( QualityEstimationFilter(