Skip to content

Commit c48d9f9

Browse files
authored
adds checks for used device in metrics tests. #3335 (#3353)
1 parent ef8d912 commit c48d9f9

9 files changed

+19
-1
lines changed

tests/ignite/metrics/gan/test_inception_score.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_inception_score(available_device):
3232

3333
p_yx = torch.rand(20, 3, 299, 299)
3434
m = InceptionScore(device=available_device)
35+
assert m._device == torch.device(available_device)
3536
m.update(p_yx)
3637
assert isinstance(m.compute(), float)
3738

tests/ignite/metrics/test_accuracy.py

+3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test_binary_wrong_inputs():
6666
@pytest.mark.parametrize("n_times", range(3))
6767
def test_binary_input(n_times, available_device, test_data_binary):
6868
acc = Accuracy(device=available_device)
69+
assert acc._device == torch.device(available_device)
6970

7071
y_pred, y, batch_size = test_data_binary
7172
acc.reset()
@@ -104,6 +105,7 @@ def test_multiclass_wrong_inputs():
104105
@pytest.mark.parametrize("n_times", range(3))
105106
def test_multiclass_input(n_times, available_device, test_data_multiclass):
106107
acc = Accuracy(device=available_device)
108+
assert acc._device == torch.device(available_device)
107109

108110
y_pred, y, batch_size = test_data_multiclass
109111
acc.reset()
@@ -155,6 +157,7 @@ def test_multilabel_wrong_inputs():
155157
@pytest.mark.parametrize("n_times", range(3))
156158
def test_multilabel_input(n_times, available_device, test_data_multilabel):
157159
acc = Accuracy(is_multilabel=True, device=available_device)
160+
assert acc._device == torch.device(available_device)
158161

159162
y_pred, y, batch_size = test_data_multilabel
160163
if batch_size > 1:

tests/ignite/metrics/test_average_precision.py

+2
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def test_data_binary_and_multilabel(request):
8585
def test_binary_and_multilabel_inputs(n_times, available_device, test_data_binary_and_multilabel):
8686
y_pred, y, batch_size = test_data_binary_and_multilabel
8787
ap = AveragePrecision(device=available_device)
88+
assert ap._device == torch.device(available_device)
8889
ap.reset()
8990
if batch_size > 1:
9091
n_iters = y.shape[0] // batch_size + 1
@@ -129,6 +130,7 @@ def update_fn(engine, batch):
129130
engine = Engine(update_fn)
130131

131132
ap_metric = AveragePrecision(device=available_device)
133+
assert ap_metric._device == torch.device(available_device)
132134
ap_metric.attach(engine, "ap")
133135

134136
np_y = y.numpy()

tests/ignite/metrics/test_cosine_similarity.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int], availabl
4545
y_pred, y, eps, batch_size = test_case
4646

4747
cos = CosineSimilarity(eps=eps, device=available_device)
48+
assert cos._device == torch.device(available_device)
4849

4950
cos.reset()
5051
if batch_size > 1:
@@ -69,6 +70,7 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int], availabl
6970

7071
def test_accumulator_detached(available_device):
7172
cos = CosineSimilarity(device=available_device)
73+
assert cos._device == torch.device(available_device)
7274

7375
y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
7476
y = torch.ones(2, 2, dtype=torch.float)

tests/ignite/metrics/test_precision.py

+3
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def ignite_average_to_scikit_average(average, data_type: str):
105105
@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
106106
def test_binary_input(n_times, available_device, average, test_data_binary):
107107
pr = Precision(average=average, device=available_device)
108+
assert pr._device == torch.device(available_device)
108109
assert pr._updated is False
109110
y_pred, y, batch_size = test_data_binary
110111

@@ -193,6 +194,7 @@ def test_multiclass_wrong_inputs():
193194
@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
194195
def test_multiclass_input(n_times, available_device, average, test_data_multiclass):
195196
pr = Precision(average=average, device=available_device)
197+
assert pr._device == torch.device(available_device)
196198
assert pr._updated is False
197199

198200
y_pred, y, batch_size = test_data_multiclass
@@ -260,6 +262,7 @@ def to_numpy_multilabel(y):
260262
@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted", "samples"])
261263
def test_multilabel_input(n_times, available_device, average, test_data_multilabel):
262264
pr = Precision(average=average, is_multilabel=True, device=available_device)
265+
assert pr._device == torch.device(available_device)
263266
assert pr._updated is False
264267

265268
y_pred, y, batch_size = test_data_multilabel

tests/ignite/metrics/test_psnr.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_psnr(test_data, available_device):
5353
data_range = (y.max() - y.min()).cpu().item()
5454

5555
psnr = PSNR(data_range=data_range, device=available_device)
56+
assert psnr._device == torch.device(available_device)
5657
psnr.update(test_data)
5758
psnr_compute = psnr.compute()
5859

tests/ignite/metrics/test_recall.py

+3
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def ignite_average_to_scikit_average(average, data_type: str):
108108
@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
109109
def test_binary_input(n_times, available_device, average, test_data_binary):
110110
re = Recall(average=average, device=available_device)
111+
assert re._device == torch.device(available_device)
111112
assert re._updated is False
112113

113114
y_pred, y, batch_size = test_data_binary
@@ -195,6 +196,7 @@ def test_multiclass_wrong_inputs():
195196
@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
196197
def test_multiclass_input(n_times, available_device, average, test_data_multiclass):
197198
re = Recall(average=average, device=available_device)
199+
assert re._device == torch.device(available_device)
198200
assert re._updated is False
199201

200202
y_pred, y, batch_size = test_data_multiclass
@@ -263,6 +265,7 @@ def to_numpy_multilabel(y):
263265
def test_multilabel_input(n_times, available_device, average, test_data_multilabel):
264266

265267
re = Recall(average=average, is_multilabel=True, device=available_device)
268+
assert re._device == torch.device(available_device)
266269
assert re._updated is False
267270

268271
y_pred, y, batch_size = test_data_multilabel

tests/ignite/metrics/test_roc_auc.py

+2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_data_binary_and_multilabel(request):
8787
def test_binary_and_multilabel_inputs(n_times, available_device, test_data_binary_and_multilabel):
8888
y_pred, y, batch_size = test_data_binary_and_multilabel
8989
roc_auc = ROC_AUC(device=available_device)
90+
assert roc_auc._device == torch.device(available_device)
9091
roc_auc.reset()
9192
if batch_size > 1:
9293
n_iters = y.shape[0] // batch_size + 1
@@ -147,6 +148,7 @@ def update_fn(engine, batch):
147148
engine = Engine(update_fn)
148149

149150
roc_auc_metric = ROC_AUC(device=available_device)
151+
assert roc_auc_metric._device == torch.device(available_device)
150152
roc_auc_metric.attach(engine, "roc_auc")
151153

152154
np_y = y.numpy()

tests/ignite/metrics/test_ssim.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def test_ssim_variable_batchsize(available_device):
163163
sigma = 1.5
164164
data_range = 1.0
165165
ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device)
166-
166+
assert ssim._device == torch.device(available_device)
167167
y_preds = [
168168
torch.rand(12, 3, 28, 28, device=available_device),
169169
torch.rand(12, 3, 28, 28, device=available_device),
@@ -229,6 +229,7 @@ def test_ssim_uint8(available_device, shape, kernel_size, gaussian, use_sample_c
229229
sigma = 1.5
230230
data_range = 255
231231
ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device)
232+
assert ssim._device == torch.device(available_device)
232233
ssim.update((y_pred, y))
233234
ignite_ssim = ssim.compute()
234235

0 commit comments

Comments
 (0)