Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/google/adk/flows/llm_flows/_nl_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ async def run_async(
return

planner = _get_planner(invocation_context)
if not planner or isinstance(planner, BuiltInPlanner):
if (
not planner
or type(planner).process_planning_response
is BuiltInPlanner.process_planning_response
):
return

# Postprocess the LLM response.
Expand Down
92 changes: 92 additions & 0 deletions tests/unittests/flows/llm_flows/test_nl_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@

"""Unit tests for NL planning logic."""

from typing import List
from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import patch

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.flows.llm_flows._nl_planning import request_processor
from google.adk.flows.llm_flows._nl_planning import response_processor
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.planners.built_in_planner import BuiltInPlanner
from google.adk.planners.plan_re_act_planner import PlanReActPlanner
from google.genai import types
Expand Down Expand Up @@ -126,3 +132,89 @@ async def test_remove_thought_from_request_with_thoughts():
for content in llm_request.contents
for part in content.parts or []
)


class OverriddenBuiltInPlanner(BuiltInPlanner):
"""Subclass that overrides process_planning_response."""

def __init__(self, *, thinking_config: types.ThinkingConfig):
super().__init__(thinking_config=thinking_config)
self.process_planning_response_called = False
self.received_parts = None

def process_planning_response(
self,
callback_context: CallbackContext,
response_parts: List[types.Part],
) -> Optional[List[types.Part]]:
self.process_planning_response_called = True
self.received_parts = response_parts
return response_parts


class NonOverriddenBuiltInPlanner(BuiltInPlanner):
"""Subclass that does NOT override process_planning_response."""

pass


@pytest.mark.asyncio
async def test_overridden_subclass_process_planning_response_called():
"""Test that subclasses overriding process_planning_response have it called.

Regression test for issue #4133.
"""
planner = OverriddenBuiltInPlanner(thinking_config=types.ThinkingConfig())
agent = Agent(name='test_agent', planner=planner)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)

response_parts = [
types.Part(text='thinking...', thought=True),
types.Part(text='Here is my response'),
]
llm_response = LlmResponse(
content=types.Content(role='model', parts=response_parts)
)

async for _ in response_processor.run_async(invocation_context, llm_response):
pass

assert planner.process_planning_response_called
assert planner.received_parts == response_parts


@pytest.mark.asyncio
@pytest.mark.parametrize(
'planner_class',
[BuiltInPlanner, NonOverriddenBuiltInPlanner],
ids=['base_class', 'non_overridden_subclass'],
)
async def test_process_planning_response_not_called_without_override(
planner_class,
):
"""Test that process_planning_response is not called for base or non-overridden subclasses."""
planner = planner_class(thinking_config=types.ThinkingConfig())
agent = Agent(name='test_agent', planner=planner)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)

response_parts = [
types.Part(text='thinking...', thought=True),
types.Part(text='Here is my response'),
]
llm_response = LlmResponse(
content=types.Content(role='model', parts=response_parts)
)

with patch.object(
BuiltInPlanner,
'process_planning_response',
) as mock_method:
Comment on lines +212 to +215
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current test patches BuiltInPlanner at the class level. While this works, it's a bit subtle as it relies on the patch interacting with the is check in the implementation. A clearer and more robust approach is to patch the method on the planner instance. This better isolates the test's scope and verifies the behavior without being tightly coupled to the implementation of the conditional check.

  with patch.object(planner, 'process_planning_response') as mock_method:

async for _ in response_processor.run_async(
invocation_context, llm_response
):
pass
mock_method.assert_not_called()