3232from typing import (
3333 Any ,
3434 AsyncIterator ,
35- Awaitable ,
3635 Callable ,
3736 Coroutine ,
3837 Dict ,
7877 _STDLIB_MODULE_NAMES : frozenset [str ] = frozenset () # type: ignore[no-redef]
7978
8079
81- try :
82- from google .cloud import storage
80+ if typing . TYPE_CHECKING :
81+ from google .cloud import storage # type: ignore[attr-defined]
8382
84- _StorageBucket : type [Any ] = storage .Bucket
85- except (ImportError , AttributeError ):
86- _StorageBucket : type [Any ] = Any # type: ignore[no-redef]
83+ _StorageBucket = storage .Bucket
84+ else :
85+ try :
86+ from google .cloud import storage # type: ignore[attr-defined]
8787
88+ _StorageBucket : type [Any ] = storage .Bucket
89+ except (ImportError , AttributeError ):
90+ _StorageBucket : type [Any ] = Any # type: ignore[no-redef]
8891
89- try :
92+
93+ if typing .TYPE_CHECKING :
9094 import packaging
9195
92- _SpecifierSet : type [Any ] = packaging .specifiers .SpecifierSet
93- except (ImportError , AttributeError ):
94- _SpecifierSet : type [Any ] = Any # type: ignore[no-redef]
96+ _SpecifierSet = packaging .specifiers .SpecifierSet
97+ else :
98+ try :
99+ import packaging
100+
101+ _SpecifierSet : type [Any ] = packaging .specifiers .SpecifierSet
102+ except (ImportError , AttributeError ):
103+ _SpecifierSet : type [Any ] = Any # type: ignore[no-redef]
95104
96105
97106try :
@@ -258,16 +267,22 @@ class OperationRegistrable(Protocol):
258267 """Protocol for agents that have registered operations."""
259268
260269 @abc .abstractmethod
261- def register_operations (self , ** kwargs ) -> Dict [str , Sequence [str ]]: # type: ignore[no-untyped-def]
270+ def register_operations (self , ** kwargs : Any ) -> dict [str , list [str ]]:
262271 """Register the user provided operations (modes and methods)."""
272+ pass
263273
264274
265- try :
275+ if typing . TYPE_CHECKING :
266276 from google .adk .agents import BaseAgent
267277
268- ADKAgent : type [Any ] = BaseAgent
269- except (ImportError , AttributeError ):
270- ADKAgent : type [Any ] = Any # type: ignore[no-redef]
278+ ADKAgent : TypeAlias = BaseAgent
279+ else :
280+ try :
281+ from google .adk .agents import BaseAgent
282+
283+ ADKAgent : Optional [TypeAlias ] = BaseAgent
284+ except (ImportError , AttributeError ):
285+ ADKAgent : Optional [TypeAlias ] = None # type: ignore[no-redef]
271286
272287_AgentEngineInterface = Union [
273288 ADKAgent ,
@@ -283,8 +298,9 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ig
283298class _ModuleAgentAttributes (TypedDict , total = False ):
284299 module_name : str
285300 agent_name : str
286- register_operations : Dict [str , Sequence [str ]]
301+ register_operations : Dict [str , list [str ]]
287302 sys_paths : Optional [Sequence [str ]]
303+ agent : _AgentEngineInterface
288304
289305
290306class ModuleAgent (Cloneable , OperationRegistrable ):
@@ -300,7 +316,7 @@ def __init__(
300316 * ,
301317 module_name : str ,
302318 agent_name : str ,
303- register_operations : Dict [str , Sequence [str ]],
319+ register_operations : Dict [str , list [str ]],
304320 sys_paths : Optional [Sequence [str ]] = None ,
305321 ):
306322 """Initializes a module-based agent.
@@ -310,7 +326,7 @@ def __init__(
310326 Required. The name of the module to import.
311327 agent_name (str):
312328 Required. The name of the agent in the module to instantiate.
313- register_operations (Dict[str, Sequence [str]]):
329+ register_operations (Dict[str, list [str]]):
314330 Required. A dictionary of API modes to a list of method names.
315331 sys_paths (Sequence[str]):
316332 Optional. The system paths to search for the module. It should
@@ -336,8 +352,11 @@ def clone(self) -> "ModuleAgent":
336352 sys_paths = self ._tmpl_attrs .get ("sys_paths" ),
337353 )
338354
339- def register_operations (self ) -> Dict [str , Sequence [str ]]:
340- self ._tmpl_attrs .get ("register_operations" )
355+ def register_operations (self , ** kwargs : Any ) -> dict [str , list [str ]]:
356+ reg_operations = self ._tmpl_attrs .get ("register_operations" )
357+ if reg_operations is None :
358+ raise ValueError ("Register operations is not set." )
359+ return reg_operations
341360
342361 def set_up (self ) -> None :
343362 """Sets up the agent for execution of queries at runtime.
@@ -411,7 +430,7 @@ def __call__(
411430class GetAsyncOperationFunction (Protocol ):
412431 async def __call__ (
413432 self , * , operation_name : str , ** kwargs : Any
414- ) -> Awaitable [ AgentEngineOperationUnion ] :
433+ ) -> AgentEngineOperationUnion :
415434 pass
416435
417436
@@ -507,7 +526,7 @@ def _await_operation(
507526def _compare_requirements (
508527 * ,
509528 requirements : Mapping [str , str ],
510- constraints : Union [Sequence [str ], Mapping [str , "_SpecifierSet" ]],
529+ constraints : Union [Sequence [str ], Mapping [str , Optional [ "_SpecifierSet" ] ]],
511530 required_packages : Optional [Iterator [str ]] = None ,
512531) -> _RequirementsValidationResult :
513532 """Compares the requirements with the constraints.
@@ -536,7 +555,7 @@ def _compare_requirements(
536555 """
537556 packaging_version = _import_packaging_version_or_raise ()
538557 if required_packages is None :
539- required_packages = _DEFAULT_REQUIRED_PACKAGES
558+ required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment]
540559 result = _RequirementsValidationResult (
541560 warnings = _RequirementsValidationWarnings (missing = set (), incompatible = set ()),
542561 actions = _RequirementsValidationActions (append = set ()),
@@ -583,7 +602,7 @@ def _generate_class_methods_spec_or_raise(
583602 if isinstance (agent , ModuleAgent ):
584603 # We do a dry-run of setting up the agent engine to have the operations
585604 # needed for registration.
586- agent : ModuleAgent = agent .clone ()
605+ agent : ModuleAgent = agent .clone () # type: ignore[no-redef]
587606 try :
588607 agent .set_up ()
589608 except Exception as e :
@@ -819,13 +838,13 @@ def _get_gcs_bucket(
819838 new_bucket = storage_client .bucket (staging_bucket )
820839 gcs_bucket = storage_client .create_bucket (new_bucket , location = location )
821840 logger .info (f"Creating bucket { staging_bucket } in { location = } " )
822- return gcs_bucket # type: ignore[no-any-return]
841+ return gcs_bucket
823842
824843
825844def _get_registered_operations (
826845 * ,
827846 agent : _AgentEngineInterface ,
828- ) -> Dict [str , List [str ]]:
847+ ) -> dict [str , list [str ]]:
829848 """Retrieves registered operations for a AgentEngine."""
830849 if isinstance (agent , OperationRegistrable ):
831850 return agent .register_operations ()
@@ -859,13 +878,13 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
859878def _import_cloud_storage_or_raise () -> types .ModuleType :
860879 """Tries to import the Cloud Storage module."""
861880 try :
862- from google .cloud import storage
881+ from google .cloud import storage # type: ignore[attr-defined]
863882 except ImportError as e :
864883 raise ImportError (
865884 "Cloud Storage is not installed. Please call "
866885 "'pip install google-cloud-aiplatform[agent_engines]'."
867886 ) from e
868- return storage
887+ return storage # type: ignore[no-any-return]
869888
870889
871890def _import_packaging_requirements_or_raise () -> types .ModuleType :
@@ -1202,7 +1221,7 @@ def _upload_agent_engine(
12021221) -> None :
12031222 """Uploads the agent engine to GCS."""
12041223 cloudpickle = _import_cloudpickle_or_raise ()
1205- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _BLOB_FILENAME } " ) # type: ignore[attr-defined]
1224+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _BLOB_FILENAME } " )
12061225 with blob .open ("wb" ) as f :
12071226 try :
12081227 cloudpickle .dump (agent , f )
@@ -1216,7 +1235,7 @@ def _upload_agent_engine(
12161235 _ = cloudpickle .load (f )
12171236 except Exception as e :
12181237 raise TypeError ("Agent engine serialized to an invalid format" ) from e
1219- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1238+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12201239 logger .info (f"Wrote to { dir_name } /{ _BLOB_FILENAME } " )
12211240
12221241
@@ -1227,9 +1246,9 @@ def _upload_requirements(
12271246 gcs_dir_name : str ,
12281247) -> None :
12291248 """Uploads the requirements file to GCS."""
1230- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _REQUIREMENTS_FILE } " ) # type: ignore[attr-defined]
1249+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _REQUIREMENTS_FILE } " )
12311250 blob .upload_from_string ("\n " .join (requirements ))
1232- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1251+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12331252 logger .info (f"Writing to { dir_name } /{ _REQUIREMENTS_FILE } " )
12341253
12351254
@@ -1246,9 +1265,9 @@ def _upload_extra_packages(
12461265 for file in extra_packages :
12471266 tar .add (file )
12481267 tar_fileobj .seek (0 )
1249- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _EXTRA_PACKAGES_FILE } " ) # type: ignore[attr-defined]
1268+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _EXTRA_PACKAGES_FILE } " )
12501269 blob .upload_from_string (tar_fileobj .read ())
1251- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1270+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12521271 logger .info (f"Writing to { dir_name } /{ _EXTRA_PACKAGES_FILE } " )
12531272
12541273
@@ -1369,7 +1388,7 @@ def _validate_requirements_or_warn(
13691388 * ,
13701389 obj : Any ,
13711390 requirements : List [str ],
1372- ) -> Mapping [ str , str ]:
1391+ ) -> List [ str ]:
13731392 """Compiles the requirements into a list of requirements."""
13741393 requirements = requirements .copy ()
13751394 try :
@@ -1380,16 +1399,14 @@ def _validate_requirements_or_warn(
13801399 requirements = current_requirements ,
13811400 constraints = constraints ,
13821401 )
1383- for warning_type , warnings in missing_requirements .get (
1384- _WARNINGS_KEY , {}
1385- ).items ():
1402+ for warning_type , warnings in missing_requirements ["warnings" ].items ():
13861403 if warnings :
13871404 logger .warning (
13881405 f"The following requirements are { warning_type } : { warnings } "
13891406 )
1390- for action_type , actions in missing_requirements . get ( _ACTIONS_KEY , {}) .items ():
1407+ for action_type , actions in missing_requirements [ "actions" ] .items ():
13911408 if actions and action_type == _ACTION_APPEND :
1392- for action in actions :
1409+ for action in actions : # type: ignore[attr-defined]
13931410 requirements .append (action )
13941411 logger .info (f"The following requirements are appended: { actions } " )
13951412 except Exception as e :
@@ -1413,7 +1430,7 @@ def _validate_requirements_or_raise(
14131430 logger .info (f"Read the following lines: { requirements } " )
14141431 except IOError as err :
14151432 raise IOError (f"Failed to read requirements from { requirements = } " ) from err
1416- requirements = _validate_requirements_or_warn ( # type: ignore[assignment]
1433+ requirements = _validate_requirements_or_warn (
14171434 obj = agent ,
14181435 requirements = requirements ,
14191436 )
@@ -1560,19 +1577,6 @@ def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
15601577 return _method
15611578
15621579
1563- AgentEngineOperationUnion = Union [
1564- genai_types .AgentEngineOperation ,
1565- genai_types .AgentEngineMemoryOperation ,
1566- genai_types .AgentEngineGenerateMemoriesOperation ,
1567- ]
1568-
1569-
1570- class GetOperationFunction (Protocol ):
1571- def __call__ ( # noqa: E704
1572- self , * , operation_name : str , ** kwargs : Any
1573- ) -> AgentEngineOperationUnion : ...
1574-
1575-
15761580def _wrap_query_operation (* , method_name : str ) -> Callable [..., Any ]:
15771581 """Wraps an Agent Engine method, creating a callable for `query` API.
15781582
@@ -1835,7 +1839,7 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18351839
18361840 return response
18371841
1838- return _method
1842+ return _method # type: ignore[return-value]
18391843
18401844
18411845def _yield_parsed_json (http_response : google_genai_types .HttpResponse ) -> Iterator [Any ]:
0 commit comments