Skip to content

Commit f653ef4

Browse files
sararobcopybara-github
authored andcommitted
chore: Resolve remaining agent engines mypy errors
PiperOrigin-RevId: 862709652
1 parent 66c4d85 commit f653ef4

File tree

6 files changed

+134
-105
lines changed

6 files changed

+134
-105
lines changed

.github/workflows/mypy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
runs-on: ubuntu-latest
1717
strategy:
1818
matrix:
19-
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14']
19+
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
2020

2121
steps:
2222
- name: Checkout code

vertexai/_genai/_agent_engines_utils.py

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from typing import (
3333
Any,
3434
AsyncIterator,
35-
Awaitable,
3635
Callable,
3736
Coroutine,
3837
Dict,
@@ -59,6 +58,12 @@
5958
from . import types as genai_types
6059

6160

61+
if sys.version_info > (3, 9):
62+
from typing import TypeAlias
63+
else:
64+
from typing_extensions import TypeAlias
65+
66+
6267
try:
6368
_BUILTIN_MODULE_NAMES: Sequence[str] = sys.builtin_module_names
6469
except AttributeError:
@@ -78,20 +83,30 @@
7883
_STDLIB_MODULE_NAMES: frozenset[str] = frozenset() # type: ignore[no-redef]
7984

8085

81-
try:
82-
from google.cloud import storage
86+
if typing.TYPE_CHECKING:
87+
from google.cloud import storage # type: ignore[attr-defined]
8388

84-
_StorageBucket: type[Any] = storage.Bucket
85-
except (ImportError, AttributeError):
86-
_StorageBucket: type[Any] = Any # type: ignore[no-redef]
89+
_StorageBucket: TypeAlias = storage.Bucket
90+
else:
91+
try:
92+
from google.cloud import storage # type: ignore[attr-defined]
8793

94+
_StorageBucket: type[Any] = storage.Bucket
95+
except (ImportError, AttributeError):
96+
_StorageBucket: type[Any] = Any # type: ignore[no-redef]
8897

89-
try:
98+
99+
if typing.TYPE_CHECKING:
90100
import packaging
91101

92-
_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
93-
except (ImportError, AttributeError):
94-
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
102+
_SpecifierSet = packaging.specifiers.SpecifierSet
103+
else:
104+
try:
105+
import packaging
106+
107+
_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
108+
except (ImportError, AttributeError):
109+
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
95110

96111

97112
try:
@@ -258,16 +273,22 @@ class OperationRegistrable(Protocol):
258273
"""Protocol for agents that have registered operations."""
259274

260275
@abc.abstractmethod
261-
def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ignore[no-untyped-def]
276+
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
262277
"""Register the user provided operations (modes and methods)."""
278+
pass
263279

264280

265-
try:
281+
if typing.TYPE_CHECKING:
266282
from google.adk.agents import BaseAgent
267283

268-
ADKAgent: type[Any] = BaseAgent
269-
except (ImportError, AttributeError):
270-
ADKAgent: type[Any] = Any # type: ignore[no-redef]
284+
ADKAgent: TypeAlias = BaseAgent
285+
else:
286+
try:
287+
from google.adk.agents import BaseAgent
288+
289+
ADKAgent: Optional[TypeAlias] = BaseAgent
290+
except (ImportError, AttributeError):
291+
ADKAgent = None # type: ignore[no-redef]
271292

272293
_AgentEngineInterface = Union[
273294
ADKAgent,
@@ -283,8 +304,9 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ig
283304
class _ModuleAgentAttributes(TypedDict, total=False):
284305
module_name: str
285306
agent_name: str
286-
register_operations: Dict[str, Sequence[str]]
307+
register_operations: Dict[str, list[str]]
287308
sys_paths: Optional[Sequence[str]]
309+
agent: _AgentEngineInterface
288310

289311

290312
class ModuleAgent(Cloneable, OperationRegistrable):
@@ -300,7 +322,7 @@ def __init__(
300322
*,
301323
module_name: str,
302324
agent_name: str,
303-
register_operations: Dict[str, Sequence[str]],
325+
register_operations: Dict[str, list[str]],
304326
sys_paths: Optional[Sequence[str]] = None,
305327
):
306328
"""Initializes a module-based agent.
@@ -310,7 +332,7 @@ def __init__(
310332
Required. The name of the module to import.
311333
agent_name (str):
312334
Required. The name of the agent in the module to instantiate.
313-
register_operations (Dict[str, Sequence[str]]):
335+
register_operations (Dict[str, list[str]]):
314336
Required. A dictionary of API modes to a list of method names.
315337
sys_paths (Sequence[str]):
316338
Optional. The system paths to search for the module. It should
@@ -336,8 +358,11 @@ def clone(self) -> "ModuleAgent":
336358
sys_paths=self._tmpl_attrs.get("sys_paths"),
337359
)
338360

339-
def register_operations(self) -> Dict[str, Sequence[str]]:
340-
self._tmpl_attrs.get("register_operations")
361+
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
362+
reg_operations = self._tmpl_attrs.get("register_operations")
363+
if reg_operations is None:
364+
raise ValueError("Register operations is not set.")
365+
return reg_operations
341366

342367
def set_up(self) -> None:
343368
"""Sets up the agent for execution of queries at runtime.
@@ -411,7 +436,7 @@ def __call__(
411436
class GetAsyncOperationFunction(Protocol):
412437
async def __call__(
413438
self, *, operation_name: str, **kwargs: Any
414-
) -> Awaitable[AgentEngineOperationUnion]:
439+
) -> AgentEngineOperationUnion:
415440
pass
416441

417442

@@ -507,7 +532,7 @@ def _await_operation(
507532
def _compare_requirements(
508533
*,
509534
requirements: Mapping[str, str],
510-
constraints: Union[Sequence[str], Mapping[str, "_SpecifierSet"]],
535+
constraints: Union[Sequence[str], Mapping[str, Optional["_SpecifierSet"]]],
511536
required_packages: Optional[Iterator[str]] = None,
512537
) -> _RequirementsValidationResult:
513538
"""Compares the requirements with the constraints.
@@ -536,7 +561,7 @@ def _compare_requirements(
536561
"""
537562
packaging_version = _import_packaging_version_or_raise()
538563
if required_packages is None:
539-
required_packages = _DEFAULT_REQUIRED_PACKAGES
564+
required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment]
540565
result = _RequirementsValidationResult(
541566
warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()),
542567
actions=_RequirementsValidationActions(append=set()),
@@ -583,7 +608,7 @@ def _generate_class_methods_spec_or_raise(
583608
if isinstance(agent, ModuleAgent):
584609
# We do a dry-run of setting up the agent engine to have the operations
585610
# needed for registration.
586-
agent: ModuleAgent = agent.clone()
611+
agent: ModuleAgent = agent.clone() # type: ignore[no-redef]
587612
try:
588613
agent.set_up()
589614
except Exception as e:
@@ -819,13 +844,13 @@ def _get_gcs_bucket(
819844
new_bucket = storage_client.bucket(staging_bucket)
820845
gcs_bucket = storage_client.create_bucket(new_bucket, location=location)
821846
logger.info(f"Creating bucket {staging_bucket} in {location=}")
822-
return gcs_bucket # type: ignore[no-any-return]
847+
return gcs_bucket
823848

824849

825850
def _get_registered_operations(
826851
*,
827852
agent: _AgentEngineInterface,
828-
) -> Dict[str, List[str]]:
853+
) -> dict[str, list[str]]:
829854
"""Retrieves registered operations for a AgentEngine."""
830855
if isinstance(agent, OperationRegistrable):
831856
return agent.register_operations()
@@ -859,13 +884,13 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
859884
def _import_cloud_storage_or_raise() -> types.ModuleType:
860885
"""Tries to import the Cloud Storage module."""
861886
try:
862-
from google.cloud import storage
887+
from google.cloud import storage # type: ignore[attr-defined]
863888
except ImportError as e:
864889
raise ImportError(
865890
"Cloud Storage is not installed. Please call "
866891
"'pip install google-cloud-aiplatform[agent_engines]'."
867892
) from e
868-
return storage
893+
return storage # type: ignore[no-any-return]
869894

870895

871896
def _import_packaging_requirements_or_raise() -> types.ModuleType:
@@ -1202,7 +1227,7 @@ def _upload_agent_engine(
12021227
) -> None:
12031228
"""Uploads the agent engine to GCS."""
12041229
cloudpickle = _import_cloudpickle_or_raise()
1205-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") # type: ignore[attr-defined]
1230+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}")
12061231
with blob.open("wb") as f:
12071232
try:
12081233
cloudpickle.dump(agent, f)
@@ -1216,7 +1241,7 @@ def _upload_agent_engine(
12161241
_ = cloudpickle.load(f)
12171242
except Exception as e:
12181243
raise TypeError("Agent engine serialized to an invalid format") from e
1219-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1244+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12201245
logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}")
12211246

12221247

@@ -1227,9 +1252,9 @@ def _upload_requirements(
12271252
gcs_dir_name: str,
12281253
) -> None:
12291254
"""Uploads the requirements file to GCS."""
1230-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}") # type: ignore[attr-defined]
1255+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}")
12311256
blob.upload_from_string("\n".join(requirements))
1232-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1257+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12331258
logger.info(f"Writing to {dir_name}/{_REQUIREMENTS_FILE}")
12341259

12351260

@@ -1246,9 +1271,9 @@ def _upload_extra_packages(
12461271
for file in extra_packages:
12471272
tar.add(file)
12481273
tar_fileobj.seek(0)
1249-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}") # type: ignore[attr-defined]
1274+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}")
12501275
blob.upload_from_string(tar_fileobj.read())
1251-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1276+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12521277
logger.info(f"Writing to {dir_name}/{_EXTRA_PACKAGES_FILE}")
12531278

12541279

@@ -1369,7 +1394,7 @@ def _validate_requirements_or_warn(
13691394
*,
13701395
obj: Any,
13711396
requirements: List[str],
1372-
) -> Mapping[str, str]:
1397+
) -> List[str]:
13731398
"""Compiles the requirements into a list of requirements."""
13741399
requirements = requirements.copy()
13751400
try:
@@ -1380,16 +1405,14 @@ def _validate_requirements_or_warn(
13801405
requirements=current_requirements,
13811406
constraints=constraints,
13821407
)
1383-
for warning_type, warnings in missing_requirements.get(
1384-
_WARNINGS_KEY, {}
1385-
).items():
1408+
for warning_type, warnings in missing_requirements["warnings"].items():
13861409
if warnings:
13871410
logger.warning(
13881411
f"The following requirements are {warning_type}: {warnings}"
13891412
)
1390-
for action_type, actions in missing_requirements.get(_ACTIONS_KEY, {}).items():
1413+
for action_type, actions in missing_requirements["actions"].items():
13911414
if actions and action_type == _ACTION_APPEND:
1392-
for action in actions:
1415+
for action in actions: # type: ignore[attr-defined]
13931416
requirements.append(action)
13941417
logger.info(f"The following requirements are appended: {actions}")
13951418
except Exception as e:
@@ -1413,7 +1436,7 @@ def _validate_requirements_or_raise(
14131436
logger.info(f"Read the following lines: {requirements}")
14141437
except IOError as err:
14151438
raise IOError(f"Failed to read requirements from {requirements=}") from err
1416-
requirements = _validate_requirements_or_warn( # type: ignore[assignment]
1439+
requirements = _validate_requirements_or_warn(
14171440
obj=agent,
14181441
requirements=requirements,
14191442
)
@@ -1560,19 +1583,6 @@ def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
15601583
return _method
15611584

15621585

1563-
AgentEngineOperationUnion = Union[
1564-
genai_types.AgentEngineOperation,
1565-
genai_types.AgentEngineMemoryOperation,
1566-
genai_types.AgentEngineGenerateMemoriesOperation,
1567-
]
1568-
1569-
1570-
class GetOperationFunction(Protocol):
1571-
def __call__( # noqa: E704
1572-
self, *, operation_name: str, **kwargs: Any
1573-
) -> AgentEngineOperationUnion: ...
1574-
1575-
15761586
def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
15771587
"""Wraps an Agent Engine method, creating a callable for `query` API.
15781588
@@ -1835,7 +1845,7 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18351845

18361846
return response
18371847

1838-
return _method
1848+
return _method # type: ignore[return-value]
18391849

18401850

18411851
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:

0 commit comments

Comments
 (0)