Skip to content

Commit 5679966

Browse files
committed
Merge branch 'softmax-fix-junzhang' into 'main'
Fix softmax tensor declaration in ctor. See merge request dl/hugectr/hugectr!1513.
2 parents 94dbb7e + bec7a6b commit 5679966

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

HugeCTR/src/layers/softmax_layer.cu

+9-7
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
#include <linalg/binary_op.cuh>
2323
#include <linalg/reduce.cuh>
2424
#include <linalg/unary_op.cuh>
25+
#include <network_buffer_channels.hpp>
2526
#include <utils.hpp>
2627
namespace HugeCTR {
2728

@@ -36,14 +37,15 @@ SoftmaxLayer<T>::SoftmaxLayer(const core23::Tensor& input_tensor,
3637
dims_ = input_tensor.shape().dims();
3738
hidden_size_ = input_tensor.shape().size(dims_ - 1);
3839
n_rows_ = len_ / hidden_size_;
39-
workspace23_ =
40-
core23::Tensor({(int64_t)n_rows_}, core23::DataType(core23::ToScalarType<T>::value));
41-
identity23_ =
42-
core23::Tensor({(int64_t)hidden_size_}, core23::DataType(core23::ToScalarType<T>::value));
43-
softmax_out23_ =
44-
core23::Tensor(input_tensor.shape(), core23::DataType(core23::ToScalarType<T>::value));
40+
core23::BufferParams buf_p{.channel = GetBlobsBufferChannel()};
41+
auto param = (input_tensor.my_params().buffer_params(buf_p));
42+
workspace23_ = core23::Tensor(
43+
param.shape({(int64_t)n_rows_}).data_type(core23::DataType(core23::ToScalarType<T>::value)));
44+
identity23_ = core23::Tensor(param.shape({(int64_t)hidden_size_})
45+
.data_type(core23::DataType(core23::ToScalarType<T>::value)));
46+
softmax_out23_ = core23::Tensor(param.shape(input_tensor.shape())
47+
.data_type(core23::DataType(core23::ToScalarType<T>::value)));
4548
}
46-
4749
template <typename T>
4850
void SoftmaxLayer<T>::initialize() {
4951
CudaDeviceContext context(get_device_id());

0 commit comments

Comments
 (0)