34
34
# End-of-sentence marker.
35
35
EOS = text_encoder .EOS_ID
36
36
37
- _ENFR_TRAIN_DATASETS = [
37
+ _ENFR_TRAIN_SMALL_DATA = [
38
38
[
39
39
"https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz" ,
40
40
("baseline-1M-enfr/baseline-1M_train.en" ,
41
41
"baseline-1M-enfr/baseline-1M_train.fr" )
42
42
],
43
- # [
44
- # "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
45
- # ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
46
- # ],
47
- # [
48
- # "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
49
- # ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
50
- # ],
51
- # [
52
- # "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
53
- # ("training/news-commentary-v9.fr-en.en",
54
- # "training/news-commentary-v9.fr-en.fr")
55
- # ],
56
- # [
57
- # "http://www.statmt.org/wmt10/training-giga-fren.tar",
58
- # ("giga-fren.release2.fixed.en.gz",
59
- # "giga-fren.release2.fixed.fr.gz")
60
- # ],
61
- # [
62
- # "http://www.statmt.org/wmt13/training-parallel-un.tgz",
63
- # ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
64
- # ],
65
43
]
66
- _ENFR_TEST_DATASETS = [
44
+ _ENFR_TEST_SMALL_DATA = [
67
45
[
68
46
"https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz" ,
69
47
("baseline-1M-enfr/baseline-1M_valid.en" ,
70
48
"baseline-1M-enfr/baseline-1M_valid.fr" )
71
49
],
72
- # [
73
- # "http://data.statmt.org/wmt17/translation-task/dev.tgz",
74
- # ("dev/newstest2013.en", "dev/newstest2013.fr")
75
- # ],
50
+ ]
51
+ _ENFR_TRAIN_LARGE_DATA = [
52
+ [
53
+ "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" ,
54
+ ("commoncrawl.fr-en.en" , "commoncrawl.fr-en.fr" )
55
+ ],
56
+ [
57
+ "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz" ,
58
+ ("training/europarl-v7.fr-en.en" , "training/europarl-v7.fr-en.fr" )
59
+ ],
60
+ [
61
+ "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" ,
62
+ ("training/news-commentary-v9.fr-en.en" ,
63
+ "training/news-commentary-v9.fr-en.fr" )
64
+ ],
65
+ [
66
+ "http://www.statmt.org/wmt10/training-giga-fren.tar" ,
67
+ ("giga-fren.release2.fixed.en.gz" ,
68
+ "giga-fren.release2.fixed.fr.gz" )
69
+ ],
70
+ [
71
+ "http://www.statmt.org/wmt13/training-parallel-un.tgz" ,
72
+ ("un/undoc.2000.fr-en.en" , "un/undoc.2000.fr-en.fr" )
73
+ ],
74
+ ]
75
+ _ENFR_TEST_LARGE_DATA = [
76
+ [
77
+ "http://data.statmt.org/wmt17/translation-task/dev.tgz" ,
78
+ ("dev/newstest2013.en" , "dev/newstest2013.fr" )
79
+ ],
76
80
]
77
81
78
82
79
83
@registry .register_problem
80
- class TranslateEnfrWmt8k (translate .TranslateProblem ):
84
+ class TranslateEnfrWmtSmall8k (translate .TranslateProblem ):
81
85
"""Problem spec for WMT En-Fr translation."""
82
86
83
87
@property
@@ -88,11 +92,18 @@ def targeted_vocab_size(self):
88
92
def vocab_name (self ):
89
93
return "vocab.enfr"
90
94
95
+ @property
96
+ def use_small_dataset (self ):
97
+ return True
98
+
91
99
def generator (self , data_dir , tmp_dir , train ):
92
100
symbolizer_vocab = generator_utils .get_or_generate_vocab (
93
101
data_dir , tmp_dir , self .vocab_file , self .targeted_vocab_size ,
94
- _ENFR_TRAIN_DATASETS )
95
- datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
102
+ _ENFR_TRAIN_SMALL_DATA )
103
+ if self .use_small_dataset :
104
+ datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
105
+ else :
106
+ datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
96
107
tag = "train" if train else "dev"
97
108
data_path = translate .compile_data (tmp_dir , datasets ,
98
109
"wmt_enfr_tok_%s" % tag )
@@ -109,15 +120,31 @@ def target_space_id(self):
109
120
110
121
111
122
@registry .register_problem
112
- class TranslateEnfrWmt32k ( TranslateEnfrWmt8k ):
123
+ class TranslateEnfrWmtSmall32k ( TranslateEnfrWmtSmall8k ):
113
124
114
125
@property
115
126
def targeted_vocab_size (self ):
116
127
return 2 ** 15 # 32768
117
128
118
129
119
130
@registry .register_problem
120
- class TranslateEnfrWmtCharacters (translate .TranslateProblem ):
131
+ class TranslateEnfrWmt8k (TranslateEnfrWmtSmall8k ):
132
+
133
+ @property
134
+ def use_small_dataset (self ):
135
+ return False
136
+
137
+
138
+ @registry .register_problem
139
+ class TranslateEnfrWmt32k (TranslateEnfrWmtSmall32k ):
140
+
141
+ @property
142
+ def use_small_dataset (self ):
143
+ return False
144
+
145
+
146
+ @registry .register_problem
147
+ class TranslateEnfrWmtSmallCharacters (translate .TranslateProblem ):
121
148
"""Problem spec for WMT En-Fr translation."""
122
149
123
150
@property
@@ -130,7 +157,10 @@ def vocab_name(self):
130
157
131
158
def generator (self , data_dir , tmp_dir , train ):
132
159
character_vocab = text_encoder .ByteTextEncoder ()
133
- datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
160
+ if self .use_small_dataset :
161
+ datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
162
+ else :
163
+ datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
134
164
tag = "train" if train else "dev"
135
165
data_path = translate .compile_data (tmp_dir , datasets ,
136
166
"wmt_enfr_chr_%s" % tag )
@@ -144,3 +174,11 @@ def input_space_id(self):
144
174
@property
145
175
def target_space_id (self ):
146
176
return problem .SpaceID .FR_CHR
177
+
178
+
179
+ @registry .register_problem
180
+ class TranslateEnfrWmtCharacters (TranslateEnfrWmtSmallCharacters ):
181
+
182
+ @property
183
+ def use_small_dataset (self ):
184
+ return False
0 commit comments