Skip to content

Commit d07f5cb

Browse files
superbobryPenzai Developers
authored and
Penzai Developers
committed
Removed named_shape= from jax.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users. PiperOrigin-RevId: 673411837
1 parent e78349f commit d07f5cb

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

tests/deprecated/v1/toolshed/sharding_util_test.py

-6
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ def test_name_to_name_sharding(self):
200200
data_array=jax.ShapeDtypeStruct(
201201
shape=(4, 8, 6),
202202
dtype=np.dtype('float32'),
203-
named_shape={},
204203
sharding=jax.sharding.NamedSharding(
205204
explicit_mesh, P('mesh_axis_1', 'mesh_axis_0', None)
206205
),
@@ -211,7 +210,6 @@ def test_name_to_name_sharding(self):
211210
data_array=jax.ShapeDtypeStruct(
212211
shape=(8,),
213212
dtype=np.dtype('float32'),
214-
named_shape={},
215213
sharding=jax.sharding.NamedSharding(
216214
explicit_mesh, P('mesh_axis_1')
217215
),
@@ -222,7 +220,6 @@ def test_name_to_name_sharding(self):
222220
data_array=jax.ShapeDtypeStruct(
223221
shape=(8, 12),
224222
dtype=np.dtype('float32'),
225-
named_shape={},
226223
sharding=jax.sharding.NamedSharding(
227224
explicit_mesh, P('mesh_axis_0', None)
228225
),
@@ -251,7 +248,6 @@ def test_name_to_name_sharding(self):
251248
data_array=jax.ShapeDtypeStruct(
252249
shape=(4, 8, 6),
253250
dtype=np.dtype('float32'),
254-
named_shape={},
255251
sharding=jax.sharding.NamedSharding(
256252
explicit_mesh, P('mesh_axis_1', 'mesh_axis_0', None)
257253
),
@@ -264,7 +260,6 @@ def test_name_to_name_sharding(self):
264260
data_array=jax.ShapeDtypeStruct(
265261
shape=(8,),
266262
dtype=np.dtype('float32'),
267-
named_shape={},
268263
sharding=jax.sharding.NamedSharding(
269264
explicit_mesh, P('mesh_axis_1')
270265
),
@@ -277,7 +272,6 @@ def test_name_to_name_sharding(self):
277272
data_array=jax.ShapeDtypeStruct(
278273
shape=(8, 12),
279274
dtype=np.dtype('float32'),
280-
named_shape={},
281275
sharding=jax.sharding.NamedSharding(
282276
explicit_mesh, P('mesh_axis_0', None)
283277
),

0 commit comments

Comments
 (0)