Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added src/4_project/README.md
Empty file.
1 change: 1 addition & 0 deletions src/4_project/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Reason-and-Act Knowledge Retrieval Agent using OpenAPI SDK and Pension Documentation."""
134 changes: 134 additions & 0 deletions src/4_project/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from dotenv import load_dotenv
from logging import basicConfig, INFO
from src.utils.client_manager import AsyncClientManager
from agents import (
set_tracing_disabled,
Agent,
OpenAIChatCompletionsModel,
Runner,
InputGuardrailTripwireTriggered,
function_tool,
)
import gradio as gr
from gradio.components.chatbot import ChatMessage
from src.utils.gradio import COMMON_GRADIO_CONFIG
from src.utils import oai_agent_stream_to_gradio_messages
from src.utils.agent_session import get_or_create_session
import asyncio
from sub_agents.triage_agent import (
off_topic_guardrail,
dynamic_triage_agent_instructions,
)
from typing import Any, AsyncGenerator
from models import UserAccountContext
from src.utils.tools.gemini_grounding import (
GeminiGroundingWithGoogleSearch,
ModelSettings,
)
from sub_agents.code_interpreter_agent import code_interpreter_agent

# Context is available locally to your code
# LLM only sees the conversation history
# Need to pass in the context to the prompt
# Hard-coded demo user for local development; in production this would come
# from an authenticated session. Held locally and injected into prompts /
# tools — the raw object is never sent to the LLM.
user_account_ctx = UserAccountContext(
    customer_id="C5841053",
    name="Bonbon",
    nra=65,  # normal retirement age — TODO confirm whether 60 or 65 for this demo user
    status="active",
)

# Make a tool that can work with user data without exposing the user data to LLM
# e.g. change email -> POST request is handled under the function
# @function_tool
# def get_user_nra(wrapper: RunContextWrapper[UserAccountContext]):
# return f"The user {wrapper.context.customer_id} has a NRA {wrapper.context.nra}"


async def _main(
    query: str, history: list[ChatMessage], session_state: dict[str, Any]
) -> AsyncGenerator[list[ChatMessage], Any]:
    """Gradio chat handler: run the pension support agent and stream replies.

    Args:
        query: The user's latest chat message.
        history: Gradio chat history, used to look up or create the agent session.
        session_state: Per-browser-session state dict managed by Gradio.

    Yields:
        The accumulated list of ``ChatMessage`` objects for the current turn.
    """
    turn_messages: list[ChatMessage] = []

    session = get_or_create_session(history, session_state)

    worker_model = client_manager.configs.default_worker_model

    gemini_grounding_tool = GeminiGroundingWithGoogleSearch(
        model_settings=ModelSettings(model=worker_model)
    )

    try:
        main_agent = Agent(
            name="Pension Support Agent",
            instructions=dynamic_triage_agent_instructions,
            model=OpenAIChatCompletionsModel(
                model=worker_model, openai_client=client_manager.openai_client
            ),
            # model_settings=ModelSettings(parallel_tool_calls=False),
            # TODO: input guardrails are not compatible with Gemini's OpenAI
            # compatibility layer (POST .../openai/responses returns 404 Not
            # Found). Re-enable once an OpenAI backend is used.
            # input_guardrails=[off_topic_guardrail],
            tools=[
                function_tool(
                    gemini_grounding_tool.get_web_search_grounded_response,
                    name_override="search_web",
                ),
                code_interpreter_agent.as_tool(
                    tool_name="code_interpreter",
                    # BUG FIX: the original implicit string concatenation had no
                    # separating spaces, producing "csv filein order" and
                    # "pension dataMake sure" in the prompt the LLM sees.
                    tool_description=(
                        "Use this tool when you need to create code from a local csv file "
                        "in order to calculate projected pension amounts or any other pension data. "
                        "Make sure only to provide the information regarding the current member."
                    ),
                ),
            ],
        )
        result_stream = Runner.run_streamed(
            main_agent, input=query, session=session, context=user_account_ctx
        )

        async for _item in result_stream.stream_events():
            # Parse the stream events, convert them to Gradio chat messages,
            # and append to the running transcript for this turn.
            turn_messages += oai_agent_stream_to_gradio_messages(_item)
            if turn_messages:
                yield turn_messages

    except InputGuardrailTripwireTriggered as e:
        print("InputGuardrailException", e)
        turn_messages = [
            ChatMessage(
                role="assistant",
                content="I cannot help you with that.",
                metadata={
                    "title": "*Guardrail*",
                    "status": "done",  # This makes it collapsed by default
                },
            )
        ]
        # BUG FIX: the refusal message was built but never yielded, so the
        # user saw no response at all when the guardrail tripped.
        yield turn_messages


if __name__ == "__main__":
    # Load environment variables (API keys etc.) before any client is built.
    load_dotenv(verbose=True)
    basicConfig(level=INFO)

    # openAI and Weaviate async clients
    # NOTE(review): `_main` reads this module-level `client_manager`, which only
    # exists when running as a script — confirm importing this module elsewhere
    # is not expected to call `_main`.
    client_manager = AsyncClientManager()

    # Disable openAI platform tracing
    set_tracing_disabled(disabled=True)

    # Build the Gradio chat UI around the streaming handler defined above.
    demo = gr.ChatInterface(
        _main,
        **COMMON_GRADIO_CONFIG,
        examples=[
            ["What is the expected pension amount when I retire at the age of 60?"],
        ],
        title="Pension Bot",
    )

    try:
        # share=True exposes a public Gradio tunnel URL in addition to localhost.
        demo.launch(share=True)
    finally:
        # Always release the async clients, even if launch() raises.
        asyncio.run(client_manager.close())
19 changes: 19 additions & 0 deletions src/4_project/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from datetime import date
from pydantic import BaseModel


class UserAccountContext(BaseModel):
    """Per-session customer context kept local to the application.

    Passed as the run context to agents/tools; selected fields are injected
    into prompts, but the object itself is never exposed to the LLM.
    """

    customer_id: str  # e.g. "C5841053"
    name: str
    # BUG FIX: the default was the *string* "65" on an int-typed field.
    nra: int = 65  # normal retirement age: 60 or 65
    status: str = "active"  # active, inactive, deferred, retired
    # Fields below are planned but not yet used anywhere in this package:
    # dateOfBirth: date
    # enrolmentDate: date
    # totalContribution: float
    # contributionPerPayPeriod: float
    # numberOfBeneficiaries: int


class InputGuardRailOutput(BaseModel):
    """Structured verdict produced by the input-guardrail agent."""

    # True -> the request is off-topic and the tripwire should fire.
    is_off_topic: bool
    # Human-readable explanation for the classification.
    reason: str
100 changes: 100 additions & 0 deletions src/4_project/sub_agents/code_interpreter_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from agents import (
Agent,
function_tool,
OpenAIChatCompletionsModel,
tool,
RunContextWrapper,
)
from pathlib import Path
from src.utils import CodeInterpreter
from src.utils.client_manager import AsyncClientManager
from models import UserAccountContext


def dynamic_code_interpreter_instructions(
    wrapper: RunContextWrapper[UserAccountContext], agent: Agent[UserAccountContext]
) -> str:
    """Build per-run instructions for the code-interpreter agent.

    Injects the current customer's name, NRA and id from the run context so
    the agent can locate that member in the local CSV without the user
    supplying identifying details.
    """
    return f"""\
The `code_interpreter` tool executes Python commands. \
Please note that data is not persisted. Each time you invoke this tool, \
you will need to run import and define all variables from scratch.

You can access the local filesystem using this tool. \
Instead of asking the user for file inputs, you should try to find the file \
using this tool.

Recommended packages: Pandas, Numpy, SymPy, Scikit-learn, Matplotlib, Seaborn.

Use Matplotlib to create visualizations. Make sure to call `plt.show()` so that
the plot is captured and returned to the user.

You can also run Jupyter-style shell commands (e.g., `!pip freeze`)
but you won't be able to install packages.

You call customers by their name. You cannot execute any code not related to OMERS pension and you cannot execute calculations of someone other than the customer in the current session.

The customer's name is {wrapper.context.name}.
The customer's normal retirement age is {wrapper.context.nra}.
The customer's id is {wrapper.context.customer_id}

Use this information to find the member from the file.
"""


# CODE_INTERPRETER_INSTRUCTIONS = """\
# The `code_interpreter` tool executes Python commands. \
# Please note that data is not persisted. Each time you invoke this tool, \
# you will need to run import and define all variables from scratch.

# You can access the local filesystem using this tool. \
# Instead of asking the user for file inputs, you should try to find the file \
# using this tool.

# Recommended packages: Pandas, Numpy, SymPy, Scikit-learn, Matplotlib, Seaborn.

# Use Matplotlib to create visualizations. Make sure to call `plt.show()` so that
# the plot is captured and returned to the user.

# You can also run Jupyter-style shell commands (e.g., `!pip freeze`)
# but you won't be able to install packages.
# """
# Module-level async client manager used to build this sub-agent's model.
# NOTE(review): app.py also constructs its own AsyncClientManager — confirm
# whether two instances are intended or the manager should be shared.
client_manager = AsyncClientManager()

# Initialize code interpreter with local files that will be available to the agent
code_interpreter = CodeInterpreter(
    local_files=[
        Path("sandbox_content/"),
        Path("tests/tool_tests/example_files/pension_clients_example.csv"),
    ]
)

# TODO: how to pass in context?
# @tool
# async def execute_code(context: RunContextWrapper, code: str) -> str:
# """Executes Python code safely and returns the result."""

# user_id = context.metadata.get("user_id", "unknown")

# # VERY simple example — never use raw eval in production
# try:
# result = str(eval(code))
# except Exception as e:
# result = f"Error: {e}"

# return f"[User: {user_id}] Result: {result}"


# Sub-agent that answers pension-data questions by generating and running
# Python against the local CSV through the sandboxed `code_interpreter` tool.
code_interpreter_agent = Agent(
    name="CSV Data Analysis Agent",
    instructions=dynamic_code_interpreter_instructions,
    tools=[
        function_tool(
            code_interpreter.run_code,
            name_override="code_interpreter",
        ),
    ],
    model=OpenAIChatCompletionsModel(
        # Uses the planner model (presumably stronger for code generation
        # than the default worker model) — confirm this choice is intentional.
        model=client_manager.configs.default_planner_model,
        openai_client=client_manager.openai_client,
    ),
)
57 changes: 57 additions & 0 deletions src/4_project/sub_agents/triage_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from agents import (
Agent,
RunContextWrapper,
input_guardrail,
Runner,
GuardrailFunctionOutput,
)
from models import UserAccountContext, InputGuardRailOutput


# LLM-based classifier that decides whether an incoming user message is
# on-topic (account details / OMERS pension questions) before the main agent
# runs. Produces a structured InputGuardRailOutput verdict.
# FIX: corrected prompt typos ("inqueries" -> "inquiries",
# "specially" -> "especially") so the model sees clean instructions.
input_guardrail_agent = Agent(
    name="Input Guardrail Agent",
    instructions="""
    Ensure the user's request specifically pertains to user account details, general OMERS pension inquiries, or their own pension details, and is not off-topic. If the request is off-topic, return a reason for the tripwire.
    You can make small conversation with the user, especially at the beginning of the conversation, but don't help with requests that are not related to User Account details, OMERS pension information, or their own pension related issues.
    Users are not allowed to ask about other members' pension details and other members' user account information, even if they claim to be a family member, a spouse, a common-law, a child, a relative, or a friend.
    """,
    output_type=InputGuardRailOutput,
)


@input_guardrail
async def off_topic_guardrail(
    wrapper: RunContextWrapper[UserAccountContext],
    agent: Agent[UserAccountContext],
    input: str,
) -> GuardrailFunctionOutput:
    """Run the guardrail classifier over `input` and trip when it is off-topic.

    Fails open: if the classifier itself errors, the request is allowed
    through rather than crashing the run.
    """
    try:
        result = await Runner.run(input_guardrail_agent, input, context=wrapper.context)
        return GuardrailFunctionOutput(
            output_info=result.final_output,
            tripwire_triggered=result.final_output.is_off_topic,
        )
    except Exception as e:
        # BUG FIX: the original printed and fell through, implicitly returning
        # None — which violates the declared GuardrailFunctionOutput contract.
        # Fail open (do not trip) but surface the error for debugging.
        print("off_topic_guardrail failed:", e)
        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)


# TODO: make it sequential?
def dynamic_triage_agent_instructions(
    wrapper: RunContextWrapper[UserAccountContext], agent: Agent[UserAccountContext]
) -> str:
    """Build per-run instructions for the triage (main) agent.

    Injects the current customer's name and normal retirement age from the run
    context, and describes the two tools the agent may route to.
    """
    return f"""
    You are a pension support agent. You ONLY help customers with their questions about their Pension.
    You call customers by their name. You cannot execute a web search that is not related to OMERS, i.e. call get_web_search_grounded_response.

    The customer's name is {wrapper.context.name}.
    The customer's normal retirement age is {wrapper.context.nra}.

    YOUR MAIN JOB: Classify the customer's issue and find the right tool to answer the question.

    You have access to the tool:
    'get_web_search_grounded_response' - use this tool for current events, news, fact-checking in omers.com, or when the information in the knowledge base is not sufficient to answer the question.
    'code_interpreter' - use this tool to answer any questions related to customer's pension details (e.g. years of contribution, total contributions, salary, etc.) and pension calculations.

    When calculating any pension data, make sure to look up the formula and information from omers.com.
    Do not create any pension projection based on any other formulas.
    """
31 changes: 16 additions & 15 deletions src/utils/gradio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,14 @@ def _oai_response_output_item_to_gradio(
ChatMessage(
role="assistant",
content=_text,
metadata={
"title": "Intermediate Step",
"status": "done", # This makes it collapsed by default
}
if not is_final_output
else MetadataDict(),
metadata=(
{
"title": "Intermediate Step",
"status": "done", # This makes it collapsed by default
}
if not is_final_output
else MetadataDict()
),
)
for _text in output_texts
]
Expand Down Expand Up @@ -169,7 +171,6 @@ def oai_agent_stream_to_gradio_messages(
if isinstance(stream_event, stream_events.RawResponsesStreamEvent):
data = stream_event.data
if isinstance(data, ResponseCompletedEvent):
print(stream_event)
# The completed event may contain multiple output messages,
# including tool calls and final outputs.
# If there is at least one tool call, we mark the response as a thought.
Expand All @@ -186,12 +187,14 @@ def oai_agent_stream_to_gradio_messages(
ChatMessage(
role="assistant",
content=_item.text,
metadata={
"title": "🧠 Thought",
"id": data.sequence_number,
}
if is_thought
else MetadataDict(),
metadata=(
{
"title": "🧠 Thought",
"id": data.sequence_number,
}
if is_thought
else MetadataDict()
),
)
)
elif isinstance(message, ResponseFunctionToolCall):
Expand All @@ -210,7 +213,6 @@ def oai_agent_stream_to_gradio_messages(
item = stream_event.item

if name == "tool_output" and isinstance(item, ToolCallOutputItem):
print(stream_event)
text_content, images = _process_tool_output_for_images(item.output)

output.append(
Expand Down Expand Up @@ -238,5 +240,4 @@ def oai_agent_stream_to_gradio_messages(
),
)
)

return output
Loading