@@ -105,6 +105,7 @@ def ignite_average_to_scikit_average(average, data_type: str):
105
105
@pytest .mark .parametrize ("average" , [None , False , "macro" , "micro" , "weighted" ])
106
106
def test_binary_input (n_times , available_device , average , test_data_binary ):
107
107
pr = Precision (average = average , device = available_device )
108
+ assert pr ._device == torch .device (available_device )
108
109
assert pr ._updated is False
109
110
y_pred , y , batch_size = test_data_binary
110
111
@@ -193,6 +194,7 @@ def test_multiclass_wrong_inputs():
193
194
@pytest .mark .parametrize ("average" , [None , False , "macro" , "micro" , "weighted" ])
194
195
def test_multiclass_input (n_times , available_device , average , test_data_multiclass ):
195
196
pr = Precision (average = average , device = available_device )
197
+ assert pr ._device == torch .device (available_device )
196
198
assert pr ._updated is False
197
199
198
200
y_pred , y , batch_size = test_data_multiclass
@@ -260,6 +262,7 @@ def to_numpy_multilabel(y):
260
262
@pytest .mark .parametrize ("average" , [None , False , "macro" , "micro" , "weighted" , "samples" ])
261
263
def test_multilabel_input (n_times , available_device , average , test_data_multilabel ):
262
264
pr = Precision (average = average , is_multilabel = True , device = available_device )
265
+ assert pr ._device == torch .device (available_device )
263
266
assert pr ._updated is False
264
267
265
268
y_pred , y , batch_size = test_data_multilabel
0 commit comments