@@ -357,9 +357,19 @@ def _build_reasoning_param(self, model_settings: ModelSettings) -> Any:
357357 reasoning_param = {
358358 "effort" : model_settings .reasoning .effort ,
359359 }
360- # Add generate_summary if specified and not None
361- if hasattr (model_settings .reasoning , 'generate_summary' ) and model_settings .reasoning .generate_summary is not None :
362- reasoning_param ["summary" ] = model_settings .reasoning .generate_summary
360+ # Add summary if specified (check both 'summary' and 'generate_summary' for compatibility)
361+ summary_value = None
362+ if hasattr (model_settings .reasoning , 'summary' ) and model_settings .reasoning .summary is not None :
363+ summary_value = model_settings .reasoning .summary
364+ elif (
365+ hasattr (model_settings .reasoning , 'generate_summary' )
366+ and model_settings .reasoning .generate_summary is not None
367+ ):
368+ summary_value = model_settings .reasoning .generate_summary
369+
370+ if summary_value is not None :
371+ reasoning_param ["summary" ] = summary_value
372+
363373 logger .debug (f"[TemporalStreamingModel] Using reasoning param: { reasoning_param } " )
364374 return reasoning_param
365375
@@ -679,9 +689,34 @@ async def get_response(
679689 output_index = getattr (event , 'output_index' , 0 )
680690
681691 if item and getattr (item , 'type' , None ) == 'reasoning' :
682- logger .debug (f"[TemporalStreamingModel] Reasoning item completed" )
683- # Don't close the context here - let it stay open for more reasoning events
684- # It will be closed when we send the final update or at the end
692+ if reasoning_context and reasoning_summaries :
693+ logger .debug (f"[TemporalStreamingModel] Reasoning itme completed, sending final update" )
694+ try :
695+ # Send a full message update with the complete reasoning content
696+ complete_reasoning_content = ReasoningContent (
697+ author = "agent" ,
698+ summary = reasoning_summaries , # Use accumulated summaries
699+ content = reasoning_contents if reasoning_contents else [],
700+ type = "reasoning" ,
701+ style = "static" ,
702+ )
703+
704+ await reasoning_context .stream_update (
705+ update = StreamTaskMessageFull (
706+ parent_task_message = reasoning_context .task_message ,
707+ content = complete_reasoning_content ,
708+ type = "full" ,
709+ ),
710+ )
711+
712+ # Close the reasoning context after sending the final update
713+ # This matches the reference implementation pattern
714+ await reasoning_context .close ()
715+ reasoning_context = None
716+ logger .debug (f"[TemporalStreamingModel] Closed reasoning context after final update" )
717+ except Exception as e :
718+ logger .warning (f"Failed to send reasoning part done update: { e } " )
719+
685720 elif item and getattr (item , 'type' , None ) == 'function_call' :
686721 # Function call completed - add to output
687722 if output_index in function_calls_in_progress :
@@ -708,34 +743,8 @@ async def get_response(
708743 current_reasoning_summary = ""
709744
710745 elif isinstance (event , ResponseReasoningSummaryPartDoneEvent ):
711- # Reasoning part completed - send final update and close if this is the last part
712- if reasoning_context and reasoning_summaries :
713- logger .debug (f"[TemporalStreamingModel] Reasoning part completed, sending final update" )
714- try :
715- # Send a full message update with the complete reasoning content
716- complete_reasoning_content = ReasoningContent (
717- author = "agent" ,
718- summary = reasoning_summaries , # Use accumulated summaries
719- content = reasoning_contents if reasoning_contents else [],
720- type = "reasoning" ,
721- style = "static" ,
722- )
723-
724- await reasoning_context .stream_update (
725- update = StreamTaskMessageFull (
726- parent_task_message = reasoning_context .task_message ,
727- content = complete_reasoning_content ,
728- type = "full" ,
729- ),
730- )
731-
732- # Close the reasoning context after sending the final update
733- # This matches the reference implementation pattern
734- await reasoning_context .close ()
735- reasoning_context = None
736- logger .debug (f"[TemporalStreamingModel] Closed reasoning context after final update" )
737- except Exception as e :
738- logger .warning (f"Failed to send reasoning part done update: { e } " )
746+ # Reasoning part completed - ResponseOutputItemDoneEvent will handle the final update
747+ logger .debug (f"[TemporalStreamingModel] Reasoning part completed" )
739748
740749 elif isinstance (event , ResponseCompletedEvent ):
741750 # Response completed
@@ -842,10 +851,16 @@ def stream_response(self, *args, **kwargs):
842851class TemporalStreamingModelProvider (ModelProvider ):
843852 """Custom model provider that returns a streaming-capable model."""
844853
845- def __init__ (self ):
846- """Initialize the provider."""
854+ def __init__ (self , openai_client : Optional [AsyncOpenAI ] = None ):
855+ """Initialize the provider.
856+
857+ Args:
858+ openai_client: Optional custom AsyncOpenAI client to use for all models.
859+ If not provided, each model will create its own default client.
860+ """
847861 super ().__init__ ()
848- logger .info ("[TemporalStreamingModelProvider] Initialized" )
862+ self .openai_client = openai_client
863+ logger .info (f"[TemporalStreamingModelProvider] Initialized, custom_client={ openai_client is not None } " )
849864
850865 @override
851866 def get_model (self , model_name : Union [str , None ]) -> Model :
@@ -860,5 +875,5 @@ def get_model(self, model_name: Union[str, None]) -> Model:
860875 # Use the provided model_name or default to gpt-4o
861876 actual_model = model_name if model_name else "gpt-4o"
862877 logger .info (f"[TemporalStreamingModelProvider] Creating TemporalStreamingModel for model_name: { actual_model } " )
863- model = TemporalStreamingModel (model_name = actual_model )
878+ model = TemporalStreamingModel (model_name = actual_model , openai_client = self . openai_client )
864879 return model
0 commit comments