Skip to content

Commit 91ea416

Browse files
Support return_outputs buffer option for 1-state optimizers
1 parent 16cc220 commit 91ea416

File tree

3 files changed

+33
-33
lines changed

3 files changed

+33
-33
lines changed

csrc/kernels.cu

+25-25
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
971971

972972
template<typename T, int OPTIMIZER>
973973
__launch_bounds__(TH, 1)
974-
__global__ void kOptimizer32bit1State(T *g, T *p,
974+
__global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
975975
float *state1, float *unorm, const float max_unorm, const float param_norm,
976976
const float beta1, const float beta2, const float eps, const float weight_decay,
977977
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1017,13 +1017,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
10171017
__syncthreads();
10181018
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
10191019
__syncthreads();
1020-
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
1020+
Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
10211021

10221022
# pragma unroll 4
10231023
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
10241024
{
10251025
g_vals[j] = gnorm_scale*((float)g_vals[j]);
1026-
if(weight_decay > 0.0f)
1026+
if(weight_decay > 0.0f && return_updates == nullptr)
10271027
g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
10281028
}
10291029

@@ -1040,26 +1040,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
10401040
else
10411041
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
10421042

1043-
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
1043+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + update_scale*(-lr*(s1_vals[j]));
10441044
break;
10451045
case LION:
1046-
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
1046+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
10471047
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
10481048
break;
10491049
case RMSPROP:
10501050
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
1051-
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
1051+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
10521052
break;
10531053
case ADAGRAD:
10541054
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
1055-
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
1055+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
10561056
break;
10571057
}
10581058
}
10591059
}
10601060

10611061
__syncthreads();
1062-
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
1062+
Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
10631063
__syncthreads();
10641064
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
10651065
}
@@ -1406,7 +1406,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
14061406
template<typename T, int OPTIMIZER>
14071407
__global__ void
14081408
__launch_bounds__(1024, 1)
1409-
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1409+
kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
14101410
const float *unorm, const float max_unorm, const float param_norm,
14111411
const float beta1, const float beta2,
14121412
const float eps, const int step, const float lr,
@@ -1462,7 +1462,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
14621462
__syncthreads();
14631463
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
14641464
__syncthreads();
1465-
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
1465+
LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
14661466

14671467
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
14681468

@@ -1472,7 +1472,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
14721472
g_val = float(g_vals[j]);
14731473
g_val *= gnorm_scale;
14741474

1475-
if(weight_decay > 0.0f) {
1475+
if(weight_decay > 0.0f && return_updates == nullptr) {
14761476
switch(OPTIMIZER) {
14771477
case MOMENTUM:
14781478
case RMSPROP:
@@ -1494,15 +1494,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
14941494
else
14951495
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
14961496

1497-
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
1497+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + (-lr*update_scale*(s1_vals[j]));
14981498
break;
14991499
case LION:
1500-
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
1500+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
15011501
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
15021502
break;
15031503
case RMSPROP:
15041504
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
1505-
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
1505+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
15061506
break;
15071507
}
15081508

@@ -1518,7 +1518,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15181518
}
15191519
}
15201520

1521-
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
1521+
StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
15221522
__syncthreads();
15231523
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
15241524
__syncthreads();
@@ -1769,7 +1769,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, T* return_upd
17691769
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
17701770
__launch_bounds__(256, 3)
17711771
__global__ void
1772-
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
1772+
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
17731773
const float beta1, const float beta2,
17741774
const float eps, const int step, const float lr,
17751775
float* __restrict__ const quantiles1,
@@ -1833,7 +1833,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
18331833
__syncthreads();
18341834
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
18351835
__syncthreads();
1836-
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
1836+
LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f);
18371837

18381838
new_local_abs_max1 = -FLT_MAX;
18391839

@@ -1845,7 +1845,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
18451845
g_val *= gnorm_scale;
18461846
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
18471847
{
1848-
if(weight_decay > 0.0f) {
1848+
if(weight_decay > 0.0f && return_updates == nullptr) {
18491849
switch(OPTIMIZER) {
18501850
case MOMENTUM:
18511851
case ADAGRAD:
@@ -1908,18 +1908,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
19081908
switch(OPTIMIZER)
19091909
{
19101910
case MOMENTUM:
1911-
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
1911+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(s1_vals[j]);
19121912
break;
19131913
case LION:
1914-
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
1914+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - ((float)g_vals[j]);
19151915
break;
19161916
case RMSPROP:
19171917
g_val = g_vals[j];
1918-
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
1918+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
19191919
break;
19201920
case ADAGRAD:
19211921
g_val = g_vals[j];
1922-
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
1922+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
19231923
break;
19241924
}
19251925
}
@@ -3679,7 +3679,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
36793679
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
36803680

36813681
#define MAKE_Optimizer32bit1State(oname, gtype) \
3682-
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
3682+
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, gtype* return_updates, float* state1, float *unorm, const float max_unorm, const float param_norm, \
36833683
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
36843684

36853685
MAKE_Optimizer32bit1State(MOMENTUM, half)
@@ -3729,7 +3729,7 @@ MAKE_PreconditionStatic8bit1State(LION, half)
37293729
MAKE_PreconditionStatic8bit1State(LION, float)
37303730

37313731
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
3732-
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
3732+
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, unsigned char* state1, \
37333733
const float *unorm, const float max_unorm, const float param_norm, \
37343734
const float beta1, \
37353735
const float beta2, \
@@ -3876,7 +3876,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)
38763876

38773877
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
38783878
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
3879-
gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
3879+
gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char* state1, \
38803880
const float beta1, const float beta2, \
38813881
const float eps, const int step, const float lr, \
38823882
float* __restrict__ const quantiles1, \

csrc/kernels.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
3838
const int step, const float lr, const float gnorm_scale, const int n);
3939

4040
template<typename T, int OPTIMIZER>
41-
__global__ void kOptimizer32bit1State(T* g, T* p,
41+
__global__ void kOptimizer32bit1State(T* g, T* p, T* return_updates,
4242
float* state1, float *unorm, const float max_unorm, const float param_norm,
4343
const float beta1, const float beta2, const float eps, const float weight_decay,
4444
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
@@ -57,7 +57,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
5757

5858
template<typename T, int OPTIMIZER>
5959
__global__ void
60-
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
60+
kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
6161
const float *unorm, const float max_unorm, const float param_norm,
6262
const float beta1, const float beta2,
6363
const float eps, const int step, const float lr,
@@ -95,7 +95,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
9595
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
9696

9797
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
98-
T* p, T* __restrict__ const g, unsigned char* state1,
98+
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
9999
const float beta1, const float beta2,
100100
const float eps, const int step, const float lr,
101101
float* __restrict__ const quantiles1,

csrc/ops.cu

+5-5
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_up
128128
CUDA_CHECK_RETURN(cudaPeekAtLastError());
129129
}
130130

131-
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
131+
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
132132
CUDA_CHECK_RETURN(cudaPeekAtLastError());
133133
break;
134134
case LION:
135135
// in lion, the momentum update after the parameter update
136-
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
136+
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
137137
CUDA_CHECK_RETURN(cudaPeekAtLastError());
138138

139139
if(max_unorm > 0.0f)
@@ -178,13 +178,13 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
178178
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
179179
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
180180
CUDA_CHECK_RETURN(cudaPeekAtLastError());
181-
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
181+
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
182182
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
183183
CUDA_CHECK_RETURN(cudaPeekAtLastError());
184184
break;
185185
case LION:
186186
// in lion, the momentum update happens after the parameter update
187-
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
187+
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
188188
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
189189
CUDA_CHECK_RETURN(cudaPeekAtLastError());
190190

@@ -223,7 +223,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
223223
case LION:
224224
num_blocks = n/BLOCKSIZE_1STATE;
225225
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
226-
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
226+
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, return_updates, state1, beta1, beta2, eps, step, lr,
227227
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
228228
CUDA_CHECK_RETURN(cudaPeekAtLastError());
229229
break;

0 commit comments

Comments
 (0)