Skip to content

Commit 7b1e380

Browse files
committed
make both confidence and expected lang id case work, add LangID in tutorial
Signed-off-by: Shuoyang Ding <[email protected]>
1 parent d444257 commit 7b1e380

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

nemo_curator/filters/classifier_filter.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _load_model(self):
6868

6969
class FastTextLangId(DocumentFilter):
7070

71-
def __init__(self, model_path=None, min_langid_score=0.3, lang="en"):
71+
def __init__(self, model_path=None, min_langid_score=0.3, lang=None):
7272
if model_path is None:
7373
raise ValueError(
7474
"Must provide a valid path to a FastText model "
@@ -98,8 +98,10 @@ def _score_document(text):
9898
return df.apply(_score_document)
9999

100100
def keep_document(self, score):
101-
# return score[0] >= self._cutoff
102-
return score[1] == self._lang_code
101+
if self._lang_code:
102+
return score[1] == self._lang_code
103+
else:
104+
return score[0] >= self._cutoff
103105

104106
def _load_model(self):
105107
return fasttext.load_model(self._model_path)

tutorials/bitext_cleaning/main.py

+14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from nemo_curator import ParallelScoreFilter, Sequential
2424
from nemo_curator.datasets.parallel_dataset import ParallelDataset
2525
from nemo_curator.filters import (
26+
FastTextLangId,
2627
HistogramFilter,
2728
LengthRatioFilter,
2829
QualityEstimationFilter,
@@ -38,6 +39,10 @@
3839
SCRIPT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
3940
DATA_DIR = os.path.join(SCRIPT_DIR_PATH, "data")
4041

42+
# If you want to test FastText language ID,
43+
# download the model from here first then update this with your local model path (https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz)
44+
FAST_TEXT_MODEL_DIR = ""
45+
4146

4247
def download_files() -> str:
4348
downloader = TedTalksDownloader(DATA_DIR)
@@ -67,6 +72,15 @@ def filter_dataset(dataset: ParallelDataset, gpu: bool = False) -> ParallelDatas
6772
]
6873
)
6974

75+
if FAST_TEXT_MODEL_DIR:
76+
filters.modules.append(
77+
ParallelScoreFilter(
78+
FastTextLangId(model_path=FAST_TEXT_MODEL_DIR, lang=SRC_LANG),
79+
FastTextLangId(model_path=FAST_TEXT_MODEL_DIR, lang=TGT_LANG),
80+
score_type=str,
81+
)
82+
)
83+
7084
if gpu:
7185
filters.modules.append(
7286
QualityEstimationFilter(

0 commit comments

Comments
 (0)