diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 91b57cb873..7d7f157233 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -19,6 +19,7 @@ import datetime import inspect import logging +from typing import Any from typing import AsyncGenerator from typing import cast from typing import Optional @@ -69,6 +70,225 @@ # Statistics configuration DEFAULT_ENABLE_CACHE_STATISTICS = False +# Guardrail constants +MAX_CONSECUTIVE_REFUSED_FUNCTION_CALLS = 3 +"""Max consecutive FCs after hitting maximum_remote_calls before guardrail.""" + +_GUARDRAIL_INSTRUCTION = ( + '\n\n**IMPORTANT: You have reached the maximum number of function ' + 'calls allowed. You MUST provide a final text response to the user ' + 'now. DO NOT attempt to call any more functions. Summarize what you ' + 'have learned so far and provide a helpful response based on the ' + 'information already gathered.**' +) +"""System instruction added during guardrail to force text response.""" + + +class GuardrailContext: + """Manages guardrail state coordination between flow methods.""" + + def __init__(self, session_state: dict[str, Any]): + self._state = session_state + self._active_key = '_adk_guardrail_active' + self._processed_key = '_adk_guardrail_processed' + + @property + def is_active(self) -> bool: + """True if guardrail should be applied to next LLM call.""" + return self._state.get(self._active_key, False) + + @property + def is_processed(self) -> bool: + """True if preprocessing already handled guardrail.""" + return self._state.get(self._processed_key, False) + + def activate(self) -> None: + """Mark guardrail as active for next LLM call.""" + self._state[self._active_key] = True + + def mark_processed(self) -> None: + """Mark that preprocessing handled guardrail.""" + self._state[self._processed_key] = True + + def clear_active(self) -> None: + """Clear active flag only.""" + self._state.pop(self._active_key, None) + + def clear_processed(self) -> None: + """Clear processed flag only.""" + self._state.pop(self._processed_key, None) + + def clear(self) -> None: + """Clear all guardrail flags.""" + self.clear_active() + self.clear_processed() + + def __repr__(self) -> str: + """Return debug representation.""" + return ( + f'GuardrailContext(active={self.is_active}, ' + f'processed={self.is_processed})' + ) + + +def _count_function_call_events( + invocation_context: InvocationContext, +) -> Optional[int]: + """Count FC events in current invocation. Returns None on error.""" + try: + events = invocation_context._get_events( + current_invocation=True, current_branch=True + ) + return sum(1 for e in events if e.get_function_calls()) + except (AttributeError, KeyError, TypeError, ValueError) as ex: + logger.error('Error counting FC events: %s', ex, exc_info=True) + return None + + +def _event_has_function_calls(event: Optional[Event]) -> Optional[bool]: + """Check if event has FCs. Returns None on error.""" + if not event: + return False + try: + return event.get_function_calls() + except (AttributeError, TypeError, ValueError) as ex: + logger.error('Error checking FCs in event: %s', ex, exc_info=True) + return None + + +def _check_afc_disabled( + afc_config: Optional[types.AutomaticFunctionCallingConfig], +) -> bool: + """Check if AFC is explicitly disabled.""" + if not afc_config: + return False + if afc_config.disable: + logger.warning('AFC disabled. 
Stopping loop.') + return True + return False + + +def _check_max_calls_invalid( + afc_config: Optional[types.AutomaticFunctionCallingConfig], +) -> bool: + """Check if maximum_remote_calls is invalid (<= 0).""" + if not afc_config or afc_config.maximum_remote_calls is None: + return False + if afc_config.maximum_remote_calls <= 0: + logger.warning( + 'max_remote_calls %s <= 0. Disabling AFC.', + afc_config.maximum_remote_calls, + ) + return True + return False + + +def _should_stop_afc_loop( + llm_request: LlmRequest, + invocation_context: InvocationContext, + last_event: Optional[Event], + count_current_event: bool = False, + consecutive_refused_fcs: int = 0, + guardrail_in_progress: bool = False, +) -> tuple[bool, int]: + """Check if AFC loop should stop. + + Args: + llm_request: LLM request with config. + invocation_context: Invocation context with session. + last_event: Last event from current step. + count_current_event: If True, include current event (pre-execution check). + consecutive_refused_fcs: Count of consecutive refused FCs. + guardrail_in_progress: True if in guardrail final iteration. + + Returns: + (should_stop, new_consecutive_count): stop decision and updated count. + """ + afc_config: Optional[types.AutomaticFunctionCallingConfig] = ( + llm_request.config and llm_request.config.automatic_function_calling + ) + if not afc_config: + return False, 0 + + # Check AFC disabled or invalid config + if _check_afc_disabled(afc_config): + return True, 0 + if _check_max_calls_invalid(afc_config): + return True, 0 + + # Check maximum_remote_calls limit + if afc_config.maximum_remote_calls is not None: + max_calls = afc_config.maximum_remote_calls + + # Count FC events, fail safe on error + function_call_count = _count_function_call_events(invocation_context) + if function_call_count is None: + logger.error('FC count failed. Stopping AFC loop (fail safe).') + return True, consecutive_refused_fcs + + # Pre-execution check: prevent execution if would exceed limit + if count_current_event: + has_function_calls = _event_has_function_calls(last_event) + if has_function_calls is None: + logger.error('FC check failed. Preventing execution (fail safe).') + return True, 0 + + if has_function_calls and function_call_count > max_calls: + logger.warning( + 'Would exceed max_remote_calls=%s. Not executing FCs.', max_calls + ) + return True, 0 + return False, 0 + + # Loop continuation check: guardrail for consecutive refused FCs + if not last_event: + return False, 0 + + has_function_calls = _event_has_function_calls(last_event) + if has_function_calls is None: + logger.error('FC check in loop failed. Stopping AFC (fail safe).') + return True, consecutive_refused_fcs + + if has_function_calls: + # Check if at or over limit (>= not >) + if function_call_count >= max_calls: + new_count = consecutive_refused_fcs + 1 + logger.debug( + 'LLM returned FCs after limit (%d/%d). max=%d, count=%d', + new_count, + MAX_CONSECUTIVE_REFUSED_FUNCTION_CALLS, + max_calls, + function_call_count, + ) + + # Don't re-trigger if in guardrail final iteration + if guardrail_in_progress: + logger.warning( + 'Guardrail final iteration has FCs (count=%d). ' + 'The outer loop will terminate after this iteration.', + new_count, + ) + return False, new_count + + if new_count >= MAX_CONSECUTIVE_REFUSED_FUNCTION_CALLS: + logger.info( + 'Guardrail triggered: %d consecutive FCs after max=%d. 
' + 'Forcing text response.', + new_count, + max_calls, + ) + return True, new_count + # Under threshold, continue with incremented count + return False, new_count + else: + # FCs within limit - reset counter + return False, 0 + else: + # No FCs - preserve count to prevent bypass via alternation + return False, consecutive_refused_fcs + + return False, 0 + class BaseLlmFlow(ABC): """A basic flow that calls the LLM in a loop until a final response is generated. @@ -359,16 +579,120 @@ async def run_async( self, invocation_context: InvocationContext ) -> AsyncGenerator[Event, None]: """Runs the flow.""" - while True: - last_event = None - async with Aclosing(self._run_one_step_async(invocation_context)) as agen: - async for event in agen: - last_event = event - yield event - if not last_event or last_event.is_final_response() or last_event.partial: - if last_event and last_event.partial: - logger.warning('The last event is partial, which is not expected.') - break + # Build llm_request once for config checks + llm_request: LlmRequest = LlmRequest() + agent: BaseAgent = invocation_context.agent + llm_request.config = ( + agent.generate_content_config.model_copy(deep=True) + if agent.generate_content_config + else types.GenerateContentConfig() + ) + + consecutive_refused_fcs: int = 0 + guardrail_triggered: bool = False + + guardrail = GuardrailContext(invocation_context.session.state) + invocation_context._guardrail = guardrail + + try: + while True: + last_event: Optional[Event] = None + async with Aclosing( + self._run_one_step_async(invocation_context) + ) as agen: + async for event in agen: + last_event = event + yield event + + # If we just completed the guardrail final iteration, handle it first + # BEFORE checking is_final_response() to ensure we log the result + if guardrail_triggered: + if not last_event or not last_event.content: + logger.warning( + 'Guardrail yielded no response. User may not have received ' + 'closing message.' + ) + else: + # Check if the final iteration also returned function calls + try: + has_fcs = last_event.get_function_calls() + except (AttributeError, TypeError, ValueError) as ex: + logger.error( + 'Error checking function calls in guardrail iteration: %s', + ex, + exc_info=True, + ) + has_fcs = False + + # Extract any text content even if there are also function calls + text_content = '' + if last_event.content and last_event.content.parts: + text_content = '\n'.join( + [p.text for p in last_event.content.parts if p.text] + ) + + if has_fcs: + if text_content: + logger.info( + 'Guardrail: LLM returned text response (also included ' + 'function calls which were ignored).' + ) + else: + logger.warning( + 'Guardrail: LLM still returned only function calls' + ' despite tools being disabled. User will not receive' + ' response.' + ) + elif not text_content: + logger.warning('Guardrail: LLM returned empty response.') + break + + # Break if there's no event or it's a final response + if ( + not last_event + or last_event.is_final_response() + or last_event.partial + ): + if last_event and last_event.partial: + logger.warning('The last event is partial, which is not expected.') + break + + # Check if we should stop AFC loop based on config and guardrail + # This includes: disable flag check, maximum_remote_calls limit, + # and infinite loop prevention via consecutive refusal tracking. 
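+        # Note: consecutive_refused_fcs carries over between iterations (even
+        # across text-only turns) and only resets once the model's function
+        # calls fall back within the configured limit.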
+ # Returns (should_stop, updated_consecutive_count) + should_stop, consecutive_refused_fcs = _should_stop_afc_loop( + llm_request, + invocation_context, + last_event, + count_current_event=False, + consecutive_refused_fcs=consecutive_refused_fcs, + guardrail_in_progress=guardrail_triggered, + ) + if should_stop: + # Don't re-trigger guardrail if already in final iteration + if guardrail_triggered: + logger.warning( + 'LLM returned function calls even during guardrail final ' + 'iteration (count=%d). Ending AFC loop.', + consecutive_refused_fcs, + ) + break + + # Check if this is guardrail trigger (vs normal stop) + if consecutive_refused_fcs >= MAX_CONSECUTIVE_REFUSED_FUNCTION_CALLS: + logger.info( + 'Guardrail: removing tools for final LLM call to force text.' + ) + guardrail_triggered = True + guardrail.activate() + # Continue to next iteration (don't break) + else: + # Normal stop (AFC disabled or max_calls <= 0) + break + finally: + # Cleanup: ensure guardrail cleared even if loop exits early + guardrail.clear() async def _run_one_step_async( self, @@ -386,6 +710,17 @@ async def _run_one_step_async( if invocation_context.end_invocation: return + # Final enforcement: ensure AFC disabled after all processors + guardrail = getattr( + invocation_context, '_guardrail', None + ) or GuardrailContext(invocation_context.session.state) + if guardrail.is_processed: + if llm_request.config and llm_request.config.automatic_function_calling: + llm_request.config.automatic_function_calling.disable = True + if llm_request.config: + llm_request.config.tools = None + guardrail.clear_processed() + # Resume the LLM agent based on the last event from the current branch. # 1. User content: continue the normal flow # 2. Function call: call the tool and get the response event. @@ -482,6 +817,25 @@ async def _preprocess_async( if not agent.tools: return + # Skip adding tools if guardrail is active + guardrail = getattr( + invocation_context, '_guardrail', None + ) or GuardrailContext(invocation_context.session.state) + if guardrail.is_active: + logger.info('Guardrail: skipping tools, disabling AFC for text response') + guardrail.mark_processed() + if llm_request.config and llm_request.config.automatic_function_calling: + llm_request.config.automatic_function_calling.disable = True + + # Add system instruction to force text response + if llm_request.config.system_instruction: + llm_request.config.system_instruction += _GUARDRAIL_INSTRUCTION + else: + llm_request.config.system_instruction = _GUARDRAIL_INSTRUCTION.strip() + + guardrail.clear_active() + return + multiple_tools = len(agent.tools) > 1 model = agent.canonical_model for tool_union in agent.tools: @@ -560,6 +914,18 @@ async def _postprocess_async( ): return + # Check if AFC should be disabled - skip function call execution + # Returns (should_stop, _) - we ignore the counter for execution checks + should_stop, _ = _should_stop_afc_loop( + llm_request, + invocation_context, + model_response_event, + count_current_event=True, + ) + if should_stop: + # AFC is disabled or limit reached - don't execute function calls + return + async with Aclosing( self._postprocess_handle_function_calls_async( invocation_context, model_response_event, llm_request @@ -649,6 +1015,18 @@ async def _postprocess_live( # Handles function calls. 
if model_response_event.get_function_calls(): + # Check if AFC should be disabled - skip function call execution + # Returns (should_stop, _) - we ignore the counter for execution checks + should_stop, _ = _should_stop_afc_loop( + llm_request, + invocation_context, + model_response_event, + count_current_event=True, + ) + if should_stop: + # AFC is disabled or limit reached - don't execute function calls + return + function_response_event = await functions.handle_function_calls_live( invocation_context, model_response_event, llm_request.tools_dict ) diff --git a/tests/unittests/flows/llm_flows/test_afc_config.py b/tests/unittests/flows/llm_flows/test_afc_config.py new file mode 100644 index 0000000000..6a3d7ccb8f --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_afc_config.py @@ -0,0 +1,1076 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Automatic Function Calling configuration handling. + +Tests for Bug #4133: Ensure that AFC config (disable=True, maximum_remote_calls) +is properly respected and that planner hooks are always called. +""" + +import asyncio +from unittest.mock import MagicMock + +from google.adk.agents.live_request_queue import LiveRequestQueue +from google.adk.agents.llm_agent import Agent +from google.adk.models.llm_response import LlmResponse +from google.adk.planners.built_in_planner import BuiltInPlanner +from google.genai import types +from google.genai.types import Part +import pytest + +from ... import testing_utils + + +@pytest.fixture +def live_test_runner(): + """Fixture that creates a CustomTestRunner for live mode tests. + + This eliminates code duplication across live mode tests by providing + a reusable runner factory that handles live streaming response collection. 
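+
+  Responses are consumed on a dedicated event loop with a 5 second timeout,
+  so tests terminate even if the live stream never completes.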
+ """ + + class CustomTestRunner(testing_utils.InMemoryRunner): + """Custom test runner for live mode tests with configurable response collection.""" + + def run_live( + self, + live_request_queue: LiveRequestQueue, + run_config: testing_utils.RunConfig = None, + max_responses: int = 3, + ) -> list[testing_utils.Event]: + collected_responses = [] + + async def consume_responses(session: testing_utils.Session): + run_res = self.runner.run_live( + session=session, + live_request_queue=live_request_queue, + run_config=run_config or testing_utils.RunConfig(), + ) + + async for response in run_res: + collected_responses.append(response) + if len(collected_responses) >= max_responses: + return + + try: + session = self.session + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + asyncio.wait_for(consume_responses(session), timeout=5.0) + ) + finally: + loop.close() + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + return collected_responses + + return CustomTestRunner + + +@pytest.mark.asyncio +async def test_afc_disabled_stops_loop(): + """Test that setting disable=True stops the AFC loop after first response.""" + # Setup: Create a mock model that returns function calls + responses = [ + # First response with function call + Part.from_function_call(name='test_tool', args={'x': 1}), + # Second response (should not be called if AFC is disabled) + 'This should not be returned', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + # Create agent with AFC disabled + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # Assert: Should stop after first LLM call (1 model response with function call) + # The tool should NOT be executed because AFC is disabled + assert call_count == 0, 'Tool should not be called when AFC is disabled' + + # Should have only 1 LLM request (not 2) + assert ( + len(mock_model.requests) == 1 + ), 'Should make only 1 LLM call when AFC is disabled' + + +@pytest.mark.asyncio +async def test_maximum_remote_calls_zero_stops_loop(): + """Test that setting maximum_remote_calls=0 stops the AFC loop.""" + responses = [ + Part.from_function_call(name='test_tool', args={'x': 1}), + 'This should not be returned', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=0 + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # Tool should not be executed + assert ( + call_count == 0 + ), 'Tool should not be called when maximum_remote_calls=0' + assert ( + len(mock_model.requests) == 1 + ), 'Should make only 1 LLM call when maximum_remote_calls=0' + + +@pytest.mark.asyncio +async def test_maximum_remote_calls_limit_enforced(): + """Test that maximum_remote_calls limit is properly enforced. 
+ + Note: maximum_remote_calls counts executed function calls. So + maximum_remote_calls=2 allows executing 2 function calls total. + """ + responses = [ + # First response + Part.from_function_call(name='test_tool', args={'x': 1}), + # Second response (after first tool execution) + Part.from_function_call(name='test_tool', args={'x': 2}), + # Third response (after second tool execution - should not be reached) + 'Should not be returned', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=2 + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # Should execute tool twice (max_remote_calls=2) + assert ( + call_count == 2 + ), 'Tool should be called exactly twice when maximum_remote_calls=2' + # Should make 3 LLM calls: initial + after 1st FC + after 2nd FC + assert ( + len(mock_model.requests) == 3 + ), 'Should make 3 LLM calls with maximum_remote_calls=2' + + +@pytest.mark.asyncio +async def test_planner_hook_called_with_maximum_remote_calls_zero(): + """Test that planner.process_planning_response is called with maximum_remote_calls=0.""" + from google.adk.planners.plan_re_act_planner import PlanReActPlanner + + responses = [ + Part.from_function_call(name='test_tool', args={'x': 1}), + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + def test_tool(x: int) -> int: + return x + 1 + + # Use PlanReActPlanner which actually processes responses + planner = PlanReActPlanner() + planner.process_planning_response = MagicMock( + wraps=planner.process_planning_response + ) + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + planner=planner, + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=0 + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + list(runner.run('test message')) + + # Verify the planner hook was called + assert planner.process_planning_response.called, ( + 'Planner.process_planning_response should be called even with ' + 'maximum_remote_calls=0' + ) + + +@pytest.mark.asyncio +async def test_afc_enabled_continues_loop(): + """Test that AFC loop continues normally when not disabled.""" + responses = [ + # First response with function call + Part.from_function_call(name='test_tool', args={'x': 1}), + # Second response after function execution + 'Final response', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + # No AFC config - should work normally + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # Tool should be executed + assert call_count == 1, 'Tool should be called once in normal AFC mode' + + # Should make 2 LLM calls: initial + after function response + assert ( + len(mock_model.requests) == 2 + ), 'Should make 2 LLM calls in normal AFC mode' + + +@pytest.mark.asyncio +async def test_afc_disabled_with_parallel_function_calls(): + """Test that AFC disabled works 
with parallel function calls.""" + # Model returns multiple function calls in one response + responses = [ + [ + Part.from_function_call(name='test_tool', args={'x': 1}), + Part.from_function_call(name='test_tool', args={'x': 2}), + Part.from_function_call(name='test_tool', args={'x': 3}), + ], + 'This should not be returned', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # None of the parallel FCs should be executed + assert call_count == 0, 'No tools should be called when AFC is disabled' + assert ( + len(mock_model.requests) == 1 + ), 'Should make only 1 LLM call when AFC is disabled' + + +@pytest.mark.asyncio +async def test_maximum_remote_calls_with_parallel_function_calls(): + """Test that maximum_remote_calls counts events, not individual FCs.""" + # Each LLM response has multiple parallel function calls + responses = [ + # First event with 2 parallel FCs + [ + Part.from_function_call(name='test_tool', args={'x': 1}), + Part.from_function_call(name='test_tool', args={'x': 2}), + ], + # Second event with 2 parallel FCs (should execute) + [ + Part.from_function_call(name='test_tool', args={'x': 3}), + Part.from_function_call(name='test_tool', args={'x': 4}), + ], + # Third event (should not execute - limit reached) + Part.from_function_call(name='test_tool', args={'x': 5}), + # Final response after limit + 'Final response after limit reached', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=2 # 2 events, not 2 individual FCs + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # Should execute 4 FCs (2 events × 2 FCs each) + assert call_count == 4, ( + 'Should execute all FCs from first 2 events ' + '(maximum_remote_calls counts events, not individual FCs)' + ) + # Should make 4 LLM calls: initial + after 1st event + after 2nd event + final call + # The final call happens because we need to get a final response after reaching the limit + assert ( + len(mock_model.requests) == 4 + ), 'Should make 4 LLM calls with maximum_remote_calls=2' + + +@pytest.mark.asyncio +async def test_maximum_remote_calls_one_allows_one_execution(): + """Test that maximum_remote_calls=1 allows exactly one FC execution.""" + responses = [ + Part.from_function_call(name='test_tool', args={'x': 1}), + Part.from_function_call(name='test_tool', args={'x': 2}), + 'Final response after limit reached', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + 
automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=1 + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + assert ( + call_count == 1 + ), 'Tool should be called once when maximum_remote_calls=1' + # Should make 3 LLM calls: initial + after 1st FC + final call + # The final call happens because we need to get a final response after reaching the limit + assert ( + len(mock_model.requests) == 3 + ), 'Should make 3 LLM calls with maximum_remote_calls=1' + + +def test_negative_maximum_remote_calls_treated_as_zero(): + """Test that negative maximum_remote_calls is caught by <= 0 check.""" + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content(role='model', parts=[types.Part(text='Done')]), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create( + [response_with_fc, final_response] + ) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=-5 # Negative value + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # Negative value should be treated like 0 (no FCs allowed) + assert ( + call_count == 0 + ), 'Tool should not be called when maximum_remote_calls=-5' + # Should make 1 LLM call: initial only (loop exits immediately) + assert ( + len(mock_model.requests) == 1 + ), 'Should make 1 LLM call when negative maximum_remote_calls' + + +def test_very_large_maximum_remote_calls(): + """Test that very large maximum_remote_calls works correctly.""" + # Create responses for 3 function calls + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content(role='model', parts=[types.Part(text='Done')]), + turn_complete=True, + ) + + mock_model = testing_utils.MockModel.create([ + response_with_fc, + response_with_fc, + response_with_fc, + final_response, + ]) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=999999 # Very large value + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + events = list(runner.run('test message')) + + # Should allow all 3 function calls since limit is very high + assert call_count == 3, 'All 3 tool calls should execute with limit=999999' + # Should make 4 LLM calls: initial + after each of 3 FCs + assert ( + len(mock_model.requests) == 4 + ), 'Should make 4 LLM calls with 3 function calls' + + +def test_corrupted_session_empty_events(): + """Test behavior when session history returns empty/corrupted data.""" + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', 
parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content(role='model', parts=[types.Part(text='Done')]), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create( + [response_with_fc, final_response] + ) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=1 + ) + ), + ) + + runner = testing_utils.InMemoryRunner(agent) + + # Clear session events to simulate corrupted state + session = runner.session + session._events = [] # Simulate empty/corrupted event history + + events = list(runner.run('test message')) + + # Even with corrupted session, the system should handle gracefully + # The first FC should execute since count starts at 0 + assert call_count == 1, 'Tool should be called once even with empty session' + + +def test_afc_disabled_in_live_mode(live_test_runner): + """Test that AFC disabled works in live streaming mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create( + [response_with_fc, final_response] + ) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True # AFC disabled + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue) + + # Tool should NOT be called because AFC is disabled + assert ( + call_count == 0 + ), 'Tool should not be called when AFC is disabled in live mode' + # Should make 1 LLM call: initial only (AFC disabled, no second call) + assert ( + len(mock_model.requests) == 1 + ), 'Should make 1 LLM call when AFC disabled in live mode' + + +def test_maximum_remote_calls_in_live_mode(live_test_runner): + """Test that maximum_remote_calls limit works in live streaming mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + # Create 3 FC responses but limit to 1 + mock_model = testing_utils.MockModel.create([ + response_with_fc, + response_with_fc, + response_with_fc, + final_response, + ]) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + 
generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=1 # Limit to 1 FC + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue, max_responses=4) + + # Tool should be called exactly once due to limit + assert ( + call_count == 1 + ), 'Tool should be called once when maximum_remote_calls=1 in live mode' + # In live mode with limit=1: initial call + 1 FC execution = 2 LLM calls total + # (different from async mode where limit is enforced differently) + assert ( + len(mock_model.requests) >= 1 + ), 'Should make at least 1 LLM call with maximum_remote_calls=1 in live mode' + + +def test_maximum_remote_calls_zero_in_live_mode(live_test_runner): + """Test that maximum_remote_calls=0 stops FCs in live streaming mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create( + [response_with_fc, final_response] + ) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=0 # No FCs allowed + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue) + + # Tool should NOT be called because maximum_remote_calls=0 + assert ( + call_count == 0 + ), 'Tool should not be called when maximum_remote_calls=0 in live mode' + assert ( + len(mock_model.requests) == 1 + ), 'Should make 1 LLM call when maximum_remote_calls=0 in live mode' + + +def test_parallel_function_calls_in_live_mode(live_test_runner): + """Test that parallel FCs count as 1 event in live mode.""" + from google.genai import types as genai_types + + # Create response with 3 parallel function calls + tool_call1 = types.Part.from_function_call(name='test_tool', args={'x': 1}) + tool_call2 = types.Part.from_function_call(name='test_tool', args={'x': 2}) + tool_call3 = types.Part.from_function_call(name='test_tool', args={'x': 3}) + response_with_parallel_fcs = LlmResponse( + content=types.Content( + role='model', parts=[tool_call1, tool_call2, tool_call3] + ), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create([ + response_with_parallel_fcs, + final_response, + ]) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + 
generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=1 # Limit to 1 event (3 parallel FCs) + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue, max_responses=4) + + # All 3 parallel function calls should execute (they count as 1 event) + assert ( + call_count == 3 + ), 'All 3 parallel FCs should execute in live mode (count as 1 event)' + # Confirms event counting: 3 parallel FCs in 1 event = 1 toward limit + assert ( + len(mock_model.requests) >= 1 + ), 'Should make at least 1 LLM call with parallel FCs in live mode' + + +def test_negative_maximum_remote_calls_in_live_mode(live_test_runner): + """Test that negative maximum_remote_calls is treated as zero in live mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create( + [response_with_fc, final_response] + ) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=-10 # Negative value + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue) + + # Tool should NOT be called (negative treated as zero) + assert ( + call_count == 0 + ), 'Tool should not be called when maximum_remote_calls=-10 in live mode' + assert ( + len(mock_model.requests) == 1 + ), 'Should make 1 LLM call when negative maximum_remote_calls in live mode' + + +def test_maximum_remote_calls_two_in_live_mode(live_test_runner): + """Test that maximum_remote_calls=2 enforces limit in live mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + # Create 3 FC responses but limit to 2 + mock_model = testing_utils.MockModel.create([ + response_with_fc, + response_with_fc, + response_with_fc, + final_response, + ]) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=2 # Limit to 2 FCs + ) + ), + ) + + runner = live_test_runner(root_agent=agent, 
response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue, max_responses=5) + + # Tool should be called exactly twice + assert ( + call_count == 2 + ), 'Tool should be called twice when maximum_remote_calls=2 in live mode' + assert ( + len(mock_model.requests) >= 1 + ), 'Should make at least 1 LLM call with maximum_remote_calls=2 in live mode' + + +def test_very_large_maximum_remote_calls_in_live_mode(live_test_runner): + """Test that very large maximum_remote_calls works in live mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + # Create 3 FC responses + mock_model = testing_utils.MockModel.create([ + response_with_fc, + response_with_fc, + response_with_fc, + final_response, + ]) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=999999 # Very large limit + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue, max_responses=5) + + # In live mode with timeout, may not get all 3 calls + # But should get at least 2 calls (verifies large limit works) + assert call_count >= 2, ( + 'At least 2 tool calls should execute with limit=999999 in live mode,' + f' got {call_count}' + ) + assert ( + len(mock_model.requests) >= 1 + ), 'Should make at least 1 LLM call with very large limit in live mode' + + +def test_corrupted_session_in_live_mode(live_test_runner): + """Test behavior when session is corrupted in live mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create( + [response_with_fc, final_response] + ) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=1 + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + # Clear session events to simulate corrupted state + runner.session._events = [] + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = 
runner.run_live(live_request_queue) + + # Even with corrupted session, should handle gracefully + assert ( + call_count == 1 + ), 'Tool should be called once even with corrupted session in live mode' + + +def test_planner_hooks_in_live_mode(live_test_runner): + """Test that maximum_remote_calls=0 works correctly in live mode.""" + from google.genai import types as genai_types + + tool_call = types.Part.from_function_call(name='test_tool', args={'x': 1}) + response_with_fc = LlmResponse( + content=types.Content(role='model', parts=[tool_call]), + turn_complete=False, + ) + final_response = LlmResponse( + content=types.Content( + role='model', parts=[genai_types.Part(text='Done')] + ), + turn_complete=True, + ) + mock_model = testing_utils.MockModel.create( + [response_with_fc, final_response] + ) + + call_count = 0 + + def test_tool(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + agent = Agent( + name='test_agent', + model=mock_model, + tools=[test_tool], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + maximum_remote_calls=0 # No FCs allowed + ) + ), + ) + + runner = live_test_runner(root_agent=agent, response_modalities=['AUDIO']) + live_request_queue = LiveRequestQueue() + live_request_queue.send_realtime( + blob=genai_types.Blob(data=b'test audio', mime_type='audio/pcm') + ) + + res_events = runner.run_live(live_request_queue) + + # AFC config should be respected in live mode + # Tool should NOT be called because maximum_remote_calls=0 + assert ( + call_count == 0 + ), 'Tool should not be called when maximum_remote_calls=0 in live mode' diff --git a/tests/unittests/flows/llm_flows/test_guardrail.py b/tests/unittests/flows/llm_flows/test_guardrail.py new file mode 100644 index 0000000000..c8cb9c963e --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_guardrail.py @@ -0,0 +1,493 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for safety valve functionality in base_llm_flow.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from google.adk.agents.llm_agent import Agent +from google.adk.flows.llm_flows.base_llm_flow import _GUARDRAIL_INSTRUCTION +from google.adk.flows.llm_flows.base_llm_flow import GuardrailContext +from google.adk.flows.llm_flows.base_llm_flow import MAX_CONSECUTIVE_REFUSED_FUNCTION_CALLS +from google.adk.models.llm_request import LlmRequest +from google.genai import types +import pytest + +from ... 
import testing_utils + + +@pytest.mark.asyncio +async def test_guardrail_constants_defined(): + """Verify safety valve constants are properly defined.""" + assert MAX_CONSECUTIVE_REFUSED_FUNCTION_CALLS == 3 + assert isinstance(_GUARDRAIL_INSTRUCTION, str) + assert 'IMPORTANT' in _GUARDRAIL_INSTRUCTION + assert 'maximum number of function calls' in _GUARDRAIL_INSTRUCTION + + +@pytest.mark.asyncio +async def test_guardrail_context(): + """Test GuardrailContext class methods.""" + state = {} + guardrail = GuardrailContext(state) + + # Initially not active + assert not guardrail.is_active + assert not guardrail.is_processed + + # Test activation + guardrail.activate() + assert guardrail.is_active + assert not guardrail.is_processed + + # Test marking as processed + guardrail.mark_processed() + assert guardrail.is_active + assert guardrail.is_processed + + # Test clear_processed + guardrail.clear_processed() + assert guardrail.is_active + assert not guardrail.is_processed + + # Test clear_active + guardrail.clear_active() + assert not guardrail.is_active + assert not guardrail.is_processed + + # Test full clear + guardrail.activate() + guardrail.mark_processed() + guardrail.clear() + assert not guardrail.is_active + assert not guardrail.is_processed + + # Test __repr__ + guardrail.activate() + repr_str = repr(guardrail) + assert 'GuardrailContext' in repr_str + assert 'active=True' in repr_str + assert 'processed=False' in repr_str + + +@pytest.mark.asyncio +async def test_guardrail_instruction_added_to_empty_system_instruction(): + """Test that safety valve instruction is added when no system instruction exists.""" + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + # Trigger safety valve + guardrail = GuardrailContext(invocation_context.session.state) + guardrail.activate() + + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + + flow = BaseLlmFlow() + + # Call preprocess which should handle safety valve + async for _ in flow._preprocess_async(invocation_context, llm_request): + pass + + # Verify system instruction was added + assert llm_request.config.system_instruction is not None + assert _GUARDRAIL_INSTRUCTION.strip() in llm_request.config.system_instruction + + +@pytest.mark.asyncio +async def test_guardrail_instruction_appended_to_existing_instruction(): + """Test that safety valve instruction is appended to existing system instruction.""" + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + # Trigger safety valve + guardrail = GuardrailContext(invocation_context.session.state) + guardrail.activate() + + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + llm_request.config.system_instruction = 'Original instruction' + + flow = BaseLlmFlow() + + # Call preprocess + async for _ in flow._preprocess_async(invocation_context, llm_request): + pass + + # Verify both instructions present + assert 'Original instruction' in llm_request.config.system_instruction + assert _GUARDRAIL_INSTRUCTION in llm_request.config.system_instruction + + +@pytest.mark.asyncio +async def test_guardrail_skips_tool_addition(): + """Test that tools are not added when safety valve is active.""" 
+ from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + mock_tool = MagicMock() + mock_tool.process_llm_request = MagicMock(return_value=None) + + agent = Agent(name='test_agent', tools=[mock_tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + # Trigger safety valve + guardrail = GuardrailContext(invocation_context.session.state) + guardrail.activate() + + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + + flow = BaseLlmFlow() + + # Call preprocess + async for _ in flow._preprocess_async(invocation_context, llm_request): + pass + + # Verify tool processing was NOT called + mock_tool.process_llm_request.assert_not_called() + + +@pytest.mark.asyncio +async def test_guardrail_disables_afc(): + """Test that AFC is disabled when safety valve is active.""" + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + # Trigger safety valve + guardrail = GuardrailContext(invocation_context.session.state) + guardrail.activate() + + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + llm_request.config.automatic_function_calling = ( + types.AutomaticFunctionCallingConfig(disable=False) + ) + + flow = BaseLlmFlow() + + # Call preprocess + async for _ in flow._preprocess_async(invocation_context, llm_request): + pass + + # Verify AFC is disabled + assert llm_request.config.automatic_function_calling.disable is True + + +@pytest.mark.asyncio +async def test_guardrail_sets_processed_flag(): + """Test that processed flag is set after safety valve handling.""" + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + # Trigger safety valve + guardrail = GuardrailContext(invocation_context.session.state) + guardrail.activate() + assert guardrail.is_active + assert not guardrail.is_processed + + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + + flow = BaseLlmFlow() + + # Call preprocess + async for _ in flow._preprocess_async(invocation_context, llm_request): + pass + + # Verify processed flag was set and active flag was cleared + guardrail_after = GuardrailContext(invocation_context.session.state) + assert guardrail_after.is_processed + assert not guardrail_after.is_active + + +@pytest.mark.asyncio +async def test_guardrail_final_enforcement_removes_tools(): + """Test that final enforcement removes tools from config.""" + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + # Mark as processed (simulating preprocess already handled it) + guardrail = GuardrailContext(invocation_context.session.state) + guardrail.mark_processed() + + flow = BaseLlmFlow() + + # Manually construct request with tools + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + llm_request.config.tools = [types.Tool(function_declarations=[])] + llm_request.config.automatic_function_calling = ( + types.AutomaticFunctionCallingConfig(disable=False) + ) + + # Simulate the enforcement check (from _run_one_step_async) + 
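+  # This intentionally mirrors the enforcement block in _run_one_step_async;
+  # keep the two in sync if that logic changes.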
guardrail_check = GuardrailContext(invocation_context.session.state) + if guardrail_check.is_processed: + if llm_request.config and llm_request.config.automatic_function_calling: + llm_request.config.automatic_function_calling.disable = True + if llm_request.config: + llm_request.config.tools = None + guardrail_check.clear_processed() + + # Verify enforcement worked + assert llm_request.config.tools is None + assert llm_request.config.automatic_function_calling.disable is True + guardrail_final = GuardrailContext(invocation_context.session.state) + assert not guardrail_final.is_processed + + +@pytest.mark.asyncio +async def test_guardrail_cleans_up_flags_on_error(): + """Test that safety valve flags are cleaned up even if error occurs.""" + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + # Set both flags + guardrail = GuardrailContext(invocation_context.session.state) + guardrail.activate() + guardrail.mark_processed() + + # Simulate cleanup in finally block (this is what run_async does) + try: + # Simulate an error + raise ValueError('Test error') + except ValueError: + # Expected error, test continues + pass + finally: + guardrail.clear() + + # Verify flags were cleared despite error + guardrail_after = GuardrailContext(invocation_context.session.state) + assert not guardrail_after.is_active + assert not guardrail_after.is_processed + + +@pytest.mark.asyncio +async def test_guardrail_live_mode_pre_execution_check(): + """Test that run_live mode checks AFC limits before executing function calls.""" + from google.adk.events.event import Event + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + flow = BaseLlmFlow() + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + llm_request.config.automatic_function_calling = ( + types.AutomaticFunctionCallingConfig( + disable=False, maximum_remote_calls=3 + ) + ) + llm_request.tools_dict = {} # Empty tools dict for this test + + # Create a model response event with function call + model_response_event = Event( + author=agent.name, + invocation_id=invocation_context.invocation_id, + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall(name='test_tool', args={}) + ) + ], + ), + ) + + # Simulate that we've already exceeded the maximum_remote_calls limit + # Add 4 FC events (over the limit of 3), so the new event would exceed too + for _ in range(4): + fc_event = Event( + author=agent.name, + invocation_id=invocation_context.invocation_id, + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall(name='test_tool', args={}) + ) + ], + ), + ) + invocation_context.session.events.append(fc_event) + + # Test _postprocess_live - should return early without executing functions + results = [] + async for event in flow._postprocess_live( + invocation_context, + llm_request, + model_response_event, + model_response_event, + ): + results.append(event) + + # Should only yield the model_response_event, no function execution + # The _postprocess_live may yield modified event so check only length + assert len(results) >= 1 + + +@pytest.mark.asyncio +async def 
test_guardrail_live_mode_allows_execution_below_threshold(): + """Test that run_live mode allows function execution when below threshold.""" + from google.adk.events.event import Event + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + flow = BaseLlmFlow() + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + llm_request.config.automatic_function_calling = ( + types.AutomaticFunctionCallingConfig( + disable=False, maximum_remote_calls=5 + ) + ) + llm_request.tools_dict = {} # Empty tools dict for this test + + # Create a model response event with function call + model_response_event = Event( + author=agent.name, + invocation_id=invocation_context.invocation_id, + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall(name='test_tool', args={}) + ) + ], + ), + ) + + # Add only 2 refused events (below threshold of 3) + for _ in range(2): + refused_event = Event( + author=agent.name, + invocation_id=invocation_context.invocation_id, + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall(name='test_tool', args={}) + ) + ], + ), + finish_reason=types.FinishReason.MAX_TOKENS, + ) + invocation_context.session.events.append(refused_event) + + # Mock function handler to verify it's called + import unittest.mock + + with unittest.mock.patch( + 'google.adk.flows.llm_flows.functions.handle_function_calls_live' + ) as mock_handler: + mock_handler.return_value = Event( + author='user', + invocation_id=invocation_context.invocation_id, + content=types.Content(role='user', parts=[]), + ) + + results = [] + async for event in flow._postprocess_live( + invocation_context, + llm_request, + model_response_event, + model_response_event, + ): + results.append(event) + + # Should call function handler since below threshold + mock_handler.assert_called_once() + + +@pytest.mark.asyncio +async def test_guardrail_live_mode_respects_afc_disable(): + """Test that run_live mode respects AFC disable flag.""" + from google.adk.events.event import Event + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + + agent = Agent(name='test_agent', tools=[MagicMock()]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test' + ) + + flow = BaseLlmFlow() + llm_request = LlmRequest() + llm_request.config = types.GenerateContentConfig() + llm_request.config.automatic_function_calling = ( + types.AutomaticFunctionCallingConfig(disable=True) + ) + + # Create a model response event with function call + model_response_event = Event( + author=agent.name, + invocation_id=invocation_context.invocation_id, + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall(name='test_tool', args={}) + ) + ], + ), + ) + + # Test _postprocess_live with AFC disabled + results = [] + async for event in flow._postprocess_live( + invocation_context, + llm_request, + model_response_event, + model_response_event, + ): + results.append(event) + + # Should return early without executing functions + # The _postprocess_live may yield modified event so check only length + assert len(results) >= 1
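
Reviewer note: the sketch below is illustrative only and is not part of the patch. It shows how the new guardrail is expected to surface to agent authors, assuming the public Agent and google.genai types APIs already used in the tests above; the model id and the search tool are placeholders.

from google.adk.agents.llm_agent import Agent
from google.genai import types


def search(query: str) -> str:
  """Stand-in tool; any ADK function tool behaves the same way."""
  return f'results for {query}'


agent = Agent(
    name='research_agent',
    model='gemini-2.0-flash',  # placeholder model id, not prescribed by this patch
    tools=[search],
    generate_content_config=types.GenerateContentConfig(
        automatic_function_calling=types.AutomaticFunctionCallingConfig(
            # Only two function-call events are executed per invocation.
            # Further calls are refused, and after three consecutive refused
            # attempts the guardrail strips tools, disables AFC, and forces
            # the model to return a plain text answer.
            maximum_remote_calls=2,
        )
    ),
)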