Skip to content

Commit 3776275

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Properly restore the pytree structure when unflattening unions
Previously we returned all the leaves in a flat list, which is unhelpful. PiperOrigin-RevId: 842637237
1 parent e43f4cb commit 3776275

File tree

2 files changed

+39
-32
lines changed

2 files changed

+39
-32
lines changed

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -431,18 +431,17 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
431431
union_bytes = 0
432432
for ref_group in ref_union.refs:
433433
byte_offset = 0
434-
for ref in jax.tree.leaves(ref_group):
434+
def unflatten(ref):
435+
nonlocal byte_offset
435436
byte_offset = align_to(byte_offset, SMEM_ALIGNMENT)
436437
assert isinstance(ref, state.AbstractRef) or isinstance(
437438
ref, pallas_core.TransformedRef
438439
)
439440
if not isinstance(ref, pallas_core.TransformedRef):
440441
ref = pallas_core.TransformedRef(ref, transforms=())
441442
transform = ExtractAliasedRef.from_transformed_ref(ref, byte_offset)
442-
flat_refs.append(
443-
pallas_core.TransformedRef(
444-
ref_union, transforms=(transform, *ref.transforms)
445-
)
443+
result = pallas_core.TransformedRef(
444+
ref_union, transforms=(transform, *ref.transforms)
446445
)
447446
if jnp.issubdtype(ref.dtype, jnp.integer):
448447
nbits = jnp.iinfo(ref.dtype).bits
@@ -457,26 +456,29 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
457456
f" {ref.dtype}{ref.shape}"
458457
)
459458
byte_offset += ref_bits // 8
459+
return result
460+
flat_refs.append(jax.tree.map(unflatten, ref_group))
460461
union_bytes = max(union_bytes, byte_offset)
461462
assert union_bytes == ref_union.shape[0]
462463
elif ref_union.memory_space == TMEM:
463464
union_cols = 0
464465
for ref_group in ref_union.refs:
465466
col_offset = 0
466-
for ref in jax.tree.leaves(ref_group):
467+
def unflatten(ref):
468+
nonlocal col_offset
467469
col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT)
468470
if not isinstance(ref, pallas_core.TransformedRef):
469471
ref = pallas_core.TransformedRef(ref, transforms=())
470472
ncols = ref.layout.cols_in_shape(ref.shape,
471473
dtypes.itemsize_bits(ref.dtype))
472474
transform = ExtractAliasedRef.from_transformed_ref(
473475
ref, col_offset, layout=ref.layout)
474-
flat_refs.append(
475-
pallas_core.TransformedRef(
476-
ref_union, transforms=(transform, *ref.transforms)
477-
)
476+
result = pallas_core.TransformedRef(
477+
ref_union, transforms=(transform, *ref.transforms)
478478
)
479479
col_offset += ncols
480+
return result
481+
flat_refs.append(jax.tree.map(unflatten, ref_group))
480482
union_cols = max(union_cols, col_offset)
481483
assert union_cols == ref_union.shape[1], (union_cols, ref_union.shape[1])
482484
else:

tests/pallas/mosaic_gpu_test.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,7 +1907,7 @@ def kernel(x_ref, y_ref, o_ref):
19071907
y = jax.lax.iota(jnp.float32, 128) * 3
19081908
np.testing.assert_array_equal(kernel(x, y), x + y)
19091909

1910-
def test_smem_aliasing_works(self):
1910+
def test_smem_aliasing_works_basic(self):
19111911
self.skip_if_wg_semantics()
19121912

19131913
in_shape = (2, 256)
@@ -1938,17 +1938,16 @@ def test_smem_aliasing_works(self):
19381938
plgpu.SMEM(
19391939
(128,),
19401940
jnp.float32,
1941-
transforms=(plgpu.TilingTransform((64,)),),
1942-
),
1941+
transforms=(plgpu.TilingTransform((64,)),)),
19431942
]
19441943
],
19451944
)
19461945
],
19471946
)
19481947
def kernel(x_ref, o_ref128, aliased_ref):
1949-
smem_ref256, _, smem_ref128 = aliased_ref
1948+
smem_ref256, [_, [smem_ref128]] = aliased_ref
19501949
# Ensure that extraction via index works the same as unfolding.
1951-
smem_ref128_2 = aliased_ref[2]
1950+
smem_ref128_2 = aliased_ref[1][1][0]
19521951
self.assertIsInstance(smem_ref128, state_types.TransformedRef)
19531952
self.assertIsInstance(smem_ref128_2, state_types.TransformedRef)
19541953
self.assertIs(smem_ref128.ref, smem_ref128_2.ref)
@@ -2005,7 +2004,7 @@ def test_smem_aliasing_works_with_subbyte_dtypes(self):
20052004
],
20062005
)
20072006
def kernel(x_ref, o_refi4, aliased_ref):
2008-
_, smem_refi8, _, smem_refi4 = aliased_ref
2007+
[_, smem_refi8], [_, smem_refi4] = aliased_ref
20092008
smem_refi8[...] = x_ref[...]
20102009
plgpu.commit_smem()
20112010
plgpu.copy_smem_to_gmem(smem_refi4, o_refi4)
@@ -3399,7 +3398,7 @@ def test_tmem_ref_aliasing(self):
33993398
thread_name="x",
34003399
)
34013400
def kernel(x_ref, y_ref, aliased_ref, smem_ref, barrier_ref):
3402-
tmem_128x32a, tmem_128x32b, tmem_128x64 = aliased_ref
3401+
[tmem_128x32a, tmem_128x32b], tmem_128x64 = aliased_ref
34033402
plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref)
34043403
plgpu.barrier_wait(barrier_ref)
34053404
# Test tmem_128x32 a and b
@@ -4252,7 +4251,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
42524251
plgpu.barrier_wait(tma_barrier)
42534252
plgpu.copy_gmem_to_smem(b_gmem, b_smem, tma_barrier)
42544253
plgpu.barrier_wait(tma_barrier)
4255-
acc_128, lhs_128, lhs_64, acc_64, _ = aliased_refs
4254+
[acc_128, lhs_128], [lhs_64, acc_64], _ = aliased_refs
42564255

42574256
# Do 128x128 @ 128x128 matmul
42584257
plgpu.async_store_tmem(lhs_128, plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05))
@@ -4289,21 +4288,27 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
42894288

42904289
f = self.kernel(
42914290
kernel,
4292-
out_shape=[jax.ShapeDtypeStruct(shape, dtype),
4293-
jax.ShapeDtypeStruct(shape, dtype)],
4291+
out_shape=[
4292+
jax.ShapeDtypeStruct(shape, dtype),
4293+
jax.ShapeDtypeStruct(shape, dtype),
4294+
],
42944295
scratch_shapes=[
4295-
plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem
4296-
plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem
4297-
plgpu.SMEM(shape, dtype, transforms=transforms), # out_smem
4298-
plgpu.Barrier(), # tma_barrier
4299-
plgpu.Barrier(orders_tensor_core=True), # mma_barrier
4300-
plgpu.RefUnion( # aliased_refs
4301-
[plgpu.TMEM((128, 128), jnp.float32), # acc
4302-
plgpu.TMEM((128, 128), dtype, packed=True)], # lhs
4303-
[plgpu.TMEM((128, 64), dtype, packed=True), # lhs
4304-
plgpu.TMEM((128, 128), jnp.float32)], # acc
4305-
plgpu.TMEM((128, 128), jnp.float32) # unused
4306-
),
4296+
plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem
4297+
plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem
4298+
plgpu.SMEM(shape, dtype, transforms=transforms), # out_smem
4299+
plgpu.Barrier(), # tma_barrier
4300+
plgpu.Barrier(orders_tensor_core=True), # mma_barrier
4301+
plgpu.RefUnion( # aliased_refs
4302+
[
4303+
plgpu.TMEM((128, 128), jnp.float32), # acc
4304+
plgpu.TMEM((128, 128), dtype, packed=True), # lhs
4305+
],
4306+
[
4307+
plgpu.TMEM((128, 64), dtype, packed=True), # lhs
4308+
plgpu.TMEM((128, 128), jnp.float32), # acc
4309+
],
4310+
plgpu.TMEM((128, 128), jnp.float32), # unused
4311+
),
43074312
],
43084313
)
43094314
x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype)

0 commit comments

Comments
 (0)