feat: Add token counting utility + Add support for it in Compression #5593
New file: `cookbook/agents/context_compression/token_based_tool_call_compression.py`

```python
"""
This example shows how to set a context-token-based limit for tool call compression.
Run: `python cookbook/agents/context_compression/token_based_tool_call_compression.py`
"""

from agno.agent import Agent
from agno.compression.manager import CompressionManager
from agno.db.sqlite import SqliteDb
from agno.models.openai import OpenAIChat
from agno.tools.duckduckgo import DuckDuckGoTools

compression_prompt = """
You are a compression expert. Your goal is to compress web search results for a competitive intelligence analyst.

YOUR GOAL: Extract only actionable competitive insights while being extremely concise.

MUST PRESERVE:
- Competitor names and specific actions (product launches, partnerships, acquisitions, pricing changes)
- Exact numbers (revenue, market share, growth rates, pricing, headcount)
- Precise dates (announcement dates, launch dates, deal dates)
- Direct quotes from executives or official statements
- Funding rounds and valuations

MUST REMOVE:
- Company history and background information
- General industry trends (unless competitor-specific)
- Analyst opinions and speculation (keep only facts)
- Detailed product descriptions (keep only key differentiators and pricing)
- Marketing fluff and promotional language

OUTPUT FORMAT:
Return a bullet-point list where each line follows this format:
"[Company Name] - [Date]: [Action/Event] ([Key Numbers/Details])"

Keep it under 200 words total. Be ruthlessly concise. Facts only.

Example:
- Acme Corp - Mar 15, 2024: Launched AcmeGPT at $99/user/month, targeting enterprise market
- TechCo - Feb 10, 2024: Acquired DataStart for $150M, gaining 500 enterprise customers
"""

compression_manager = CompressionManager(
    model=OpenAIChat(id="gpt-5-mini"),
    compress_tool_results_token_limit=5000,
    compress_tool_call_instructions=compression_prompt,
)

agent = Agent(
    model=OpenAIChat(id="gpt-4o-mini"),
    tools=[DuckDuckGoTools()],
    description="Specialized in tracking competitor activities",
    instructions="Use the search tools and always use the latest information and data.",
    db=SqliteDb(db_file="tmp/dbs/token_based_tool_call_compression.db"),
    compression_manager=compression_manager,
    add_history_to_context=True,  # Add history to context
    num_history_runs=3,
    session_id="token_based_tool_call_compression",
)

agent.print_response(
    """
    Use the search tools and always use the latest information and data.
    Research recent activities (last 3 months) for these AI companies:
    1. OpenAI - product launches, partnerships, pricing
    2. Anthropic - new features, enterprise deals, funding
    3. Google DeepMind - research breakthroughs, product releases
    4. Meta AI - open source releases, research papers
    For each, find specific actions with dates and numbers.""",
    stream=True,
)
```
AWS Bedrock model, first hunk: `count_tokens` delegates to Bedrock's CountTokens API and falls back to the base implementation on failure:

```diff
@@ -357,6 +357,32 @@ def _format_messages(
         # TODO: Add caching: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html
         return formatted_messages, system_message

+    def count_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> int:
+        try:
+            formatted_messages, system_message = self._format_messages(messages, compress_tool_results=True)
+            converse_input: Dict[str, Any] = {"messages": formatted_messages}
+            if system_message:
+                converse_input["system"] = system_message
+
+            response = self.get_client().count_tokens(modelId=self.id, input={"converse": converse_input})
+            tokens = response.get("inputTokens", 0)
+
+            # Count tool tokens
+            if tools:
+                from agno.utils.tokens import _count_tool_tokens
+
+                includes_system = any(m.role == "system" for m in messages)
+                tokens += _count_tool_tokens(tools, self.id, includes_system)
+
+            return tokens
+        except Exception as e:
+            log_warning(f"Failed to count tokens via Bedrock API: {e}")
+            return super().count_tokens(messages, tools)
+
     def invoke(
         self,
         messages: List[Message],
```

Reviewer thread on the CountTokens call:

Contributor: Please confirm that this works for Bedrock.
Author: Works.
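For reference, the underlying API can be exercised directly with boto3 (a sketch for illustration; the region and model id are examples, and the response field matches what the hunk above reads back):

```python
import boto3

# CountTokens is served by the bedrock-runtime client.
client = boto3.client("bedrock-runtime", region_name="us-east-1")  # example region

response = client.count_tokens(
    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",  # example model id
    input={"converse": {"messages": [{"role": "user", "content": [{"text": "Hello!"}]}]}},
)
print(response["inputTokens"])  # the same field the PR reads via response.get("inputTokens", 0)
```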
Second hunk: `_get_metrics` now logs the token usage it extracts:

```diff
@@ -719,4 +745,9 @@ def _get_metrics(self, response_usage: Dict[str, Any]) -> Metrics:
         metrics.output_tokens = response_usage.get("outputTokens", 0) or 0
         metrics.total_tokens = metrics.input_tokens + metrics.output_tokens

+        log_debug(
+            f"Bedrock response metrics: input_tokens={metrics.input_tokens}, "
+            f"output_tokens={metrics.output_tokens}, total_tokens={metrics.total_tokens}"
+        )
+
         return metrics
```
Base `Model` class: `Sequence` is added to the typing imports, and a default `count_tokens` delegates to the new utility:

```diff
@@ -15,6 +15,7 @@
     List,
     Literal,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
```

```diff
@@ -427,6 +428,15 @@ def _format_tools(self, tools: Optional[List[Union[Function, dict]]]) -> List[Dict[str, Any]]:
             _tool_dicts.append(tool)
         return _tool_dicts

+    def count_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[Sequence[Union[Function, Dict[str, Any]]]] = None,
+    ) -> int:
+        from agno.utils.tokens import count_tokens
+
+        return count_tokens(messages, tools=list(tools) if tools else None, model_id=self.id)
+
     def response(
         self,
         messages: List[Message],
```
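The `agno.utils.tokens.count_tokens` utility itself is not shown in this diff. A counter along these lines, using `tiktoken` with a fallback encoding, would be a plausible shape (an illustrative sketch only, not the PR's actual implementation):

```python
import json
from typing import Any, Dict, List, Optional

import tiktoken

def count_tokens_sketch(
    messages: List[Any],
    tools: Optional[List[Dict[str, Any]]] = None,
    model_id: str = "gpt-4o-mini",
) -> int:
    """Approximate the context size by encoding message content and tool schemas."""
    try:
        encoding = tiktoken.encoding_for_model(model_id)
    except KeyError:
        # Unknown model ids fall back to a general-purpose encoding.
        encoding = tiktoken.get_encoding("cl100k_base")

    total = 0
    for message in messages:
        if message.content:
            total += len(encoding.encode(str(message.content)))
        total += 4  # rough per-message structural overhead

    if tools:
        # Tool/function schemas are sent with the request and consume context too.
        total += len(encoding.encode(json.dumps(tools)))

    return total
```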
In `response` (and, in the analogous hunks below, in `aresponse`, `response_stream`, and `aresponse_stream`), the compression check moves to the top of the `while True` loop so it runs before each model invocation:

```diff
@@ -476,6 +486,10 @@ def response(
         _compress_tool_results = compression_manager is not None and compression_manager.compress_tool_results

         while True:
+            # Compress tool results
+            if compression_manager and compression_manager.should_compress(messages, tools, main_model=self):
+                compression_manager.compress(messages)
+
             # Get response from model
             assistant_message = Message(role=self.assistant_message_role)
             self._process_model_response(
```

The old post-tool-execution compression block is removed:

```diff
@@ -574,11 +588,6 @@ def response(
             # Add a function call for each successful execution
             function_call_count += len(function_call_results)

-            all_messages = messages + function_call_results
-
-            # Compress tool results
-            if compression_manager and compression_manager.should_compress(all_messages):
-                compression_manager.compress(all_messages)
-
             # Format and add results to messages
             self.format_function_call_results(
                 messages=messages,
```

Reviewer thread on this removal:

Contributor: Why are we changing this? I think you're probably right, but there was a reason we did it here.
Author: Before we were limited by […]
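The net effect of the move, in simplified form (a control-flow sketch of the surrounding loop; `call_model` and `run_tools` are hypothetical stand-ins for the real machinery):

```python
def run_loop(model, compression_manager, messages, tools):
    while True:
        # NEW: check the full context (including loaded history) before every model call
        if compression_manager and compression_manager.should_compress(messages, tools, main_model=model):
            compression_manager.compress(messages)

        assistant_message = call_model(model, messages)  # hypothetical helper
        if not assistant_message.tool_calls:
            return assistant_message

        results = run_tools(assistant_message.tool_calls)  # hypothetical helper
        messages.extend(results)
        # The OLD code compressed only here, after tool execution, so the first
        # call of a run with oversized history could already overflow the context.
```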
The same change in `aresponse`:

```diff
@@ -678,6 +687,10 @@ async def aresponse(
         function_call_count = 0

         while True:
+            # Compress existing tool results BEFORE making API call to avoid context overflow
+            if compression_manager and compression_manager.should_compress(messages, tools, main_model=self):
+                await compression_manager.acompress(messages)
+
             # Get response from model
             assistant_message = Message(role=self.assistant_message_role)
             await self._aprocess_model_response(
```

```diff
@@ -775,11 +788,6 @@ async def aresponse(
             # Add a function call for each successful execution
             function_call_count += len(function_call_results)

-            all_messages = messages + function_call_results
-            # Compress tool results
-            if compression_manager and compression_manager.should_compress(all_messages):
-                await compression_manager.acompress(all_messages)
-
             # Format and add results to messages
             self.format_function_call_results(
                 messages=messages,
```
And in `response_stream`:

```diff
@@ -1105,6 +1113,10 @@ def response_stream(
         function_call_count = 0

         while True:
+            # Compress existing tool results BEFORE invoke
+            if compression_manager and compression_manager.should_compress(messages, tools, main_model=self):
+                compression_manager.compress(messages)
+
             assistant_message = Message(role=self.assistant_message_role)
             # Create assistant message and stream data
             stream_data = MessageData()
```

```diff
@@ -1166,11 +1178,6 @@ def response_stream(
             # Add a function call for each successful execution
             function_call_count += len(function_call_results)

-            all_messages = messages + function_call_results
-            # Compress tool results
-            if compression_manager and compression_manager.should_compress(all_messages):
-                compression_manager.compress(all_messages)
-
             # Format and add results to messages
             if stream_data and stream_data.extra is not None:
                 self.format_function_call_results(
```
And in `aresponse_stream`:

```diff
@@ -1323,6 +1330,10 @@ async def aresponse_stream(
         function_call_count = 0

         while True:
+            # Compress existing tool results BEFORE making API call to avoid context overflow
+            if compression_manager and compression_manager.should_compress(messages, tools, main_model=self):
+                await compression_manager.acompress(messages)
+
             # Create assistant message and stream data
             assistant_message = Message(role=self.assistant_message_role)
             stream_data = MessageData()
```

```diff
@@ -1384,11 +1395,6 @@ async def aresponse_stream(
             # Add a function call for each successful execution
             function_call_count += len(function_call_results)

-            all_messages = messages + function_call_results
-            # Compress tool results
-            if compression_manager and compression_manager.should_compress(all_messages):
-                await compression_manager.acompress(all_messages)
-
             # Format and add results to messages
             if stream_data and stream_data.extra is not None:
                 self.format_function_call_results(
```
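With both pieces in place, any model can be asked for a count directly (usage sketch; the `Message` import path is assumed here, not confirmed by the diff):

```python
from agno.models.message import Message  # assumed import path
from agno.models.openai import OpenAIChat

model = OpenAIChat(id="gpt-4o-mini")
token_count = model.count_tokens([Message(role="user", content="How many tokens is this?")])
print(token_count)
```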