Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,18 +431,17 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
union_bytes = 0
for ref_group in ref_union.refs:
byte_offset = 0
for ref in jax.tree.leaves(ref_group):
def unflatten(ref):
nonlocal byte_offset
byte_offset = align_to(byte_offset, SMEM_ALIGNMENT)
assert isinstance(ref, state.AbstractRef) or isinstance(
ref, pallas_core.TransformedRef
)
if not isinstance(ref, pallas_core.TransformedRef):
ref = pallas_core.TransformedRef(ref, transforms=())
transform = ExtractAliasedRef.from_transformed_ref(ref, byte_offset)
flat_refs.append(
pallas_core.TransformedRef(
ref_union, transforms=(transform, *ref.transforms)
)
result = pallas_core.TransformedRef(
ref_union, transforms=(transform, *ref.transforms)
)
if jnp.issubdtype(ref.dtype, jnp.integer):
nbits = jnp.iinfo(ref.dtype).bits
Expand All @@ -457,26 +456,29 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
f" {ref.dtype}{ref.shape}"
)
byte_offset += ref_bits // 8
return result
flat_refs.append(jax.tree.map(unflatten, ref_group))
union_bytes = max(union_bytes, byte_offset)
assert union_bytes == ref_union.shape[0]
elif ref_union.memory_space == TMEM:
union_cols = 0
for ref_group in ref_union.refs:
col_offset = 0
for ref in jax.tree.leaves(ref_group):
def unflatten(ref):
nonlocal col_offset
col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT)
if not isinstance(ref, pallas_core.TransformedRef):
ref = pallas_core.TransformedRef(ref, transforms=())
ncols = ref.layout.cols_in_shape(ref.shape,
dtypes.itemsize_bits(ref.dtype))
transform = ExtractAliasedRef.from_transformed_ref(
ref, col_offset, layout=ref.layout)
flat_refs.append(
pallas_core.TransformedRef(
ref_union, transforms=(transform, *ref.transforms)
)
result = pallas_core.TransformedRef(
ref_union, transforms=(transform, *ref.transforms)
)
col_offset += ncols
return result
flat_refs.append(jax.tree.map(unflatten, ref_group))
union_cols = max(union_cols, col_offset)
assert union_cols == ref_union.shape[1], (union_cols, ref_union.shape[1])
else:
Expand Down
49 changes: 27 additions & 22 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1907,7 +1907,7 @@ def kernel(x_ref, y_ref, o_ref):
y = jax.lax.iota(jnp.float32, 128) * 3
np.testing.assert_array_equal(kernel(x, y), x + y)

def test_smem_aliasing_works(self):
def test_smem_aliasing_works_basic(self):
self.skip_if_wg_semantics()

in_shape = (2, 256)
Expand Down Expand Up @@ -1938,17 +1938,16 @@ def test_smem_aliasing_works(self):
plgpu.SMEM(
(128,),
jnp.float32,
transforms=(plgpu.TilingTransform((64,)),),
),
transforms=(plgpu.TilingTransform((64,)),)),
]
],
)
],
)
def kernel(x_ref, o_ref128, aliased_ref):
smem_ref256, _, smem_ref128 = aliased_ref
smem_ref256, [_, [smem_ref128]] = aliased_ref
# Ensure that extraction via index works the same as unfolding.
smem_ref128_2 = aliased_ref[2]
smem_ref128_2 = aliased_ref[1][1][0]
self.assertIsInstance(smem_ref128, state_types.TransformedRef)
self.assertIsInstance(smem_ref128_2, state_types.TransformedRef)
self.assertIs(smem_ref128.ref, smem_ref128_2.ref)
Expand Down Expand Up @@ -2005,7 +2004,7 @@ def test_smem_aliasing_works_with_subbyte_dtypes(self):
],
)
def kernel(x_ref, o_refi4, aliased_ref):
_, smem_refi8, _, smem_refi4 = aliased_ref
[_, smem_refi8], [_, smem_refi4] = aliased_ref
smem_refi8[...] = x_ref[...]
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_refi4, o_refi4)
Expand Down Expand Up @@ -3415,7 +3414,7 @@ def test_tmem_ref_aliasing(self):
thread_name="x",
)
def kernel(x_ref, y_ref, aliased_ref, smem_ref, barrier_ref):
tmem_128x32a, tmem_128x32b, tmem_128x64 = aliased_ref
[tmem_128x32a, tmem_128x32b], tmem_128x64 = aliased_ref
plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref)
plgpu.barrier_wait(barrier_ref)
# Test tmem_128x32 a and b
Expand Down Expand Up @@ -4268,7 +4267,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,
plgpu.barrier_wait(tma_barrier)
plgpu.copy_gmem_to_smem(b_gmem, b_smem, tma_barrier)
plgpu.barrier_wait(tma_barrier)
acc_128, lhs_128, lhs_64, acc_64, _ = aliased_refs
[acc_128, lhs_128], [lhs_64, acc_64], _ = aliased_refs

# Do 128x128 @ 128x128 matmul
plgpu.async_store_tmem(lhs_128, plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05))
Expand Down Expand Up @@ -4305,21 +4304,27 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64,

f = self.kernel(
kernel,
out_shape=[jax.ShapeDtypeStruct(shape, dtype),
jax.ShapeDtypeStruct(shape, dtype)],
out_shape=[
jax.ShapeDtypeStruct(shape, dtype),
jax.ShapeDtypeStruct(shape, dtype),
],
scratch_shapes=[
plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem
plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem
plgpu.SMEM(shape, dtype, transforms=transforms), # out_smem
plgpu.Barrier(), # tma_barrier
plgpu.Barrier(orders_tensor_core=True), # mma_barrier
plgpu.RefUnion( # aliased_refs
[plgpu.TMEM((128, 128), jnp.float32), # acc
plgpu.TMEM((128, 128), dtype, packed=True)], # lhs
[plgpu.TMEM((128, 64), dtype, packed=True), # lhs
plgpu.TMEM((128, 128), jnp.float32)], # acc
plgpu.TMEM((128, 128), jnp.float32) # unused
),
plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem
plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem
plgpu.SMEM(shape, dtype, transforms=transforms), # out_smem
plgpu.Barrier(), # tma_barrier
plgpu.Barrier(orders_tensor_core=True), # mma_barrier
plgpu.RefUnion( # aliased_refs
[
plgpu.TMEM((128, 128), jnp.float32), # acc
plgpu.TMEM((128, 128), dtype, packed=True), # lhs
],
[
plgpu.TMEM((128, 64), dtype, packed=True), # lhs
plgpu.TMEM((128, 128), jnp.float32), # acc
],
plgpu.TMEM((128, 128), jnp.float32), # unused
),
],
)
x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype)
Expand Down
Loading