Working on a mini-gpt example for @monicadsong. It loads a decently large (few GB) tf.data.Dataset. Colab: https://colab.research.google.com/gist/mattdangerw/4f871c46f3eb5af49f828e2aea3bef79/mini-gpt-from-scatch.ipynb

This works as written on the tf and jax backends without issue, but on the torch backend we OOM the GPU in the middle of the first epoch. It looks like a leak or some other inconsistency: the failure hits a variable number of steps into training, anywhere from a few hundred to a few thousand depending on the run.
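For reference, here is a minimal sketch of the kind of setup that exercises this code path. It is not the notebook's actual code; the vocabulary size, sequence length, model, and random data are placeholder assumptions:

```python
# Minimal sketch of the failing pattern (placeholder sizes, not the Colab's).
import os
os.environ["KERAS_BACKEND"] = "torch"  # also runs on "tensorflow" and "jax"

import numpy as np
import keras

vocab_size = 32_000  # assumed
seq_len = 256        # assumed

# Tiny language-model-shaped network: token ids in, per-position
# probabilities over the vocabulary out.
inputs = keras.Input(shape=(seq_len,), dtype="int32")
x = keras.layers.Embedding(vocab_size, 256)(inputs)
outputs = keras.layers.Dense(vocab_size, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(),  # from_logits=False path
)

# Random token ids stand in for the real few-GB tf.data.Dataset.
tokens = np.random.randint(0, vocab_size, size=(4096, seq_len + 1))
model.fit(tokens[:, :-1], tokens[:, 1:], batch_size=64, epochs=1)
```

On the torch backend the run eventually dies inside the loss computation: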
/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in __call__(self, y_true, y_pred, sample_weight)
689 def __call__(self, y_true, y_pred, sample_weight=None):
690 with ops.name_scope(self.name):
--> 691 return self.call(y_true, y_pred, sample_weight)
692
693 def call(self, y_true, y_pred, sample_weight=None):
/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in call(self, y_true, y_pred, sample_weight)
698 _, loss_fn, loss_weight, _ = self._flat_losses[0]
699 loss_value = ops.cast(
--> 700 loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
701 )
702 if loss_weight is not None:
/usr/local/lib/python3.11/dist-packages/keras/src/losses/loss.py in __call__(self, y_true, y_pred, sample_weight)
65 )
66
---> 67 losses = self.call(y_true, y_pred)
68 out_mask = backend.get_keras_mask(losses)
69
/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in call(self, y_true, y_pred)
31 y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred)
32 y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred)
---> 33 return self.fn(y_true, y_pred, **self._fn_kwargs)
34
35 def get_config(self):
/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, ignore_class, axis)
2244 )
2245
-> 2246 res = ops.sparse_categorical_crossentropy(
2247 y_true,
2248 y_pred,
/usr/local/lib/python3.11/dist-packages/keras/src/ops/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
1961 from_logits=from_logits, axis=axis
1962 ).symbolic_call(target, output)
-> 1963 return backend.nn.sparse_categorical_crossentropy(
1964 target, output, from_logits=from_logits, axis=axis
1965 )
/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
705 output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
706 log_prob = torch.log(output)
--> 707 target = one_hot(target, output.shape[axis], axis=axis)
708 return -torch.sum(target * log_prob, dim=axis)
709
/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in one_hot(x, num_classes, axis, dtype, sparse)
629 # `where` afterwards.
630 output = tnn.one_hot(maximum(x, 0), num_classes)
--> 631 output = where(expand_dims(x, axis=-1) >= 0, output, zero)
632 output = convert_to_tensor(output, dtype=dtype)
633 dims = output.dim()
/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/numpy.py in where(condition, x1, x2)
1529 x1 = convert_to_tensor(x1)
1530 x2 = convert_to_tensor(x2)
-> 1531 return torch.where(condition, x1, x2)
1532 else:
1533 return torch.where(condition)
OutOfMemoryError: CUDA out of memory. Tried to allocate 7.81 GiB. GPU 0 has a total capacity of 39.56 GiB of which 7.34 GiB is free. Process 30879 has 32.21 GiB memory in use. Of the allocated memory 25.56 GiB is allocated by PyTorch, and 6.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
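For scale: the allocation the traceback lands on is the dense one-hot expansion of the integer targets that the torch backend builds inside sparse_categorical_crossentropy, followed by a torch.where that produces a second tensor of the same full shape. A back-of-the-envelope with made-up shapes (the real sizes are whatever the Colab's batch, sequence length, and vocabulary work out to):

```python
# Rough size of the dense one-hot target the torch backend materializes per
# loss call. The shapes below are illustrative placeholders only.
batch, seq_len, vocab_size = 64, 512, 50_000
n_elements = batch * seq_len * vocab_size

# torch.nn.functional.one_hot returns int64; the cast to the compute dtype
# (e.g. float32) only happens afterwards in the backend's one_hot helper.
print(f"one-hot (int64):   {n_elements * 8 / 2**30:.2f} GiB")
print(f"one-hot (float32): {n_elements * 4 / 2**30:.2f} GiB")
```

So whether or not there is also a leak, each step on this non-from_logits path transiently asks for several tensors of that size, which is why the failure shows up as a multi-GiB allocation inside the loss.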