File tree 1 file changed +9
-7
lines changed
1 file changed +9
-7
lines changed Original file line number Diff line number Diff line change 22
22
#include < linalg/binary_op.cuh>
23
23
#include < linalg/reduce.cuh>
24
24
#include < linalg/unary_op.cuh>
25
+ #include < network_buffer_channels.hpp>
25
26
#include < utils.hpp>
26
27
namespace HugeCTR {
27
28
@@ -36,14 +37,15 @@ SoftmaxLayer<T>::SoftmaxLayer(const core23::Tensor& input_tensor,
36
37
dims_ = input_tensor.shape ().dims ();
37
38
hidden_size_ = input_tensor.shape ().size (dims_ - 1 );
38
39
n_rows_ = len_ / hidden_size_;
39
- workspace23_ =
40
- core23::Tensor ({(int64_t )n_rows_}, core23::DataType (core23::ToScalarType<T>::value));
41
- identity23_ =
42
- core23::Tensor ({(int64_t )hidden_size_}, core23::DataType (core23::ToScalarType<T>::value));
43
- softmax_out23_ =
44
- core23::Tensor (input_tensor.shape (), core23::DataType (core23::ToScalarType<T>::value));
40
+ core23::BufferParams buf_p{.channel = GetBlobsBufferChannel ()};
41
+ auto param = (input_tensor.my_params ().buffer_params (buf_p));
42
+ workspace23_ = core23::Tensor (
43
+ param.shape ({(int64_t )n_rows_}).data_type (core23::DataType (core23::ToScalarType<T>::value)));
44
+ identity23_ = core23::Tensor (param.shape ({(int64_t )hidden_size_})
45
+ .data_type (core23::DataType (core23::ToScalarType<T>::value)));
46
+ softmax_out23_ = core23::Tensor (param.shape (input_tensor.shape ())
47
+ .data_type (core23::DataType (core23::ToScalarType<T>::value)));
45
48
}
46
-
47
49
template <typename T>
48
50
void SoftmaxLayer<T>::initialize() {
49
51
CudaDeviceContext context (get_device_id ());
You can’t perform that action at this time.
0 commit comments