diff --git a/mmcv/parallel/_functions.py b/mmcv/parallel/_functions.py
index 43580b46f9..e40442044c 100644
--- a/mmcv/parallel/_functions.py
+++ b/mmcv/parallel/_functions.py
@@ -4,7 +4,8 @@
 import torch
+from packaging import version
 from torch import Tensor
 from torch.nn.parallel._functions import _get_stream
 
 
 def scatter(input: Union[List, Tensor],
             devices: List,
@@ -72,7 +73,16 @@ def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
         streams = None
         if input_device == -1 and target_gpus != [-1]:
             # Perform CPU to GPU copies in a background stream
-            streams = [_get_stream(device) for device in target_gpus]
+            # torch>=2.1 expects a torch.device here, older releases an int.
+            # Compare release tuples so 2.1 pre-release builds (e.g.
+            # '2.1.0a0+...') are not ordered before '2.1.0'.
+            if version.parse(torch.__version__).release >= (2, 1):
+                streams = [
+                    _get_stream(torch.device('cuda', device))
+                    for device in target_gpus
+                ]
+            else:
+                streams = [_get_stream(device) for device in target_gpus]
 
         outputs = scatter(input, target_gpus, streams)
         # Synchronize with the copy stream