
[Feature] Make HeavyBall compatible with FSDP2 (DTensor) #15

Open
casper-hansen opened this issue Nov 22, 2024 · 4 comments
@casper-hansen

Here is an example traceback from trying to use heavyball in TorchTitan. It seems heavyball is not yet compatible with FSDP2 because some of its utilities operate on plain tensors instead of DTensors.

Code:

import heavyball
import torch
from torchtitan.config_manager import JobConfig

# consider split between PP and non-PP
def build_optimizers(model_parts, job_config: JobConfig):
    """Wrap one optimizer per model part in an OptimizersContainer which provides a single
    step() and zero_grad() method for all the child optimizers.
    """

    def _build_optimizer(model):
        name = job_config.optimizer.name
        lr = job_config.optimizer.lr
        fused = job_config.optimizer.fused

        # Common parameters for both optimizers
        optimizer_kwargs = {
            "lr": lr,
            "betas": (0.9, 0.95),
            "weight_decay": 0.1,
            "fused": fused,
            "foreach": not fused,
        }
        if name == "Adam":
            # TODO: make the optimizer options configurable by toml/cmd args
            optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
        elif name == "AdamW":
            optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
        else:
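            # Heavyball optimizer path: rebuild the kwargs without the fused flag.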
            optimizer_kwargs = {
                "lr": lr,
                "betas": (0.9, 0.95),
                "weight_decay": 0.1,
                "foreach": True,
            }
            optimizer_cls = getattr(heavyball, name)
            optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)

        return optimizer

    class OptimizersContainer:
        """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages"""

        def __init__(self, optimizers):
            self.optimizers = optimizers

        def step(self):
            for optimizer in self.optimizers:
                optimizer.step()

        def zero_grad(self):
            for optimizer in self.optimizers:
                optimizer.zero_grad()

    return OptimizersContainer([_build_optimizer(model) for model in model_parts])

Traceback:

0: [rank0]:Traceback (most recent call last):
0: [rank0]:  File "/workspace/./nlp_train/torchtitan/train.py", line 540, in <module>
0: [rank0]:    main(job_config)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
0: [rank0]:    return f(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/workspace/./nlp_train/torchtitan/train.py", line 392, in main
0: [rank0]:    optimizers.step()
0: [rank0]:  File "/workspace/nlp_train/torchtitan/optimizer/precond_soap.py", line 55, in step
0: [rank0]:    optimizer.step()
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 140, in wrapper
0: [rank0]:    return func.__get__(opt, opt.__class__)(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/optim/optimizer.py", line 494, in wrapper
0: [rank0]:    out = func(*args, **kwargs)
0: [rank0]:          ^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 446, in step
0: [rank0]:    self._step(group)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/precond_schedule_palm_foreach_soap.py", line 63, in _step
0: [rank0]:    update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 354, in update_preconditioner
0: [rank0]:    compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 23, in _fn
0: [rank0]:    return func(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/heavyball/utils.py", line 331, in compute_ggt
0: [rank0]:    GG[idx].lerp_(promote(outer_product), 1 - beta)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
0: [rank0]:    return disable_fn(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
0: [rank0]:    return fn(*args, **kwargs)
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 343, in __torch_dispatch__
0: [rank0]:    return DTensor._op_dispatcher.dispatch(
0: [rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 166, in dispatch
0: [rank0]:    op_info = self.unwrap_to_op_info(op_call, args, kwargs)
0: [rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 371, in unwrap_to_op_info
0: [rank0]:    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
0: [rank0]:  File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 470, in _try_replicate_spec_for_scalar_tensor
0: [rank0]:    raise RuntimeError(
0: [rank0]:RuntimeError: aten.lerp_.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
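
For context, the failing lerp_ mixes a plain torch.Tensor with a DTensor: under FSDP2 the gradients (and anything derived from them, such as the outer product) are DTensors, while the preconditioner state appears to be allocated as a plain tensor. Below is a rough sketch of one possible workaround, not heavyball's code; compute_ggt_dtensor_safe is a hypothetical helper, and full_tensor() gathers the shards, so it trades memory and communication for compatibility.

import torch
from torch.distributed.tensor import DTensor  # public in recent torch; older versions use torch.distributed._tensor


def compute_ggt_dtensor_safe(grad: torch.Tensor, GG: torch.Tensor, beta: float):
    # Materialize the full gradient so the preconditioner update only ever
    # sees plain tensors, avoiding the mixed Tensor/DTensor dispatch error.
    if isinstance(grad, DTensor):
        grad = grad.full_tensor()  # all-gather the shards into a regular tensor
    outer_product = torch.einsum("...i,...j->ij", grad, grad)  # illustrative GG^T term
    GG.lerp_(outer_product, 1 - beta)

The other direction, allocating the optimizer state as DTensors on the same mesh as the parameters, keeps memory sharded but means touching every state initializer in the library.
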
@ClashLuke
Owner

Interesting, thank you for raising an issue.
I can't currently maintain DTensors and personally don't use them either. However, @ethansmith2000 has started an implementation at #14, which may interest you.

@ethansmith2000
Contributor

I have it working with FSDP now, though there are a few sharp bits that need to be adjusted, especially around gradient clipping and checkpoint saving:
#11
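
For the gradient-clipping sharp bit specifically: with DTensor gradients the per-shard norm is only part of the picture, so the clip threshold has to be compared against a cross-shard reduction. A hedged sketch of computing that global norm (global_grad_norm is a made-up helper, not code from the linked PR):

import torch
from torch.distributed.tensor import DTensor


def global_grad_norm(params):
    # Sum of squared gradient entries across all parameters. For DTensor
    # gradients, full_tensor() reduces the per-shard partial sums so the
    # result is the global norm rather than a per-rank one.
    total = None
    for p in params:
        if p.grad is None:
            continue
        sq = p.grad.pow(2).sum()
        if isinstance(sq, DTensor):
            sq = sq.full_tensor()
        total = sq if total is None else total + sq
    return total.sqrt() if total is not None else torch.zeros(())
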

@casper-hansen
Author

> Interesting, thank you for raising an issue. I can't currently maintain DTensors and personally don't use them either. However, @ethansmith2000 has started an implementation at #14, which may interest you.

That is fair enough. I wish there were compatibility, but it seems it would take quite some work to get there. I will have a look at https://github.com/facebookresearch/optimizers/, which seems to support a distributed version of Shampoo.

@ethansmith2000
Contributor

ethansmith2000 commented Nov 23, 2024

@casper-hansen
I have added FSDP support for SOAP and psgd-kron here: https://github.com/ethansmith2000/fsdp_optimizers. It conflicts a bit with the existing optimizations, since there are many places where we can't use compiled, in-place operations, or some of the foreach paths, but I'm hoping torch will add more support for this in the future.

It somewhat comes down to whether memory management or speed is the higher priority.
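
To make the tradeoff concrete, here is a hedged sketch (not code from the repo above): the fused path updates a whole list of state tensors in place through torch's private _foreach API, which is what the foreach optimizers rely on for speed, while a DTensor-friendly fallback loops per tensor so that each op is dispatched individually, trading extra kernel launches for compatibility.

import torch


def lerp_states_fused(states, targets, weight):
    # Fast path: in-place, horizontally fused across the whole list.
    torch._foreach_lerp_(states, targets, weight)


def lerp_states_per_tensor(states, targets, weight):
    # Fallback path: one op per tensor; slower due to extra kernel launches,
    # but each call is dispatched on its own, which is easier to keep DTensor-safe.
    for s, t in zip(states, targets):
        s.lerp_(t, weight)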
