Skip to content

Commit 199d769

Browse files
sararobcopybara-github
authored andcommitted
chore: Resolve remaining agent engines mypy errors
PiperOrigin-RevId: 862709652
1 parent d685d81 commit 199d769

File tree

6 files changed

+128
-105
lines changed

6 files changed

+128
-105
lines changed

.github/workflows/mypy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
runs-on: ubuntu-latest
1717
strategy:
1818
matrix:
19-
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14']
19+
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
2020

2121
steps:
2222
- name: Checkout code

vertexai/_genai/_agent_engines_utils.py

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from typing import (
3333
Any,
3434
AsyncIterator,
35-
Awaitable,
3635
Callable,
3736
Coroutine,
3837
Dict,
@@ -78,20 +77,30 @@
7877
_STDLIB_MODULE_NAMES: frozenset[str] = frozenset() # type: ignore[no-redef]
7978

8079

81-
try:
82-
from google.cloud import storage
80+
if typing.TYPE_CHECKING:
81+
from google.cloud import storage # type: ignore[attr-defined]
8382

84-
_StorageBucket: type[Any] = storage.Bucket
85-
except (ImportError, AttributeError):
86-
_StorageBucket: type[Any] = Any # type: ignore[no-redef]
83+
_StorageBucket = storage.Bucket
84+
else:
85+
try:
86+
from google.cloud import storage # type: ignore[attr-defined]
8787

88+
_StorageBucket: type[Any] = storage.Bucket
89+
except (ImportError, AttributeError):
90+
_StorageBucket: type[Any] = Any # type: ignore[no-redef]
8891

89-
try:
92+
93+
if typing.TYPE_CHECKING:
9094
import packaging
9195

92-
_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
93-
except (ImportError, AttributeError):
94-
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
96+
_SpecifierSet = packaging.specifiers.SpecifierSet
97+
else:
98+
try:
99+
import packaging
100+
101+
_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
102+
except (ImportError, AttributeError):
103+
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
95104

96105

97106
try:
@@ -258,16 +267,22 @@ class OperationRegistrable(Protocol):
258267
"""Protocol for agents that have registered operations."""
259268

260269
@abc.abstractmethod
261-
def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ignore[no-untyped-def]
270+
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
262271
"""Register the user provided operations (modes and methods)."""
272+
pass
263273

264274

265-
try:
275+
if typing.TYPE_CHECKING:
266276
from google.adk.agents import BaseAgent
267277

268-
ADKAgent: type[Any] = BaseAgent
269-
except (ImportError, AttributeError):
270-
ADKAgent: type[Any] = Any # type: ignore[no-redef]
278+
ADKAgent: TypeAlias = BaseAgent
279+
else:
280+
try:
281+
from google.adk.agents import BaseAgent
282+
283+
ADKAgent: Optional[TypeAlias] = BaseAgent
284+
except (ImportError, AttributeError):
285+
ADKAgent: Optional[TypeAlias] = None # type: ignore[no-redef]
271286

272287
_AgentEngineInterface = Union[
273288
ADKAgent,
@@ -283,8 +298,9 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ig
283298
class _ModuleAgentAttributes(TypedDict, total=False):
284299
module_name: str
285300
agent_name: str
286-
register_operations: Dict[str, Sequence[str]]
301+
register_operations: Dict[str, list[str]]
287302
sys_paths: Optional[Sequence[str]]
303+
agent: _AgentEngineInterface
288304

289305

290306
class ModuleAgent(Cloneable, OperationRegistrable):
@@ -300,7 +316,7 @@ def __init__(
300316
*,
301317
module_name: str,
302318
agent_name: str,
303-
register_operations: Dict[str, Sequence[str]],
319+
register_operations: Dict[str, list[str]],
304320
sys_paths: Optional[Sequence[str]] = None,
305321
):
306322
"""Initializes a module-based agent.
@@ -310,7 +326,7 @@ def __init__(
310326
Required. The name of the module to import.
311327
agent_name (str):
312328
Required. The name of the agent in the module to instantiate.
313-
register_operations (Dict[str, Sequence[str]]):
329+
register_operations (Dict[str, list[str]]):
314330
Required. A dictionary of API modes to a list of method names.
315331
sys_paths (Sequence[str]):
316332
Optional. The system paths to search for the module. It should
@@ -336,8 +352,11 @@ def clone(self) -> "ModuleAgent":
336352
sys_paths=self._tmpl_attrs.get("sys_paths"),
337353
)
338354

339-
def register_operations(self) -> Dict[str, Sequence[str]]:
340-
self._tmpl_attrs.get("register_operations")
355+
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
356+
reg_operations = self._tmpl_attrs.get("register_operations")
357+
if reg_operations is None:
358+
raise ValueError("Register operations is not set.")
359+
return reg_operations
341360

342361
def set_up(self) -> None:
343362
"""Sets up the agent for execution of queries at runtime.
@@ -411,7 +430,7 @@ def __call__(
411430
class GetAsyncOperationFunction(Protocol):
412431
async def __call__(
413432
self, *, operation_name: str, **kwargs: Any
414-
) -> Awaitable[AgentEngineOperationUnion]:
433+
) -> AgentEngineOperationUnion:
415434
pass
416435

417436

@@ -507,7 +526,7 @@ def _await_operation(
507526
def _compare_requirements(
508527
*,
509528
requirements: Mapping[str, str],
510-
constraints: Union[Sequence[str], Mapping[str, "_SpecifierSet"]],
529+
constraints: Union[Sequence[str], Mapping[str, Optional["_SpecifierSet"]]],
511530
required_packages: Optional[Iterator[str]] = None,
512531
) -> _RequirementsValidationResult:
513532
"""Compares the requirements with the constraints.
@@ -536,7 +555,7 @@ def _compare_requirements(
536555
"""
537556
packaging_version = _import_packaging_version_or_raise()
538557
if required_packages is None:
539-
required_packages = _DEFAULT_REQUIRED_PACKAGES
558+
required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment]
540559
result = _RequirementsValidationResult(
541560
warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()),
542561
actions=_RequirementsValidationActions(append=set()),
@@ -583,7 +602,7 @@ def _generate_class_methods_spec_or_raise(
583602
if isinstance(agent, ModuleAgent):
584603
# We do a dry-run of setting up the agent engine to have the operations
585604
# needed for registration.
586-
agent: ModuleAgent = agent.clone()
605+
agent: ModuleAgent = agent.clone() # type: ignore[no-redef]
587606
try:
588607
agent.set_up()
589608
except Exception as e:
@@ -819,13 +838,13 @@ def _get_gcs_bucket(
819838
new_bucket = storage_client.bucket(staging_bucket)
820839
gcs_bucket = storage_client.create_bucket(new_bucket, location=location)
821840
logger.info(f"Creating bucket {staging_bucket} in {location=}")
822-
return gcs_bucket # type: ignore[no-any-return]
841+
return gcs_bucket
823842

824843

825844
def _get_registered_operations(
826845
*,
827846
agent: _AgentEngineInterface,
828-
) -> Dict[str, List[str]]:
847+
) -> dict[str, list[str]]:
829848
"""Retrieves registered operations for a AgentEngine."""
830849
if isinstance(agent, OperationRegistrable):
831850
return agent.register_operations()
@@ -859,13 +878,13 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
859878
def _import_cloud_storage_or_raise() -> types.ModuleType:
860879
"""Tries to import the Cloud Storage module."""
861880
try:
862-
from google.cloud import storage
881+
from google.cloud import storage # type: ignore[attr-defined]
863882
except ImportError as e:
864883
raise ImportError(
865884
"Cloud Storage is not installed. Please call "
866885
"'pip install google-cloud-aiplatform[agent_engines]'."
867886
) from e
868-
return storage
887+
return storage # type: ignore[no-any-return]
869888

870889

871890
def _import_packaging_requirements_or_raise() -> types.ModuleType:
@@ -1202,7 +1221,7 @@ def _upload_agent_engine(
12021221
) -> None:
12031222
"""Uploads the agent engine to GCS."""
12041223
cloudpickle = _import_cloudpickle_or_raise()
1205-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") # type: ignore[attr-defined]
1224+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}")
12061225
with blob.open("wb") as f:
12071226
try:
12081227
cloudpickle.dump(agent, f)
@@ -1216,7 +1235,7 @@ def _upload_agent_engine(
12161235
_ = cloudpickle.load(f)
12171236
except Exception as e:
12181237
raise TypeError("Agent engine serialized to an invalid format") from e
1219-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1238+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12201239
logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}")
12211240

12221241

@@ -1227,9 +1246,9 @@ def _upload_requirements(
12271246
gcs_dir_name: str,
12281247
) -> None:
12291248
"""Uploads the requirements file to GCS."""
1230-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}") # type: ignore[attr-defined]
1249+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}")
12311250
blob.upload_from_string("\n".join(requirements))
1232-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1251+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12331252
logger.info(f"Writing to {dir_name}/{_REQUIREMENTS_FILE}")
12341253

12351254

@@ -1246,9 +1265,9 @@ def _upload_extra_packages(
12461265
for file in extra_packages:
12471266
tar.add(file)
12481267
tar_fileobj.seek(0)
1249-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}") # type: ignore[attr-defined]
1268+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}")
12501269
blob.upload_from_string(tar_fileobj.read())
1251-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1270+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12521271
logger.info(f"Writing to {dir_name}/{_EXTRA_PACKAGES_FILE}")
12531272

12541273

@@ -1369,7 +1388,7 @@ def _validate_requirements_or_warn(
13691388
*,
13701389
obj: Any,
13711390
requirements: List[str],
1372-
) -> Mapping[str, str]:
1391+
) -> List[str]:
13731392
"""Compiles the requirements into a list of requirements."""
13741393
requirements = requirements.copy()
13751394
try:
@@ -1380,16 +1399,14 @@ def _validate_requirements_or_warn(
13801399
requirements=current_requirements,
13811400
constraints=constraints,
13821401
)
1383-
for warning_type, warnings in missing_requirements.get(
1384-
_WARNINGS_KEY, {}
1385-
).items():
1402+
for warning_type, warnings in missing_requirements["warnings"].items():
13861403
if warnings:
13871404
logger.warning(
13881405
f"The following requirements are {warning_type}: {warnings}"
13891406
)
1390-
for action_type, actions in missing_requirements.get(_ACTIONS_KEY, {}).items():
1407+
for action_type, actions in missing_requirements["actions"].items():
13911408
if actions and action_type == _ACTION_APPEND:
1392-
for action in actions:
1409+
for action in actions: # type: ignore[attr-defined]
13931410
requirements.append(action)
13941411
logger.info(f"The following requirements are appended: {actions}")
13951412
except Exception as e:
@@ -1413,7 +1430,7 @@ def _validate_requirements_or_raise(
14131430
logger.info(f"Read the following lines: {requirements}")
14141431
except IOError as err:
14151432
raise IOError(f"Failed to read requirements from {requirements=}") from err
1416-
requirements = _validate_requirements_or_warn( # type: ignore[assignment]
1433+
requirements = _validate_requirements_or_warn(
14171434
obj=agent,
14181435
requirements=requirements,
14191436
)
@@ -1560,19 +1577,6 @@ def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
15601577
return _method
15611578

15621579

1563-
AgentEngineOperationUnion = Union[
1564-
genai_types.AgentEngineOperation,
1565-
genai_types.AgentEngineMemoryOperation,
1566-
genai_types.AgentEngineGenerateMemoriesOperation,
1567-
]
1568-
1569-
1570-
class GetOperationFunction(Protocol):
1571-
def __call__( # noqa: E704
1572-
self, *, operation_name: str, **kwargs: Any
1573-
) -> AgentEngineOperationUnion: ...
1574-
1575-
15761580
def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
15771581
"""Wraps an Agent Engine method, creating a callable for `query` API.
15781582
@@ -1835,7 +1839,7 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18351839

18361840
return response
18371841

1838-
return _method
1842+
return _method # type: ignore[return-value]
18391843

18401844

18411845
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:

0 commit comments

Comments
 (0)