Description & Motivation

In the example below, the model is compiled first, a DDPStrategy is passed to the Trainer, and the strategy is only applied later, inside fit. As a result, forward is compiled, but _pre_forward/_post_forward in the DDP class are not.
Because of that, the cpp_reducer is never disabled inside DDP's _pre_forward/_post_forward, which later causes a problem when queueing the final autograd callback.
When DDP itself is also compiled, the cpp_reducer is disabled as expected.
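To make the ordering concrete outside of Lightning, here is a minimal sketch (assumptions: plain PyTorch DDP, a single-process "gloo" group on CPU, written purely for illustration) contrasting the two orders. In the Lightning repro below, only the first order is reachable, because DDP wrapping happens inside fit, after the user has already called torch.compile.

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

torch._dynamo.config.optimize_ddp = "python_reducer"

# Single-process process group so DDP can be constructed locally (assumption for this sketch).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Order used in the repro below: forward is compiled first, DDP wraps the compiled
# module afterwards, so DDP._pre_forward/_post_forward themselves still run eagerly.
compile_then_wrap = DDP(torch.compile(nn.Linear(10, 1)), static_graph=True)

# Order where DDP is also compiled: the DDP pre/post-forward hooks are traced too,
# which is the case where the cpp_reducer gets disabled as expected.
wrap_then_compile = torch.compile(DDP(nn.Linear(10, 1), static_graph=True))

dist.destroy_process_group()
```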
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
import functools
from torch._dynamo import compiled_autograd

torch._dynamo.config.optimize_ddp = "python_reducer"


class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.mean((self(x) - y) ** 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


def create_dataset(num_samples=1000):
    x = torch.randn(num_samples, 10)
    y = torch.sum(x, dim=1, keepdim=True)
    return TensorDataset(x, y)


def run_training():
    dataset = create_dataset()
    train_loader = DataLoader(dataset, batch_size=32)
    model = SomeModel()
    # First compile the whole model
    model = torch.compile(model)
    # static_graph has to be True, which makes _DDPSink.backward queue the callback
    ddp_strategy = DDPStrategy(static_graph=True)
    trainer = pl.Trainer(
        max_epochs=2,
        accelerator='cpu',
        devices=1,
        strategy=ddp_strategy
    )
    # DDP will be applied inside fit, so DDP pre/post forward won't be compiled while forward is
    with compiled_autograd.enable(torch.compile()):
        trainer.fit(model, train_loader)


if __name__ == "__main__":
    run_training()
Expected repro:
[rank0]: Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
[rank0]: RuntimeError: Final callbacks can only be installed during backward pass.
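For what it's worth, the message comes from the autograd engine's check that a backward pass is currently running when a final callback is queued. A tiny sketch of that condition (not DDP or Lightning code, just the same engine call the traceback points at):

```python
import torch

# Queueing a final callback while no backward pass is running trips the same check
# that the traceback above hits from inside DDP/_DDPSink.
torch.autograd.Variable._execution_engine.queue_callback(lambda: None)
# -> RuntimeError: Final callbacks can only be installed during backward pass.
```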
Pitch
It seems useful to compile after the strategy has been applied, so my suggestion is to add something like a boolean use_compile parameter to the Trainer. That would help in situations like this one and would also be cleaner to use.
It probably needs to be more expressive than a plain bool, so that a specific backend and other optional torch.compile parameters can be set.
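As a rough illustration of the pitch, a hypothetical user-facing sketch (the compile and compile_kwargs parameters do not exist in Lightning; they are invented here only to show the shape such an API could take):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# Hypothetical API: the Trainer would call torch.compile(self.model, **compile_kwargs)
# itself, *after* the strategy has wrapped the model.
trainer = pl.Trainer(
    max_epochs=2,
    accelerator="cpu",
    devices=1,
    strategy=DDPStrategy(static_graph=True),
    compile=True,                               # proposed flag (does not exist yet)
    compile_kwargs={"backend": "inductor"},     # proposed knobs (do not exist yet)
)
```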
Alternatives
src/lightning/pytorch/trainer/trainer.py
class Trainer:
    def _run(
        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        (...)
        # ----------------------------
        # SET UP THE TRAINER
        # ----------------------------
        (...)
        self.strategy.setup(self)
        (...)
------> ### Pseudo proposition
------> if self.use_compile:
------>     self.model = torch.compile(self.model)
        # ----------------------------
        # RUN THE TRAINER
        # ----------------------------
        results = self._run_stage()
        (...)
Additional context
File: torch/nn/parallel/distributed.py
Dynamo replaces the output of torch._utils.is_compiling() with True when the code is compiled; otherwise it returns False.
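A small standalone sketch of that behavior (the scale function is made up here; it only stands in for the kind of branch DDP takes on this check):

```python
import torch

def scale(x):
    # torch.compiler.is_compiling() is the public spelling of torch._utils.is_compiling();
    # Dynamo folds it to True while tracing, so the two calls below take different branches.
    if torch.compiler.is_compiling():
        return x * 2  # branch taken under torch.compile (analogue: python reducer, cpp reducer disabled)
    return x * 1      # branch taken in eager mode (analogue: cpp reducer stays enabled)

x = torch.ones(3)
print(scale(x))                 # eager    -> tensor([1., 1., 1.])
print(torch.compile(scale)(x))  # compiled -> tensor([2., 2., 2.])
```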
cc @Borda