diff --git a/src/google/adk/flows/llm_flows/_nl_planning.py b/src/google/adk/flows/llm_flows/_nl_planning.py index c81648ea73..066c0f7d87 100644 --- a/src/google/adk/flows/llm_flows/_nl_planning.py +++ b/src/google/adk/flows/llm_flows/_nl_planning.py @@ -82,7 +82,11 @@ async def run_async( return planner = _get_planner(invocation_context) - if not planner or isinstance(planner, BuiltInPlanner): + if ( + not planner + or type(planner).process_planning_response + is BuiltInPlanner.process_planning_response + ): return # Postprocess the LLM response. diff --git a/tests/unittests/flows/llm_flows/test_nl_planning.py b/tests/unittests/flows/llm_flows/test_nl_planning.py index e4bdff7332..53a1b8d05a 100644 --- a/tests/unittests/flows/llm_flows/test_nl_planning.py +++ b/tests/unittests/flows/llm_flows/test_nl_planning.py @@ -14,11 +14,17 @@ """Unit tests for NL planning logic.""" +from typing import List +from typing import Optional from unittest.mock import MagicMock +from unittest.mock import patch +from google.adk.agents.callback_context import CallbackContext from google.adk.agents.llm_agent import Agent from google.adk.flows.llm_flows._nl_planning import request_processor +from google.adk.flows.llm_flows._nl_planning import response_processor from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse from google.adk.planners.built_in_planner import BuiltInPlanner from google.adk.planners.plan_re_act_planner import PlanReActPlanner from google.genai import types @@ -126,3 +132,89 @@ async def test_remove_thought_from_request_with_thoughts(): for content in llm_request.contents for part in content.parts or [] ) + + +class OverriddenBuiltInPlanner(BuiltInPlanner): + """Subclass that overrides process_planning_response.""" + + def __init__(self, *, thinking_config: types.ThinkingConfig): + super().__init__(thinking_config=thinking_config) + self.process_planning_response_called = False + self.received_parts = None + + def process_planning_response( + self, + callback_context: CallbackContext, + response_parts: List[types.Part], + ) -> Optional[List[types.Part]]: + self.process_planning_response_called = True + self.received_parts = response_parts + return response_parts + + +class NonOverriddenBuiltInPlanner(BuiltInPlanner): + """Subclass that does NOT override process_planning_response.""" + + pass + + +@pytest.mark.asyncio +async def test_overridden_subclass_process_planning_response_called(): + """Test that subclasses overriding process_planning_response have it called. + + Regression test for issue #4133. + """ + planner = OverriddenBuiltInPlanner(thinking_config=types.ThinkingConfig()) + agent = Agent(name='test_agent', planner=planner) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + response_parts = [ + types.Part(text='thinking...', thought=True), + types.Part(text='Here is my response'), + ] + llm_response = LlmResponse( + content=types.Content(role='model', parts=response_parts) + ) + + async for _ in response_processor.run_async(invocation_context, llm_response): + pass + + assert planner.process_planning_response_called + assert planner.received_parts == response_parts + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'planner_class', + [BuiltInPlanner, NonOverriddenBuiltInPlanner], + ids=['base_class', 'non_overridden_subclass'], +) +async def test_process_planning_response_not_called_without_override( + planner_class, +): + """Test that process_planning_response is not called for base or non-overridden subclasses.""" + planner = planner_class(thinking_config=types.ThinkingConfig()) + agent = Agent(name='test_agent', planner=planner) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + response_parts = [ + types.Part(text='thinking...', thought=True), + types.Part(text='Here is my response'), + ] + llm_response = LlmResponse( + content=types.Content(role='model', parts=response_parts) + ) + + with patch.object( + BuiltInPlanner, + 'process_planning_response', + ) as mock_method: + async for _ in response_processor.run_async( + invocation_context, llm_response + ): + pass + mock_method.assert_not_called()