
Commit 819aa2a (parent c0bd3a5)

Fix an issue with cudnn conv3d.
Also checked in some code for bfloat16 support; not done yet.

5 files changed: +220 -15 lines

lib/ccv.h (+17 -1)

@@ -67,7 +67,11 @@ static const ssize_t _ccv_get_data_type_size[] = {
 	-1, -1, -1, -1, -1, -1, -1, 8,
 	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2,
 	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1
+	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1,
+	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2
 };
 
 #define CCV_GET_DATA_TYPE(x) ((x) & 0xFF000)

@@ -885,10 +889,22 @@ ccv_dense_matrix_t ccv_reshape(ccv_dense_matrix_t* a, int y, int x, int rows, in
 void ccv_float_to_half_precision(const float* f, uint16_t* h, size_t len);
 void ccv_half_precision_to_float(const uint16_t* h, float* f, size_t len);
 
+// 32-bit float to 16-bit bfloat
+void ccv_float_to_bfloat(const float* f, uint16_t* h, size_t len);
+void ccv_bfloat_to_float(const uint16_t* h, float* f, size_t len);
+
 // 64-bit float to 16-bit float
 void ccv_double_to_half_precision(const double* f, uint16_t* h, size_t len);
 void ccv_half_precision_to_double(const uint16_t* h, double* f, size_t len);
 
+// 64-bit float to 16-bit bfloat
+void ccv_double_to_bfloat(const double* f, uint16_t* h, size_t len);
+void ccv_bfloat_to_double(const uint16_t* h, double* f, size_t len);
+
+// 16-bit float to 16-bit bfloat
+void ccv_bfloat_to_half_precision(const uint16_t* h, uint16_t* f, size_t len);
+void ccv_half_precision_to_bfloat(const uint16_t* f, uint16_t* h, size_t len);
+
 /* basic data structures ccv_util.c */
 
 typedef struct {
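For context (not part of the diff): a bfloat16 value keeps a float32's sign bit, full 8-bit exponent, and top 7 mantissa bits, so float-to-bfloat conversion is a 16-bit truncation and bfloat-to-float is a zero-extension. A minimal, hypothetical sanity check of that layout; the shift/memcpy form below is endian-neutral, while the library's pointer-cast implementation additionally assumes a little-endian host where u[i * 2 + 1] is the high half of the float:

#include <assert.h>
#include <stdint.h>
#include <string.h>

int main(void)
{
	float x = 3.140625f; /* bit pattern 0x40490000: low 16 mantissa bits are zero */
	uint32_t bits;
	memcpy(&bits, &x, sizeof(bits));
	uint16_t bf = (uint16_t)(bits >> 16);  /* float -> bfloat: keep the top 16 bits */
	uint32_t widened = (uint32_t)bf << 16; /* bfloat -> float: zero-fill the bottom */
	float y;
	memcpy(&y, &widened, sizeof(y));
	assert(x == y); /* exact round trip because x fits in 7 mantissa bits */
	return 0;
}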

lib/ccv_util.c (+73)

@@ -1439,6 +1439,27 @@ void ccv_float_to_half_precision(const float* f, uint16_t* h, size_t len)
 		h[i] = _ccv_base_table[(u[i] >> 23) & 0x1ff] + ((u[i] & 0x007fffff) >> _ccv_shift_table[(u[i] >> 23) & 0x1ff]);
 }
 
+void ccv_float_to_bfloat(const float* f, uint16_t* h, size_t len)
+{
+	int i;
+	const uint16_t* u = (const uint16_t*)f;
+	for (i = 0; i < len; i++)
+		h[i] = u[i * 2 + 1];
+}
+
+void ccv_bfloat_to_half_precision(const uint16_t* h, uint16_t* f, size_t len)
+{
+	int i;
+	for (i = 0; i < len; i++)
+	{
+		union {
+			uint16_t h[2];
+			uint32_t p;
+		} u = { .h = { 0, h[i] } };
+		f[i] = _ccv_base_table[(u.p >> 23) & 0x1ff] + ((u.p & 0x007fffff) >> _ccv_shift_table[(u.p >> 23) & 0x1ff]);
+	}
+}
+
 void ccv_double_to_half_precision(const double* f, uint16_t* h, size_t len)
 {
 	int i;

@@ -1452,6 +1473,19 @@ void ccv_double_to_half_precision(const double* f, uint16_t* h, size_t len)
 	}
 }
 
+void ccv_double_to_bfloat(const double* f, uint16_t* h, size_t len)
+{
+	int i;
+	for (i = 0; i < len; i++)
+	{
+		union {
+			float v;
+			uint16_t h[2];
+		} u = { .v = (float)f[i] };
+		h[i] = u.h[1];
+	}
+}
+
 static uint32_t _ccv_mantissa_table[2048] = {
 	0x0, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34a00000, 0x34c00000, 0x34e00000,
 	0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000,

@@ -1741,6 +1775,32 @@ void ccv_half_precision_to_float(const uint16_t* h, float* f, size_t len)
 		u[i] = _ccv_mantissa_table[_ccv_offset_table[h[i] >> 10] + (h[i] & 0x3ff)] + _ccv_exponent_table[h[i] >> 10];
 }
 
+void ccv_bfloat_to_float(const uint16_t* h, float* f, size_t len)
+{
+	int i;
+	uint16_t* u = (uint16_t*)f;
+	for (i = 0; i < len; i++)
+	{
+		u[i * 2] = 0;
+		u[i * 2 + 1] = h[i];
+	}
+}
+
+void ccv_half_precision_to_bfloat(const uint16_t* h, uint16_t* f, size_t len)
+{
+	int i;
+	for (i = 0; i < len; i++)
+	{
+		union {
+			uint16_t h[2];
+			uint32_t p;
+		} u = {
+			.p = _ccv_mantissa_table[_ccv_offset_table[h[i] >> 10] + (h[i] & 0x3ff)] + _ccv_exponent_table[h[i] >> 10]
+		};
+		f[i] = u.h[1];
+	}
+}
+
 void ccv_half_precision_to_double(const uint16_t* h, double* f, size_t len)
 {
 	int i;

@@ -1756,6 +1816,19 @@ void ccv_half_precision_to_double(const uint16_t* h, double* f, size_t len)
 	}
 }
 
+void ccv_bfloat_to_double(const uint16_t* h, double* f, size_t len)
+{
+	int i;
+	for (i = 0; i < len; i++)
+	{
+		union {
+			float v;
+			uint16_t h[2];
+		} u = { .h = { 0, h[i] } };
+		f[i] = u.v;
+	}
+}
+
 void ccv_array_push(ccv_array_t* array, const void* r)
 {
 	array->rnum++;
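The new widening routines place the 16 bfloat bits in the high half of a float32 (the union's h[1] slot on a little-endian host). A table-free reference, not from the commit, that computes the same thing as ccv_bfloat_to_double; memcpy sidesteps both the endianness assumption and strict-aliasing concerns:

#include <stdint.h>
#include <string.h>

static double bfloat_to_double_ref(uint16_t h)
{
	const uint32_t bits = (uint32_t)h << 16; /* low mantissa bits become zero */
	float v;
	memcpy(&v, &bits, sizeof(v));
	return (double)v; /* float -> double promotion is exact */
}

ccv_half_precision_to_bfloat goes through the existing half-to-float tables first because half and bfloat disagree on exponent width (5 vs. 8 bits), so there is no direct 16-bit-to-16-bit shortcut; likewise ccv_bfloat_to_half_precision reuses the float-to-half tables, which already handle the overflow and underflow that bfloat's wider range can produce.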

lib/nnc/cmd/util/ccv_nnc_util_cpu_ref.c (+120 -11)

@@ -386,6 +386,77 @@ void _ccv_nnc_tensor_set_cpu_ref_f16(ccv_nnc_tensor_view_t* const a, const float
 	}
 }
 
+void _ccv_nnc_tensor_set_cpu_ref_bf16(ccv_nnc_tensor_view_t* const a, const float b)
+{
+	// Assuming this is short.
+	int dim[CCV_NNC_MAX_DIM_ALLOC];
+	int astride[CCV_NNC_MAX_DIM_ALLOC];
+	short h;
+	ccv_float_to_bfloat((float*)&b, (uint16_t*)&h, 1);
+	int x;
+	if (!CCV_IS_TENSOR_VIEW(a))
+	{
+		// Super optimal case, just do one for-loop to set all elements.
+		const int tensor_count = ccv_nnc_tensor_count(a->info);
+		for (x = 0; x < tensor_count; x++)
+			a->data.f16[x].v = h;
+		return;
+	}
+	assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
+	ccv_nnc_tensor_view_get_dim(a, dim);
+	ccv_nnc_tensor_view_get_stride(a, astride);
+	int i[CCV_NNC_MAX_DIM + 2];
+	short* const ap = (short*)a->data.f16;
+	const int count = dim[2] * dim[3];
+	if (astride[2] == dim[3])
+	{
+		// Special casing if the ainc[3] is the same as dim[3]
+		for (i[0] = 0; i[0] < dim[0]; i[0]++)
+		{
+			short* ap0 = ap + i[0] * astride[0];
+			for (i[1] = 0; i[1] < dim[1]; i[1]++)
+			{
+				for (x = 0; x < count; x++)
+					ap0[x] = h;
+				ap0 += astride[1];
+			}
+		}
+		return;
+	} else if (astride[3] == 1) {
+		// The case the last dimension is packed.
+		for (i[0] = 0; i[0] < dim[0]; i[0]++)
+		{
+			short* const ap0 = ap + i[0] * astride[0];
+			for (i[1] = 0; i[1] < dim[1]; i[1]++)
+			{
+				short* ap1 = ap0 + i[1] * astride[1];
+				for (i[2] = 0; i[2] < dim[2]; i[2]++)
+				{
+					for (x = 0; x < dim[3]; x++)
+						ap1[x] = h;
+					ap1 += astride[2];
+				}
+			}
+		}
+		return;
+	}
+	// Non-optimal case, need to do skip copy.
+	for (i[0] = 0; i[0] < dim[0]; i[0]++)
+	{
+		short* const ap0 = ap + i[0] * astride[0];
+		for (i[1] = 0; i[1] < dim[1]; i[1]++)
+		{
+			short* ap1 = ap0 + i[1] * astride[1];
+			for (i[2] = 0; i[2] < dim[2]; i[2]++)
+			{
+				for (x = 0; x < dim[3]; x++)
+					ap1[x * astride[3]] = h;
+				ap1 += astride[2];
+			}
+		}
+	}
+}
+
 void _ccv_nnc_tensor_set_cpu_ref_f32(ccv_nnc_tensor_view_t* const a, const float b)
 {
 	// Assuming this is float 32.

@@ -603,7 +674,7 @@ static int _ccv_nnc_data_transfer(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t
 	if (a != b) // Only do transfer if these are two different tensors.
 	{
 		assert(a->info.datatype == b->info.datatype);
-		if (a->info.datatype == CCV_16F)
+		if (a->info.datatype == CCV_16F || a->info.datatype == CCV_16BF)
 			_ccv_nnc_tensor_transfer_cpu_ref_f16(a, b);
 		else if (a->info.datatype == CCV_32F || a->info.datatype == CCV_32S)
 			_ccv_nnc_tensor_transfer_cpu_ref_f32(a, b);

@@ -619,7 +690,7 @@ static int _ccv_nnc_data_transfer(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t
 REGISTER_COMMAND_BACKEND(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_32S;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_32S | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_data_transfer;

@@ -628,7 +699,7 @@ REGISTER_COMMAND_BACKEND(CCV_NNC_DATA_TRANSFER_FORWARD, CCV_NNC_BACKEND_CPU_REF)
 REGISTER_COMMAND_BACKEND(CCV_NNC_DATA_TRANSFER_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_32S;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_32S | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_data_transfer;

@@ -644,6 +715,8 @@ static int _ccv_nnc_set_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
 	for (i = 0; i < output_size; i++)
 		if (outputs[i]->info.datatype == CCV_16F)
 			_ccv_nnc_tensor_set_cpu_ref_f16((ccv_nnc_tensor_view_t*)outputs[i], cmd.info.blas.a[0]);
+		else if (outputs[i]->info.datatype == CCV_16BF)
+			_ccv_nnc_tensor_set_cpu_ref_bf16((ccv_nnc_tensor_view_t*)outputs[i], cmd.info.blas.a[0]);
 		else if (outputs[i]->info.datatype == CCV_32F)
 			_ccv_nnc_tensor_set_cpu_ref_f32((ccv_nnc_tensor_view_t*)outputs[i], cmd.info.blas.a[0]);
 		else if (outputs[i]->info.datatype == CCV_64F)

@@ -666,7 +739,7 @@ static int _ccv_nnc_set_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint,
 REGISTER_COMMAND_BACKEND(CCV_NNC_SET_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_32S;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_32S | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_set_forw;

@@ -675,7 +748,7 @@ REGISTER_COMMAND_BACKEND(CCV_NNC_SET_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_c
 REGISTER_COMMAND_BACKEND(CCV_NNC_SET_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_32S | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_set_back;

@@ -1040,7 +1113,7 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
 		} else if (a->info.format == CCV_TENSOR_FORMAT_CHWN && b->info.format == CCV_TENSOR_FORMAT_NCHW) {
 			assert(0);
 		}
-	} else if (a->info.datatype == CCV_16F) {
+	} else if (a->info.datatype == CCV_16F || a->info.datatype == CCV_16BF) {
 		if (a->info.format == b->info.format) {
 			// If it is the same, just do a normal data transfer.
 			_ccv_nnc_tensor_transfer_cpu_ref_f16(a, b);

@@ -1084,7 +1157,7 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
 REGISTER_COMMAND_BACKEND(CCV_NNC_FORMAT_TRANSFORM_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_32S | CCV_16F | CCV_8U;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_32S | CCV_16F | CCV_8U | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_format_transform;

@@ -1093,7 +1166,7 @@ REGISTER_COMMAND_BACKEND(CCV_NNC_FORMAT_TRANSFORM_FORWARD, CCV_NNC_BACKEND_CPU_R
 REGISTER_COMMAND_BACKEND(CCV_NNC_FORMAT_TRANSFORM_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_32S | CCV_16F | CCV_8U;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_32S | CCV_16F | CCV_8U | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_format_transform;

@@ -1209,7 +1282,7 @@ static int _ccv_nnc_datatype_conversion(const ccv_nnc_cmd_t cmd, const ccv_nnc_h
 	assert(a->info.format == b->info.format);
 	if (a->info.datatype == b->info.datatype) {
 		// If it is the same, just do a normal data transfer.
-		if (a->info.datatype == CCV_16F)
+		if (a->info.datatype == CCV_16F || a->info.datatype == CCV_16BF)
 			_ccv_nnc_tensor_transfer_cpu_ref_f16(a, b);
 		else if (a->info.datatype == CCV_32F)
 			_ccv_nnc_tensor_transfer_cpu_ref_f32(a, b);

@@ -1254,6 +1327,42 @@ static int _ccv_nnc_datatype_conversion(const ccv_nnc_cmd_t cmd, const ccv_nnc_h
 		const int tensor_count = ccv_nnc_tensor_count(a->info);
 		assert(tensor_count == ccv_nnc_tensor_count(b->info));
 		ccv_half_precision_to_double((uint16_t*)a->data.f16, b->data.f64, tensor_count);
+	} else if (a->info.datatype == CCV_16F && b->info.datatype == CCV_16BF) {
+		assert(CCV_IS_TENSOR_CONTIGUOUS(a));
+		assert(CCV_IS_TENSOR_CONTIGUOUS(b));
+		const size_t tensor_count = ccv_nnc_tensor_count(a->info);
+		assert(tensor_count == ccv_nnc_tensor_count(b->info));
+		ccv_half_precision_to_bfloat((uint16_t*)a->data.f16, (uint16_t*)b->data.f16, tensor_count);
+	} else if (a->info.datatype == CCV_16BF && b->info.datatype == CCV_16F) {
+		assert(CCV_IS_TENSOR_CONTIGUOUS(a));
+		assert(CCV_IS_TENSOR_CONTIGUOUS(b));
+		const int tensor_count = ccv_nnc_tensor_count(a->info);
+		assert(tensor_count == ccv_nnc_tensor_count(b->info));
+		ccv_bfloat_to_half_precision((uint16_t*)a->data.f16, (uint16_t*)b->data.f16, tensor_count);
+	} else if (a->info.datatype == CCV_32F && b->info.datatype == CCV_16BF) {
+		assert(CCV_IS_TENSOR_CONTIGUOUS(a));
+		assert(CCV_IS_TENSOR_CONTIGUOUS(b));
+		const size_t tensor_count = ccv_nnc_tensor_count(a->info);
+		assert(tensor_count == ccv_nnc_tensor_count(b->info));
+		ccv_float_to_bfloat(a->data.f32, (uint16_t*)b->data.f16, tensor_count);
+	} else if (a->info.datatype == CCV_16BF && b->info.datatype == CCV_32F) {
+		assert(CCV_IS_TENSOR_CONTIGUOUS(a));
+		assert(CCV_IS_TENSOR_CONTIGUOUS(b));
+		const int tensor_count = ccv_nnc_tensor_count(a->info);
+		assert(tensor_count == ccv_nnc_tensor_count(b->info));
+		ccv_bfloat_to_float((uint16_t*)a->data.f16, b->data.f32, tensor_count);
+	} else if (a->info.datatype == CCV_64F && b->info.datatype == CCV_16BF) {
+		assert(CCV_IS_TENSOR_CONTIGUOUS(a));
+		assert(CCV_IS_TENSOR_CONTIGUOUS(b));
+		const size_t tensor_count = ccv_nnc_tensor_count(a->info);
+		assert(tensor_count == ccv_nnc_tensor_count(b->info));
+		ccv_double_to_bfloat(a->data.f64, (uint16_t*)b->data.f16, tensor_count);
+	} else if (a->info.datatype == CCV_16BF && b->info.datatype == CCV_64F) {
+		assert(CCV_IS_TENSOR_CONTIGUOUS(a));
+		assert(CCV_IS_TENSOR_CONTIGUOUS(b));
+		const int tensor_count = ccv_nnc_tensor_count(a->info);
+		assert(tensor_count == ccv_nnc_tensor_count(b->info));
+		ccv_bfloat_to_double((uint16_t*)a->data.f16, b->data.f64, tensor_count);
 		}
 	}
 	return CCV_NNC_EXEC_SUCCESS;

@@ -1262,7 +1371,7 @@ static int _ccv_nnc_datatype_conversion(const ccv_nnc_cmd_t cmd, const ccv_nnc_h
 REGISTER_COMMAND_BACKEND(CCV_NNC_DATATYPE_CONVERSION_FORWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_datatype_conversion;

@@ -1271,7 +1380,7 @@ REGISTER_COMMAND_BACKEND(CCV_NNC_DATATYPE_CONVERSION_FORWARD, CCV_NNC_BACKEND_CP
 REGISTER_COMMAND_BACKEND(CCV_NNC_DATATYPE_CONVERSION_BACKWARD, CCV_NNC_BACKEND_CPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_CPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_datatype_conversion;
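The bulk of _ccv_nnc_tensor_set_cpu_ref_bf16 is view handling: contiguous tensors take the single flat loop, while tensor views walk dim[]/astride[]. A standalone toy, not from the commit, showing the same skip-copy pattern in two dimensions by filling a 2x3 view embedded in a 2x4 parent buffer:

#include <stdint.h>
#include <stdio.h>

int main(void)
{
	uint16_t buf[8] = { 0 };
	const int dim[2] = { 2, 3 };    /* the view is 2 rows x 3 columns */
	const int stride[2] = { 4, 1 }; /* parent rows are 4 elements wide */
	const uint16_t h = 0x4049;      /* bfloat16 pattern of 3.140625f */
	int i, x;
	for (i = 0; i < dim[0]; i++)
		for (x = 0; x < dim[1]; x++)
			buf[i * stride[0] + x * stride[1]] = h;
	for (i = 0; i < 8; i++)
		printf("%04x ", buf[i]); /* prints 4049 4049 4049 0000 4049 4049 4049 0000 */
	printf("\n");
	return 0;
}

The fourth parent element of each row stays zero because it lies outside the view, which is exactly what the non-optimal "skip copy" branch above guarantees for arbitrary strides.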

lib/nnc/cmd/util/gpu/ccv_nnc_util_gpu_ref.cu (+2 -2)

@@ -145,7 +145,7 @@ static int _ccv_nnc_datatype_conversion(const ccv_nnc_cmd_t cmd, const ccv_nnc_h
 REGISTER_COMMAND_BACKEND(CCV_NNC_DATATYPE_CONVERSION_FORWARD, CCV_NNC_BACKEND_GPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_GPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_datatype_conversion;

@@ -154,7 +154,7 @@ REGISTER_COMMAND_BACKEND(CCV_NNC_DATATYPE_CONVERSION_FORWARD, CCV_NNC_BACKEND_GP
 REGISTER_COMMAND_BACKEND(CCV_NNC_DATATYPE_CONVERSION_BACKWARD, CCV_NNC_BACKEND_GPU_REF)(ccv_nnc_cmd_backend_registry_t* const registry)
 {
 	registry->tensor_formats = CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_CHWN;
-	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F;
+	registry->tensor_datatypes = CCV_64F | CCV_32F | CCV_16F | CCV_16BF;
 	registry->tensor_memory = CCV_TENSOR_GPU_MEMORY;
 	registry->algorithms = 1;
 	registry->exec = _ccv_nnc_datatype_conversion;
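This hunk only registers CCV_16BF with the GPU reference backend; the conversion kernels themselves are outside this diff. For orientation, a hypothetical CUDA kernel (not the one this backend uses) that converts float32 to bfloat16 with the cuda_bf16.h intrinsic; note the intrinsic rounds to nearest-even, whereas the CPU reference path above truncates, so the two can legitimately differ in the last mantissa bit:

#include <cuda_bf16.h>

/* Illustrative only: one thread per element, grid-sized by the caller. */
__global__ void float_to_bf16(const float* const in, __nv_bfloat16* const out, const size_t len)
{
	const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
	if (i < len)
		out[i] = __float2bfloat16(in[i]); /* round-to-nearest-even conversion */
}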
