Skip to content

Commit 12c3b4c

Browse files
committed
♻️ Centralize tool registration with attach_core_tools macro
Introduce a registry module with attach_core_tools! macro that attaches the standard 7-tool set (git_status, git_diff, git_log, git_changed_files, file_read, code_search, project_docs) to any agent builder. This prevents tool list drift between main agents and subagents. Additional improvements: - Replace hand-rolled error types with define_tool_error! macro in workspace, docs, and parallel_analyze tools - Convert CodeSearch and Workspace args to use typed enums (SearchType, WorkspaceAction, TaskPriority, TaskStatus) with schemars derivation - Generate tool schemas from Rust types via parameters_schema::<T>() - Remove unused DynClientBuilder from IrisAgent and AgentSetupService - Extract inject_style_instructions() from execute_task() in iris.rs - Delegate commit/PR formatting to types module helpers
1 parent e2de14e commit 12c3b4c

File tree

9 files changed

+327
-310
lines changed

9 files changed

+327
-310
lines changed

src/agents/iris.rs

Lines changed: 96 additions & 122 deletions
Large diffs are not rendered by default.

src/agents/setup.rs

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
//! including configuration loading, client creation, and agent setup.
55
66
use anyhow::Result;
7-
use rig::client::builder::DynClientBuilder;
87
use std::sync::Arc;
98

109
use crate::agents::context::TaskContext;
@@ -19,7 +18,6 @@ use crate::providers::Provider;
1918
pub struct AgentSetupService {
2019
config: Config,
2120
git_repo: Option<GitRepo>,
22-
client_builder: Option<DynClientBuilder>,
2321
}
2422

2523
impl AgentSetupService {
@@ -28,7 +26,6 @@ impl AgentSetupService {
2826
Self {
2927
config,
3028
git_repo: None,
31-
client_builder: None,
3229
}
3330
}
3431

@@ -59,10 +56,10 @@ impl AgentSetupService {
5956
/// Create a configured Iris agent
6057
pub fn create_iris_agent(&mut self) -> Result<IrisAgent> {
6158
let backend = AgentBackend::from_config(&self.config)?;
62-
let client_builder = self.create_client_builder(&backend)?;
59+
// Validate environment (API keys etc) before creating agent
60+
self.validate_provider(&backend)?;
6361

6462
let mut agent = IrisAgentBuilder::new()
65-
.with_client(client_builder)
6663
.with_provider(&backend.provider_name)
6764
.with_model(&backend.model)
6865
.build()?;
@@ -74,9 +71,8 @@ impl AgentSetupService {
7471
Ok(agent)
7572
}
7673

77-
/// Create a Rig client builder based on the backend configuration
78-
fn create_client_builder(&mut self, backend: &AgentBackend) -> Result<DynClientBuilder> {
79-
// Parse and validate provider
74+
/// Validate provider configuration (API keys etc)
75+
fn validate_provider(&self, backend: &AgentBackend) -> Result<()> {
8076
let provider: Provider = backend
8177
.provider_name
8278
.parse()
@@ -96,11 +92,7 @@ impl AgentSetupService {
9692
));
9793
}
9894

99-
// Create client builder - Rig will read from environment variables
100-
let client_builder = DynClientBuilder::new();
101-
self.client_builder = Some(DynClientBuilder::new());
102-
103-
Ok(client_builder)
95+
Ok(())
10496
}
10597

10698
/// Get the git repository instance
@@ -112,11 +104,6 @@ impl AgentSetupService {
112104
pub fn config(&self) -> &Config {
113105
&self.config
114106
}
115-
116-
/// Get the client builder
117-
pub fn client_builder(&self) -> Option<&DynClientBuilder> {
118-
self.client_builder.as_ref()
119-
}
120107
}
121108

122109
/// High-level function to handle tasks with agents using a common pattern
@@ -147,10 +134,7 @@ where
147134

148135
/// Simple factory function for creating agents with minimal configuration
149136
pub fn create_agent_with_defaults(provider: &str, model: &str) -> Result<IrisAgent> {
150-
let client_builder = DynClientBuilder::new();
151-
152137
IrisAgentBuilder::new()
153-
.with_client(client_builder)
154138
.with_provider(provider)
155139
.with_model(model)
156140
.build()
@@ -312,10 +296,7 @@ impl IrisAgentService {
312296

313297
/// Create a configured Iris agent
314298
fn create_agent(&self) -> Result<IrisAgent> {
315-
let client_builder = DynClientBuilder::new();
316-
317299
let mut agent = IrisAgentBuilder::new()
318-
.with_client(client_builder)
319300
.with_provider(&self.provider)
320301
.with_model(&self.model)
321302
.build()?;

src/agents/tools/code_search.rs

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ use anyhow::Result;
77
use rig::completion::ToolDefinition;
88
use rig::tool::Tool;
99
use serde::{Deserialize, Serialize};
10-
use serde_json::json;
1110
use std::path::Path;
1211
use std::process::Command;
1312

13+
use super::common::parameters_schema;
1414
use crate::define_tool_error;
1515

1616
define_tool_error!(CodeSearchError);
@@ -142,12 +142,52 @@ pub struct SearchResult {
142142
pub context_lines: usize,
143143
}
144144

145-
#[derive(Debug, Deserialize, Serialize)]
145+
/// Search type for code search
146+
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema, Default)]
147+
#[serde(rename_all = "lowercase")]
148+
pub enum SearchType {
149+
/// Search for function definitions
150+
Function,
151+
/// Search for class/struct definitions
152+
Class,
153+
/// Search for variable assignments
154+
Variable,
155+
/// General text search (case-insensitive)
156+
#[default]
157+
Text,
158+
/// Regex pattern search
159+
Pattern,
160+
}
161+
162+
impl SearchType {
163+
fn as_str(&self) -> &'static str {
164+
match self {
165+
SearchType::Function => "function",
166+
SearchType::Class => "class",
167+
SearchType::Variable => "variable",
168+
SearchType::Text => "text",
169+
SearchType::Pattern => "pattern",
170+
}
171+
}
172+
}
173+
174+
#[derive(Debug, Deserialize, Serialize, schemars::JsonSchema)]
146175
pub struct CodeSearchArgs {
176+
/// Search query - function name, class name, variable, or pattern
147177
pub query: String,
148-
pub search_type: String, // "function", "class", "variable", "text", "pattern"
178+
/// Type of search to perform
179+
#[serde(default)]
180+
pub search_type: SearchType,
181+
/// Optional file glob pattern to limit scope (e.g., "*.rs", "*.js")
182+
#[serde(default)]
149183
pub file_pattern: Option<String>,
150-
pub max_results: Option<usize>,
184+
/// Maximum results to return (default: 20, max: 100)
185+
#[serde(default = "default_max_results")]
186+
pub max_results: usize,
187+
}
188+
189+
fn default_max_results() -> usize {
190+
20
151191
}
152192

153193
impl Tool for CodeSearch {
@@ -157,55 +197,29 @@ impl Tool for CodeSearch {
157197
type Output = String;
158198

159199
async fn definition(&self, _: String) -> ToolDefinition {
160-
serde_json::from_value(json!({
161-
"name": "code_search",
162-
"description": "Search for code patterns, functions, classes, and related files in the repository using ripgrep. Supports multiple search types and file filtering.",
163-
"parameters": {
164-
"type": "object",
165-
"properties": {
166-
"query": {
167-
"type": "string",
168-
"description": "Search query - can be function name, class name, variable, or text pattern"
169-
},
170-
"search_type": {
171-
"type": "string",
172-
"enum": ["function", "class", "variable", "text", "pattern"],
173-
"description": "Type of search to perform (default: text)"
174-
},
175-
"file_pattern": {
176-
"type": ["string", "null"],
177-
"description": "Optional file glob pattern to limit search scope (e.g., '*.rs', '*.js')"
178-
},
179-
"max_results": {
180-
"type": ["integer", "null"],
181-
"description": "Maximum number of results to return (default: 20, max: 100)",
182-
"default": 20,
183-
"minimum": 1,
184-
"maximum": 100
185-
}
186-
},
187-
"required": ["query", "search_type", "file_pattern", "max_results"]
188-
}
189-
}))
190-
.expect("code_search tool definition should be valid JSON")
200+
ToolDefinition {
201+
name: "code_search".to_string(),
202+
description: "Search for code patterns, functions, classes, and related files in the repository using ripgrep. Supports multiple search types and file filtering.".to_string(),
203+
parameters: parameters_schema::<CodeSearchArgs>(),
204+
}
191205
}
192206

193207
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
194208
let current_dir = std::env::current_dir().map_err(CodeSearchError::from)?;
195-
let max_results = args.max_results.unwrap_or(20);
209+
let max_results = args.max_results.min(100); // Cap at 100
196210

197211
let results = Self::execute_ripgrep_search(
198212
&args.query,
199213
&current_dir,
200214
args.file_pattern.as_deref(),
201-
&args.search_type,
215+
args.search_type.as_str(),
202216
max_results,
203217
)
204218
.map_err(CodeSearchError::from)?;
205219

206220
let result = serde_json::json!({
207221
"query": args.query,
208-
"search_type": args.search_type,
222+
"search_type": args.search_type.as_str(),
209223
"results": results,
210224
"total_found": results.len(),
211225
"max_results": max_results,

src/agents/tools/docs.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,8 @@ use std::path::PathBuf;
1111

1212
use super::common::parameters_schema;
1313

14-
#[derive(Debug, thiserror::Error)]
15-
#[error("Docs error: {0}")]
16-
pub struct DocsError(String);
17-
18-
impl From<std::io::Error> for DocsError {
19-
fn from(err: std::io::Error) -> Self {
20-
DocsError(err.to_string())
21-
}
22-
}
14+
// Use standard tool error macro for consistency
15+
crate::define_tool_error!(DocsError);
2316

2417
/// Tool for fetching project documentation files
2518
#[derive(Debug, Clone, Serialize, Deserialize)]

src/agents/tools/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
pub mod common;
88
pub use common::{get_current_repo, parameters_schema};
99

10+
// Tool registry for consistent attachment
11+
pub mod registry;
12+
pub use registry::CORE_TOOLS;
13+
1014
// Tool modules with Rig-based implementations
1115
pub mod git;
1216

src/agents/tools/parallel_analyze.rs

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ use serde_json::json;
1717
use std::sync::Arc;
1818
use tokio::sync::Mutex;
1919

20-
use super::{CodeSearch, FileRead, GitChangedFiles, GitDiff, GitLog, GitStatus, ProjectDocs};
2120
use crate::agents::debug as agent_debug;
22-
use crate::agents::debug_tool::DebugTool;
2321

2422
/// Arguments for parallel analysis
2523
#[derive(Debug, Deserialize, JsonSchema)]
@@ -102,37 +100,16 @@ impl SubagentRunner {
102100
- Return a clear, structured summary\n\
103101
- Be concise but comprehensive";
104102

103+
// Use shared tool registry for consistent tool attachment
105104
let result = match self {
106105
Self::OpenAI { client, model } => {
107-
let agent = client
108-
.agent(model)
109-
.preamble(preamble)
110-
.max_tokens(4096)
111-
.tool(DebugTool::new(GitStatus))
112-
.tool(DebugTool::new(GitDiff))
113-
.tool(DebugTool::new(GitLog))
114-
.tool(DebugTool::new(GitChangedFiles))
115-
.tool(DebugTool::new(FileRead))
116-
.tool(DebugTool::new(CodeSearch))
117-
.tool(DebugTool::new(ProjectDocs))
118-
.build();
119-
106+
let builder = client.agent(model).preamble(preamble).max_tokens(4096);
107+
let agent = crate::attach_core_tools!(builder).build();
120108
agent.prompt(task).await
121109
}
122110
Self::Anthropic { client, model } => {
123-
let agent = client
124-
.agent(model)
125-
.preamble(preamble)
126-
.max_tokens(4096)
127-
.tool(DebugTool::new(GitStatus))
128-
.tool(DebugTool::new(GitDiff))
129-
.tool(DebugTool::new(GitLog))
130-
.tool(DebugTool::new(GitChangedFiles))
131-
.tool(DebugTool::new(FileRead))
132-
.tool(DebugTool::new(CodeSearch))
133-
.tool(DebugTool::new(ProjectDocs))
134-
.build();
135-
111+
let builder = client.agent(model).preamble(preamble).max_tokens(4096);
112+
let agent = crate::attach_core_tools!(builder).build();
136113
agent.prompt(task).await
137114
}
138115
};
@@ -179,9 +156,8 @@ impl ParallelAnalyze {
179156
}
180157
}
181158

182-
#[derive(Debug, thiserror::Error)]
183-
#[error("Parallel analysis error: {0}")]
184-
pub struct ParallelAnalyzeError(String);
159+
// Use standard tool error macro for consistency
160+
crate::define_tool_error!(ParallelAnalyzeError);
185161

186162
impl Tool for ParallelAnalyze {
187163
const NAME: &'static str = "parallel_analyze";
@@ -228,7 +204,10 @@ impl Tool for ParallelAnalyze {
228204

229205
agent_debug::debug_context_management(
230206
"ParallelAnalyze",
231-
&format!("Spawning {} subagents (fast model: {})", num_tasks, self.model),
207+
&format!(
208+
"Spawning {} subagents (fast model: {})",
209+
num_tasks, self.model
210+
),
232211
);
233212

234213
// Collect results using Arc<Mutex> for thread-safe access
@@ -268,7 +247,10 @@ impl Tool for ParallelAnalyze {
268247

269248
agent_debug::debug_context_management(
270249
"ParallelAnalyze",
271-
&format!("{}/{} successful in {}ms", successful, num_tasks, execution_time_ms),
250+
&format!(
251+
"{}/{} successful in {}ms",
252+
successful, num_tasks, execution_time_ms
253+
),
272254
);
273255

274256
Ok(ParallelAnalyzeResult {

0 commit comments

Comments
 (0)