Memory leak/crash on torch backend #2142

Closed
mattdangerw opened this issue Mar 15, 2025 · 0 comments
Working on a mini-gpt example for @monicadsong. It loads a decently large (a few GB) tf.data.Dataset. Colab:

https://colab.research.google.com/gist/mattdangerw/4f871c46f3eb5af49f828e2aea3bef79/mini-gpt-from-scatch.ipynb

This works as written on the tf and jax backends without issue, but on the torch backend we OOM the GPU in the middle of the first epoch. This appears to be a leak, or at least something nondeterministic, as the OOM hits a variable number of steps into training: a few hundred or a few thousand depending on the run.
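
For reference, a minimal sketch of the code path involved (not the Colab above; the backend setup, shapes, and vocab size here are illustrative assumptions). Under the torch backend, the sparse categorical crossentropy loss goes through the one_hot call shown in the traceback below:

import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import numpy as np
import keras

# Illustrative shapes, not the ones from the Colab.
batch, seq_len, vocab_size = 64, 128, 5000
y_true = np.random.randint(0, vocab_size, size=(batch, seq_len))
y_pred = np.random.rand(batch, seq_len, vocab_size).astype("float32")

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# In the torch backend (backend/torch/nn.py in the traceback), the targets
# are one-hot encoded, materializing a dense (batch, seq_len, vocab_size)
# tensor on every loss call.
value = loss(y_true, y_pred)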

/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in __call__(self, y_true, y_pred, sample_weight)
    689     def __call__(self, y_true, y_pred, sample_weight=None):
    690         with ops.name_scope(self.name):
--> 691             return self.call(y_true, y_pred, sample_weight)
    692 
    693     def call(self, y_true, y_pred, sample_weight=None):

/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in call(self, y_true, y_pred, sample_weight)
    698             _, loss_fn, loss_weight, _ = self._flat_losses[0]
    699             loss_value = ops.cast(
--> 700                 loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
    701             )
    702             if loss_weight is not None:

/usr/local/lib/python3.11/dist-packages/keras/src/losses/loss.py in __call__(self, y_true, y_pred, sample_weight)
     65             )
     66 
---> 67             losses = self.call(y_true, y_pred)
     68             out_mask = backend.get_keras_mask(losses)
     69 

/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in call(self, y_true, y_pred)
     31         y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred)
     32         y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred)
---> 33         return self.fn(y_true, y_pred, **self._fn_kwargs)
     34 
     35     def get_config(self):

/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, ignore_class, axis)
   2244         )
   2245 
-> 2246     res = ops.sparse_categorical_crossentropy(
   2247         y_true,
   2248         y_pred,

/usr/local/lib/python3.11/dist-packages/keras/src/ops/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
   1961             from_logits=from_logits, axis=axis
   1962         ).symbolic_call(target, output)
-> 1963     return backend.nn.sparse_categorical_crossentropy(
   1964         target, output, from_logits=from_logits, axis=axis
   1965     )

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
    705         output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
    706         log_prob = torch.log(output)
--> 707     target = one_hot(target, output.shape[axis], axis=axis)
    708     return -torch.sum(target * log_prob, dim=axis)
    709 

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in one_hot(x, num_classes, axis, dtype, sparse)
    629     # `where` afterwards.
    630     output = tnn.one_hot(maximum(x, 0), num_classes)
--> 631     output = where(expand_dims(x, axis=-1) >= 0, output, zero)
    632     output = convert_to_tensor(output, dtype=dtype)
    633     dims = output.dim()

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/numpy.py in where(condition, x1, x2)
   1529         x1 = convert_to_tensor(x1)
   1530         x2 = convert_to_tensor(x2)
-> 1531         return torch.where(condition, x1, x2)
   1532     else:
   1533         return torch.where(condition)

OutOfMemoryError: CUDA out of memory. Tried to allocate 7.81 GiB. GPU 0 has a total capacity of 39.56 GiB of which 7.34 GiB is free. Process 30879 has 32.21 GiB memory in use. Of the allocated memory 25.56 GiB is allocated by PyTorch, and 6.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
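
Not a fix for an actual leak, but the allocator hint in the message above is cheap to try (untested here whether it helps in this case). It has to be set before torch initializes CUDA:

import os
# Allocator config suggested by the OOM message itself; set before importing
# torch or doing any CUDA work. This mitigates fragmentation only, not a leak.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # import only after setting the env var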