@@ -88,12 +88,13 @@ static int _ccv_nnc_group_norm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_
88
88
assert (ccv_nnc_tensor_view_check_dim ((ccv_nnc_tensor_view_t *)outputs[0 ], adim));
89
89
int x;
90
90
int n = 1 ;
91
- for (x = 0 ; x < CCV_NNC_MAX_DIM + 2 ; x++)
91
+ const int a_nd = ccv_nnc_tensor_nd (adim);
92
+ for (x = 0 ; x < a_nd; x++)
92
93
n *= adim[x];
93
- for (x = 0 ; x < CCV_NNC_MAX_DIM + 2 ; x++)
94
+ for (x = 0 ; x < a_nd ; x++)
94
95
n /= rdim[x];
95
96
int rcount = 1 ;
96
- for (x = 0 ; x < CCV_NNC_MAX_DIM + 2 ; x++)
97
+ for (x = 0 ; x < a_nd ; x++)
97
98
rcount *= rdim[x];
98
99
const float inv_n = 1 . / n;
99
100
cudnnReduceTensorDescriptor_t reduce = ccv_nnc_stream_context_get_reduce_tensor_descriptor (stream_context);
@@ -197,12 +198,13 @@ static int _ccv_nnc_group_norm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_
197
198
static const float one = 1 , zero = 0 , neg_one = -1 ;
198
199
int x;
199
200
int n = 1 ;
200
- for (x = 0 ; x < CCV_NNC_MAX_DIM + 2 ; x++)
201
+ const int g_nd = ccv_nnc_tensor_nd (gdim);
202
+ for (x = 0 ; x < g_nd; x++)
201
203
n *= gdim[x];
202
- for (x = 0 ; x < CCV_NNC_MAX_DIM + 2 ; x++)
204
+ for (x = 0 ; x < g_nd ; x++)
203
205
n /= rdim[x];
204
206
int gcount = 1 , rcount = 1 ;
205
- for (x = 0 ; x < CCV_NNC_MAX_DIM + 2 ; x++)
207
+ for (x = 0 ; x < g_nd ; x++)
206
208
gcount *= gdim[x], rcount *= rdim[x];
207
209
const float neg_inv_n = -1 . / n;
208
210
cudnnReduceTensorDescriptor_t reduce = ccv_nnc_stream_context_get_reduce_tensor_descriptor (stream_context);
0 commit comments