Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 112 additions & 15 deletions src/google/adk/a2a/executor/a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import asyncio
from datetime import datetime
from datetime import timezone
import inspect
Expand Down Expand Up @@ -87,9 +88,16 @@ def __init__(
super().__init__()
self._runner = runner
self._config = config or A2aAgentExecutorConfig()
# Track active tasks by task_id for cancellation support
self._active_tasks: dict[str, asyncio.Task] = {}
# Lock to protect _active_tasks from race conditions
self._tasks_lock = asyncio.Lock()

async def _resolve_runner(self) -> Runner:
"""Resolve the runner, handling cases where it's a callable that returns a Runner."""
"""Resolve the runner.

Handles cases where it's a callable returning a Runner.
"""
# If already resolved and cached, return it
if isinstance(self._runner, Runner):
return self._runner
Expand All @@ -114,9 +122,72 @@ async def _resolve_runner(self) -> Runner:

@override
async def cancel(self, context: RequestContext, event_queue: EventQueue):
"""Cancel the execution."""
# TODO: Implement proper cancellation logic if needed
raise NotImplementedError('Cancellation is not supported')
"""Cancel the execution of a running task.

Args:
context: The request context containing the task_id to cancel.
event_queue: The event queue to publish cancellation events to.

If the task is found and running, it will be cancelled and a cancellation
event will be published. If the task is not found or already completed,
the method will log a warning and return gracefully.
"""
if not context.task_id:
logger.warning('Cannot cancel task: no task_id provided in context')
return

# Use lock to prevent race conditions with _handle_request cleanup
async with self._tasks_lock:
task = self._active_tasks.get(context.task_id)
if not task:
logger.warning(
'Task %s not found or already completed', context.task_id
)
return

if task.done():
# Task already completed, clean up
self._active_tasks.pop(context.task_id, None)
logger.info('Task %s already completed', context.task_id)
return

# Remove from tracking before cancelling to prevent double cleanup
self._active_tasks.pop(context.task_id, None)

# Cancel the task (outside lock to avoid blocking other operations)
logger.info('Cancelling task %s', context.task_id)
if not task.cancel():
# Task completed before it could be cancelled
logger.info('Task %s completed before it could be cancelled', context.task_id)
return

try:
# Wait for cancellation to complete with timeout
await asyncio.wait_for(task, timeout=1.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
# Expected when task is cancelled or timeout occurs
pass

# Publish cancellation event
try:
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=TaskState.failed,
timestamp=datetime.now(timezone.utc).isoformat(),
message=Message(
message_id=str(uuid.uuid4()),
role=Role.agent,
parts=[TextPart(text='Task was cancelled')],
),
),
context_id=context.context_id,
final=True,
)
)
except Exception as e:
logger.error('Failed to publish cancellation event: %s', e, exc_info=True)

@override
async def execute(
Expand Down Expand Up @@ -221,17 +292,43 @@ async def _handle_request(
)

task_result_aggregator = TaskResultAggregator()
async with Aclosing(runner.run_async(**vars(run_request))) as agen:
async for adk_event in agen:
for a2a_event in self._config.event_converter(
adk_event,
invocation_context,
context.task_id,
context.context_id,
self._config.gen_ai_part_converter,
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)

# Helper function to iterate over async generator
async def _process_events():
async with Aclosing(runner.run_async(**vars(run_request))) as agen:
async for adk_event in agen:
for a2a_event in self._config.event_converter(
adk_event,
invocation_context,
context.task_id,
context.context_id,
self._config.gen_ai_part_converter,
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)

# Create and track the task for cancellation support
if context.task_id:
task = asyncio.create_task(_process_events())
# Use lock to prevent race conditions with cancel()
async with self._tasks_lock:
self._active_tasks[context.task_id] = task
try:
await task
except asyncio.CancelledError:
# Task was cancelled
# Note: cancellation event is published by cancel() method,
# so we just log and handle gracefully here
logger.info('Task %s was cancelled', context.task_id)
# Return early - don't publish completion events for cancelled tasks
return
finally:
# Clean up task tracking (use lock to prevent race conditions)
async with self._tasks_lock:
self._active_tasks.pop(context.task_id, None)
else:
# No task_id, run without tracking
await _process_events()

# publish the task result event - this is final
if (
Expand Down
176 changes: 166 additions & 10 deletions tests/unittests/a2a/executor/test_a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
Expand All @@ -20,6 +21,7 @@
from a2a.server.events.event_queue import EventQueue
from a2a.types import Message
from a2a.types import TaskState
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart
from google.adk.a2a.converters.request_converter import AgentRunRequest
from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
Expand Down Expand Up @@ -583,22 +585,176 @@ async def test_cancel_with_task_id(self):
"""Test cancellation with a task ID."""
self.mock_context.task_id = "test-task-id"

# The current implementation raises NotImplementedError
with pytest.raises(
NotImplementedError, match="Cancellation is not supported"
):
await self.executor.cancel(self.mock_context, self.mock_event_queue)
# Cancel should succeed without raising
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# If no task is running, should log warning but not raise
# Verify event queue was not called (no task to cancel)
assert self.mock_event_queue.enqueue_event.call_count == 0

@pytest.mark.asyncio
async def test_cancel_without_task_id(self):
"""Test cancellation without a task ID."""
self.mock_context.task_id = None

# The current implementation raises NotImplementedError regardless of task_id
with pytest.raises(
NotImplementedError, match="Cancellation is not supported"
):
await self.executor.cancel(self.mock_context, self.mock_event_queue)
# Cancel should handle missing task_id gracefully
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Should not publish any events when task_id is missing
assert self.mock_event_queue.enqueue_event.call_count == 0

@pytest.mark.asyncio
async def test_cancel_running_task(self):
"""Test cancellation of a running task."""
self.mock_context.task_id = "test-task-id"

# Setup: Create a running task by starting execution
self.mock_request_converter.return_value = AgentRunRequest(
user_id="test-user",
session_id="test-session",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
mock_session = Mock()
mock_session.id = "test-session"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
mock_invocation_context = Mock()
self.mock_runner._new_invocation_context.return_value = (
mock_invocation_context
)

# Create an async generator that yields events slowly
async def slow_generator():
mock_event = Mock(spec=Event)
yield mock_event
# This will hang if not cancelled
await asyncio.sleep(10)

# Replace run_async with the async generator function
self.mock_runner.run_async = slow_generator
self.mock_event_converter.return_value = []

# Start execution in background
execute_task = asyncio.create_task(
self.executor.execute(self.mock_context, self.mock_event_queue)
)

# Wait a bit to ensure task is running
await asyncio.sleep(0.1)

# Cancel the task
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Wait for cancellation to complete
try:
await asyncio.wait_for(execute_task, timeout=2.0)
except asyncio.CancelledError:
pass

# Verify cancellation event was published
assert self.mock_event_queue.enqueue_event.call_count > 0
# Find the cancellation event (should be the last one with failed state)
cancellation_events = [
call[0][0]
for call in self.mock_event_queue.enqueue_event.call_args_list
if isinstance(call[0][0], TaskStatusUpdateEvent)
and call[0][0].status.state == TaskState.failed
and call[0][0].final is True
]
assert len(cancellation_events) > 0, "No cancellation event found"
cancellation_event = cancellation_events[-1]
assert cancellation_event.status.state == TaskState.failed
assert cancellation_event.final is True

@pytest.mark.asyncio
async def test_cancel_nonexistent_task(self):
"""Test cancellation of a non-existent task."""
self.mock_context.task_id = "nonexistent-task-id"

# Cancel should handle gracefully
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Should not publish any events for non-existent task
assert self.mock_event_queue.enqueue_event.call_count == 0

@pytest.mark.asyncio
async def test_cancel_completed_task(self):
"""Test cancellation of an already completed task."""
self.mock_context.task_id = "test-task-id"

# Setup and run a task to completion
self.mock_request_converter.return_value = AgentRunRequest(
user_id="test-user",
session_id="test-session",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
mock_session = Mock()
mock_session.id = "test-session"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
mock_invocation_context = Mock()
self.mock_runner._new_invocation_context.return_value = (
mock_invocation_context
)

# Create a generator that completes immediately
async def quick_generator():
mock_event = Mock(spec=Event)
yield mock_event

self.mock_runner.run_async.return_value = quick_generator()
self.mock_event_converter.return_value = []

# Run to completion
await self.executor.execute(self.mock_context, self.mock_event_queue)

# Now try to cancel (should handle gracefully)
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Should not publish additional cancellation event for completed task
# (The execute already published final event)

@pytest.mark.asyncio
async def test_cancel_race_condition_task_completes_before_cancel(self):
"""Test race condition where task completes before cancel() is called."""
self.mock_context.task_id = "test-task-id"

# Create a mock task that is already done
mock_task = Mock(spec=asyncio.Task)
mock_task.done.return_value = False # Initially not done (passes check)
mock_task.cancel.return_value = (
False # Returns False because task completed between check and cancel
)

# Manually add task to _active_tasks to simulate race condition
self.executor._active_tasks["test-task-id"] = mock_task

# Call cancel
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Verify task.cancel() was called
mock_task.cancel.assert_called_once()

# Verify no cancellation event was published (since cancel() returned False)
# Check that no TaskStatusUpdateEvent with "Task was cancelled" was published
cancellation_events = [
call[0][0]
for call in self.mock_event_queue.enqueue_event.call_args_list
if isinstance(call[0][0], TaskStatusUpdateEvent)
and call[0][0].status.state == TaskState.failed
and any(
part.text == "Task was cancelled"
for part in call[0][0].status.message.parts
if hasattr(part, "text")
)
]
assert (
len(cancellation_events) == 0
), "Should not publish cancellation event when task completed before cancel"

@pytest.mark.asyncio
async def test_execute_with_exception_handling(self):
Expand Down