diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 2c2cb8fd1bfb5..3480aa6e87e2c 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -764,6 +764,7 @@ XPUOpMap& get_kl3_ops() { phi::DataType::FLOAT16, phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, + {"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})}, {"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"group_norm_silu_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, diff --git a/test/xpu/test_grid_sampler_op_xpu.py b/test/xpu/test_grid_sampler_op_xpu.py index a5674327eaf59..816930da5344a 100644 --- a/test/xpu/test_grid_sampler_op_xpu.py +++ b/test/xpu/test_grid_sampler_op_xpu.py @@ -23,6 +23,7 @@ from op_test_xpu import XPUOpTest import paddle +from paddle.base import core paddle.enable_static() @@ -484,6 +485,10 @@ def initTestCase(self): self.mode = "bilinear" # 3d grid_sample_grad is not supported yet + @unittest.skipIf( + core.get_xpu_device_version(0) == core.XPUVersion.XPU3, + "grid_sample3d for XPU3 is not supported", + ) class TestGridSample3DBilinear(TestXPUGridSamplerOp): def initTestCase(self): self.x_shape = (2, 3, 5, 6, 7) @@ -495,6 +500,10 @@ def initTestCase(self): self.no_need_check_grad = True + @unittest.skipIf( + core.get_xpu_device_version(0) == core.XPUVersion.XPU3, + "grid_sample3d for XPU3 is not supported", + ) class TestGridSample3DNearest(TestXPUGridSamplerOp): def initTestCase(self): self.x_shape = (2, 3, 5, 6, 7) @@ -506,6 +515,10 @@ def initTestCase(self): self.no_need_check_grad = True + @unittest.skipIf( + core.get_xpu_device_version(0) == core.XPUVersion.XPU3, + "grid_sample3d for XPU3 is not supported", + ) class TestGridSample3DBorder(TestXPUGridSamplerOp): def initTestCase(self): self.x_shape = (2, 3, 5, 6, 7) @@ -517,6 +530,10 @@ def initTestCase(self): self.no_need_check_grad = True + @unittest.skipIf( + core.get_xpu_device_version(0) == core.XPUVersion.XPU3, + "grid_sample3d for XPU3 is not supported", + ) class TestGridSample3DReflection(TestXPUGridSamplerOp): def initTestCase(self): self.x_shape = (2, 3, 5, 6, 7) @@ -528,6 +545,10 @@ def initTestCase(self): self.no_need_check_grad = True + @unittest.skipIf( + core.get_xpu_device_version(0) == core.XPUVersion.XPU3, + "grid_sample3d for XPU3 is not supported", + ) class TestGridSample3DAlignCornersFalse(TestXPUGridSamplerOp): def initTestCase(self): self.x_shape = (2, 3, 5, 6, 7)