@@ -200,7 +200,6 @@ def test_name_to_name_sharding(self):
200
200
data_array = jax .ShapeDtypeStruct (
201
201
shape = (4 , 8 , 6 ),
202
202
dtype = np .dtype ('float32' ),
203
- named_shape = {},
204
203
sharding = jax .sharding .NamedSharding (
205
204
explicit_mesh , P ('mesh_axis_1' , 'mesh_axis_0' , None )
206
205
),
@@ -211,7 +210,6 @@ def test_name_to_name_sharding(self):
211
210
data_array = jax .ShapeDtypeStruct (
212
211
shape = (8 ,),
213
212
dtype = np .dtype ('float32' ),
214
- named_shape = {},
215
213
sharding = jax .sharding .NamedSharding (
216
214
explicit_mesh , P ('mesh_axis_1' )
217
215
),
@@ -222,7 +220,6 @@ def test_name_to_name_sharding(self):
222
220
data_array = jax .ShapeDtypeStruct (
223
221
shape = (8 , 12 ),
224
222
dtype = np .dtype ('float32' ),
225
- named_shape = {},
226
223
sharding = jax .sharding .NamedSharding (
227
224
explicit_mesh , P ('mesh_axis_0' , None )
228
225
),
@@ -251,7 +248,6 @@ def test_name_to_name_sharding(self):
251
248
data_array = jax .ShapeDtypeStruct (
252
249
shape = (4 , 8 , 6 ),
253
250
dtype = np .dtype ('float32' ),
254
- named_shape = {},
255
251
sharding = jax .sharding .NamedSharding (
256
252
explicit_mesh , P ('mesh_axis_1' , 'mesh_axis_0' , None )
257
253
),
@@ -264,7 +260,6 @@ def test_name_to_name_sharding(self):
264
260
data_array = jax .ShapeDtypeStruct (
265
261
shape = (8 ,),
266
262
dtype = np .dtype ('float32' ),
267
- named_shape = {},
268
263
sharding = jax .sharding .NamedSharding (
269
264
explicit_mesh , P ('mesh_axis_1' )
270
265
),
@@ -277,7 +272,6 @@ def test_name_to_name_sharding(self):
277
272
data_array = jax .ShapeDtypeStruct (
278
273
shape = (8 , 12 ),
279
274
dtype = np .dtype ('float32' ),
280
- named_shape = {},
281
275
sharding = jax .sharding .NamedSharding (
282
276
explicit_mesh , P ('mesh_axis_0' , None )
283
277
),
0 commit comments