Skip to content

Commit 76a776e

Browse files
authored
Fix grid_sample (#3340)
1 parent ca59597 commit 76a776e

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

py/torch_tensorrt/dynamo/conversion/impl/grid.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
66
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
8-
from torch_tensorrt.fx.types import TRTTensor
7+
from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name
8+
from torch_tensorrt.dynamo.types import TRTTensor
99

10-
# nearest, linear, cubic
10+
# bilinear, nearest, bicubic
1111
GridSamplerInterpolationMode = {
12-
0: trt.InterpolationMode.NEAREST,
13-
1: trt.InterpolationMode.LINEAR,
12+
0: trt.InterpolationMode.LINEAR,
13+
1: trt.InterpolationMode.NEAREST,
1414
2: trt.InterpolationMode.CUBIC,
1515
}
1616

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+9
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,15 @@ def scaled_dot_product_cudnn_attention_decomposition(
566566
return attn, None, None, None, 0, 0, None, None, None
567567

568568

569+
@register_torch_trt_decomposition(
570+
aten.cudnn_grid_sampler, registry=TORCH_TRT_DECOMPOSITIONS
571+
)
572+
def cudnn_grid_sampler_decomposition(
573+
x: torch.Tensor, grid: torch.Tensor
574+
) -> torch.Tensor:
575+
return torch.grid_sampler_2d(x, grid, 0, 0, True)
576+
577+
569578
def get_decompositions(
570579
enable_experimental_decompositions: bool = False,
571580
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

+50
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,56 @@ def forward(self, query, key, value, attn_bias=None):
21172117
msg="Scaled_dot_product_cudnn_attention TRT outputs don't match with the original model.",
21182118
)
21192119

2120+
def test_lowering_cudnn_grid_sampler(self):
2121+
class TestModule(torch.nn.Module):
2122+
def forward(self, x, grid):
2123+
return torch.ops.aten.cudnn_grid_sampler.default(x, grid)
2124+
2125+
# Operations expected to be removed in the traced graph after decompositions
2126+
expected_ops = {torch.ops.aten.grid_sampler_2d.default}
2127+
unexpected_ops = {torch.ops.aten.cudnn_grid_sampler.default}
2128+
2129+
inputs = [
2130+
torch.randn(1, 3, 5, 7, device="cuda"),
2131+
torch.randn(1, 5, 7, 2, device="cuda"),
2132+
]
2133+
2134+
exported_program = torch.export.export(TestModule(), tuple(inputs))
2135+
fx_graph = exported_program.module()
2136+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
2137+
fx_graph,
2138+
inputs,
2139+
expected_ops=expected_ops,
2140+
unexpected_ops=unexpected_ops,
2141+
min_block_size=1,
2142+
)
2143+
2144+
self.assertEqual(
2145+
len(unexpected_ops_seen),
2146+
0,
2147+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
2148+
)
2149+
2150+
self.assertEqual(
2151+
len(expected_ops_unseen),
2152+
0,
2153+
f"The following expected ops were not encountered: {expected_ops_unseen}",
2154+
)
2155+
2156+
torch._dynamo.reset()
2157+
2158+
# Validate that the results between Torch and Torch-TRT are similar
2159+
trt_model = torch_tensorrt.dynamo.compile(
2160+
exported_program, inputs, min_block_size=1
2161+
)
2162+
torch.testing.assert_close(
2163+
trt_model(*inputs),
2164+
fx_graph(*inputs),
2165+
rtol=RTOL,
2166+
atol=ATOL,
2167+
msg="Cudnn_grid_sampler TRT outputs don't match with the original model.",
2168+
)
2169+
21202170

21212171
if __name__ == "__main__":
21222172
run_tests()

0 commit comments

Comments
 (0)