Skip to content

Commit 9658d27

Browse files
committed
update docstring in losses
1 parent d7481a2 commit 9658d27

File tree

2 files changed

+160
-69
lines changed

2 files changed

+160
-69
lines changed

deel/torchlip/functional.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,7 @@ def apply_reduction(val: torch.Tensor, reduction: str) -> torch.Tensor:
287287
return red(val)
288288

289289

290-
def kr_loss(
291-
input: torch.Tensor, target: torch.Tensor, multi_gpu=False, true_values=None
292-
) -> torch.Tensor:
290+
def kr_loss(input: torch.Tensor, target: torch.Tensor, multi_gpu=False) -> torch.Tensor:
293291
r"""
294292
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
295293
as per
@@ -300,12 +298,19 @@ def kr_loss(
300298
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
301299
302300
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
303-
two possible labels as specific by ``true_values``.
301+
two possible labels as specific by their sign.
302+
303+
`target` accepts label values in (0, 1), (-1, 1), or pre-processed with the
304+
`deel.torchlip.functional.process_labels_for_multi_gpu()` function.
305+
306+
Using a multi-GPU/TPU strategy requires to set `multi_gpu` to True and to
307+
pre-process the labels `target` with the
308+
`deel.torchlip.functional.process_labels_for_multi_gpu()` function.
304309
305310
Args:
306311
input: Tensor of arbitrary shape.
307312
target: Tensor of the same shape as input.
308-
true_values: depreciated (target>0 is used)
313+
multi_gpu (bool): set to True when running on multi-GPU/TPU
309314
310315
Returns:
311316
The Wasserstein-1 loss between ``input`` and ``target``.
@@ -316,9 +321,7 @@ def kr_loss(
316321
return kr_loss_standard(input, target)
317322

318323

319-
def kr_loss_standard(
320-
input: torch.Tensor, target: torch.Tensor, true_values=None
321-
) -> torch.Tensor:
324+
def kr_loss_standard(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
322325
r"""
323326
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
324327
as per
@@ -329,12 +332,13 @@ def kr_loss_standard(
329332
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
330333
331334
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
332-
two possible labels as specific by ``true_values``.
335+
two possible labels as specific by their sign.
336+
337+
`target` accepts label values in (0, 1), (-1, 1)
333338
334339
Args:
335340
input: Tensor of arbitrary shape.
336341
target: Tensor of the same shape as input.
337-
true_values: depreciated (target>0 is used)
338342
339343
Returns:
340344
The Wasserstein-1 loss between ``input`` and ``target``.
@@ -384,7 +388,6 @@ def neg_kr_loss(
384388
input: torch.Tensor,
385389
target: torch.Tensor,
386390
multi_gpu=False,
387-
true_values=None,
388391
) -> torch.Tensor:
389392
"""
390393
Loss to estimate the negative wasserstein-1 distance using Kantorovich-Rubinstein
@@ -393,7 +396,7 @@ def neg_kr_loss(
393396
Args:
394397
input: Tensor of arbitrary shape.
395398
target: Tensor of the same shape as input.
396-
true_values: depreciated (target>0 is used)
399+
multi_gpu (bool): set to True when running on multi-GPU/TPU
397400
398401
Returns:
399402
The negative Wasserstein-1 loss between ``input`` and ``target``.
@@ -437,7 +440,6 @@ def hkr_loss(
437440
alpha: float,
438441
min_margin: float = 1.0,
439442
multi_gpu=False,
440-
true_values=None,
441443
) -> torch.Tensor:
442444
"""
443445
Loss to estimate the wasserstein-1 distance with a hinge regularization using
@@ -446,9 +448,9 @@ def hkr_loss(
446448
Args:
447449
input: Tensor of arbitrary shape.
448450
target: Tensor of the same shape as input.
449-
alpha: Regularization factor between the hinge and the KR loss.
451+
alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
450452
min_margin: Minimal margin for the hinge loss.
451-
true_values: tuple containing the two label for each predicted class.
453+
multi_gpu (bool): set to True when running on multi-GPU/TPU
452454
453455
Returns:
454456
The regularized Wasserstein-1 loss.
@@ -478,7 +480,7 @@ def hinge_multiclass_loss(
478480
"""
479481
Loss to estimate the Hinge loss in a multiclass setup. It compute the
480482
elementwise hinge term. Note that this formulation differs from the
481-
one commonly found in tensorflow/pytorch (with marximise the difference
483+
one commonly found in tensorflow/pytorch (with maximise the difference
482484
between the two largest logits). This formulation is consistent with the
483485
binary classification loss used in a multiclass fashion.
484486
@@ -515,9 +517,9 @@ def hkr_multiclass_loss(
515517
Args:
516518
input: Tensor of arbitrary shape.
517519
target: Tensor of the same shape as input.
518-
alpha: Regularization factor between the hinge and the KR loss.
520+
alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
519521
min_margin: Minimal margin for the hinge loss.
520-
true_values: tuple containing the two label for each predicted class.
522+
multi_gpu (bool): set to True when running on multi-GPU/TPU
521523
522524
Returns:
523525
The regularized Wasserstein-1 loss.

0 commit comments

Comments
 (0)