Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
82979ef
fix grad, translate and lower vjp
albi3ro Dec 9, 2025
3f0f17b
Merge branch 'main' into capture-vjp-lowering
albi3ro Dec 9, 2025
cd314a4
[skip-ci] more fixeS
albi3ro Dec 11, 2025
8f2f98e
fixes
albi3ro Dec 11, 2025
685569f
Merge branch 'main' into capture-vjp-lowering
albi3ro Dec 16, 2025
83b215f
black, isort, etc. [skip-ci]
albi3ro Dec 16, 2025
b2ead14
Merge branch 'capture-vjp-lowering' of https://github.com/PennyLaneAI…
albi3ro Dec 16, 2025
180dc83
Merge branch 'main' into capture-vjp-lowering
albi3ro Dec 16, 2025
87c2afb
switch to using pl namespace
albi3ro Dec 17, 2025
65384f2
Merge branch 'main' into capture-vjp-lowering
albi3ro Jan 19, 2026
fcd9e59
Merge branch 'main' into capture-vjp-lowering
albi3ro Jan 22, 2026
480bee1
remove targeting special branch
albi3ro Jan 22, 2026
072b76a
Merge branch 'capture-vjp-lowering' of https://github.com/PennyLaneAI…
albi3ro Jan 22, 2026
a469236
Merge branch 'main' into capture-vjp-lowering
albi3ro Jan 22, 2026
7ff8277
polishing up now pennylane PR merged
albi3ro Jan 22, 2026
0f51f93
Merge branch 'main' into capture-vjp-lowering
albi3ro Jan 22, 2026
80ce261
pylint
albi3ro Jan 22, 2026
9636df2
Merge branch 'capture-vjp-lowering' of https://github.com/PennyLaneAI…
albi3ro Jan 22, 2026
b5b11c0
bump version, black
albi3ro Jan 22, 2026
fbeab5e
Update frontend/test/pytest/test_autograph.py
albi3ro Jan 23, 2026
2da94dc
Apply suggestion from @albi3ro
albi3ro Jan 23, 2026
8fa6167
Update frontend/test/pytest/test_autograph.py
albi3ro Jan 23, 2026
f0130c0
black
albi3ro Jan 23, 2026
9268881
Merge branch 'main' into capture-vjp-lowering
albi3ro Jan 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ enzyme=v0.0.203

# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
pennylane=0.45.0-dev12
pennylane=0.45.0-dev13

# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
Expand Down
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@
length and the number of qubit operands are the same, and that all of the Pauli words are legal.
[(#2405)](https://github.com/PennyLaneAI/catalyst/pull/2405)

* `qml.vjp` can now be used with Catalyst and program capture.
[(#2279)](https://github.com/PennyLaneAI/catalyst/pull/2279)

<h3>Breaking changes 💔</h3>

* When an integer `argnums` is provided to `catalyst.vjp`, a singleton dimension is now squeezed
  out. This brings the behaviour in line with that of `grad` and `jacobian`.
[(#2279)](https://github.com/PennyLaneAI/catalyst/pull/2279)

* Dropped support for NumPy 1.x following its end-of-life. NumPy 2.0 or higher is now required.
[(#2407)](https://github.com/PennyLaneAI/catalyst/pull/2407)

Expand Down
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ lxml_html_clean
--extra-index-url https://test.pypi.org/simple/
pennylane-lightning-kokkos==0.45.0-dev8
pennylane-lightning==0.45.0-dev8
pennylane==0.45.0-dev12
pennylane==0.45.0-dev13
9 changes: 7 additions & 2 deletions frontend/catalyst/api_extensions/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Callable, Iterable, List, Optional, Union

import jax
import pennylane as qml
from jax._src.api import _dtype
from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten
from jax.api_util import debug_info
Expand Down Expand Up @@ -546,13 +547,14 @@ def f(x):
(Array([0.09983342, 0.04 , 0.02 ], dtype=float64),
(Array([-0.43750208, 0.07 ], dtype=float64),))
"""
if qml.capture.enabled():
return qml.vjp(f, params, cotangents, method=method, h=h, argnums=argnums)

def check_is_iterable(x, hint):
    """Raise ``ValueError`` unless *x* is an iterable; *hint* names the argument."""
    if isinstance(x, Iterable):
        return
    raise ValueError(f"vjp '{hint}' argument must be an iterable, not {type(x)}")

check_is_iterable(params, "params")
check_is_iterable(cotangents, "cotangents")

if EvaluationContext.is_tracing():
scalar_out = False
Expand All @@ -564,7 +566,10 @@ def check_is_iterable(x, hint):
grad_params = _check_grad_params(method, scalar_out, h, argnums, len(args_flatten), in_tree)

args_argnums = tuple(params[i] for i in grad_params.argnums)
_, in_tree = tree_flatten(args_argnums)
if isinstance(argnums, int) or argnums is None:
_, in_tree = tree_flatten(0)
else:
_, in_tree = tree_flatten(args_argnums)

jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *params)

Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,10 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None):
qml.prod,
catalyst.ctrl,
qml.ctrl,
qml.grad,
qml.jacobian,
qml.vjp,
qml.jvp,
catalyst.grad,
catalyst.value_and_grad,
catalyst.jacobian,
Expand Down
13 changes: 13 additions & 0 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import transform_prim
from pennylane.capture.primitives import vjp_prim as pl_vjp_prim
from pennylane.transforms import commute_controlled as pl_commute_controlled
from pennylane.transforms import decompose as pl_decompose
from pennylane.transforms import gridsynth as pl_gridsynth
Expand Down Expand Up @@ -239,6 +240,18 @@ def handle_grad(self, *args, jaxpr, n_consts, **kwargs):
)


@WorkflowInterpreter.register_primitive(pl_vjp_prim)
def handle_vjp(self, *args, jaxpr, **kwargs):
    """Translate a vjp equation by re-tracing its body through this interpreter.

    The leading entries of ``args`` are the inputs of ``jaxpr``; the trailing
    ``len(jaxpr.outvars)`` entries are extra operands of the primitive
    (presumably the cotangents — confirm against ``pl_vjp_prim``) and are
    excluded from the re-trace.
    """
    # Evaluate the body with a copy of this interpreter so nested primitives
    # are translated as well; the copy keeps this instance's state untouched.
    f = partial(copy(self).eval, jaxpr, [])
    new_jaxpr = jax.make_jaxpr(f)(*args[: -len(jaxpr.outvars)])

    # Promote the constants captured during tracing to explicit leading
    # arguments, producing a jaxpr with no constvars before re-binding.
    new_args = (*new_jaxpr.consts, *args)
    j = new_jaxpr.jaxpr
    new_j = j.replace(constvars=(), invars=j.constvars + j.invars)
    return pl_vjp_prim.bind(*new_args, jaxpr=new_j, **kwargs)


# pylint: disable=unused-argument, too-many-arguments
@WorkflowInterpreter.register_primitive(qnode_prim)
def handle_qnode(
Expand Down
34 changes: 34 additions & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
)

from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import vjp_prim as pl_vjp_prim

from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
Expand Down Expand Up @@ -931,6 +932,38 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params):
).results


def _capture_vjp_lowering(ctx, *args, jaxpr, fn, method, argnums, h):
    """Lower the PennyLane capture ``vjp`` primitive to a Catalyst ``VJPOp``.

    Args:
        ctx: JAX MLIR lowering rule context.
        args: function inputs followed by the cotangent operands.
        jaxpr: jaxpr of the function being differentiated.
        fn: the differentiated callable, forwarded to ``lower_jaxpr``.
        method: name of the differentiation method.
        argnums: indices of the differentiable arguments.
        h: finite-difference step size; falsy when not applicable.

    Returns:
        MLIR results: the function results followed by one vjp result per
        entry in ``argnums``.
    """
    args = list(args)
    mlir_ctx = ctx.module_context.context
    # The first len(jaxpr.invars) operands are the function inputs; anything
    # after them is a cotangent.
    n_params = len(jaxpr.invars)
    new_argnums = np.array(argnums)

    output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
    flat_output_types = util.flatten(output_types)
    func_args = args[:n_params]
    cotang_args = args[n_params:]
    # The trailing len(argnums) output types belong to the vjp values; the
    # leading ones are the function's own results.
    func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)]
    vjp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :]

    func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn)

    symbol_ref = get_symbolref(ctx, func_op)
    return VJPOp(
        func_result_types,
        vjp_result_types,
        ir.StringAttr.get(method),
        symbol_ref,
        mlir.flatten_ir_values(func_args),
        mlir.flatten_ir_values(cotang_args),
        diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
        # The step-size attribute is only attached when h is truthy
        # (presumably the finite-difference method — confirm against VJPOp).
        finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
    ).results


#
# zne
#
Expand Down Expand Up @@ -2843,6 +2876,7 @@ def subroutine_lowering(*args, **kwargs):
(for_p, _for_loop_lowering),
(grad_p, _grad_lowering),
(pl_jac_prim, _capture_grad_lowering),
(pl_vjp_prim, _capture_vjp_lowering),
(func_p, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
Expand Down
11 changes: 9 additions & 2 deletions frontend/test/pytest/test_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,19 +396,26 @@ def fn(x: float):
assert check_cache(inner)
assert fn(3) == tuple([jax.numpy.array(2.0), jax.numpy.array(6.0)])

@pytest.mark.usefixtures("use_both_frontend")
@pytest.mark.parametrize("vjp_func", [vjp, qml.vjp])
def test_vjp_wrapper(self, vjp_func):
    """Test conversion is happening successfully on functions wrapped with 'vjp'."""

    if qml.capture.enabled() and vjp_func == vjp:  # pylint: disable=comparison-with-callable
        pytest.xfail("program capture autograph doesn't work with catalyst.vjp")

    def inner(x):
        # Data-dependent branch: requires autograph conversion to trace.
        if x > 0:
            return 2 * x, x**2
        return 4 * x, x**8

    @qjit(autograph=True)
    def fn(x: float):
        return vjp_func(inner, (x,), (1.0, 1.0))

    assert hasattr(fn.user_function, "ag_unconverted")
    # NOTE(review): the cache check is skipped under capture — presumably
    # conversion happens through a different path there; confirm.
    if not qml.capture.enabled():
        assert check_cache(inner)
    assert np.allclose(fn(3)[0], tuple([jnp.array(6.0), jnp.array(9.0)]))
    assert np.allclose(fn(3)[1], jnp.array(8.0))

Expand Down
Loading