diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index b6880aaa5c..421097a3b8 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio from datetime import datetime from datetime import timezone import inspect @@ -87,9 +88,16 @@ def __init__( super().__init__() self._runner = runner self._config = config or A2aAgentExecutorConfig() + # Track active tasks by task_id for cancellation support + self._active_tasks: dict[str, asyncio.Task] = {} + # Lock to protect _active_tasks from race conditions + self._tasks_lock = asyncio.Lock() async def _resolve_runner(self) -> Runner: - """Resolve the runner, handling cases where it's a callable that returns a Runner.""" + """Resolve the runner. + + Handles cases where it's a callable returning a Runner. + """ # If already resolved and cached, return it if isinstance(self._runner, Runner): return self._runner @@ -114,9 +122,70 @@ async def _resolve_runner(self) -> Runner: @override async def cancel(self, context: RequestContext, event_queue: EventQueue): - """Cancel the execution.""" - # TODO: Implement proper cancellation logic if needed - raise NotImplementedError('Cancellation is not supported') + """Cancel the execution of a running task. + + Args: + context: The request context containing the task_id to cancel. + event_queue: The event queue to publish cancellation events to. + + If the task is found and running, it will be cancelled and a cancellation + event will be published. If the task is not found or already completed, + the method will log a warning and return gracefully. + """ + if not context.task_id: + logger.warning('Cannot cancel task: no task_id provided in context') + return + + # Use lock to prevent race conditions with _handle_request cleanup + async with self._tasks_lock: + task = self._active_tasks.pop(context.task_id, None) + + if not task: + logger.warning( + 'Task %s not found or already completed', context.task_id + ) + return + + if task.done(): + logger.info('Task %s already completed', context.task_id) + return + + # Cancel the task (outside lock to avoid blocking other operations) + logger.info('Cancelling task %s', context.task_id) + if not task.cancel(): + # Task completed before it could be cancelled + logger.info( + 'Task %s completed before it could be cancelled', context.task_id + ) + return + + try: + # Wait for cancellation to complete with timeout + await asyncio.wait_for(task, timeout=1.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + # Expected when task is cancelled or timeout occurs + pass + + # Publish cancellation event + try: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.failed, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text='Task was cancelled')], + ), + ), + context_id=context.context_id, + final=True, + ) + ) + except Exception as e: + logger.error('Failed to publish cancellation event: %s', e, exc_info=True) @override async def execute( @@ -221,17 +290,43 @@ async def _handle_request( ) task_result_aggregator = TaskResultAggregator() - async with Aclosing(runner.run_async(**vars(run_request))) as agen: - async for adk_event in agen: - for a2a_event in self._config.event_converter( - adk_event, - invocation_context, - context.task_id, - context.context_id, - self._config.gen_ai_part_converter, - ): - task_result_aggregator.process_event(a2a_event) - await event_queue.enqueue_event(a2a_event) + + # Helper function to iterate over async generator + async def _process_events(): + async with Aclosing(runner.run_async(**vars(run_request))) as agen: + async for adk_event in agen: + for a2a_event in self._config.event_converter( + adk_event, + invocation_context, + context.task_id, + context.context_id, + self._config.gen_ai_part_converter, + ): + task_result_aggregator.process_event(a2a_event) + await event_queue.enqueue_event(a2a_event) + + # Create and track the task for cancellation support + if context.task_id: + task = asyncio.create_task(_process_events()) + # Use lock to prevent race conditions with cancel() + async with self._tasks_lock: + self._active_tasks[context.task_id] = task + try: + await task + except asyncio.CancelledError: + # Task was cancelled + # Note: cancellation event is published by cancel() method, + # so we just log and handle gracefully here + logger.info('Task %s was cancelled', context.task_id) + # Return early - don't publish completion events for cancelled tasks + return + finally: + # Clean up task tracking (use lock to prevent race conditions) + async with self._tasks_lock: + self._active_tasks.pop(context.task_id, None) + else: + # No task_id, run without tracking + await _process_events() # publish the task result event - this is final if ( diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 58d7521f7d..9322542edf 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -20,6 +21,7 @@ from a2a.server.events.event_queue import EventQueue from a2a.types import Message from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor @@ -583,22 +585,176 @@ async def test_cancel_with_task_id(self): """Test cancellation with a task ID.""" self.mock_context.task_id = "test-task-id" - # The current implementation raises NotImplementedError - with pytest.raises( - NotImplementedError, match="Cancellation is not supported" - ): - await self.executor.cancel(self.mock_context, self.mock_event_queue) + # Cancel should succeed without raising + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + # If no task is running, should log warning but not raise + # Verify event queue was not called (no task to cancel) + assert self.mock_event_queue.enqueue_event.call_count == 0 @pytest.mark.asyncio async def test_cancel_without_task_id(self): """Test cancellation without a task ID.""" self.mock_context.task_id = None - # The current implementation raises NotImplementedError regardless of task_id - with pytest.raises( - NotImplementedError, match="Cancellation is not supported" - ): - await self.executor.cancel(self.mock_context, self.mock_event_queue) + # Cancel should handle missing task_id gracefully + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + # Should not publish any events when task_id is missing + assert self.mock_event_queue.enqueue_event.call_count == 0 + + @pytest.mark.asyncio + async def test_cancel_running_task(self): + """Test cancellation of a running task.""" + self.mock_context.task_id = "test-task-id" + + # Setup: Create a running task by starting execution + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Create an async generator that yields events slowly + async def slow_generator(): + mock_event = Mock(spec=Event) + yield mock_event + # This will hang if not cancelled + await asyncio.sleep(10) + + # Replace run_async with the async generator function + self.mock_runner.run_async = slow_generator + self.mock_event_converter.return_value = [] + + # Start execution in background + execute_task = asyncio.create_task( + self.executor.execute(self.mock_context, self.mock_event_queue) + ) + + # Wait a bit to ensure task is running + await asyncio.sleep(0.1) + + # Cancel the task + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + # Wait for cancellation to complete + try: + await asyncio.wait_for(execute_task, timeout=2.0) + except asyncio.CancelledError: + pass + + # Verify cancellation event was published + assert self.mock_event_queue.enqueue_event.call_count > 0 + # Find the cancellation event (should be the last one with failed state) + cancellation_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if isinstance(call[0][0], TaskStatusUpdateEvent) + and call[0][0].status.state == TaskState.failed + and call[0][0].final is True + ] + assert len(cancellation_events) > 0, "No cancellation event found" + cancellation_event = cancellation_events[-1] + assert cancellation_event.status.state == TaskState.failed + assert cancellation_event.final is True + + @pytest.mark.asyncio + async def test_cancel_nonexistent_task(self): + """Test cancellation of a non-existent task.""" + self.mock_context.task_id = "nonexistent-task-id" + + # Cancel should handle gracefully + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + # Should not publish any events for non-existent task + assert self.mock_event_queue.enqueue_event.call_count == 0 + + @pytest.mark.asyncio + async def test_cancel_completed_task(self): + """Test cancellation of an already completed task.""" + self.mock_context.task_id = "test-task-id" + + # Setup and run a task to completion + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Create a generator that completes immediately + async def quick_generator(): + mock_event = Mock(spec=Event) + yield mock_event + + self.mock_runner.run_async.return_value = quick_generator() + self.mock_event_converter.return_value = [] + + # Run to completion + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Now try to cancel (should handle gracefully) + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + # Should not publish additional cancellation event for completed task + # (The execute already published final event) + + @pytest.mark.asyncio + async def test_cancel_race_condition_task_completes_before_cancel(self): + """Test race condition where task completes before cancel() is called.""" + self.mock_context.task_id = "test-task-id" + + # Create a mock task that is already done + mock_task = Mock(spec=asyncio.Task) + mock_task.done.return_value = False # Initially not done (passes check) + mock_task.cancel.return_value = ( + False # Returns False because task completed between check and cancel + ) + + # Manually add task to _active_tasks to simulate race condition + self.executor._active_tasks["test-task-id"] = mock_task + + # Call cancel + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + # Verify task.cancel() was called + mock_task.cancel.assert_called_once() + + # Verify no cancellation event was published (since cancel() returned False) + # Check that no TaskStatusUpdateEvent with "Task was cancelled" was published + cancellation_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if isinstance(call[0][0], TaskStatusUpdateEvent) + and call[0][0].status.state == TaskState.failed + and any( + part.text == "Task was cancelled" + for part in call[0][0].status.message.parts + if hasattr(part, "text") + ) + ] + assert ( + len(cancellation_events) == 0 + ), "Should not publish cancellation event when task completed before cancel" @pytest.mark.asyncio async def test_execute_with_exception_handling(self):