3232from typing import (
3333 Any ,
3434 AsyncIterator ,
35- Awaitable ,
3635 Callable ,
3736 Coroutine ,
3837 Dict ,
5958from . import types as genai_types
6059
6160
61+ if sys .version_info > (3 , 9 ):
62+ from typing import TypeAlias
63+ else :
64+ from typing_extensions import TypeAlias
65+
66+
6267try :
6368 _BUILTIN_MODULE_NAMES : Sequence [str ] = sys .builtin_module_names
6469except AttributeError :
7883 _STDLIB_MODULE_NAMES : frozenset [str ] = frozenset () # type: ignore[no-redef]
7984
8085
81- try :
82- from google .cloud import storage
86+ if typing . TYPE_CHECKING :
87+ from google .cloud import storage # type: ignore[attr-defined]
8388
84- _StorageBucket : type [Any ] = storage .Bucket
85- except (ImportError , AttributeError ):
86- _StorageBucket : type [Any ] = Any # type: ignore[no-redef]
89+ _StorageBucket : TypeAlias = storage .Bucket
90+ else :
91+ try :
92+ from google .cloud import storage # type: ignore[attr-defined]
8793
94+ _StorageBucket : type [Any ] = storage .Bucket
95+ except (ImportError , AttributeError ):
96+ _StorageBucket : type [Any ] = Any # type: ignore[no-redef]
8897
89- try :
98+
99+ if typing .TYPE_CHECKING :
90100 import packaging
91101
92- _SpecifierSet : type [Any ] = packaging .specifiers .SpecifierSet
93- except (ImportError , AttributeError ):
94- _SpecifierSet : type [Any ] = Any # type: ignore[no-redef]
102+ _SpecifierSet = packaging .specifiers .SpecifierSet
103+ else :
104+ try :
105+ import packaging
106+
107+ _SpecifierSet : type [Any ] = packaging .specifiers .SpecifierSet
108+ except (ImportError , AttributeError ):
109+ _SpecifierSet : type [Any ] = Any # type: ignore[no-redef]
95110
96111
97112try :
@@ -258,16 +273,22 @@ class OperationRegistrable(Protocol):
258273 """Protocol for agents that have registered operations."""
259274
260275 @abc .abstractmethod
261- def register_operations (self , ** kwargs ) -> Dict [str , Sequence [str ]]: # type: ignore[no-untyped-def]
276+ def register_operations (self , ** kwargs : Any ) -> dict [str , list [str ]]:
262277 """Register the user provided operations (modes and methods)."""
278+ pass
263279
264280
265- try :
281+ if typing . TYPE_CHECKING :
266282 from google .adk .agents import BaseAgent
267283
268- ADKAgent : type [Any ] = BaseAgent
269- except (ImportError , AttributeError ):
270- ADKAgent : type [Any ] = Any # type: ignore[no-redef]
284+ ADKAgent : TypeAlias = BaseAgent
285+ else :
286+ try :
287+ from google .adk .agents import BaseAgent
288+
289+ ADKAgent : Optional [TypeAlias ] = BaseAgent
290+ except (ImportError , AttributeError ):
291+ ADKAgent = None # type: ignore[no-redef]
271292
272293_AgentEngineInterface = Union [
273294 ADKAgent ,
@@ -283,8 +304,9 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ig
283304class _ModuleAgentAttributes (TypedDict , total = False ):
284305 module_name : str
285306 agent_name : str
286- register_operations : Dict [str , Sequence [str ]]
307+ register_operations : Dict [str , list [str ]]
287308 sys_paths : Optional [Sequence [str ]]
309+ agent : _AgentEngineInterface
288310
289311
290312class ModuleAgent (Cloneable , OperationRegistrable ):
@@ -300,7 +322,7 @@ def __init__(
300322 * ,
301323 module_name : str ,
302324 agent_name : str ,
303- register_operations : Dict [str , Sequence [str ]],
325+ register_operations : Dict [str , list [str ]],
304326 sys_paths : Optional [Sequence [str ]] = None ,
305327 ):
306328 """Initializes a module-based agent.
@@ -310,7 +332,7 @@ def __init__(
310332 Required. The name of the module to import.
311333 agent_name (str):
312334 Required. The name of the agent in the module to instantiate.
313- register_operations (Dict[str, Sequence [str]]):
335+ register_operations (Dict[str, list [str]]):
314336 Required. A dictionary of API modes to a list of method names.
315337 sys_paths (Sequence[str]):
316338 Optional. The system paths to search for the module. It should
@@ -336,8 +358,11 @@ def clone(self) -> "ModuleAgent":
336358 sys_paths = self ._tmpl_attrs .get ("sys_paths" ),
337359 )
338360
339- def register_operations (self ) -> Dict [str , Sequence [str ]]:
340- self ._tmpl_attrs .get ("register_operations" )
361+ def register_operations (self , ** kwargs : Any ) -> dict [str , list [str ]]:
362+ reg_operations = self ._tmpl_attrs .get ("register_operations" )
363+ if reg_operations is None :
364+ raise ValueError ("Register operations is not set." )
365+ return reg_operations
341366
342367 def set_up (self ) -> None :
343368 """Sets up the agent for execution of queries at runtime.
@@ -411,7 +436,7 @@ def __call__(
411436class GetAsyncOperationFunction (Protocol ):
412437 async def __call__ (
413438 self , * , operation_name : str , ** kwargs : Any
414- ) -> Awaitable [ AgentEngineOperationUnion ] :
439+ ) -> AgentEngineOperationUnion :
415440 pass
416441
417442
@@ -507,7 +532,7 @@ def _await_operation(
507532def _compare_requirements (
508533 * ,
509534 requirements : Mapping [str , str ],
510- constraints : Union [Sequence [str ], Mapping [str , "_SpecifierSet" ]],
535+ constraints : Union [Sequence [str ], Mapping [str , Optional [ "_SpecifierSet" ] ]],
511536 required_packages : Optional [Iterator [str ]] = None ,
512537) -> _RequirementsValidationResult :
513538 """Compares the requirements with the constraints.
@@ -536,7 +561,7 @@ def _compare_requirements(
536561 """
537562 packaging_version = _import_packaging_version_or_raise ()
538563 if required_packages is None :
539- required_packages = _DEFAULT_REQUIRED_PACKAGES
564+ required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment]
540565 result = _RequirementsValidationResult (
541566 warnings = _RequirementsValidationWarnings (missing = set (), incompatible = set ()),
542567 actions = _RequirementsValidationActions (append = set ()),
@@ -583,7 +608,7 @@ def _generate_class_methods_spec_or_raise(
583608 if isinstance (agent , ModuleAgent ):
584609 # We do a dry-run of setting up the agent engine to have the operations
585610 # needed for registration.
586- agent : ModuleAgent = agent .clone ()
611+ agent : ModuleAgent = agent .clone () # type: ignore[no-redef]
587612 try :
588613 agent .set_up ()
589614 except Exception as e :
@@ -819,13 +844,13 @@ def _get_gcs_bucket(
819844 new_bucket = storage_client .bucket (staging_bucket )
820845 gcs_bucket = storage_client .create_bucket (new_bucket , location = location )
821846 logger .info (f"Creating bucket { staging_bucket } in { location = } " )
822- return gcs_bucket # type: ignore[no-any-return]
847+ return gcs_bucket
823848
824849
825850def _get_registered_operations (
826851 * ,
827852 agent : _AgentEngineInterface ,
828- ) -> Dict [str , List [str ]]:
853+ ) -> dict [str , list [str ]]:
829854 """Retrieves registered operations for a AgentEngine."""
830855 if isinstance (agent , OperationRegistrable ):
831856 return agent .register_operations ()
@@ -859,13 +884,13 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
859884def _import_cloud_storage_or_raise () -> types .ModuleType :
860885 """Tries to import the Cloud Storage module."""
861886 try :
862- from google .cloud import storage
887+ from google .cloud import storage # type: ignore[attr-defined]
863888 except ImportError as e :
864889 raise ImportError (
865890 "Cloud Storage is not installed. Please call "
866891 "'pip install google-cloud-aiplatform[agent_engines]'."
867892 ) from e
868- return storage
893+ return storage # type: ignore[no-any-return]
869894
870895
871896def _import_packaging_requirements_or_raise () -> types .ModuleType :
@@ -1202,7 +1227,7 @@ def _upload_agent_engine(
12021227) -> None :
12031228 """Uploads the agent engine to GCS."""
12041229 cloudpickle = _import_cloudpickle_or_raise ()
1205- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _BLOB_FILENAME } " ) # type: ignore[attr-defined]
1230+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _BLOB_FILENAME } " )
12061231 with blob .open ("wb" ) as f :
12071232 try :
12081233 cloudpickle .dump (agent , f )
@@ -1216,7 +1241,7 @@ def _upload_agent_engine(
12161241 _ = cloudpickle .load (f )
12171242 except Exception as e :
12181243 raise TypeError ("Agent engine serialized to an invalid format" ) from e
1219- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1244+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12201245 logger .info (f"Wrote to { dir_name } /{ _BLOB_FILENAME } " )
12211246
12221247
@@ -1227,9 +1252,9 @@ def _upload_requirements(
12271252 gcs_dir_name : str ,
12281253) -> None :
12291254 """Uploads the requirements file to GCS."""
1230- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _REQUIREMENTS_FILE } " ) # type: ignore[attr-defined]
1255+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _REQUIREMENTS_FILE } " )
12311256 blob .upload_from_string ("\n " .join (requirements ))
1232- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1257+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12331258 logger .info (f"Writing to { dir_name } /{ _REQUIREMENTS_FILE } " )
12341259
12351260
@@ -1246,9 +1271,9 @@ def _upload_extra_packages(
12461271 for file in extra_packages :
12471272 tar .add (file )
12481273 tar_fileobj .seek (0 )
1249- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _EXTRA_PACKAGES_FILE } " ) # type: ignore[attr-defined]
1274+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _EXTRA_PACKAGES_FILE } " )
12501275 blob .upload_from_string (tar_fileobj .read ())
1251- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1276+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12521277 logger .info (f"Writing to { dir_name } /{ _EXTRA_PACKAGES_FILE } " )
12531278
12541279
@@ -1369,7 +1394,7 @@ def _validate_requirements_or_warn(
13691394 * ,
13701395 obj : Any ,
13711396 requirements : List [str ],
1372- ) -> Mapping [ str , str ]:
1397+ ) -> List [ str ]:
13731398 """Compiles the requirements into a list of requirements."""
13741399 requirements = requirements .copy ()
13751400 try :
@@ -1380,16 +1405,14 @@ def _validate_requirements_or_warn(
13801405 requirements = current_requirements ,
13811406 constraints = constraints ,
13821407 )
1383- for warning_type , warnings in missing_requirements .get (
1384- _WARNINGS_KEY , {}
1385- ).items ():
1408+ for warning_type , warnings in missing_requirements ["warnings" ].items ():
13861409 if warnings :
13871410 logger .warning (
13881411 f"The following requirements are { warning_type } : { warnings } "
13891412 )
1390- for action_type , actions in missing_requirements . get ( _ACTIONS_KEY , {}) .items ():
1413+ for action_type , actions in missing_requirements [ "actions" ] .items ():
13911414 if actions and action_type == _ACTION_APPEND :
1392- for action in actions :
1415+ for action in actions : # type: ignore[attr-defined]
13931416 requirements .append (action )
13941417 logger .info (f"The following requirements are appended: { actions } " )
13951418 except Exception as e :
@@ -1413,7 +1436,7 @@ def _validate_requirements_or_raise(
14131436 logger .info (f"Read the following lines: { requirements } " )
14141437 except IOError as err :
14151438 raise IOError (f"Failed to read requirements from { requirements = } " ) from err
1416- requirements = _validate_requirements_or_warn ( # type: ignore[assignment]
1439+ requirements = _validate_requirements_or_warn (
14171440 obj = agent ,
14181441 requirements = requirements ,
14191442 )
@@ -1560,19 +1583,6 @@ def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
15601583 return _method
15611584
15621585
1563- AgentEngineOperationUnion = Union [
1564- genai_types .AgentEngineOperation ,
1565- genai_types .AgentEngineMemoryOperation ,
1566- genai_types .AgentEngineGenerateMemoriesOperation ,
1567- ]
1568-
1569-
1570- class GetOperationFunction (Protocol ):
1571- def __call__ ( # noqa: E704
1572- self , * , operation_name : str , ** kwargs : Any
1573- ) -> AgentEngineOperationUnion : ...
1574-
1575-
15761586def _wrap_query_operation (* , method_name : str ) -> Callable [..., Any ]:
15771587 """Wraps an Agent Engine method, creating a callable for `query` API.
15781588
@@ -1835,7 +1845,7 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18351845
18361846 return response
18371847
1838- return _method
1848+ return _method # type: ignore[return-value]
18391849
18401850
18411851def _yield_parsed_json (http_response : google_genai_types .HttpResponse ) -> Iterator [Any ]:
0 commit comments