Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Most examples expect provider API keys in the environment (e.g. `OPENAI_API_KEY`
| `discord_bot` | See source. |
| `enum_dispatch` | See source. |
| `extractor` | Demonstrates typed extraction and extraction with usage metadata. |
| `force_tool_first_turn` | Demonstrates a per-turn `RequestPatch` footgun and its fix: forcing `tool_choice = Required` on *every* turn loops until `max_turns`, so an `AgentHook` gates the patch on `ctx.turn() == 1` to force the tool only up front. |
| `gemini_deep_research` | See source. |
| `gemini_default_api_recovery` | Demonstrates recovering from Gemini emitting a legacy `default_api` tool name. |
| `gemini_extractor_with_rag` | See source. |
Expand Down
16 changes: 16 additions & 0 deletions examples/force_tool_first_turn/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "force_tool_first_turn"
version.workspace = true
edition.workspace = true
publish = false

[lints]
workspace = true

[dependencies]
rig.workspace = true
anyhow = { workspace = true }
tokio = { workspace = true, features = ["full"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
166 changes: 166 additions & 0 deletions examples/force_tool_first_turn/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
//! # Forcing a tool on the first turn: a `RequestPatch` footgun and its fix
//!
//! A hook can steer a single model turn by returning [`Flow::patch_request`] on
//! the [`StepEvent::CompletionCall`] event. A common wish is "make the model call
//! a tool *first*", done by patching `tool_choice = Required`.
//!
//! **The footgun.** A [`RequestPatch`] is **per-turn and non-sticky**: the
//! `CompletionCall` event re-fires on *every* turn, so a hook that patches
//! `Required` unconditionally forces a tool call on *every* turn. The model never
//! reaches a turn where it is free to stop calling tools and write the final
//! answer, so the run loops until `max_turns` and fails with
//! [`PromptError::MaxTurnsError`].
//!
//! **The fix.** Gate the patch on the turn index — force `Required` only on the
//! first turn (`ctx.turn() == 1`). The model is nudged to call the tool up front;
//! later turns inherit the agent's baseline (`auto`), so it can stop and answer.
//!
//! This example runs the footgun first (and catches the resulting
//! `MaxTurnsError`), then runs the fix.
//!
//! Requires `OPENAI_API_KEY`.

use anyhow::Result;
use rig::agent::{AgentHook, Flow, HookContext, RequestPatch, StepEvent};
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::{CompletionModel, Prompt, PromptError, ToolDefinition};
use rig::message::ToolChoice;
use rig::providers::openai;
use rig::tool::Tool;
use serde::Deserialize;
use serde_json::json;

const PREAMBLE: &str =
"You are a calculator assistant. Use the add tool for arithmetic, then report the result.";
const PROMPT: &str = "What is 21 + 21? Use the add tool, then tell me the answer.";

// ---------------------------------------------------------------------------
// A tiny calculator tool the hook can force the model to call.
// ---------------------------------------------------------------------------

#[derive(Deserialize)]
struct AddArgs {
x: i64,
y: i64,
}

#[derive(Debug, thiserror::Error)]
#[error("math error")]
struct MathError;

#[derive(Clone)]
struct Add;

impl Tool for Add {
const NAME: &'static str = "add";
type Error = MathError;
type Args = AddArgs;
type Output = i64;

async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: Self::NAME.to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": { "type": "number", "description": "The first addend" },
"y": { "type": "number", "description": "The second addend" }
},
"required": ["x", "y"]
}),
}
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
Ok(args.x + args.y)
}
}

// ---------------------------------------------------------------------------
// The footgun: force `Required` on EVERY completion call.
// ---------------------------------------------------------------------------

#[derive(Clone)]
struct ForceToolEveryTurn;

impl<M> AgentHook<M> for ForceToolEveryTurn
where
M: CompletionModel,
{
async fn on_event(&self, _ctx: &HookContext, event: StepEvent<'_, M>) -> Flow {
if matches!(event, StepEvent::CompletionCall { .. }) {
// BUG: re-applied every turn. The model is forced to call a tool on
// every turn and can never produce a final text answer, so the run
// loops until `max_turns`.
return Flow::patch_request(RequestPatch::new().tool_choice(ToolChoice::Required));
}
Flow::cont()
}
}

// ---------------------------------------------------------------------------
// The fix: force `Required` on the FIRST turn only.
// ---------------------------------------------------------------------------

#[derive(Clone)]
struct ForceToolOnFirstTurn;

impl<M> AgentHook<M> for ForceToolOnFirstTurn
where
M: CompletionModel,
{
async fn on_event(&self, ctx: &HookContext, event: StepEvent<'_, M>) -> Flow {
// Gate the per-turn patch on the turn index. On turn 1 we force the tool;
// on later turns we return `Continue`, so the request inherits the agent's
// baseline `tool_choice` and the model is free to answer.
if matches!(event, StepEvent::CompletionCall { .. }) && ctx.turn() == 1 {
return Flow::patch_request(RequestPatch::new().tool_choice(ToolChoice::Required));
}
Flow::cont()
}
}

#[tokio::main]
async fn main() -> Result<()> {
let client = openai::Client::from_env()?;
// A fresh agent per run (both share the same tool and preamble).
let make_agent = || {
client
.agent(openai::GPT_4O)
.preamble(PREAMBLE)
.tool(Add)
.build()
};

// 1) The footgun. Forcing `Required` on every turn re-forces a tool call each
// turn, so the run loops until `max_turns` and errors.
println!("=== forcing tool_choice=Required on EVERY turn (the footgun) ===");
let agent = make_agent();
match agent
.prompt(PROMPT)
.max_turns(4)
.add_hook(ForceToolEveryTurn)
.await
{
Ok(answer) => println!("(unexpected) got a final answer: {answer}\n"),
Err(PromptError::MaxTurnsError { max_turns, .. }) => println!(
"hit MaxTurnsError after {max_turns} turns — every turn re-forced a tool call, so the \
model never produced a final answer.\n"
),
Err(err) => println!("run failed: {err}\n"),
}

// 2) The fix. Forcing `Required` on the first turn only nudges the model to
// call the tool up front, then lets it answer.
println!("=== forcing tool_choice=Required on the FIRST turn only (the fix) ===");
let agent = make_agent();
let answer = agent
.prompt(PROMPT)
.max_turns(4)
.add_hook(ForceToolOnFirstTurn)
.await?;
println!("final answer: {answer}");

Ok(())
}
Loading