@@ -971,7 +971,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
971
971
972
972
template <typename T, int OPTIMIZER>
973
973
__launch_bounds__ (TH, 1 )
974
- __global__ void kOptimizer32bit1State(T *g, T *p,
974
+ __global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
975
975
float *state1, float *unorm, const float max_unorm, const float param_norm,
976
976
const float beta1, const float beta2, const float eps, const float weight_decay,
977
977
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1017,13 +1017,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
1017
1017
__syncthreads ();
1018
1018
LoadFloat (temp_storage.loadf ).Load (&(state1[i]), s1_vals, valid_items);
1019
1019
__syncthreads ();
1020
- Load (temp_storage.load ).Load (&(p[i]), p_vals, valid_items);
1020
+ Load (temp_storage.load ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1021
1021
1022
1022
# pragma unroll 4
1023
1023
for (unsigned int j = 0 ; j < NUM_PER_THREAD; j++)
1024
1024
{
1025
1025
g_vals[j] = gnorm_scale*((float )g_vals[j]);
1026
- if (weight_decay > 0 .0f )
1026
+ if (weight_decay > 0 .0f && return_updates == nullptr )
1027
1027
g_vals[j] = (float )g_vals[j] + (((float )p_vals[j])*weight_decay);
1028
1028
}
1029
1029
@@ -1040,26 +1040,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
1040
1040
else
1041
1041
s1_vals[j] = s1_vals[j]*beta1 + ((float )g_vals[j]);
1042
1042
1043
- p_vals[j] = ((float )p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
1043
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) + update_scale*(-lr*(s1_vals[j]));
1044
1044
break ;
1045
1045
case LION:
1046
- p_vals[j] = ((float )p_vals[j]) - update_scale*(lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_vals[j]))));
1046
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - update_scale*(lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_vals[j]))));
1047
1047
s1_vals[j] = s1_vals[j]*beta2 + ((1 .0f -beta2)*((float )g_vals[j]));
1048
1048
break ;
1049
1049
case RMSPROP:
1050
1050
s1_vals[j] = s1_vals[j]*beta1 + ((1 .0f -beta1)*((float )g_vals[j])*((float )g_vals[j]));
1051
- p_vals[j] = ((float )p_vals[j]) - update_scale*(lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps));
1051
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - update_scale*(lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps));
1052
1052
break ;
1053
1053
case ADAGRAD:
1054
1054
s1_vals[j] = s1_vals[j] + ((float )g_vals[j])*((float )g_vals[j]);
1055
- p_vals[j] = ((float )p_vals[j]) - lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps);
1055
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps);
1056
1056
break ;
1057
1057
}
1058
1058
}
1059
1059
}
1060
1060
1061
1061
__syncthreads ();
1062
- Store (temp_storage.store ).Store (&(p[i]), p_vals, valid_items);
1062
+ Store (temp_storage.store ).Store (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1063
1063
__syncthreads ();
1064
1064
StoreFloat (temp_storage.storef ).Store (&(state1[i]), s1_vals, valid_items);
1065
1065
}
@@ -1406,7 +1406,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
1406
1406
template <typename T, int OPTIMIZER>
1407
1407
__global__ void
1408
1408
__launch_bounds__ (1024 , 1 )
1409
- kOptimizerStatic8bit1State(T* p, T* const g, unsigned char * state1,
1409
+ kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char * state1,
1410
1410
const float *unorm, const float max_unorm, const float param_norm,
1411
1411
const float beta1, const float beta2,
1412
1412
const float eps, const int step, const float lr,
@@ -1462,7 +1462,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1462
1462
__syncthreads ();
1463
1463
LoadChar (temp_storage.loadc ).Load (&(state1[i]), c1s, valid_items, 128 );
1464
1464
__syncthreads ();
1465
- LoadT (temp_storage.loadh ).Load (&(p[i]), p_vals, valid_items);
1465
+ LoadT (temp_storage.loadh ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1466
1466
1467
1467
if ((i + (threadIdx .x *NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue ; }
1468
1468
@@ -1472,7 +1472,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1472
1472
g_val = float (g_vals[j]);
1473
1473
g_val *= gnorm_scale;
1474
1474
1475
- if (weight_decay > 0 .0f ) {
1475
+ if (weight_decay > 0 .0f && return_updates == nullptr ) {
1476
1476
switch (OPTIMIZER) {
1477
1477
case MOMENTUM:
1478
1478
case RMSPROP:
@@ -1494,15 +1494,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1494
1494
else
1495
1495
s1_vals[j] = s1_vals[j]*beta1 + ((float )g_vals[j]);
1496
1496
1497
- p_vals[j] = ((float )p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
1497
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) + (-lr*update_scale*(s1_vals[j]));
1498
1498
break ;
1499
1499
case LION:
1500
- p_vals[j] = ((float )p_vals[j]) - (lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_val))));
1500
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - (lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_val))));
1501
1501
s1_vals[j] = s1_vals[j]*beta2 + ((1 .0f -beta2)*g_val);
1502
1502
break ;
1503
1503
case RMSPROP:
1504
1504
s1_vals[j] = s1_vals[j]*beta1 + ((1 .0f -beta1)*(g_val*g_val));
1505
- p_vals[j] = ((float )p_vals[j]) - (lr*__fdividef (g_val,sqrtf (s1_vals[j])+eps));
1505
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - (lr*__fdividef (g_val,sqrtf (s1_vals[j])+eps));
1506
1506
break ;
1507
1507
}
1508
1508
@@ -1518,7 +1518,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1518
1518
}
1519
1519
}
1520
1520
1521
- StoreT (temp_storage.storeh ).Store (&(p[i]), p_vals, valid_items);
1521
+ StoreT (temp_storage.storeh ).Store (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1522
1522
__syncthreads ();
1523
1523
StoreChar (temp_storage.storec ).Store (&(state1[i]), c1s, valid_items);
1524
1524
__syncthreads ();
@@ -1769,7 +1769,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, T* return_upd
1769
1769
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
1770
1770
__launch_bounds__ (256 , 3 )
1771
1771
__global__ void
1772
- kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char * state1,
1772
+ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char * state1,
1773
1773
const float beta1, const float beta2,
1774
1774
const float eps, const int step, const float lr,
1775
1775
float * __restrict__ const quantiles1,
@@ -1833,7 +1833,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
1833
1833
__syncthreads ();
1834
1834
LoadChar (temp_storage.loadc ).Load (&(state1[i]), c1s, valid_items, 128 );
1835
1835
__syncthreads ();
1836
- LoadT (temp_storage.loadh ).Load (&(p[i]), p_vals, valid_items, (T)0 .0f );
1836
+ LoadT (temp_storage.loadh ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items, (T)0 .0f );
1837
1837
1838
1838
new_local_abs_max1 = -FLT_MAX;
1839
1839
@@ -1845,7 +1845,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
1845
1845
g_val *= gnorm_scale;
1846
1846
if (!skip_zeros || (skip_zeros && ((float )g_vals[j] != 0 .0f )))
1847
1847
{
1848
- if (weight_decay > 0 .0f ) {
1848
+ if (weight_decay > 0 .0f && return_updates == nullptr ) {
1849
1849
switch (OPTIMIZER) {
1850
1850
case MOMENTUM:
1851
1851
case ADAGRAD:
@@ -1908,18 +1908,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
1908
1908
switch (OPTIMIZER)
1909
1909
{
1910
1910
case MOMENTUM:
1911
- p_vals[j] = ((float )p_vals[j]) - lr*(s1_vals[j]);
1911
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(s1_vals[j]);
1912
1912
break ;
1913
1913
case LION:
1914
- p_vals[j] = ((float )p_vals[j]) - ((float )g_vals[j]);
1914
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - ((float )g_vals[j]);
1915
1915
break ;
1916
1916
case RMSPROP:
1917
1917
g_val = g_vals[j];
1918
- p_vals[j] = ((float )p_vals[j]) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
1918
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
1919
1919
break ;
1920
1920
case ADAGRAD:
1921
1921
g_val = g_vals[j];
1922
- p_vals[j] = ((float )p_vals[j]) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
1922
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
1923
1923
break ;
1924
1924
}
1925
1925
}
@@ -3679,7 +3679,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
3679
3679
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float )
3680
3680
3681
3681
#define MAKE_Optimizer32bit1State (oname, gtype ) \
3682
- template __global__ void kOptimizer32bit1State <gtype, oname>(gtype* g, gtype* p, float * state1, float *unorm, const float max_unorm, const float param_norm, \
3682
+ template __global__ void kOptimizer32bit1State <gtype, oname>(gtype* g, gtype* p, gtype* return_updates, float * state1, float *unorm, const float max_unorm, const float param_norm, \
3683
3683
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
3684
3684
3685
3685
MAKE_Optimizer32bit1State (MOMENTUM, half)
@@ -3729,7 +3729,7 @@ MAKE_PreconditionStatic8bit1State(LION, half)
3729
3729
MAKE_PreconditionStatic8bit1State(LION, float )
3730
3730
3731
3731
#define MAKE_optimizerStatic8bit1State (oname, gtype ) \
3732
- template __global__ void kOptimizerStatic8bit1State <gtype, oname>(gtype* p, gtype* const g, unsigned char * state1, \
3732
+ template __global__ void kOptimizerStatic8bit1State <gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, unsigned char * state1, \
3733
3733
const float *unorm, const float max_unorm, const float param_norm, \
3734
3734
const float beta1, \
3735
3735
const float beta2, \
@@ -3876,7 +3876,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)
3876
3876
3877
3877
#define MAKE_OptimizerStatic8bit1StateBlockwise (oname, gtype, block_size, num_per_thread ) \
3878
3878
template __global__ void kOptimizerStatic8bit1StateBlockwise <gtype, oname, block_size, num_per_thread>( \
3879
- gtype* p, gtype* __restrict__ const g, unsigned char * state1, \
3879
+ gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char * state1, \
3880
3880
const float beta1, const float beta2, \
3881
3881
const float eps, const int step, const float lr, \
3882
3882
float * __restrict__ const quantiles1, \
0 commit comments