
Add something like use_compile parameter for Trainer #20242

Open

mieshkiwrk opened this issue Sep 3, 2024 · 1 comment · May be fixed by #20269
mieshkiwrk commented Sep 3, 2024

Description & Motivation

In the example below, the model is compiled first and a DDPStrategy is passed to the Trainer. The strategy is only applied inside the fit method, so the model's forward is compiled but DistributedDataParallel's _pre_forward/_post_forward are not.
Because of that, the C++ reducer is never disabled in _pre_forward/_post_forward, which later causes a problem when the backward pass tries to queue its final callback.
When the DDP wrapper is also compiled, the C++ reducer is disabled as expected.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
import functools
from torch._dynamo import compiled_autograd


torch._dynamo.config.optimize_ddp = "python_reducer"

class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.mean((self(x) - y) ** 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


def create_dataset(num_samples=1000):
    x = torch.randn(num_samples, 10)
    y = torch.sum(x, dim=1, keepdim=True)
    return TensorDataset(x, y)


def run_training():
    dataset = create_dataset()
    train_loader = DataLoader(dataset, batch_size=32)

    model = SomeModel()

    # First compile whole model
    model = torch.compile(model)

    # static_graph must be True, which makes _DDPSink.backward queue a final callback
    ddp_strategy = DDPStrategy(static_graph=True)

    trainer = pl.Trainer(
            max_epochs=2,
            accelerator='cpu',
            devices=1,
            strategy=ddp_strategy
    )

    # DDP is applied inside the fit method, so DDP's _pre_forward/_post_forward won't be compiled even though forward is
    with compiled_autograd.enable(torch.compile()):
        trainer.fit(model, train_loader)


if __name__ == "__main__":
    run_training()

Error produced by the repro:

[rank0]:     Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
[rank0]: RuntimeError: Final callbacks can only be installed during backward pass. 

Pitch

It seems useful to compile the model only after the strategy has been applied, so my suggestion is to add something like a boolean use_compile parameter to the Trainer. It would help in situations like the one above and would also be cleaner to use.
It should probably be more than a plain bool, so that a specific backend and other optional torch.compile arguments can be configured.
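
For illustration, a hypothetical user-facing API could look like the sketch below. The use_compile parameter and the dict of options forwarded to torch.compile are assumptions for this proposal, not an existing Trainer argument:

trainer = pl.Trainer(
    max_epochs=2,
    accelerator='cpu',
    devices=1,
    strategy=DDPStrategy(static_graph=True),
    # hypothetical: either a plain bool (use_compile=True), or a dict whose
    # entries are forwarded to torch.compile to pick backend/mode etc.
    use_compile={"backend": "inductor"},
)
trainer.fit(model, train_loader)

With such an option the Trainer could call torch.compile only after the strategy has wrapped the module, so DDP's _pre_forward/_post_forward would be traced together with forward.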

Alternatives

src/lightning/pytorch/trainer/trainer.py

class Trainer:
    def _run(
        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        (...)
        # ----------------------------
        # SET UP THE TRAINER
        # ----------------------------
        (...)
        self.strategy.setup(self)
        (...)
        
------> ### Pseudo proposition
------> if self.use_compile:
------>     self.model = torch.compile(self.model)

        # ----------------------------
        # RUN THE TRAINER
        # ----------------------------
        results = self._run_stage()

        (...)
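
A slightly fuller variant of the same pseudo proposition, where the names use_compile and compile_kwargs are hypothetical and it is assumed that the strategy-wrapped module is reachable via self.strategy.model:

class Trainer:
    def __init__(self, ..., use_compile: bool = False, compile_kwargs: Optional[dict] = None):
        (...)
        self.use_compile = use_compile
        self.compile_kwargs = compile_kwargs or {}

    def _run(
        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        (...)
        self.strategy.setup(self)

        # compile only after the strategy has wrapped the module, so the
        # DDP _pre_forward/_post_forward hooks are traced together with forward
        if self.use_compile:
            self.strategy.model = torch.compile(self.strategy.model, **self.compile_kwargs)

        results = self._run_stage()
        (...)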

Additional context

File: torch/nn/parallel/distributed.py

class DistributedDataParallel(Module, Joinable): 
    def _should_disable_cpp_reducer(self) -> bool: 
        return self._use_python_reducer and ( 
            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer 
        )

    def _pre_forward(self, *inputs, **kwargs):        
        if self._should_disable_cpp_reducer():                  
            return inputs, kwargs
        (...)

    def _post_forward(self, output): 
        if self._should_disable_cpp_reducer(): 
            return output
        (...)

Dynamo replaces the result of torch._utils.is_compiling() with True while the code is being compiled; in eager execution it returns False.
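
A minimal standalone check of that behavior (independent of DDP), assuming a PyTorch 2.x build where torch._utils.is_compiling is available:

import torch

def branch_on_compile(x):
    # Dynamo constant-folds this check to True inside a compiled region
    if torch._utils.is_compiling():
        return x + 1
    return x - 1

x = torch.ones(1)
print(branch_on_compile(x))                 # eager: tensor([0.])
print(torch.compile(branch_on_compile)(x))  # compiled: tensor([2.])

This is why _should_disable_cpp_reducer() only returns True when the DDP forward hooks themselves run under torch.compile, which does not happen when only the inner LightningModule is compiled before the strategy wraps it.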

cc @Borda

mieshkiwrk added the feature (Is an improvement or enhancement) and needs triage (Waiting to be triaged by maintainers) labels on Sep 3, 2024
mieshkiwrk (Author) commented:
cc @jerome-habana

mieshkiwrk linked a pull request on Sep 10, 2024 that will close this issue