@@ -287,9 +287,7 @@ def apply_reduction(val: torch.Tensor, reduction: str) -> torch.Tensor:
287
287
return red (val )
288
288
289
289
290
- def kr_loss (
291
- input : torch .Tensor , target : torch .Tensor , multi_gpu = False , true_values = None
292
- ) -> torch .Tensor :
290
+ def kr_loss (input : torch .Tensor , target : torch .Tensor , multi_gpu = False ) -> torch .Tensor :
293
291
r"""
294
292
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
295
293
as per
@@ -300,12 +298,19 @@ def kr_loss(
300
298
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
301
299
302
300
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
303
- two possible labels as specific by ``true_values``.
301
+ two possible labels as specific by their sign.
302
+
303
+ `target` accepts label values in (0, 1), (-1, 1), or pre-processed with the
304
+ `deel.torchlip.functional.process_labels_for_multi_gpu()` function.
305
+
306
+ Using a multi-GPU/TPU strategy requires to set `multi_gpu` to True and to
307
+ pre-process the labels `target` with the
308
+ `deel.torchlip.functional.process_labels_for_multi_gpu()` function.
304
309
305
310
Args:
306
311
input: Tensor of arbitrary shape.
307
312
target: Tensor of the same shape as input.
308
- true_values: depreciated (target>0 is used)
313
+ multi_gpu (bool): set to True when running on multi-GPU/TPU
309
314
310
315
Returns:
311
316
The Wasserstein-1 loss between ``input`` and ``target``.
@@ -316,9 +321,7 @@ def kr_loss(
316
321
return kr_loss_standard (input , target )
317
322
318
323
319
- def kr_loss_standard (
320
- input : torch .Tensor , target : torch .Tensor , true_values = None
321
- ) -> torch .Tensor :
324
+ def kr_loss_standard (input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
322
325
r"""
323
326
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
324
327
as per
@@ -329,12 +332,13 @@ def kr_loss_standard(
329
332
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
330
333
331
334
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
332
- two possible labels as specific by ``true_values``.
335
+ two possible labels as specific by their sign.
336
+
337
+ `target` accepts label values in (0, 1), (-1, 1)
333
338
334
339
Args:
335
340
input: Tensor of arbitrary shape.
336
341
target: Tensor of the same shape as input.
337
- true_values: depreciated (target>0 is used)
338
342
339
343
Returns:
340
344
The Wasserstein-1 loss between ``input`` and ``target``.
@@ -384,7 +388,6 @@ def neg_kr_loss(
384
388
input : torch .Tensor ,
385
389
target : torch .Tensor ,
386
390
multi_gpu = False ,
387
- true_values = None ,
388
391
) -> torch .Tensor :
389
392
"""
390
393
Loss to estimate the negative wasserstein-1 distance using Kantorovich-Rubinstein
@@ -393,7 +396,7 @@ def neg_kr_loss(
393
396
Args:
394
397
input: Tensor of arbitrary shape.
395
398
target: Tensor of the same shape as input.
396
- true_values: depreciated (target>0 is used)
399
+ multi_gpu (bool): set to True when running on multi-GPU/TPU
397
400
398
401
Returns:
399
402
The negative Wasserstein-1 loss between ``input`` and ``target``.
@@ -437,7 +440,6 @@ def hkr_loss(
437
440
alpha : float ,
438
441
min_margin : float = 1.0 ,
439
442
multi_gpu = False ,
440
- true_values = None ,
441
443
) -> torch .Tensor :
442
444
"""
443
445
Loss to estimate the wasserstein-1 distance with a hinge regularization using
@@ -446,9 +448,9 @@ def hkr_loss(
446
448
Args:
447
449
input: Tensor of arbitrary shape.
448
450
target: Tensor of the same shape as input.
449
- alpha: Regularization factor between the hinge and the KR loss.
451
+ alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
450
452
min_margin: Minimal margin for the hinge loss.
451
- true_values: tuple containing the two label for each predicted class.
453
+ multi_gpu (bool): set to True when running on multi-GPU/TPU
452
454
453
455
Returns:
454
456
The regularized Wasserstein-1 loss.
@@ -478,7 +480,7 @@ def hinge_multiclass_loss(
478
480
"""
479
481
Loss to estimate the Hinge loss in a multiclass setup. It compute the
480
482
elementwise hinge term. Note that this formulation differs from the
481
- one commonly found in tensorflow/pytorch (with marximise the difference
483
+ one commonly found in tensorflow/pytorch (with maximise the difference
482
484
between the two largest logits). This formulation is consistent with the
483
485
binary classification loss used in a multiclass fashion.
484
486
@@ -515,9 +517,9 @@ def hkr_multiclass_loss(
515
517
Args:
516
518
input: Tensor of arbitrary shape.
517
519
target: Tensor of the same shape as input.
518
- alpha: Regularization factor between the hinge and the KR loss.
520
+ alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
519
521
min_margin: Minimal margin for the hinge loss.
520
- true_values: tuple containing the two label for each predicted class.
522
+ multi_gpu (bool): set to True when running on multi-GPU/TPU
521
523
522
524
Returns:
523
525
The regularized Wasserstein-1 loss.
0 commit comments