
Commit e934636

Davis Yoshida authored and Google-ML-Automation committed
Support Hijax types in emit_pipeline.
PiperOrigin-RevId: 841158497
1 parent 863e4e7 commit e934636

File tree: 10 files changed, +488 −52 lines


jax/_src/pallas/core.py

Lines changed: 21 additions & 0 deletions
@@ -1379,6 +1379,27 @@ def _get_sds(aval: jax_core.AbstractValue):
 core_map_p = jax_core.Primitive("core_map")
 core_map_p.multiple_results = True
 
+def _core_map_is_high(*avals, jaxpr, **params):
+  del avals, params
+  return jaxpr.is_high
+core_map_p.is_high = _core_map_is_high  # type: ignore[method-assign]
+
+def _core_map_to_lojax(*consts, jaxpr, mesh, **params):
+  closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
+  with (
+      tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()),
+      jax_core.extend_axis_env_nd(mesh.shape.items()),
+  ):
+    closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr)
+  assert not closed_lo_jaxpr.is_high
+  return core_map_p.bind(
+      *closed_lo_jaxpr.consts,
+      jaxpr=closed_lo_jaxpr.jaxpr,
+      mesh=mesh,
+      **params,
+  )
+core_map_p.to_lojax = _core_map_to_lojax
+
 
 def core_map(
     mesh,
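The core.py hunk above registers two optional hooks on `core_map_p`: `is_high` reports whether the wrapped jaxpr still contains high-level (Hijax) types, and `to_lojax` lowers the closed jaxpr and rebinds the primitive with the lowered consts and body. Below is a minimal sketch of that registration pattern on a hypothetical toy primitive (`toy_call_p`); the hook names, `Jaxpr.is_high`, and `pe.lower_jaxpr` are internal machinery assumed from this revision, not public API, so the sketch only illustrates the contract.

# Minimal sketch of the is_high / to_lojax hook-registration pattern, on a toy
# primitive. These hooks belong to internal lojax-lowering machinery at this
# revision; this is illustrative only, not a supported public API.
import jax
from jax.extend import core as jex_core                # public home of Primitive / ClosedJaxpr
from jax._src.interpreters import partial_eval as pe   # internal; assumed to expose lower_jaxpr here

toy_call_p = jex_core.Primitive("toy_call")
toy_call_p.multiple_results = True


def _toy_call_is_high(*avals, jaxpr, **params):
  # The primitive is "high" exactly when its body jaxpr still uses Hijax types.
  del avals, params
  return getattr(jaxpr, "is_high", False)  # defensive: flag may not exist on older JAX


def _toy_call_to_lojax(*consts, jaxpr, **params):
  # Close the high-level jaxpr over its consts, lower it, and rebind the
  # primitive with the lowered jaxpr and consts (mirroring _core_map_to_lojax).
  closed_hi = jex_core.ClosedJaxpr(jaxpr, consts)
  closed_lo = pe.lower_jaxpr(closed_hi)  # same internal helper the commit calls
  return toy_call_p.bind(*closed_lo.consts, jaxpr=closed_lo.jaxpr, **params)


toy_call_p.is_high = _toy_call_is_high
toy_call_p.to_lojax = _toy_call_to_lojax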

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 95 additions & 48 deletions
@@ -23,18 +23,19 @@
 from typing import Any, Union
 
 import jax
+from jax import core as jax_core
 from jax import lax
 from jax import tree_util
 from jax._src import util as jax_util
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas import primitives as primitives
 from jax._src.pallas.mosaic import core as tpu_core
 from jax._src.pallas.mosaic import helpers as tpu_helpers
-from jax._src.pallas.mosaic import tpu_info
 from jax._src.pallas.mosaic import primitives as tpu_primitives
+from jax._src.pallas.mosaic import tpu_info
+from jax._src.state import types as state_types
 from jax.experimental import pallas as pl
 import jax.numpy as jnp
-import numpy as np
 
 
 SMEM = tpu_core.MemorySpace.SMEM
@@ -79,16 +80,25 @@ def add_leaves(i, x):
 def _get_tpu_generation() -> int:
   return tpu_info.get_tpu_info().generation
 
-def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
+
+def _make_tiling(
+    shape: tuple[int, ...], ty: jax_core.AbstractValue
+) -> tuple[int | None, ...]:
   # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
   # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
   # (2, 3, 128, 128) -> (1, 1, 8, 128).
   if len(shape) < 2:
     raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
+
+  dtype = getattr(ty, 'dtype', None)
+  if dtype is None:
+    return (None,) * len(shape)
+
   leading_dims, final_dims = shape[:-2], shape[-2:]
   # We want to find the minimum power of 2 that fits the second-minor dimension
   # of shape, with maximum value 8.
   second_minor, _ = final_dims
+
   packing = 4 // dtype.itemsize
   max_tiling = _TILING[0]
   second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
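For intuition, here is a hypothetical, self-contained sketch of the tiling rule described in the comments of `_make_tiling`, including the new early return for types without a dtype. It assumes a fixed (8, 128) trailing tile and deliberately ignores the dtype-packing and TPU-generation adjustments in the real function; `toy_make_tiling` is an illustrative stand-in, not the committed code.

# Hypothetical, simplified sketch of the documented tiling rule: a fixed
# (8, 128) trailing tile, 1s for leading dims, and no tiling (all None) when
# the type carries no dtype (the new Hijax path).
from types import SimpleNamespace


def toy_make_tiling(shape, ty):
  if len(shape) < 2:
    raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
  if getattr(ty, "dtype", None) is None:
    return (None,) * len(shape)               # dtype-less type: no tiling constraint
  return (1,) * (len(shape) - 2) + (8, 128)   # real code also accounts for packing


f32 = SimpleNamespace(shape=(256, 256), dtype="float32")  # stand-in aval
assert toy_make_tiling((256, 256), f32) == (8, 128)
assert toy_make_tiling((2, 3, 128, 128), f32) == (1, 1, 8, 128)
assert toy_make_tiling((4, 4), SimpleNamespace(shape=(4, 4))) == (None, None)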
@@ -114,17 +124,23 @@ def _make_block_ds(
   assert isinstance(out, pl.Slice)
   return out
 
-def _create_blocked_slice(block_index: jax.Array | int,
-                          block_size: int,
-                          dim_size: int,
-                          tiling: int):
+
+def _create_blocked_slice(
+    block_index: jax.Array | int,
+    block_size: int,
+    dim_size: int,
+    tiling: int | None,
+):
   block_start = block_size * block_index
   if (dim_rem := dim_size % block_size) == 0:
     return pl.ds(block_start, block_size)
+  if tiling is None:
+    raise ValueError("If tiling is None, block_size must divide dim_size.")
   if block_size % tiling != 0:
     raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
   num_blocks = pl.cdiv(dim_size, block_size)
   is_last = block_index == num_blocks - 1
+
   rounded_size = jnp.where(
       is_last,
       _round_up_to_nearest_multiple(dim_rem % block_size, tiling),
@@ -133,31 +149,35 @@ def _create_blocked_slice(block_index: jax.Array | int,
   rounded_size = pl.multiple_of(rounded_size, tiling)
   return pl.ds(block_index * block_size, rounded_size)
 
+
 def _create_bounded_slice(slice_start: jax.Array | int,
                           slice_size: jax.Array | int,
                           block_size: int,
                           dim_size: int,
-                          tiling: int):
-  if block_size % tiling != 0:
+                          tiling: int | None):
+  if tiling is not None and block_size % tiling != 0:
     raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
+
   # We assume by construction that slice_size <= block_size. We also assume
   # that the slice_start is already aligned to the tiling.
 
-  # If we are out of bound, we need to round the slice size down to the nearest
-  # multiple of the tiling.
-  is_oob = slice_start + slice_size > dim_size
-  remaining = dim_size - slice_start
-  rounded_size = jnp.where(
-      is_oob,
-      _round_up_to_nearest_multiple(remaining, tiling),
-      slice_size,
-  )
-  rounded_size = pl.multiple_of(rounded_size, tiling)
+  rounded_size = slice_size
+  if tiling is not None:
+    # If we are out of bound, we need to round the slice size down to the nearest
+    # multiple of the tiling.
+    is_oob = slice_start + slice_size > dim_size
+    remaining = dim_size - slice_start
+
+    rounded_size = jnp.where(
+        is_oob,
+        _round_up_to_nearest_multiple(remaining, tiling),
+        slice_size,
+    )
   return pl.ds(slice_start, rounded_size)
 
 def _make_block_slice(
     block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int,
-    tiling: int
+    tiling: int | None
 ) -> pl.Slice | slice | int | jax.Array:
   # Computes a slice given a block index and block size. In the default case,
   # we return slice(block_index * block_size, (block_index + 1) * block_size).
@@ -332,7 +352,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
   def compute_index(self):
     return self.spec.index_map
 
-  def get_dma_slice(self, src_shape, src_dtype, grid_indices):
+  def get_dma_slice(self, src_ty, grid_indices):
     # We need to handle blocks that might go OOB in the src array. An in bounds
     # block looks like this (for array shape (600, 600) and block shape
     # (256, 256)):
@@ -379,10 +399,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
     # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
     # for the last iteration on that dimension, we will pick the next highest
     # tile multiple, i.e. (96, 256).
+
+    if (src_shape := getattr(src_ty, "shape", None)) is None:
+      raise ValueError(f'Type {src_ty} does not have a type.')
+
     if len(src_shape) < 2:
       raise NotImplementedError("Must use >1D values.")
 
-    tiling = _make_tiling(src_shape, src_dtype)
+    tiling = _make_tiling(src_shape, src_ty)
     block_indices = self.compute_index(*grid_indices)
     return tuple(
         _make_block_slice(bi, bs, ss, t)
@@ -403,6 +427,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
     """Returns a new BufferedRefBase with the given block spec."""
     raise NotImplementedError()
 
+def _ref_to_value_aval(ref):
+  """Return the inner aval of a ref, or a ShapedArray for TransformedRefs."""
+  return (
+      jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype)
+      if isinstance(ref, state_types.TransformedRef)
+      else jax.typeof(ref).inner_aval
+  )
+
 
 # TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
 # instead of slot index.
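The helper above lets the pipeline hand `get_dma_slice` a single abstract value instead of a separate shape and dtype. A small sketch of what such an aval looks like for an ordinary array follows; it assumes a JAX version that provides `jax.typeof`, as this revision does.

# Sketch: the aval of a concrete array carries both .shape and .dtype, which is
# what get_dma_slice now consumes; a Hijax value whose aval lacks .dtype would
# instead take the no-tiling path in _make_tiling.
import jax
import jax.numpy as jnp

aval = jax.typeof(jnp.zeros((600, 600), jnp.bfloat16))
print(aval.shape)                                # (600, 600)
print(aval.dtype)                                # bfloat16
print(getattr(aval, "dtype", None) is not None)  # True -> a real tiling is computed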
@@ -413,7 +445,6 @@ class BufferedRef(BufferedRefBase):
 
   Attributes:
     spec: pallas blockspec.
-    dtype: dtype for buffers.
    buffer_type: enum indicating whether this is an input, output, or in/out
      accumulator buffered reference.
    window_ref: a multiple-buffer to hold the working and dirty buffers used
@@ -444,7 +475,6 @@ class BufferedRef(BufferedRefBase):
      copy.
  """
  _spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True))
-  dtype: Any = dataclasses.field(metadata=dict(static=True))
  _buffer_type: BufferType = dataclasses.field(metadata=dict(static=True))
  window_ref: ArrayRef | None
  accum_ref: ArrayRef | None
@@ -507,7 +537,7 @@ def buffer_types() -> type[BufferType]:
    return BufferType
 
  @classmethod
-  def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
+  def create(cls, spec: pl.BlockSpec, dtype_or_type, buffer_type, buffer_count,
             needs_swap_ref=True,
             grid_rank=None,
             use_lookahead=False,
@@ -516,7 +546,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
 
    Args:
      spec: pallas blockspec.
-      dtype: dtype for buffers.
+      dtype_or_type: dtype or aval for buffers. If an aval, the shape is
+        ignored.
      buffer_type: enum indicating whether this is an input, output, or in/out
        accumulator buffered reference.
      needs_swap_ref: whether a swap slots tracker needs to be allocated.
@@ -527,9 +558,15 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
     Returns:
       Initialized BufferedRef
     """
+    ty = (
+        dtype_or_type
+        if isinstance(dtype_or_type, jax_core.AbstractValue) else
+        jax_core.ShapedArray((1, 1,), dtype_or_type)  # dummy shape
+    )
+
     block_shape = _get_block_shape(spec)
     if buffer_type is BufferType.ACCUMULATOR:
-      accum_ref = VMEM(block_shape, dtype)
+      accum_ref = VMEM.from_type(ty.update(shape=block_shape))
     else:
       accum_ref = None
     if source_memory_space == VMEM:
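The `ty = (...)` expression above accepts either a raw dtype or an abstract value. Below is a hedged sketch of that normalization; `normalize` and its duck-typed check on `.dtype`/`.update` are illustrative stand-ins, not the committed `isinstance(..., jax_core.AbstractValue)` check, and the sketch assumes `jax.typeof` and `ShapedArray.update` are available, as they are at this revision.

# Hypothetical sketch of the dtype-or-aval normalization used by
# BufferedRef.create, plus the shape replacement used later for the buffer aval.
import jax
import jax.numpy as jnp


def normalize(dtype_or_type):
  # An aval already exposes .dtype and .update; a raw dtype gets a dummy-shape aval.
  if hasattr(dtype_or_type, "dtype") and hasattr(dtype_or_type, "update"):
    return dtype_or_type
  return jax.typeof(jnp.zeros((1, 1), dtype_or_type))  # dummy shape, as above


block_shape = (128, 128)
for arg in (jnp.float32, jax.typeof(jnp.ones((512, 512), jnp.int8))):
  ty = normalize(arg)
  buffer_ty = ty.update(shape=(2, *block_shape))  # (buffer_count, *block_shape)
  print(buffer_ty.shape, buffer_ty.dtype)         # (2, 128, 128) float32 / int8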
@@ -541,7 +578,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
             f"Cannot hold a non-buffered ref in {spec.memory_space=}")
       return cls(
           _spec=spec,
-          dtype=dtype,
           _buffer_type=buffer_type,
           window_ref=None,  # to be bound to existing ref by the pipeline routine
           accum_ref=accum_ref,
@@ -570,11 +606,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
       raise ValueError(
           "grid_rank must be specified when use_lookahead is True."
       )
+
+    buffer_ty = ty.update(shape=(buffer_count, *block_shape))
     return cls(
         _spec=spec,
-        dtype=dtype,
         _buffer_type=buffer_type,
-        window_ref=buffer_memory_space((buffer_count,) + block_shape, dtype),
+        window_ref=buffer_memory_space.from_type(buffer_ty),
        accum_ref=accum_ref,
        copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
        wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
@@ -601,22 +638,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
     )
 
   @classmethod
-  def input(cls, spec, dtype, buffer_count=2, **kwargs):
-    return cls.create(spec, dtype, BufferType.INPUT, buffer_count, **kwargs)
+  def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+    return cls.create(
+        spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs
+    )
 
   @classmethod
-  def output(cls, spec, dtype, buffer_count=2, **kwargs):
-    return cls.create(spec, dtype, BufferType.OUTPUT, buffer_count, **kwargs)
+  def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+    return cls.create(
+        spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs
+    )
 
   @classmethod
-  def accumulator(cls, spec, dtype, buffer_count=2, **kwargs):
-    return cls.create(spec, dtype, BufferType.ACCUMULATOR, buffer_count,
-                      **kwargs)
+  def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+    return cls.create(
+        spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs
+    )
 
   @classmethod
-  def input_output(cls, spec, dtype, buffer_count=2, **kwargs):
-    return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, buffer_count,
-                      **kwargs)
+  def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+    return cls.create(
+        spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs
+    )
 
   @property
   def block_shape(self):
@@ -923,7 +966,7 @@ def copy_in(self, src_ref, grid_indices):
     if self.swap is not None:
       self.swap[0] = True
     slot = self.current_copy_in_slot
-    src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
+    src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
     dst_slice = tuple(
         pl.ds(0, s.size)
         for s, bd in zip(src_slice, self.block_shape)
@@ -944,7 +987,7 @@ def copy_out(self, dst_ref, grid_indices):
     if self.swap is not None:
       self.swap[0] = True
     slot = self.current_copy_out_slot
-    dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
+    dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
     src_slice = tuple(
         pl.ds(0, s.size)
         for s, bd in zip(dst_slice, self.block_shape)
@@ -962,7 +1005,7 @@ def wait_in(self, src_ref, grid_indices):
     if not self.is_buffered: return
     assert not (self.window_ref is None or isinstance(self.window_ref, REF))
     assert self.sem_recvs is not None
-    src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
+    src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
     dst_slice = tuple(
         pl.ds(0, s.size)
         for s, bd in zip(src_slice, self.block_shape)
@@ -984,7 +1027,7 @@ def wait_out(self, dst_ref, grid_indices):
     assert not (self.window_ref is None or isinstance(self.window_ref, REF))
     assert self.sem_sends is not None
     wait_slot = self.current_wait_out_slot
-    dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
+    dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
     src_slice = tuple(
         pl.ds(0, s.size)
         for s, bd in zip(dst_slice, self.block_shape)
@@ -1682,7 +1725,9 @@ def make_input_bref(in_spec, in_ref):
       use_lookahead = in_spec.pipeline_mode.use_lookahead
     if use_lookahead and grid is None:
       raise ValueError("Grid must be specified when using lookahead.")
-    return BufferedRef.input(in_spec, in_ref.dtype, buffer_count,
+
+    in_aval = _ref_to_value_aval(in_ref)
+    return BufferedRef.input(in_spec, in_aval, buffer_count,
                              needs_swap_ref=needs_swap_ref,
                              grid_rank=len(grid),
                              use_lookahead=use_lookahead,
@@ -1695,11 +1740,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
     if out_spec.pipeline_mode.use_lookahead:
       raise ValueError("Output buffering does not support lookahead.")
 
+    out_aval = _ref_to_value_aval(out_ref)
+
     if accumulate:
-      return BufferedRef.accumulator(out_spec, out_ref.dtype, buffer_count,
+      return BufferedRef.accumulator(out_spec, out_aval, buffer_count,
                                      needs_swap_ref=needs_swap_ref,
                                      source_memory_space=out_ref.memory_space)
-    return BufferedRef.output(out_spec, out_ref.dtype, buffer_count,
+    return BufferedRef.output(out_spec, out_aval, buffer_count,
                               needs_swap_ref=needs_swap_ref,
                               source_memory_space=out_ref.memory_space)
   out_brefs = jax.tree.map(
@@ -1817,7 +1864,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
     bref = dst
     hbm_ref = src
     copy_in = True
-  hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices)
+  hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices)
   bref_slice = tuple(
       pl.ds(0, s.size)
       for s, bd in zip(hbm_slice, bref.block_shape)
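Tying the pieces together: `get_dma_slice` ultimately rounds a ragged final block up to the tiling (the (601, 600) → (96, 256) example in the comments above), while the new `tiling=None` path simply requires the block size to divide the dimension. Below is a hypothetical pure-Python sketch of that per-dimension rule; `round_up` and `last_block_size` are illustrative stand-ins for `_round_up_to_nearest_multiple` and `_create_blocked_slice`, not the committed code.

# Hypothetical stand-ins showing only the size of the final (possibly ragged)
# block along one dimension.
def round_up(x, m):
  return ((x + m - 1) // m) * m


def last_block_size(dim_size, block_size, tiling):
  rem = dim_size % block_size
  if rem == 0:
    return block_size                # evenly divisible: full-sized block
  if tiling is None:
    raise ValueError("If tiling is None, block_size must divide dim_size.")
  return round_up(rem, tiling)       # ragged block rounded up to a tile multiple


assert last_block_size(600, 256, 8) == 88      # 600 % 256 == 88, already tile-aligned
assert last_block_size(601, 256, 8) == 96      # 89 rounds up to 96, as in the comment
assert last_block_size(512, 256, None) == 256  # dtype-less (Hijax) path: must divide evenly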
