2323from typing import Any , Union
2424
2525import jax
26+ from jax import core as jax_core
2627from jax import lax
2728from jax import tree_util
2829from jax ._src import util as jax_util
2930from jax ._src .pallas import core as pallas_core
3031from jax ._src .pallas import primitives as primitives
3132from jax ._src .pallas .mosaic import core as tpu_core
3233from jax ._src .pallas .mosaic import helpers as tpu_helpers
33- from jax ._src .pallas .mosaic import tpu_info
3434from jax ._src .pallas .mosaic import primitives as tpu_primitives
35+ from jax ._src .pallas .mosaic import tpu_info
36+ from jax ._src .state import types as state_types
3537from jax .experimental import pallas as pl
3638import jax .numpy as jnp
37- import numpy as np
3839
3940
4041SMEM = tpu_core .MemorySpace .SMEM
@@ -79,16 +80,25 @@ def add_leaves(i, x):
7980def _get_tpu_generation () -> int :
8081 return tpu_info .get_tpu_info ().generation
8182
82- def _make_tiling (shape : tuple [int , ...], dtype : np .dtype ) -> tuple [int , ...]:
83+
84+ def _make_tiling (
85+ shape : tuple [int , ...], ty : jax_core .AbstractValue
86+ ) -> tuple [int | None , ...]:
8387 # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
8488 # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
8589 # (2, 3, 128, 128) -> (1, 1, 8, 128).
8690 if len (shape ) < 2 :
8791 raise ValueError (f"Shape must have at least 2 dimensions: { shape = } " )
92+
93+ dtype = getattr (ty , 'dtype' , None )
94+ if dtype is None :
95+ return (None ,) * len (shape )
96+
8897 leading_dims , final_dims = shape [:- 2 ], shape [- 2 :]
8998 # We want to find the minimum power of 2 that fits the second-minor dimension
9099 # of shape, with maximum value 8.
91100 second_minor , _ = final_dims
101+
92102 packing = 4 // dtype .itemsize
93103 max_tiling = _TILING [0 ]
94104 second_minor_tiling = (1 + int (_get_tpu_generation () < 4 )) * packing
@@ -114,17 +124,23 @@ def _make_block_ds(
114124 assert isinstance (out , pl .Slice )
115125 return out
116126
117- def _create_blocked_slice (block_index : jax .Array | int ,
118- block_size : int ,
119- dim_size : int ,
120- tiling : int ):
127+
128+ def _create_blocked_slice (
129+ block_index : jax .Array | int ,
130+ block_size : int ,
131+ dim_size : int ,
132+ tiling : int | None ,
133+ ):
121134 block_start = block_size * block_index
122135 if (dim_rem := dim_size % block_size ) == 0 :
123136 return pl .ds (block_start , block_size )
137+ if tiling is None :
138+ raise ValueError ("If tiling is None, block_size must divide dim_size." )
124139 if block_size % tiling != 0 :
125140 raise ValueError (f"Block size must divide tiling: { block_size = } , { tiling = } " )
126141 num_blocks = pl .cdiv (dim_size , block_size )
127142 is_last = block_index == num_blocks - 1
143+
128144 rounded_size = jnp .where (
129145 is_last ,
130146 _round_up_to_nearest_multiple (dim_rem % block_size , tiling ),
@@ -133,31 +149,35 @@ def _create_blocked_slice(block_index: jax.Array | int,
133149 rounded_size = pl .multiple_of (rounded_size , tiling )
134150 return pl .ds (block_index * block_size , rounded_size )
135151
152+
136153def _create_bounded_slice (slice_start : jax .Array | int ,
137154 slice_size : jax .Array | int ,
138155 block_size : int ,
139156 dim_size : int ,
140- tiling : int ):
141- if block_size % tiling != 0 :
157+ tiling : int | None ):
158+ if tiling is not None and block_size % tiling != 0 :
142159 raise ValueError (f"Block size must divide tiling: { block_size = } , { tiling = } " )
160+
143161 # We assume by construction that slice_size <= block_size. We also assume
144162 # that the slice_start is already aligned to the tiling.
145163
146- # If we are out of bound, we need to round the slice size down to the nearest
147- # multiple of the tiling.
148- is_oob = slice_start + slice_size > dim_size
149- remaining = dim_size - slice_start
150- rounded_size = jnp .where (
151- is_oob ,
152- _round_up_to_nearest_multiple (remaining , tiling ),
153- slice_size ,
154- )
155- rounded_size = pl .multiple_of (rounded_size , tiling )
164+ rounded_size = slice_size
165+ if tiling is not None :
166+ # If we are out of bound, we need to round the slice size down to the nearest
167+ # multiple of the tiling.
168+ is_oob = slice_start + slice_size > dim_size
169+ remaining = dim_size - slice_start
170+
171+ rounded_size = jnp .where (
172+ is_oob ,
173+ _round_up_to_nearest_multiple (remaining , tiling ),
174+ slice_size ,
175+ )
156176 return pl .ds (slice_start , rounded_size )
157177
158178def _make_block_slice (
159179 block_index : jax .Array , block_size : pl .BlockDim | int | None , size : int ,
160- tiling : int
180+ tiling : int | None
161181) -> pl .Slice | slice | int | jax .Array :
162182 # Computes a slice given a block index and block size. In the default case,
163183 # we return slice(block_index * block_size, (block_index + 1) * block_size).
@@ -332,7 +352,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
332352 def compute_index (self ):
333353 return self .spec .index_map
334354
335- def get_dma_slice (self , src_shape , src_dtype , grid_indices ):
355+ def get_dma_slice (self , src_ty , grid_indices ):
336356 # We need to handle blocks that might go OOB in the src array. An in bounds
337357 # block looks like this (for array shape (600, 600) and block shape
338358 # (256, 256)):
@@ -379,10 +399,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
379399 # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
380400 # for the last iteration on that dimension, we will pick the next highest
381401 # tile multiple, i.e. (96, 256).
402+
403+ if (src_shape := getattr (src_ty , "shape" , None )) is None :
404+ raise ValueError (f'Type { src_ty } does not have a type.' )
405+
382406 if len (src_shape ) < 2 :
383407 raise NotImplementedError ("Must use >1D values." )
384408
385- tiling = _make_tiling (src_shape , src_dtype )
409+ tiling = _make_tiling (src_shape , src_ty )
386410 block_indices = self .compute_index (* grid_indices )
387411 return tuple (
388412 _make_block_slice (bi , bs , ss , t )
@@ -403,6 +427,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
403427 """Returns a new BufferedRefBase with the given block spec."""
404428 raise NotImplementedError ()
405429
430+ def _ref_to_value_aval (ref ):
431+ """Return the inner of a ref, or a ShapedArray for TransformedRefs."""
432+ return (
433+ jax_core .ShapedArray (shape = ref .shape , dtype = ref .dtype )
434+ if isinstance (ref , state_types .TransformedRef )
435+ else jax .typeof (ref ).inner_aval
436+ )
437+
406438
407439# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
408440# instead of slot index.
@@ -413,7 +445,6 @@ class BufferedRef(BufferedRefBase):
413445
414446 Attributes:
415447 spec: pallas blockspec.
416- dtype: dtype for buffers.
417448 buffer_type: enum indicating whether this is an input, output, or in/out
418449 accumulator buffered reference.
419450 window_ref: a multiple-buffer to hold the working and dirty buffers used
@@ -444,7 +475,6 @@ class BufferedRef(BufferedRefBase):
444475 copy.
445476 """
446477 _spec : pl .BlockSpec = dataclasses .field (metadata = dict (static = True ))
447- dtype : Any = dataclasses .field (metadata = dict (static = True ))
448478 _buffer_type : BufferType = dataclasses .field (metadata = dict (static = True ))
449479 window_ref : ArrayRef | None
450480 accum_ref : ArrayRef | None
@@ -507,7 +537,7 @@ def buffer_types() -> type[BufferType]:
507537 return BufferType
508538
509539 @classmethod
510- def create (cls , spec : pl .BlockSpec , dtype , buffer_type , buffer_count ,
540+ def create (cls , spec : pl .BlockSpec , dtype_or_type , buffer_type , buffer_count ,
511541 needs_swap_ref = True ,
512542 grid_rank = None ,
513543 use_lookahead = False ,
@@ -516,7 +546,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
516546
517547 Args:
518548 spec: pallas blockspec.
519- dtype: dtype for buffers.
549+ dtype_or_type: dtype or aval for buffers. If an aval, the shape is
550+ ignored.
520551 buffer_type: enum indicating whether this is an input, output, or in/out
521552 accumulator buffered reference.
522553 needs_swap_ref: whether a swap slots tracker needs to be allocated.
@@ -527,9 +558,15 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
527558 Returns:
528559 Initialized BufferedRef
529560 """
561+ ty = (
562+ dtype_or_type
563+ if isinstance (dtype_or_type , jax_core .AbstractValue ) else
564+ jax_core .ShapedArray ((1 , 1 ,), dtype_or_type ) # dummy shape
565+ )
566+
530567 block_shape = _get_block_shape (spec )
531568 if buffer_type is BufferType .ACCUMULATOR :
532- accum_ref = VMEM ( block_shape , dtype )
569+ accum_ref = VMEM . from_type ( ty . update ( shape = block_shape ) )
533570 else :
534571 accum_ref = None
535572 if source_memory_space == VMEM :
@@ -541,7 +578,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
541578 f"Cannot hold a non-buffered ref in { spec .memory_space = } " )
542579 return cls (
543580 _spec = spec ,
544- dtype = dtype ,
545581 _buffer_type = buffer_type ,
546582 window_ref = None , # to be bound to existing ref by the pipeline routine
547583 accum_ref = accum_ref ,
@@ -570,11 +606,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
570606 raise ValueError (
571607 "grid_rank must be specified when use_lookahead is True."
572608 )
609+
610+ buffer_ty = ty .update (shape = (buffer_count , * block_shape ))
573611 return cls (
574612 _spec = spec ,
575- dtype = dtype ,
576613 _buffer_type = buffer_type ,
577- window_ref = buffer_memory_space (( buffer_count ,) + block_shape , dtype ),
614+ window_ref = buffer_memory_space . from_type ( buffer_ty ),
578615 accum_ref = accum_ref ,
579616 copy_in_slot = SMEM ((1 ,), jnp .uint32 ) if buffer_type .is_input else None ,
580617 wait_in_slot = SMEM ((1 ,), jnp .uint32 ) if buffer_type .is_input else None ,
@@ -601,22 +638,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
601638 )
602639
603640 @classmethod
604- def input (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
605- return cls .create (spec , dtype , BufferType .INPUT , buffer_count , ** kwargs )
641+ def input (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
642+ return cls .create (
643+ spec , dtype_or_type , BufferType .INPUT , buffer_count , ** kwargs
644+ )
606645
607646 @classmethod
608- def output (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
609- return cls .create (spec , dtype , BufferType .OUTPUT , buffer_count , ** kwargs )
647+ def output (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
648+ return cls .create (
649+ spec , dtype_or_type , BufferType .OUTPUT , buffer_count , ** kwargs
650+ )
610651
611652 @classmethod
612- def accumulator (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
613- return cls .create (spec , dtype , BufferType .ACCUMULATOR , buffer_count ,
614- ** kwargs )
653+ def accumulator (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
654+ return cls .create (
655+ spec , dtype_or_type , BufferType .ACCUMULATOR , buffer_count , ** kwargs
656+ )
615657
616658 @classmethod
617- def input_output (cls , spec , dtype , buffer_count = 2 , ** kwargs ):
618- return cls .create (spec , dtype , BufferType .INPUT_OUTPUT , buffer_count ,
619- ** kwargs )
659+ def input_output (cls , spec , dtype_or_type , buffer_count = 2 , ** kwargs ):
660+ return cls .create (
661+ spec , dtype_or_type , BufferType .INPUT_OUTPUT , buffer_count , ** kwargs
662+ )
620663
621664 @property
622665 def block_shape (self ):
@@ -923,7 +966,7 @@ def copy_in(self, src_ref, grid_indices):
923966 if self .swap is not None :
924967 self .swap [0 ] = True
925968 slot = self .current_copy_in_slot
926- src_slice = self .get_dma_slice (src_ref . shape , src_ref . dtype , grid_indices )
969+ src_slice = self .get_dma_slice (_ref_to_value_aval ( src_ref ) , grid_indices )
927970 dst_slice = tuple (
928971 pl .ds (0 , s .size )
929972 for s , bd in zip (src_slice , self .block_shape )
@@ -944,7 +987,7 @@ def copy_out(self, dst_ref, grid_indices):
944987 if self .swap is not None :
945988 self .swap [0 ] = True
946989 slot = self .current_copy_out_slot
947- dst_slice = self .get_dma_slice (dst_ref . shape , dst_ref . dtype , grid_indices )
990+ dst_slice = self .get_dma_slice (_ref_to_value_aval ( dst_ref ) , grid_indices )
948991 src_slice = tuple (
949992 pl .ds (0 , s .size )
950993 for s , bd in zip (dst_slice , self .block_shape )
@@ -962,7 +1005,7 @@ def wait_in(self, src_ref, grid_indices):
9621005 if not self .is_buffered : return
9631006 assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
9641007 assert self .sem_recvs is not None
965- src_slice = self .get_dma_slice (src_ref . shape , src_ref . dtype , grid_indices )
1008+ src_slice = self .get_dma_slice (_ref_to_value_aval ( src_ref ) , grid_indices )
9661009 dst_slice = tuple (
9671010 pl .ds (0 , s .size )
9681011 for s , bd in zip (src_slice , self .block_shape )
@@ -984,7 +1027,7 @@ def wait_out(self, dst_ref, grid_indices):
9841027 assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
9851028 assert self .sem_sends is not None
9861029 wait_slot = self .current_wait_out_slot
987- dst_slice = self .get_dma_slice (dst_ref . shape , dst_ref . dtype , grid_indices )
1030+ dst_slice = self .get_dma_slice (_ref_to_value_aval ( dst_ref ) , grid_indices )
9881031 src_slice = tuple (
9891032 pl .ds (0 , s .size )
9901033 for s , bd in zip (dst_slice , self .block_shape )
@@ -1682,7 +1725,9 @@ def make_input_bref(in_spec, in_ref):
16821725 use_lookahead = in_spec .pipeline_mode .use_lookahead
16831726 if use_lookahead and grid is None :
16841727 raise ValueError ("Grid must be specified when using lookahead." )
1685- return BufferedRef .input (in_spec , in_ref .dtype , buffer_count ,
1728+
1729+ in_aval = _ref_to_value_aval (in_ref )
1730+ return BufferedRef .input (in_spec , in_aval , buffer_count ,
16861731 needs_swap_ref = needs_swap_ref ,
16871732 grid_rank = len (grid ),
16881733 use_lookahead = use_lookahead ,
@@ -1695,11 +1740,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
16951740 if out_spec .pipeline_mode .use_lookahead :
16961741 raise ValueError ("Output buffering does not support lookahead." )
16971742
1743+ out_aval = _ref_to_value_aval (out_ref )
1744+
16981745 if accumulate :
1699- return BufferedRef .accumulator (out_spec , out_ref . dtype , buffer_count ,
1746+ return BufferedRef .accumulator (out_spec , out_aval , buffer_count ,
17001747 needs_swap_ref = needs_swap_ref ,
17011748 source_memory_space = out_ref .memory_space )
1702- return BufferedRef .output (out_spec , out_ref . dtype , buffer_count ,
1749+ return BufferedRef .output (out_spec , out_aval , buffer_count ,
17031750 needs_swap_ref = needs_swap_ref ,
17041751 source_memory_space = out_ref .memory_space )
17051752 out_brefs = jax .tree .map (
@@ -1817,7 +1864,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
18171864 bref = dst
18181865 hbm_ref = src
18191866 copy_in = True
1820- hbm_slice = bref .get_dma_slice (hbm_ref . shape , hbm_ref . dtype , indices )
1867+ hbm_slice = bref .get_dma_slice (_ref_to_value_aval ( hbm_ref ) , indices )
18211868 bref_slice = tuple (
18221869 pl .ds (0 , s .size )
18231870 for s , bd in zip (hbm_slice , bref .block_shape )
0 commit comments