from tensor2tensor.data_generators import tokenizer
import tensorflow as tf

-FLAGS = tf.app.flags.FLAGS
+FLAGS = tf.flags.FLAGS

-_TESTDATA = "google3/third_party/py/tensor2tensor/data_generators/test_data"
+pkg_dir, _ = os.path.split(__file__)
+_TESTDATA = os.path.join(pkg_dir, "test_data")


class TokenizerTest(tf.test.TestCase):
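
Note on the hunk above: building `_TESTDATA` from `__file__` lets the tests find their fixtures without the google3-specific `FLAGS.test_srcdir` lookup, so they run from any working directory. A minimal sketch of the pattern (the `corpus-1.txt` fixture name is only for illustration):

import os

# Directory containing this test module.
pkg_dir, _ = os.path.split(__file__)
# test_data/ sits next to the module, so the join resolves correctly
# no matter where the test runner was started from.
_TESTDATA = os.path.join(pkg_dir, "test_data")

fixture = os.path.join(_TESTDATA, "corpus-1.txt")  # illustrative name
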
@@ -41,18 +42,13 @@ def test_encode(self):
    self.assertListEqual(
        [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."],
        tokenizer.encode(u"Dude - that's so cool."))
-    self.assertListEqual(
-        [u"Łukasz", u"est", u"né", u"en", u"1981", u"."],
-        tokenizer.encode(u"Łukasz est né en 1981."))
-    self.assertListEqual(
-        [u" ", u"Spaces", u"at", u"the", u"ends", u" "],
-        tokenizer.encode(u" Spaces at the ends "))
-    self.assertListEqual(
-        [u"802", u".", u"11b"],
-        tokenizer.encode(u"802.11b"))
-    self.assertListEqual(
-        [u"two", u". \n ", u"lines"],
-        tokenizer.encode(u"two. \n lines"))
+    self.assertListEqual([u"Łukasz", u"est", u"né", u"en", u"1981", u"."],
+                         tokenizer.encode(u"Łukasz est né en 1981."))
+    self.assertListEqual([u" ", u"Spaces", u"at", u"the", u"ends", u" "],
+                         tokenizer.encode(u" Spaces at the ends "))
+    self.assertListEqual([u"802", u".", u"11b"], tokenizer.encode(u"802.11b"))
+    self.assertListEqual([u"two", u". \n ", u"lines"],
+                         tokenizer.encode(u"two. \n lines"))

  def test_decode(self):
    self.assertEqual(
@@ -62,19 +58,16 @@ def test_decode(self):

  def test_invertibility_on_random_strings(self):
    for _ in xrange(1000):
-      s = u"".join(
-          six.unichr(random.randint(0, 65535)) for _ in xrange(10))
+      s = u"".join(six.unichr(random.randint(0, 65535)) for _ in xrange(10))
      self.assertEqual(s, tokenizer.decode(tokenizer.encode(s)))


class TestTokenCounts(tf.test.TestCase):

  def setUp(self):
    super(TestTokenCounts, self).setUp()
-    self.corpus_path = os.path.join(
-        FLAGS.test_srcdir, _TESTDATA, "corpus-*.txt")
-    self.vocab_path = os.path.join(
-        FLAGS.test_srcdir, _TESTDATA, "vocab-*.txt")
+    self.corpus_path = os.path.join(_TESTDATA, "corpus-*.txt")
+    self.vocab_path = os.path.join(_TESTDATA, "vocab-*.txt")

  def test_corpus_token_counts_split_on_newlines(self):
    token_counts = tokenizer.corpus_token_counts(
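
The `setUp` in the hunk above now stores glob patterns under `test_data` rather than files under `FLAGS.test_srcdir`, and those patterns are passed straight to the tokenizer helpers. A quick sketch of what such a pattern matches (the numbered fixture names are an assumption based on the `corpus-*.txt` glob and the vocab-1/vocab-2 comment later in the diff):

import glob
import os

pkg_dir = os.path.dirname(os.path.abspath(__file__))
corpus_pattern = os.path.join(pkg_dir, "test_data", "corpus-*.txt")

# Expanding the glob by hand just to show which fixtures a run touches;
# the tests hand the pattern itself to corpus_token_counts().
for path in sorted(glob.glob(corpus_pattern)):
  print(os.path.basename(path))  # e.g. corpus-1.txt (assumed name)
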
@@ -117,31 +110,33 @@ def test_corpus_token_counts_no_split_with_max_lines(self):

    self.assertIn(u"slept", token_counts)
    self.assertNotIn(u"Mitch", token_counts)
-    self.assertDictContainsSubset(
-        {u".\n\n": 1, u"\n": 2, u".\n": 1}, token_counts)
+    self.assertDictContainsSubset({
+        u".\n\n": 1,
+        u"\n": 2,
+        u".\n": 1
+    }, token_counts)

  def test_vocab_token_counts(self):
-    token_counts = tokenizer.vocab_token_counts(
-        self.vocab_path, 0)
+    token_counts = tokenizer.vocab_token_counts(self.vocab_path, 0)

    expected = {
-        "lollipop": 8,
-        "reverberated": 12,
-        "kattywampus": 11,
-        "balderdash": 10,
-        "jiggery-pokery": 14,
+        u"lollipop": 8,
+        u"reverberated": 12,
+        u"kattywampus": 11,
+        u"balderdash": 10,
+        u"jiggery-pokery": 14,
    }
    self.assertDictEqual(expected, token_counts)

  def test_vocab_token_counts_with_max_lines(self):
-    token_counts = tokenizer.vocab_token_counts(
-        self.vocab_path, 4)
+    # vocab-1 has 2 lines, vocab-2 has 3
+    token_counts = tokenizer.vocab_token_counts(self.vocab_path, 4)

    expected = {
-        "lollipop": 8,
-        "reverberated": 12,
-        "kattywampus": 11,
-        "balderdash": 10,
+        u"lollipop": 8,
+        u"reverberated": 12,
+        u"kattywampus": 11,
+        u"balderdash": 10,
    }
    self.assertDictEqual(expected, token_counts)
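
On the `max_lines=4` expectation: the added comment notes that vocab-1 has two lines and vocab-2 has three, and the expected dict keeps the first four entries while dropping `jiggery-pokery`, which suggests the cap applies to the total line count across every file the glob matches, not per file. A rough sketch of that reading behavior (the `capped_lines` helper is an assumption for illustration, not `vocab_token_counts`' actual internals):

import glob

def capped_lines(pattern, max_lines):
  # Yield at most max_lines lines in total across all files matching
  # the pattern; max_lines=0 means no cap, mirroring the 0 passed in
  # test_vocab_token_counts.
  read = 0
  for path in sorted(glob.glob(pattern)):
    with open(path) as vocab_file:
      for line in vocab_file:
        if max_lines and read >= max_lines:
          return
        yield line
        read += 1
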