Skip to content

Commit e71c5be

Browse files
Merge pull request #197 from scaleapi/dm/bug-fix-streaming
Bug fixes streaming provider
2 parents 5ad8b88 + d1b06ad commit e71c5be

File tree

1 file changed

+53
-38
lines changed

1 file changed

+53
-38
lines changed

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,19 @@ def _build_reasoning_param(self, model_settings: ModelSettings) -> Any:
357357
reasoning_param = {
358358
"effort": model_settings.reasoning.effort,
359359
}
360-
# Add generate_summary if specified and not None
361-
if hasattr(model_settings.reasoning, 'generate_summary') and model_settings.reasoning.generate_summary is not None:
362-
reasoning_param["summary"] = model_settings.reasoning.generate_summary
360+
# Add summary if specified (check both 'summary' and 'generate_summary' for compatibility)
361+
summary_value = None
362+
if hasattr(model_settings.reasoning, 'summary') and model_settings.reasoning.summary is not None:
363+
summary_value = model_settings.reasoning.summary
364+
elif (
365+
hasattr(model_settings.reasoning, 'generate_summary')
366+
and model_settings.reasoning.generate_summary is not None
367+
):
368+
summary_value = model_settings.reasoning.generate_summary
369+
370+
if summary_value is not None:
371+
reasoning_param["summary"] = summary_value
372+
363373
logger.debug(f"[TemporalStreamingModel] Using reasoning param: {reasoning_param}")
364374
return reasoning_param
365375

@@ -679,9 +689,34 @@ async def get_response(
679689
output_index = getattr(event, 'output_index', 0)
680690

681691
if item and getattr(item, 'type', None) == 'reasoning':
682-
logger.debug(f"[TemporalStreamingModel] Reasoning item completed")
683-
# Don't close the context here - let it stay open for more reasoning events
684-
# It will be closed when we send the final update or at the end
692+
if reasoning_context and reasoning_summaries:
693+
logger.debug(f"[TemporalStreamingModel] Reasoning item completed, sending final update")
694+
try:
695+
# Send a full message update with the complete reasoning content
696+
complete_reasoning_content = ReasoningContent(
697+
author="agent",
698+
summary=reasoning_summaries, # Use accumulated summaries
699+
content=reasoning_contents if reasoning_contents else [],
700+
type="reasoning",
701+
style="static",
702+
)
703+
704+
await reasoning_context.stream_update(
705+
update=StreamTaskMessageFull(
706+
parent_task_message=reasoning_context.task_message,
707+
content=complete_reasoning_content,
708+
type="full",
709+
),
710+
)
711+
712+
# Close the reasoning context after sending the final update
713+
# This matches the reference implementation pattern
714+
await reasoning_context.close()
715+
reasoning_context = None
716+
logger.debug(f"[TemporalStreamingModel] Closed reasoning context after final update")
717+
except Exception as e:
718+
logger.warning(f"Failed to send reasoning part done update: {e}")
719+
685720
elif item and getattr(item, 'type', None) == 'function_call':
686721
# Function call completed - add to output
687722
if output_index in function_calls_in_progress:
@@ -708,34 +743,8 @@ async def get_response(
708743
current_reasoning_summary = ""
709744

710745
elif isinstance(event, ResponseReasoningSummaryPartDoneEvent):
711-
# Reasoning part completed - send final update and close if this is the last part
712-
if reasoning_context and reasoning_summaries:
713-
logger.debug(f"[TemporalStreamingModel] Reasoning part completed, sending final update")
714-
try:
715-
# Send a full message update with the complete reasoning content
716-
complete_reasoning_content = ReasoningContent(
717-
author="agent",
718-
summary=reasoning_summaries, # Use accumulated summaries
719-
content=reasoning_contents if reasoning_contents else [],
720-
type="reasoning",
721-
style="static",
722-
)
723-
724-
await reasoning_context.stream_update(
725-
update=StreamTaskMessageFull(
726-
parent_task_message=reasoning_context.task_message,
727-
content=complete_reasoning_content,
728-
type="full",
729-
),
730-
)
731-
732-
# Close the reasoning context after sending the final update
733-
# This matches the reference implementation pattern
734-
await reasoning_context.close()
735-
reasoning_context = None
736-
logger.debug(f"[TemporalStreamingModel] Closed reasoning context after final update")
737-
except Exception as e:
738-
logger.warning(f"Failed to send reasoning part done update: {e}")
746+
# Reasoning part completed - ResponseOutputItemDoneEvent will handle the final update
747+
logger.debug(f"[TemporalStreamingModel] Reasoning part completed")
739748

740749
elif isinstance(event, ResponseCompletedEvent):
741750
# Response completed
@@ -842,10 +851,16 @@ def stream_response(self, *args, **kwargs):
842851
class TemporalStreamingModelProvider(ModelProvider):
843852
"""Custom model provider that returns a streaming-capable model."""
844853

845-
def __init__(self):
846-
"""Initialize the provider."""
854+
def __init__(self, openai_client: Optional[AsyncOpenAI] = None):
855+
"""Initialize the provider.
856+
857+
Args:
858+
openai_client: Optional custom AsyncOpenAI client to use for all models.
859+
If not provided, each model will create its own default client.
860+
"""
847861
super().__init__()
848-
logger.info("[TemporalStreamingModelProvider] Initialized")
862+
self.openai_client = openai_client
863+
logger.info(f"[TemporalStreamingModelProvider] Initialized, custom_client={openai_client is not None}")
849864

850865
@override
851866
def get_model(self, model_name: Union[str, None]) -> Model:
@@ -860,5 +875,5 @@ def get_model(self, model_name: Union[str, None]) -> Model:
860875
# Use the provided model_name or default to gpt-4o
861876
actual_model = model_name if model_name else "gpt-4o"
862877
logger.info(f"[TemporalStreamingModelProvider] Creating TemporalStreamingModel for model_name: {actual_model}")
863-
model = TemporalStreamingModel(model_name=actual_model)
878+
model = TemporalStreamingModel(model_name=actual_model, openai_client=self.openai_client)
864879
return model

0 commit comments

Comments
 (0)