@@ -146,6 +146,13 @@ def hps_dlpack(model_name, embedding_file_list, data_file, enable_cache, cache_t
146
146
True ,
147
147
hugectr .inference .EmbeddingCacheType_t .Dynamic ,
148
148
)
149
# Lookup with the fourth argument (enable_cache) set to False; the later
# consistency check reports this run as the "Database backend" result.
h1, h2 = hps_dlpack(
    model_name,
    embedding_file_list,
    data_file,
    False,
    hugectr.inference.EmbeddingCacheType_t.Dynamic,
)
# Lookup with the cache enabled and the UVM cache type selected.
u1, u2 = hps_dlpack(
    model_name,
    embedding_file_list,
    data_file,
    True,
    hugectr.inference.EmbeddingCacheType_t.UVM,
)
@@ -173,15 +180,28 @@ def hps_dlpack(model_name, embedding_file_list, data_file, enable_cache, cache_t
173
180
# Compare the UVM-cache lookup (u2) against the Dynamic-cache lookup (d2).
# The mean *absolute* element-wise difference is used: with a signed mean,
# deviations of opposite sign cancel out and a negative bias could never
# exceed the positive threshold, so real mismatches would go unnoticed.
diff = u2.reshape(1, 26 * 16) - d2.reshape(1, 26 * 16)
uvm_error = abs(diff).mean()
if uvm_error > 1e-3:
    # The uncaught exception already terminates the script with a non-zero
    # exit status, so no explicit sys.exit() is needed after it.
    raise RuntimeError(
        "The lookup results of UVM cache are not consistent with Dynamic cache: {}".format(
            uvm_error
        )
    )
print(
    "The lookup results on UVM are consistent with Dynamic cache, mse: {}".format(uvm_error)
)

# Same check for the lookup performed with the cache disabled (h2), which the
# messages below refer to as the Database backend.
diff = h2.reshape(1, 26 * 16) - d2.reshape(1, 26 * 16)
db_error = abs(diff).mean()
if db_error > 1e-3:
    raise RuntimeError(
        "The lookup results of Database backend are not consistent with Dynamic cache: {}".format(
            db_error
        )
    )
print(
    "The lookup results on Database backend are consistent with Dynamic cache , mse: {}".format(
        db_error
    )
)
187
- # hps_dlpack(model_name, network_file, dense_file, embedding_file_list, data_file, False)
0 commit comments