From 52299e6a659e68879b2c656ffd82ae2a9d28afd8 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <162080376+keshavj-cerebras@users.noreply.github.com> Date: Fri, 7 Feb 2025 12:01:07 +0530 Subject: [PATCH] Missing Shape Inference for Prod (#4003) Added required but missing Shape inference for `aten.prod` --- .../ltc/csrc/base_lazy_backend/shape_inference.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 2e42e4fed3ba..04f81dac0446 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -252,6 +252,18 @@ std::vector compute_shape_native_group_norm( return shapes; } +std::vector +compute_shape_prod(const at::Tensor &self, + c10::optional dtype) { + if (dtype.has_value()) { + return {Shape(dtype.value(), {})}; + } + if (isIntegralType(self.scalar_type(), true)) { + return {Shape(c10::ScalarType::Long, {})}; + } + return {Shape(self.scalar_type(), {})}; +} + std::vector compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding,