@@ -79,64 +79,6 @@ Tensor bmm_nested(const Tensor& self, const Tensor& mat2) {
79
79
return output;
80
80
}
81
81
82
- // utilities support `matmul_nested`
83
- namespace {
84
- // Args:
85
- // self_sizes: the sizes of `self` in `matmul_nested`
86
- // mat2_sizes: the sizes of `mat2` in `matmul_nested`
87
- // buffer_op: the options for new buffer
88
- // sizemat_op: the options for new size matrix
89
- // Returns:
90
- // the batch size of each input underlying tensor, i.e. the product of batch-dimension sizes
91
- // the empty output nested tensor
92
- inline std::tuple<std::vector<int64_t >, Tensor>
93
- matmul_nested_helper (
94
- const std::vector<IntArrayRef>& self_sizes,
95
- const std::vector<IntArrayRef>& mat2_sizes,
96
- const c10::TensorOptions& buffer_op,
97
- const c10::TensorOptions& sizemat_op) {
98
- int64_t ntensors = self_sizes.size (),
99
- ndims = self_sizes[0 ].size ();
100
- std::vector<int64_t > batch_sizes (ntensors, 1 );
101
- Tensor sizemat = at::empty ({ntensors, ndims}, sizemat_op);
102
- int64_t * sizemat_ptr = sizemat.mutable_data_ptr <int64_t >();
103
- int64_t numel = 0 ;
104
- for (int64_t i = 0 ; i < ntensors; i++) {
105
- const IntArrayRef& self_size = self_sizes[i],
106
- & mat2_size = mat2_sizes[i];
107
- int64_t & batch_size = batch_sizes[i];
108
- // batch dimensions
109
- for (int64_t j = 0 ; j < ndims - 2 ; j++) {
110
- const int64_t & self_sizej = self_size[j],
111
- & mat2_sizej = mat2_size[j];
112
- TORCH_CHECK (
113
- self_sizej == mat2_sizej,
114
- " matmul: For nested tensors, no broadcasting is currently performed: " ,
115
- i, " -th nested matrices in batch at dimension " , j + 1 ,
116
- " have mismatching sizes " , self_sizej, " and " , mat2_sizej);
117
- sizemat_ptr[j] = self_sizej;
118
- batch_size *= sizemat_ptr[j];
119
- }
120
- // matrix multiplication dimensions
121
- const int64_t & self_size0 = self_size[ndims - 2 ], & self_size1 = self_size[ndims - 1 ],
122
- & mat2_size0 = mat2_size[ndims - 2 ], & mat2_size1 = mat2_size[ndims - 1 ];
123
- TORCH_CHECK (
124
- self_size1 == mat2_size0,
125
- " matmul: " ,
126
- i, " -th nested matrices in batch cannot be multiplied (" ,
127
- self_size0, " x" , self_size1, " and " ,
128
- mat2_size0, " x" , mat2_size1, " )" );
129
- sizemat_ptr[ndims - 2 ] = self_size0;
130
- sizemat_ptr[ndims - 1 ] = mat2_size1;
131
- sizemat_ptr += ndims;
132
- numel += batch_size * self_size0 * mat2_size1;
133
- }
134
- Tensor buffer = at::empty (numel, buffer_op);
135
- Tensor output = wrap_buffer (buffer, sizemat);
136
- return std::make_tuple (batch_sizes, output);
137
- }
138
- }
139
-
140
82
Tensor matmul_with_bmm_nested (const Tensor& self, const Tensor& mat2) {
141
83
// Tensor self = self_.contiguous();
142
84
// Tensor mat2 = mat2_.contiguous();
0 commit comments