@@ -166,10 +166,10 @@ TEST_CASE("schedule symbolic graph to data parallel with broadcast and reduce")
166
166
ccv_nnc_graph_free (graph );
167
167
ccv_nnc_tensor_arena_free (tensor_arena );
168
168
ccv_nnc_graph_exec_arena_free (graph_exec_arena );
169
- REQUIRE_TENSOR_EQ ( np_updated [0 ], updated [0 ], "updated params should be equal" );
170
- REQUIRE_TENSOR_EQ ( np_updated [1 ], updated [1 ], "updated params should be equal" );
171
- REQUIRE_TENSOR_EQ ( np_updated [2 ], updated [2 ], "updated params should be equal" );
172
- REQUIRE_TENSOR_EQ ( np_updated [3 ], updated [3 ], "updated params should be equal" );
169
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [0 ]-> data . f32 , updated [0 ]-> data . f32 , 8 * 3 * 5 * 5 , 1e-4 , "updated params should be equal" );
170
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [1 ]-> data . f32 , updated [1 ]-> data . f32 , 8 , 1e-5 , "updated params should be equal" );
171
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [2 ]-> data . f32 , updated [2 ]-> data . f32 , 8 * 8 * 5 * 5 , 1e-4 , "updated params should be equal" );
172
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [3 ]-> data . f32 , updated [3 ]-> data . f32 , 8 , 1e-4 , "updated params should be equal" );
173
173
ccv_nnc_tensor_free (cpu_input );
174
174
ccv_nnc_tensor_free (cpu_fit );
175
175
ccv_nnc_tensor_free (np_updated [0 ]);
@@ -345,10 +345,10 @@ TEST_CASE("schedule symbolic graph to data parallel with allreduce")
345
345
ccv_nnc_graph_free (graph );
346
346
ccv_nnc_tensor_arena_free (tensor_arena );
347
347
ccv_nnc_graph_exec_arena_free (graph_exec_arena );
348
- REQUIRE_TENSOR_EQ ( np_updated [0 ], updated [0 ], "updated params should be equal" );
349
- REQUIRE_TENSOR_EQ ( np_updated [1 ], updated [1 ], "updated params should be equal" );
350
- REQUIRE_TENSOR_EQ ( np_updated [2 ], updated [2 ], "updated params should be equal" );
351
- REQUIRE_TENSOR_EQ ( np_updated [3 ], updated [3 ], "updated params should be equal" );
348
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [0 ]-> data . f32 , updated [0 ]-> data . f32 , 8 * 3 * 5 * 5 , 1e-4 , "updated params should be equal" );
349
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [1 ]-> data . f32 , updated [1 ]-> data . f32 , 8 , 1e-5 , "updated params should be equal" );
350
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [2 ]-> data . f32 , updated [2 ]-> data . f32 , 8 * 8 * 5 * 5 , 1e-4 , "updated params should be equal" );
351
+ REQUIRE_ARRAY_EQ_WITH_TOLERANCE ( float , np_updated [3 ]-> data . f32 , updated [3 ]-> data . f32 , 8 , 1e-4 , "updated params should be equal" );
352
352
ccv_nnc_tensor_free (cpu_input );
353
353
ccv_nnc_tensor_free (cpu_fit );
354
354
ccv_nnc_tensor_free (np_updated [0 ]);
0 commit comments