diff --git a/src/hitachione/README.md b/src/hitachione/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/__init__.py b/src/hitachione/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/agents/README.md b/src/hitachione/agents/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/agents/__init__.py b/src/hitachione/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/agents/knowledge_retrieval.py b/src/hitachione/agents/knowledge_retrieval.py new file mode 100644 index 0000000..d6b41fa --- /dev/null +++ b/src/hitachione/agents/knowledge_retrieval.py @@ -0,0 +1,87 @@ +"""Knowledge Retrieval Agent – consults Weaviate for context enrichment. + +Responsibilities: + • Resolve entity aliases (e.g. "Google" → GOOGL) + • Find prior top-lists or summaries relevant to the query + • Provide entity hints the Orchestrator can use for planning + +Non-blocking: failures return partial / empty results rather than crashing. 
+""" + +from __future__ import annotations + +import logging +from typing import Any + +from ..config.settings import ( + WEAVIATE_COLLECTION, WEAVIATE_API_KEY, WEAVIATE_HTTP_HOST, +) +from ..models.schemas import TaskContext + +logger = logging.getLogger(__name__) + + +def _weaviate_client(): + """Build a *sync* Weaviate client.""" + import weaviate + from weaviate.auth import AuthApiKey + + if WEAVIATE_HTTP_HOST.endswith(".weaviate.cloud"): + return weaviate.connect_to_weaviate_cloud( + cluster_url=f"https://{WEAVIATE_HTTP_HOST}", + auth_credentials=AuthApiKey(WEAVIATE_API_KEY), + ) + return weaviate.connect_to_custom( + http_host=WEAVIATE_HTTP_HOST, + http_port=443, http_secure=True, + grpc_host=WEAVIATE_HTTP_HOST, grpc_port=443, grpc_secure=True, + auth_credentials=AuthApiKey(WEAVIATE_API_KEY), + ) + + +class KnowledgeRetrievalAgent: + """Look up aliases, entity hints, and prior summaries from the KB.""" + + def run(self, ctx: TaskContext) -> dict[str, Any]: + """Return ``{"aliases": {...}, "entity_hints": [...], "summaries": [...]}``.""" + result: dict[str, Any] = {"aliases": {}, "entity_hints": [], "summaries": []} + try: + client = _weaviate_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + # 1) BM25 search on the user query for entity hints / summaries + resp = col.query.bm25( + query=ctx.user_query, + limit=8, + return_properties=[ + "text", "title", "ticker", "company", + "dataset_source", "date", + ], + ) + for obj in resp.objects: + p = {k: v for k, v in obj.properties.items() if v is not None} + ticker = p.get("ticker") + company = p.get("company") + if ticker and company: + result["aliases"][company] = ticker + if ticker and ticker not in result["entity_hints"]: + result["entity_hints"].append(ticker) + result["summaries"].append( + f"[{p.get('dataset_source','')} | {p.get('date','')}] " + f"{p.get('title','')}" + ) + + # 2) If the context already has entities, resolve them + for entity in ctx.entities: + if entity.upper() not in 
result["aliases"].values(): + result["aliases"][entity] = entity.upper() + + finally: + client.close() + + except Exception as exc: + logger.warning("KnowledgeRetrievalAgent error: %s", exc) + ctx.uncertainties.append(f"KB lookup failed: {exc}") + + return result diff --git a/src/hitachione/agents/orchestrator.py b/src/hitachione/agents/orchestrator.py new file mode 100644 index 0000000..7833f4b --- /dev/null +++ b/src/hitachione/agents/orchestrator.py @@ -0,0 +1,368 @@ +"""Orchestrator – the agentic ReAct loop (Plan → Act → Observe → Reflect). + +This is the top-level entry point. Given a free-text user prompt it: + +1. **Parses intent** (rank, compare, snapshot, event_reaction, …) +2. **Plans** subgoals and identifies information gaps +3. **Acts** by calling sub-agents (KB retrieval, company retrieval, + researcher, synthesizer, reviewer) +4. **Observes** outputs and assesses sufficiency +5. **Reflects** — if the Reviewer flags issues and we have iterations left, + revises the plan and loops +6. **Stops** when the Reviewer says OK, the budget is exhausted, or + information gain is negligible +7. 
**Returns** a clean ``SynthesizedAnswer`` with rationale + caveats + +Usage:: + + from hitachione.agents.orchestrator import Orchestrator + answer = Orchestrator().run("Top 3 tech stocks of 2024") + print(answer.markdown) +""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Any + +from openai import OpenAI + +from ..config.settings import ( + MAX_ITERATIONS, OPENAI_API_KEY, OPENAI_BASE_URL, PLANNER_MODEL, +) +from ..models.schemas import ( + Intent, SynthesizedAnswer, TaskContext, +) +from ..services.tracing import Tracer +from .knowledge_retrieval import KnowledgeRetrievalAgent +from .researcher import ResearcherAgent +from .reviewer import ReviewerAgent +from .synthesizer import SynthesizerAgent + +logger = logging.getLogger(__name__) + +# ── Intent parsing prompt ──────────────────────────────────────────────── + +_INTENT_PROMPT = """\ +Classify the user's financial query into exactly one intent and extract +structured fields. Return ONLY valid JSON (no markdown fences): + +{ + "intent": "", + "entities": ["", ...], + "timeframe": "", + "sector": "" +} + +Rules: +- "entities" must be uppercase ticker symbols when identifiable. +- If the user mentions company names instead of tickers, map them to tickers + (e.g. "Tesla" → "TSLA", "Apple" → "AAPL", "Google" → "GOOGL"). +- If unsure about tickers, leave entities empty. +- Keep it concise. 
+""" + +# ── Company retrieval helper ───────────────────────────────────────────── + +def _find_symbols(query: str) -> list[str]: + """Call the company filtering tool to discover relevant tickers.""" + try: + from ..tools.company_filtering_tool.tool import find_relevant_symbols + return find_relevant_symbols(query, use_llm_filter=True) + except Exception as exc: + logger.warning("Company filtering tool error: %s", exc) + return [] + + +# ── Orchestrator ───────────────────────────────────────────────────────── + +class Orchestrator: + """Agentic ReAct orchestrator for financial intelligence queries.""" + + def __init__(self, max_iterations: int = MAX_ITERATIONS): + self.max_iter = max_iterations + self._llm = OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY) + self.kb_agent = KnowledgeRetrievalAgent() + self.researcher = ResearcherAgent() + self.synthesizer = SynthesizerAgent() + self.reviewer = ReviewerAgent() + + # ── public API ────────────────────────────────────────────────────── + + def run( + self, + user_query: str, + *, + default_timeframe: str = "", + metadata: dict[str, Any] | None = None, + ) -> SynthesizedAnswer: + """Execute the full plan-act-observe-reflect loop.""" + + tracer = Tracer.start( + "orchestrator_run", + metadata={"query": user_query, **(metadata or {})}, + ) + + ctx = TaskContext(user_query=user_query, timeframe=default_timeframe) + + # ── STEP 1: Parse intent ──────────────────────────────────────── + with tracer.span("intent_parse") as sp: + self._parse_intent(ctx) + sp.update(output={ + "intent": ctx.intent.value, + "entities": ctx.entities, + "timeframe": ctx.timeframe, + "sector": ctx.sector, + }) + ctx.observations.append( + f"Intent={ctx.intent.value}, entities={ctx.entities}, " + f"timeframe={ctx.timeframe}, sector={ctx.sector}" + ) + + # Track whether user explicitly provided entities in the prompt. 
+ # If not, we should run company filtering even if KB retrieval yields + # noisy hints, because broad sector/list queries rely on this step. + explicit_entities_in_query = bool(ctx.entities) + + answer = SynthesizedAnswer() + + for iteration in range(1, self.max_iter + 1): + ctx.iteration = iteration + logger.info("── Iteration %d/%d ──", iteration, self.max_iter) + + # ── STEP 2: Plan ──────────────────────────────────────────── + with tracer.span("planning", metadata={"iteration": iteration}) as sp: + plan = self._plan(ctx) + ctx.plan = plan + sp.update(output=plan) + ctx.observations.append(f"Plan (iter {iteration}): {plan}") + + # ── STEP 3: Act – KB retrieval ────────────────────────────── + with tracer.span("knowledge_retrieval") as sp: + kb_data = self.kb_agent.run(ctx) + sp.update(output=kb_data) + # Merge entity hints into context + for hint in kb_data.get("entity_hints", []): + if hint not in ctx.entities: + ctx.entities.append(hint) + + # ── STEP 3b: Act – Company retrieval for broad queries ───── + if not explicit_entities_in_query and iteration == 1: + with tracer.span("company_retrieval") as sp: + original_hints = list(ctx.entities) + ctx.observations.append( + "No explicit entities in prompt – calling company filter" + ) + symbols = _find_symbols(ctx.user_query) + # Use company filter as authoritative for broad queries. + # Fall back to KB hints only if filter returns nothing. + if symbols: + ctx.entities = symbols + source = "company_filter" + else: + ctx.entities = original_hints + source = "kb_hints" + sp.update(output={ + "symbols": symbols, + "kb_hints": original_hints, + "source": source, + }) + + if not ctx.entities: + ctx.uncertainties.append("Could not identify any tickers") + answer.caveats.append( + "No tickers could be identified for this query." + ) + answer.markdown = ( + f"I wasn't able to identify specific tickers for: " + f"*{user_query}*. Please try including ticker symbols " + f"(e.g. AAPL, TSLA)." 
+ ) + answer.confidence = 0.0 + tracer.end(output={"markdown": answer.markdown}) + return answer + + # ── STEP 4: Act – Research ────────────────────────────────── + with tracer.span("research_fanout") as sp: + research = self.researcher.run(ctx, ctx.entities) + + # On retry iterations, merge new results into prior good ones + if iteration > 1 and hasattr(ctx, '_prior_research'): + merged: dict[str, Any] = {} + for cr in ctx._prior_research: # type: ignore[attr-defined] + merged[cr.ticker] = cr + # Overwrite with new results (which may fix prior errors) + for cr in research: + merged[cr.ticker] = cr + research = [merged[t] for t in ctx.entities if t in merged] + + # Stash current research for potential future merging + ctx._prior_research = research # type: ignore[attr-defined] + + sp.update(output={ + "count": len(research), + "tickers": [r.ticker for r in research], + }) + + # ── STEP 5: Act – Synthesize ──────────────────────────────── + with tracer.span("synthesizer") as sp: + answer = self.synthesizer.run(ctx, research) + sp.update(output={ + "confidence": answer.confidence, + "caveats": answer.caveats, + "md_length": len(answer.markdown), + }) + + # ── STEP 6: Observe – Review ──────────────────────────────── + with tracer.span("reviewer") as sp: + feedback = self.reviewer.run(ctx, answer) + sp.update(output={ + "ok": feedback.ok, + "missing": feedback.missing, + "notes": feedback.notes, + }) + ctx.observations.append( + f"Reviewer (iter {iteration}): ok={feedback.ok}, " + f"notes={feedback.notes}" + ) + + # ── STEP 7: Reflect & decide to stop ─────────────────────── + if feedback.ok: + logger.info("Reviewer OK – stopping") + break + + # If the issues are not retriable (e.g. data simply does not + # exist in the KB), stop immediately – looping won't help. 
+ if not feedback.retriable: + logger.info( + "Reviewer flagged %d non-retriable issues – stopping", + len(feedback.missing), + ) + if not any("knowledge base" in c.lower() for c in answer.caveats): + answer.caveats.append( + "No data available in the knowledge base for some " + "or all of the requested tickers." + ) + break + + if iteration < self.max_iter: + # Reflect: try to address missing items + with tracer.span("reflection") as sp: + adjustments = self._reflect(ctx, feedback) + sp.update(output=adjustments) + ctx.observations.append(f"Reflection: {adjustments}") + else: + logger.info("Max iterations reached – returning best effort") + answer.caveats.append( + "Maximum analysis iterations reached; some data may be incomplete." + ) + + tracer.end(output={"confidence": answer.confidence, "caveats": answer.caveats}) + return answer + + # ── Private helpers ───────────────────────────────────────────────── + + def _parse_intent(self, ctx: TaskContext) -> None: + """Use LLM to classify intent and extract entities / timeframe.""" + try: + resp = self._llm.chat.completions.create( + model=PLANNER_MODEL, + messages=[ + {"role": "system", "content": _INTENT_PROMPT}, + {"role": "user", "content": ctx.user_query}, + ], + temperature=0, + ) + raw = (resp.choices[0].message.content or "").strip() + # Strip markdown fences if present + if "```" in raw: + raw = raw.split("```")[1] + if raw.startswith("json"): + raw = raw[4:] + raw = raw.strip() + data = json.loads(raw) + except Exception as exc: + logger.warning("Intent parse error: %s", exc) + data = {} + + intent_str = data.get("intent", "mixed") + try: + ctx.intent = Intent(intent_str) + except ValueError: + ctx.intent = Intent.MIXED + + ctx.entities = [e.upper() for e in data.get("entities", [])] + ctx.timeframe = data.get("timeframe", ctx.timeframe) or "" + ctx.sector = data.get("sector", "") or "" + + def _plan(self, ctx: TaskContext) -> list[str]: + """Generate a simple step plan from the current context.""" + steps 
= [] + if not ctx.entities: + steps.append("Discover relevant entities via KB + company filter") + else: + steps.append(f"Research entities: {', '.join(ctx.entities)}") + + if ctx.intent in (Intent.RANK, Intent.COMPARE): + steps.append("Fetch sentiment + performance for each entity") + steps.append("Score & rank entities") + elif ctx.intent == Intent.SNAPSHOT: + steps.append("Fetch latest data for the entity") + elif ctx.intent == Intent.EVENT_REACTION: + steps.append("Fetch recent news + price reaction") + else: + steps.append("Fetch sentiment + performance data") + + steps.append("Synthesize answer with rationale + caveats") + steps.append("Review for quality and completeness") + + return steps + + def _reflect(self, ctx: TaskContext, feedback) -> dict[str, Any]: + """Adjust the context based on reviewer feedback and apply changes. + + On retry iterations, narrows ``ctx.entities`` to only the tickers that + need re-research — this avoids wasting API calls on entities that + already have complete data. 
+ """ + adjustments: dict[str, Any] = {"action": "none"} + + missing = feedback.missing + + # Collect tickers that need re-research from ANY kind of missing item + tickers_to_retry: set[str] = set() + + for msg in missing: + # "Entity XXXX not researched" + match = re.search(r"Entity\s+(\S+)\s+not researched", msg, re.IGNORECASE) + if match: + tickers_to_retry.add(match.group(1).upper()) + continue + # "XXXX: missing sentiment rating" / "XXXX: missing performance score" + match = re.search(r"^(\S+):\s+missing\s+(sentiment|performance)", msg, re.IGNORECASE) + if match: + tickers_to_retry.add(match.group(1).upper()) + continue + # "XXXX: performance unavailable" (caveat-style) + match = re.search(r"^(\S+):\s+\w+\s+unavailable", msg, re.IGNORECASE) + if match: + tickers_to_retry.add(match.group(1).upper()) + + if tickers_to_retry: + adjustments["action"] = "retry_failed_entities" + adjustments["entities"] = sorted(tickers_to_retry) + # Narrow entity list to only failed tickers for the retry + ctx.entities = sorted(tickers_to_retry) + ctx.observations.append( + f"Retrying {len(tickers_to_retry)} failed entities: " + f"{', '.join(sorted(tickers_to_retry))}" + ) + + # If confidence is low, try broader KB search + elif any("confidence" in m.lower() for m in missing): + adjustments["action"] = "broaden_search" + ctx.observations.append("Broadening search due to low confidence") + + return adjustments diff --git a/src/hitachione/agents/researcher.py b/src/hitachione/agents/researcher.py new file mode 100644 index 0000000..106cb12 --- /dev/null +++ b/src/hitachione/agents/researcher.py @@ -0,0 +1,154 @@ +"""Researcher Agent – per-entity data fetch with parallel fan-out. + +For each ticker the Researcher: + 1. Calls the sentiment analysis tool → rating 1-10 + rationale + 2. Calls the performance analysis tool → score 1-10 + outlook + 3. Retries failed tool calls up to ``_MAX_RETRIES`` times with backoff + 4. 
Captures errors per entity (never crashes the whole run) + +All tickers are researched **in parallel** using a thread pool, and the +two tool calls per ticker also run concurrently. This cuts wall-clock +time from ``O(n × 2 × latency)`` to roughly ``O(latency)`` for typical +workloads (≤ 10 entities). +""" + +from __future__ import annotations + +import logging +import time +from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed +from typing import Any + +from ..models.schemas import CompanyResearch, TaskContext, ToolError + +logger = logging.getLogger(__name__) + +# Max parallel threads. Keep modest to avoid API rate-limits. +_MAX_WORKERS = 8 + +# Per-tool timeout in seconds (prevents a single slow call from blocking) +_TOOL_TIMEOUT = 60 + +# Retry settings for transient errors (connection drops, gRPC resets, etc.) +_MAX_RETRIES = 3 +_RETRY_BACKOFF = 2.0 # seconds; doubles each attempt + + +# ── Lazy imports for the existing tools (avoid circular / heavy init) ── + +def _sentiment(ticker: str) -> dict[str, Any]: + from ..tools.sentiment_analysis_tool.tool import analyze_ticker_sentiment_sync + return analyze_ticker_sentiment_sync(ticker) + + +def _performance(ticker: str) -> dict[str, Any]: + from ..tools.performance_analysis_tool.tool import analyse_stock_performance + return analyse_stock_performance(ticker) + + +def _call_with_retry(fn, ticker: str, tool_name: str) -> dict[str, Any]: + """Call *fn(ticker)* with up to ``_MAX_RETRIES`` attempts on failure. + + Raises the last exception if all attempts fail. 
+ """ + last_exc: Exception | None = None + for attempt in range(1, _MAX_RETRIES + 1): + try: + return fn(ticker) + except Exception as exc: + last_exc = exc + if attempt < _MAX_RETRIES: + wait = _RETRY_BACKOFF * (2 ** (attempt - 1)) + logger.warning( + "%s attempt %d/%d failed for %s: %s – retrying in %.1fs", + tool_name, attempt, _MAX_RETRIES, ticker, exc, wait, + ) + time.sleep(wait) + else: + logger.warning( + "%s attempt %d/%d failed for %s: %s – giving up", + tool_name, attempt, _MAX_RETRIES, ticker, exc, + ) + raise last_exc # type: ignore[misc] + + +def _research_one(ticker: str) -> CompanyResearch: + """Fetch sentiment + performance for one ticker **in parallel**.""" + cr = CompanyResearch(ticker=ticker) + + # Fire both tool calls concurrently within a small thread pool + with ThreadPoolExecutor(max_workers=2) as inner: + sent_future = inner.submit(_call_with_retry, _sentiment, ticker, "sentiment") + perf_future = inner.submit(_call_with_retry, _performance, ticker, "performance") + + # Collect sentiment + try: + cr.sentiment = sent_future.result(timeout=_TOOL_TIMEOUT) + refs = cr.sentiment.get("references", []) + cr.news_snippets = [str(r) for r in refs][:5] + except TimeoutError: + logger.warning("Sentiment timed out for %s after %ds", ticker, _TOOL_TIMEOUT) + cr.errors.append(ToolError(entity=ticker, tool="sentiment", error=f"timeout after {_TOOL_TIMEOUT}s")) + sent_future.cancel() + except Exception as exc: + logger.warning("Sentiment error for %s: %s", ticker, exc) + cr.errors.append(ToolError(entity=ticker, tool="sentiment", error=str(exc))) + + # Collect performance + try: + cr.performance = perf_future.result(timeout=_TOOL_TIMEOUT) + except TimeoutError: + logger.warning("Performance timed out for %s after %ds", ticker, _TOOL_TIMEOUT) + cr.errors.append(ToolError(entity=ticker, tool="performance", error=f"timeout after {_TOOL_TIMEOUT}s")) + perf_future.cancel() + except Exception as exc: + logger.warning("Performance error for %s: %s", ticker, exc) 
+ cr.errors.append(ToolError(entity=ticker, tool="performance", error=str(exc))) + + return cr + + +class ResearcherAgent: + """Fan-out research across a list of entities **in parallel**.""" + + def __init__(self, max_workers: int = _MAX_WORKERS): + self.max_workers = max_workers + + def run(self, ctx: TaskContext, entities: list[str]) -> list[CompanyResearch]: + """Research every entity concurrently; accumulate errors without crashing.""" + ctx.observations.append( + f"Researching {len(entities)} entities in parallel " + f"(max_workers={self.max_workers})…" + ) + + results_map: dict[str, CompanyResearch] = {} + + with ThreadPoolExecutor(max_workers=self.max_workers) as pool: + future_to_ticker = { + pool.submit(_research_one, ticker): ticker + for ticker in entities + } + for future in as_completed(future_to_ticker): + ticker = future_to_ticker[future] + try: + cr = future.result() + except Exception as exc: + logger.error("Unhandled research error for %s: %s", ticker, exc) + cr = CompanyResearch(ticker=ticker) + cr.errors.append( + ToolError(entity=ticker, tool="research", error=str(exc)) + ) + results_map[ticker] = cr + + # Preserve the original entity order in results + results: list[CompanyResearch] = [] + for ticker in entities: + cr = results_map.get(ticker, CompanyResearch(ticker=ticker)) + results.append(cr) + if cr.errors: + for e in cr.errors: + ctx.uncertainties.append( + f"{e.tool} failed for {e.entity}: {e.error}" + ) + + return results diff --git a/src/hitachione/agents/reviewer.py b/src/hitachione/agents/reviewer.py new file mode 100644 index 0000000..7b85d7f --- /dev/null +++ b/src/hitachione/agents/reviewer.py @@ -0,0 +1,119 @@ +"""Reviewer Agent – deterministic quality gate for synthesized answers. + +Checks: + 1. Entity coverage: were requested entities actually researched? + For broad queries (>10 entities), requires ≥80% coverage instead of 100%. + 2. Score completeness: does each researched entity have sentiment + performance? + 3. 
Answer quality: is the markdown non-empty and reasonably long? + 4. Confidence threshold: is overall confidence above the minimum? + +Returns ``ReviewFeedback(ok=True/False, missing=[...], notes=...)``. +""" + +from __future__ import annotations + +import logging + +from ..config.settings import QUALITY_THRESHOLD +from ..models.schemas import ( + ReviewFeedback, SynthesizedAnswer, TaskContext, +) + +logger = logging.getLogger(__name__) + +# When entity count exceeds this, switch from 100% to 80% coverage check +_BROAD_QUERY_THRESHOLD = 10 +_BROAD_COVERAGE_RATIO = 0.8 + + +# Sentinel phrases emitted by the tools when data is absent +_NO_DATA_PHRASES = ("no data found",) + + +def _is_no_kb_data(rationale: str | None) -> bool: + """Return True when a tool rationale means the KB simply has no records.""" + if not rationale or not isinstance(rationale, str): + return False + lower = rationale.lower() + return any(phrase in lower for phrase in _NO_DATA_PHRASES) + + +class ReviewerAgent: + """Deterministic checks – no LLM calls, fast and predictable.""" + + def run(self, ctx: TaskContext, answer: SynthesizedAnswer) -> ReviewFeedback: + fb = ReviewFeedback() + missing: list[str] = [] + retriable_issues = 0 # count issues that a retry could potentially fix + + # 1. 
Entity coverage + researched = {cr.ticker for cr in answer.raw_research} + total_entities = len(ctx.entities) + not_researched = [e for e in ctx.entities if e.upper() not in researched] + + if total_entities > _BROAD_QUERY_THRESHOLD: + # Broad query: require ≥80% coverage + coverage_ratio = 1.0 - (len(not_researched) / total_entities) if total_entities else 1.0 + if coverage_ratio < _BROAD_COVERAGE_RATIO: + missing.append( + f"Entity coverage {coverage_ratio:.0%} below " + f"{_BROAD_COVERAGE_RATIO:.0%} threshold " + f"({len(not_researched)}/{total_entities} not researched)" + ) + retriable_issues += 1 + else: + # Narrow query: require 100% coverage + for entity in not_researched: + missing.append(f"Entity {entity} not researched") + retriable_issues += 1 + + # 2. Score completeness + # Distinguish "no data in KB" (not retriable) from transient errors + # (retriable). The tools set a rationale / justification that + # contains 'No data found' when the KB has no records. + for cr in answer.raw_research: + sent_rationale = cr.sentiment.get("rationale") if cr.sentiment else None + perf_justification = cr.performance.get("justification") if cr.performance else None + + if not cr.sentiment or cr.sentiment.get("rating") is None: + if _is_no_kb_data(sent_rationale): + missing.append(f"{cr.ticker}: no sentiment data in knowledge base") + else: + missing.append(f"{cr.ticker}: missing sentiment rating") + retriable_issues += 1 + + if not cr.performance or cr.performance.get("performance_score") is None: + if _is_no_kb_data(perf_justification): + missing.append(f"{cr.ticker}: no performance data in knowledge base") + else: + missing.append(f"{cr.ticker}: missing performance score") + retriable_issues += 1 + + # 3. Answer quality + if len(answer.markdown) < 50: + missing.append("Answer text too short") + retriable_issues += 1 + + # 4. 
Confidence + if answer.confidence < QUALITY_THRESHOLD: + missing.append( + f"Confidence {answer.confidence:.2f} below threshold " + f"{QUALITY_THRESHOLD:.2f}" + ) + # Low confidence by itself is not retriable – it's a consequence + # of missing data. Only if other retriable issues exist will + # a retry improve confidence. + + fb.missing = missing + fb.ok = len(missing) == 0 + fb.retriable = retriable_issues > 0 + fb.notes = ( + "All checks passed." if fb.ok + else f"{len(missing)} issue(s) found: {'; '.join(missing[:5])}" + ) + + logger.info( + "Reviewer: ok=%s, retriable=%s, issues=%d", + fb.ok, fb.retriable, len(missing), + ) + return fb diff --git a/src/hitachione/agents/synthesizer.py b/src/hitachione/agents/synthesizer.py new file mode 100644 index 0000000..171f903 --- /dev/null +++ b/src/hitachione/agents/synthesizer.py @@ -0,0 +1,148 @@ +"""Synthesizer Agent – composes a ranked / comparative answer from research. + +Takes a list of ``CompanyResearch`` objects and produces a user-facing +``SynthesizedAnswer`` with: + • Markdown answer (ranked list or comparison table) + • Rationale explaining the scoring / ranking + • Caveats for partial data + • Citations from news snippets + • Confidence estimate +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from openai import OpenAI + +from ..config.settings import OPENAI_BASE_URL, OPENAI_API_KEY, WORKER_MODEL +from ..models.schemas import ( + CompanyResearch, Intent, SynthesizedAnswer, TaskContext, +) + +logger = logging.getLogger(__name__) + +# ── Prompt templates ───────────────────────────────────────────────────── + +_SYSTEM = ( + "You are a financial intelligence synthesizer. " + "Given structured research data for companies, compose a clear, " + "well-structured Markdown answer for the user. " + "Include a rationale section explaining your reasoning. " + "If data is partial, add explicit caveats. " + "If news references are available, include them as citations. 
" + "Do NOT give investment advice – present analytics and insights only. " + "Return raw Markdown text (not wrapped in JSON)." +) + + +def _build_research_block(research: list[CompanyResearch]) -> str: + """Serialize research data into a text block for the LLM.""" + parts: list[str] = [] + for cr in research: + lines = [f"## {cr.ticker}"] + if cr.sentiment: + lines.append( + f"- Sentiment: rating={cr.sentiment.get('rating')}, " + f"label={cr.sentiment.get('label')}, " + f"rationale={cr.sentiment.get('rationale','')[:200]}" + ) + if cr.performance: + lines.append( + f"- Performance: score={cr.performance.get('performance_score')}, " + f"outlook={cr.performance.get('outlook')}, " + f"justification={cr.performance.get('justification','')[:200]}" + ) + if cr.news_snippets: + lines.append("- News references: " + "; ".join(cr.news_snippets[:3])) + if cr.errors: + lines.append( + "- ⚠ Data gaps: " + + ", ".join(f"{e.tool} ({e.error[:60]})" for e in cr.errors) + ) + parts.append("\n".join(lines)) + return "\n\n".join(parts) + + +def _estimate_confidence(research: list[CompanyResearch]) -> float: + """Heuristic confidence score 0-1 based on data completeness.""" + if not research: + return 0.0 + scores: list[float] = [] + for cr in research: + s = 0.0 + if cr.sentiment and cr.sentiment.get("rating") is not None: + s += 0.5 + if cr.performance and cr.performance.get("performance_score") is not None: + s += 0.5 + scores.append(s) + return sum(scores) / len(scores) + + +class SynthesizerAgent: + """Compose a user-facing answer from per-entity research.""" + + def __init__(self): + self._llm = OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY) + + def run( + self, + ctx: TaskContext, + research: list[CompanyResearch], + ) -> SynthesizedAnswer: + answer = SynthesizedAnswer(raw_research=research) + answer.confidence = _estimate_confidence(research) + + # Gather caveats + for cr in research: + for e in cr.errors: + answer.caveats.append(f"{e.entity}: {e.tool} unavailable 
({e.error[:80]})") + if answer.confidence < 0.5: + answer.caveats.append("Low overall data coverage – results are best-effort.") + + # Gather citations + for cr in research: + answer.citations.extend(cr.news_snippets[:3]) + + # Build the research context for the LLM + data_block = _build_research_block(research) + user_msg = ( + f"User query: {ctx.user_query}\n" + f"Intent: {ctx.intent.value}\n" + f"Timeframe: {ctx.timeframe or 'not specified'}\n\n" + f"Research data:\n{data_block}" + ) + + try: + resp = self._llm.chat.completions.create( + model=WORKER_MODEL, + messages=[ + {"role": "system", "content": _SYSTEM}, + {"role": "user", "content": user_msg}, + ], + temperature=0.2, + ) + md = (resp.choices[0].message.content or "").strip() + answer.markdown = md + # Extract first paragraph as rationale summary + answer.rationale = md.split("\n\n")[0] if md else "" + except Exception as exc: + logger.error("Synthesizer LLM error: %s", exc) + # Fallback: build a simple text answer from raw data + answer.markdown = self._fallback_markdown(ctx, research) + answer.rationale = "Generated from raw data (LLM unavailable)." + + return answer + + @staticmethod + def _fallback_markdown( + ctx: TaskContext, research: list[CompanyResearch] + ) -> str: + lines = [f"## Results for: {ctx.user_query}\n"] + for cr in research: + sent = cr.sentiment.get("rating", "?") if cr.sentiment else "?" + perf = cr.performance.get("performance_score", "?") if cr.performance else "?" + lines.append(f"- **{cr.ticker}**: sentiment={sent}/10, performance={perf}/10") + return "\n".join(lines) diff --git a/src/hitachione/config/__init__.py b/src/hitachione/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/config/settings.py b/src/hitachione/config/settings.py new file mode 100644 index 0000000..fe6483d --- /dev/null +++ b/src/hitachione/config/settings.py @@ -0,0 +1,38 @@ +"""Centralised settings for the multi-agent financial intelligence system. 
"""Centralised settings for the multi-agent financial intelligence system.

All secrets come from env-vars / .env – never hard-coded.
"""

from __future__ import annotations

import os
from pathlib import Path
from dotenv import load_dotenv

# Load .env from project root
# parents[3]: settings.py → config → hitachione → src → project root.
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
load_dotenv(_PROJECT_ROOT / ".env")

# ── LLM ──────────────────────────────────────────────────────────────────
# Default base URL points at Gemini's OpenAI-compatible endpoint; any
# OpenAI-compatible server can be substituted via OPENAI_BASE_URL.
OPENAI_BASE_URL = os.getenv(
    "OPENAI_BASE_URL",
    "https://generativelanguage.googleapis.com/v1beta/openai/",
)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# Planner model drives intent parsing / planning; worker model drives
# synthesis. Both default to the same Gemini model.
PLANNER_MODEL = os.getenv("DEFAULT_PLANNER_MODEL", "gemini-2.5-flash")
WORKER_MODEL = os.getenv("DEFAULT_WORKER_MODEL", "gemini-2.5-flash")

# ── Weaviate ─────────────────────────────────────────────────────────────
WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY", "")
WEAVIATE_HTTP_HOST = os.getenv("WEAVIATE_HTTP_HOST", "localhost")
WEAVIATE_GRPC_HOST = os.getenv("WEAVIATE_GRPC_HOST", "localhost")

# ── Langfuse (optional) ─────────────────────────────────────────────────
# May be None; tracing is expected to no-op when keys are absent.
LANGFUSE_PUBLIC_KEY = os.getenv("LANGFUSE_PUBLIC_KEY")
LANGFUSE_SECRET_KEY = os.getenv("LANGFUSE_SECRET_KEY")
LANGFUSE_HOST = os.getenv("LANGFUSE_HOST", "https://us.cloud.langfuse.com")

# ── Orchestrator defaults ────────────────────────────────────────────────
# MAX_ITERATIONS: ReAct loop budget. QUALITY_THRESHOLD: minimum answer
# confidence (0-1) the Reviewer accepts without flagging.
MAX_ITERATIONS = int(os.getenv("ORCHESTRATOR_MAX_ITERATIONS", "2"))
QUALITY_THRESHOLD = float(os.getenv("ORCHESTRATOR_QUALITY_THRESHOLD", "0.6"))
def main() -> None:
    """CLI entry point: parse arguments, then run one query or launch the UI.

    Modes:
        --cli / --query: run a single query through the Orchestrator and
            print the markdown answer, caveats, and confidence to stdout.
        default: build and launch the Gradio app.

    Bug fix: ``--port`` was parsed but never used; it is now forwarded to
    ``demo.launch`` so the Gradio server actually binds the requested port.
    """
    parser = argparse.ArgumentParser(description="Financial Intelligence Agent")
    parser.add_argument(
        "--cli", action="store_true", help="Run a single query from stdin"
    )
    parser.add_argument(
        "--query", type=str, default="", help="Query string (CLI mode)"
    )
    parser.add_argument(
        "--port", type=int, default=7860, help="Gradio server port"
    )
    args = parser.parse_args()

    if args.cli or args.query:
        # Import lazily so UI-only runs don't pay the agent import cost.
        from hitachione.agents.orchestrator import Orchestrator

        query = args.query or input("Enter query: ")
        orch = Orchestrator()
        answer = orch.run(query)
        print(answer.markdown)
        if answer.caveats:
            print("\nCaveats:")
            for c in answer.caveats:
                print(f"  - {c}")
        print(f"\nConfidence: {answer.confidence:.0%}")
    else:
        from hitachione.ui.app import build_app

        demo = build_app()
        # Honour the --port flag (previously silently ignored).
        demo.launch(share=True, server_port=args.port)


if __name__ == "__main__":
    main()
+""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any + + +# ── Intent taxonomy ────────────────────────────────────────────────────── + +class Intent(str, Enum): + RANK = "rank" + COMPARE = "compare" + SNAPSHOT = "snapshot" + EVENT_REACTION = "event_reaction" + FUNDAMENTALS = "fundamentals" + MACRO = "macro" + MIXED = "mixed" + + +# ── Per-run context (scratchpad / blackboard) ──────────────────────────── + +@dataclass +class TaskContext: + """Short-lived context for a single orchestrator run.""" + + run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12]) + user_query: str = "" + intent: Intent = Intent.MIXED + timeframe: str = "" # e.g. "last 3 years", "2024 Q3" + sector: str = "" # e.g. "automotive" + entities: list[str] = field(default_factory=list) # ticker symbols + constraints: dict[str, Any] = field(default_factory=dict) + + # Blackboard – accumulates across iterations + plan: list[str] = field(default_factory=list) + observations: list[str] = field(default_factory=list) + uncertainties: list[str] = field(default_factory=list) + iteration: int = 0 + timestamp: str = field( + default_factory=lambda: datetime.now(timezone.utc).isoformat(timespec="seconds") + ) + + +# ── Tool / agent outputs ──────────────────────────────────────────────── + +@dataclass +class ToolError: + """Captures a non-fatal error from a tool call.""" + + entity: str + tool: str + error: str + + +@dataclass +class CompanyResearch: + """Research bundle for one company/ticker.""" + + ticker: str + sentiment: dict[str, Any] = field(default_factory=dict) + performance: dict[str, Any] = field(default_factory=dict) + news_snippets: list[str] = field(default_factory=list) + errors: list[ToolError] = field(default_factory=list) + + +@dataclass +class SynthesizedAnswer: + """Final user-facing answer produced by the Synthesizer.""" + + markdown: str = "" + 
@dataclass
class ReviewFeedback:
    """Output of the Reviewer agent.

    Indicates whether the synthesized answer passed review and whether a
    further orchestrator iteration could plausibly improve it.
    """

    ok: bool = False  # True when the answer passed review
    retriable: bool = True  # False when issues are unfixable (e.g. no KB data)
    missing: list[str] = field(default_factory=list)  # gaps flagged by the reviewer — presumably entities/data points; confirm against Reviewer usage
    notes: str = ""  # free-form reviewer commentary
logger.info("Langfuse keys not set – tracing disabled") +except Exception as exc: + logger.debug("Langfuse unavailable: %s", exc) + + +# ── Public helpers ────────────────────────────────────────────────────── + +class Tracer: + """Thin wrapper around a Langfuse trace / span tree. + + Usage:: + + tracer = Tracer.start("orchestrator_run", metadata={...}) + with tracer.span("intent_parse") as sp: + sp.update(output={"intent": "rank"}) + tracer.end(output=final_answer) + + The root span acts as the top-level container. All child spans + created via ``tracer.span(...)`` are nested under it so the full + Plan → Act → Observe → Reflect flow is visible in Langfuse. + """ + + def __init__(self, root_span: Any | None = None): + self._root = root_span # LangfuseSpan or None + self._trace_id: str | None = None + if root_span is not None: + self._trace_id = root_span.trace_id + + # --- factory --- + @classmethod + def start( + cls, + name: str, + *, + user_id: str = "", + metadata: dict | None = None, + ) -> "Tracer": + if _langfuse is None: + return cls(None) + try: + root = _langfuse.start_span(name=name, metadata=metadata or {}) + # Attach trace-level info (name, user_id, tags) + root.update_trace( + name=name, + user_id=user_id or None, + metadata=metadata or {}, + ) + logger.info("Langfuse trace started: %s", root.trace_id) + return cls(root) + except Exception as exc: + logger.warning("Langfuse trace start error: %s", exc) + return cls(None) + + # --- span context-manager --- + @contextmanager + def span(self, name: str, **kwargs) -> Generator["_Span", None, None]: + """Create a child span nested under the root.""" + sp = _Span.create(name, parent=self._root, **kwargs) + try: + yield sp + except Exception as exc: + sp.update(level="ERROR", status_message=str(exc)) + raise + finally: + sp.finish() + + # --- finalise --- + def end(self, *, output: Any = None): + if self._root is not None: + try: + if output is not None: + self._root.update(output=output) + self._root.end() 
class _Span:
    """One span inside a trace; degrades to a no-op when tracing is off."""

    def __init__(self, name: str, lang_span: Any | None = None):
        self.name = name
        self._span = lang_span  # underlying LangfuseSpan, or None for no-op

    @classmethod
    def create(cls, name: str, parent: Any | None = None, **kwargs) -> "_Span":
        """Create a span as a child of *parent* (a ``LangfuseSpan``)."""
        if parent is None:
            return cls(name, None)
        try:
            child = parent.start_span(
                name=name,
                metadata=kwargs.get("metadata"),
            )
        except Exception as exc:
            # Fall back to a no-op span rather than breaking the caller.
            logger.debug("Langfuse child span error (%s): %s", name, exc)
            return cls(name, None)
        return cls(name, child)

    def update(self, **kwargs):
        """Forward *kwargs* to the underlying span, swallowing SDK errors."""
        if self._span is None:
            return
        try:
            self._span.update(**kwargs)
        except Exception:
            pass

    def finish(self):
        """Close the underlying span if one exists."""
        if self._span is None:
            return
        try:
            self._span.end()
        except Exception:
            pass
+Dataset: danidanou/Bloomberg_Financial_News +""" + +import os +import pandas as pd + + +def load_bloomberg_data( + limit: int | None = None, + split: str = "train" +) -> pd.DataFrame: + """Load Bloomberg Financial News dataset from HuggingFace. + + Args: + limit: Optional maximum number of articles to load + split: Dataset split to load (default: "train") + + Returns: + DataFrame with financial news articles + """ + print(f"Loading Bloomberg Financial News dataset (split: {split})...") + + # Try direct parquet reading approach + try: + print("Method 1: Direct download and read as parquet...") + from huggingface_hub import hf_hub_download + import pyarrow.parquet as pq + + print("Starting download of parquet file...") + # Download the file + file_path = hf_hub_download( + repo_id="danidanou/Bloomberg_Financial_News", + filename="bloomberg_financial_data.parquet.gzip", + repo_type="dataset" + ) + print(f"Download complete. File path: {file_path}") + + print("Starting to read parquet file...") + table = pq.read_table(file_path) + print(f"Parquet table read successfully. Table shape: {table.shape}") + + print("Converting table to pandas DataFrame...") + df = table.to_pandas() + print(f"Conversion to pandas complete. 
def explore_dataset(df: pd.DataFrame) -> None:
    """Print summary statistics about the dataset.

    Purely informational: prints size, columns, dtypes, one sample record,
    text-column samples, missing-value counts, and text-length statistics.

    Fixes over the previous version:
      * no longer mutates the caller's DataFrame (it used to append
        ``<col>_length`` columns as a side effect);
      * handles an empty DataFrame instead of raising on ``df.iloc[0]``.

    Args:
        df: DataFrame containing Bloomberg articles
    """
    print("\n" + "="*60)
    print("DATASET EXPLORATION")
    print("="*60)

    print(f"\nTotal articles: {len(df):,}")
    print(f"\nColumns: {list(df.columns)}")
    print(f"\nData types:\n{df.dtypes}")

    if df.empty:
        # Nothing more to show without at least one row.
        print("\nDataFrame is empty - no sample record to display.")
        return

    print("\n" + "-"*60)
    print("SAMPLE RECORD:")
    print("-"*60)
    print(df.iloc[0].to_dict())

    # Heuristic: columns whose names mention text/content hold article bodies.
    text_columns = [
        col for col in df.columns
        if 'text' in col.lower() or 'content' in col.lower()
    ]
    if text_columns:
        print(f"\nText columns found: {text_columns}")
        for col in text_columns:
            print(f"\nSample from '{col}':")
            sample = str(df[col].iloc[0])
            print(sample[:500] + "..." if len(sample) > 500 else sample)

    # Check for missing values
    print("\n" + "-"*60)
    print("MISSING VALUES:")
    print("-"*60)
    print(df.isnull().sum())

    # Text length statistics, computed on a local Series so the caller's
    # DataFrame is left untouched.
    for col in text_columns:
        lengths = df[col].astype(str).str.len()
        print(f"\n{col} length statistics:")
        print(lengths.describe())
+""" + +import sys +from typing import List +from dotenv import load_dotenv +from pathlib import Path + +# Load .env file (project root is 5 levels up from this file) +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +import os +from tool import ( + find_relevant_symbols, + find_relevant_sp500_symbols, + get_all_symbols, + get_company_mapping, + TOOL_SCHEMA, +) + + +def print_section(title: str): + """Print a formatted section header.""" + print("\n" + "=" * 80) + print(f" {title}") + print("=" * 80 + "\n") + + +def test_symbol_extraction(): + """Test extracting all symbols from the Weaviate knowledge base.""" + print_section("Testing Symbol Extraction from Weaviate") + + print("Extracting all unique tickers from Weaviate collection...") + collection = os.getenv('WEAVIATE_COLLECTION_NAME', 'hitachi-finance-news') + print(f"(Iterates through {collection} collection)\n") + + try: + import time + start = time.time() + + symbols = get_all_symbols() + elapsed = time.time() - start + + print(f"✓ Successfully extracted {len(symbols)} unique tickers in {elapsed:.2f}s") + print(f"\nTickers: {', '.join(symbols)}") + + # Also show company mapping + companies = get_company_mapping() + print(f"\nCompany mapping:") + for ticker, company in sorted(companies.items()): + print(f" {ticker}: {company}") + + except Exception as e: + print(f"✗ Error: {e}") + + print() + + +def test_symbol_search(): + """Test the main symbol search functionality with LLM filtering.""" + print_section("Testing Tool Function with LLM Filtering") + + test_cases = [ + "List all the top automotive stocks of 2012", + "Find technology companies from 2015", + "Show me healthcare stocks", + "top 5 tech stocks", + ] + + import os + api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + if not api_key: + print("⚠️ No LLM API key set - LLM filtering disabled") + print(" Tool will return all symbols without filtering\n") + else: + print("✓ LLM API key found - 
def test_year_extraction():
    """Explain that query year-extraction is a retired legacy feature."""
    print_section("Legacy Test: Year Extraction")
    for line in (
        "Note: This functionality is no longer used.",
        "The tool now returns all symbols and lets the LLM do filtering.\n",
    ):
        print(line)
def benchmark_query_performance():
    """Benchmark the performance of symbol extraction.

    Calls ``get_all_symbols()`` twice: the first call does the real
    extraction work, the second should be served from the in-memory
    cache, so the two timings illustrate the cache benefit.
    """
    print_section("Performance Benchmarking")

    import time

    print("Benchmarking symbol extraction...\n")

    # First run (may need to read from disk)
    start_time = time.time()
    symbols = get_all_symbols()
    elapsed = time.time() - start_time

    print(f"First call: {elapsed:.3f}s ({len(symbols)} symbols)")

    # Second run (should use cache)
    start_time = time.time()
    symbols = get_all_symbols()
    elapsed = time.time() - start_time

    print(f"Cached call: {elapsed:.3f}s (from memory cache)")
    print()
schema - Display tool schema") + + choice = input("\nSelect mode (1-6) or press Enter for 'all': ").strip() + + mode_map = { + '1': 'all', + '2': 'quick', + '3': 'extract', + '4': 'interactive', + '5': 'benchmark', + '6': 'schema', + } + + mode = mode_map.get(choice, 'all') + + if mode in ['all', 'quick']: + test_tool_schema() + + if mode in ['all', 'extract']: + test_symbol_extraction() + test_symbol_search() + + if mode == 'interactive': + run_interactive_test() + + if mode == 'benchmark': + benchmark_query_performance() + + if mode == 'schema': + test_tool_schema() + + print_section("Test Harness Complete") + print("To run specific tests, use: python test_sp500_tool.py ") + print("Available modes: all, quick, extract, interactive, benchmark, schema\n") + + +if __name__ == "__main__": + main() diff --git a/src/hitachione/tools/company_filtering_tool/tool.py b/src/hitachione/tools/company_filtering_tool/tool.py new file mode 100644 index 0000000..4e17fdd --- /dev/null +++ b/src/hitachione/tools/company_filtering_tool/tool.py @@ -0,0 +1,475 @@ +""" +Tool for finding relevant stock symbols from the Weaviate financial news knowledge base. + +This tool queries the Weaviate financial news collection to retrieve unique +stock tickers and uses an LLM to filter them based on user queries. 
+""" + +from typing import List +from pathlib import Path +import os +import json +import asyncio +import re +import difflib +import csv +import io +import logging +from urllib import request + +import weaviate +from weaviate.auth import AuthApiKey +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +# Import client manager from the utils +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) +from utils.client_manager import AsyncClientManager + + +logger = logging.getLogger(__name__) + +# Cache for symbols and company mapping +_cached_symbols: List[str] | None = None +_cached_companies: dict[str, str] | None = None # ticker -> company name +_client_manager = None + +# Weaviate collection name (from WEAVIATE_COLLECTION_NAME env var) +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") +SP500_CONSTITUENTS_URL = os.getenv( + "SP500_CONSTITUENTS_URL", + "https://datahub.io/core/s-and-p-500-companies/r/constituents.csv", +) +SP500_UNIVERSE_ENABLED = os.getenv("SP500_UNIVERSE_ENABLED", "true").lower() == "true" + +_COMMON_TYPO_FIXES = { + "enery": "energy", + "techonology": "technology", + "teh": "the", +} + +_SECTOR_TERMS = [ + "energy", + "technology", + "tech", + "automotive", + "healthcare", + "financial", + "finance", + "retail", +] + +_SECTOR_TICKER_MAP = { + "energy": {"TSLA"}, + "technology": {"AAPL", "AMZN", "GOOGL", "META", "MSFT", "NVDA"}, + "tech": {"AAPL", "AMZN", "GOOGL", "META", "MSFT", "NVDA"}, + "automotive": {"TSLA"}, + "financial": {"JPM", "V"}, + "finance": {"JPM", "V"}, + "retail": {"AMZN", "WMT"}, +} + + +def _normalize_query(query: str) -> str: + """Normalize obvious typos in user queries (e.g., 'enery' -> 'energy').""" + normalized = query + + for wrong, right in _COMMON_TYPO_FIXES.items(): + normalized = re.sub(rf"\b{re.escape(wrong)}\b", right, normalized, flags=re.IGNORECASE) + + words = normalized.split() + 
def _deterministic_sector_filter(query: str, symbols: List[str]) -> List[str]:
    """Apply deterministic sector-based filtering when sector terms are present.

    Returns the sorted tickers from ``_SECTOR_TICKER_MAP`` whose sector name
    appears as a whole word in *query*, restricted to tickers actually in
    *symbols*. Empty when no sector term matches.
    """
    lowered = query.lower()
    available = set(symbols)
    hits: set[str] = set()

    for sector_name, sector_tickers in _SECTOR_TICKER_MAP.items():
        if re.search(rf"\b{re.escape(sector_name)}\b", lowered):
            hits.update(sector_tickers & available)

    return sorted(hits)
def _get_weaviate_sync_client():
    """Create a synchronous Weaviate client from environment variables."""
    host = os.getenv("WEAVIATE_HTTP_HOST", "localhost")
    credentials = AuthApiKey(os.getenv("WEAVIATE_API_KEY", ""))

    # Managed Weaviate Cloud clusters are addressed by cluster URL alone.
    if host.endswith(".weaviate.cloud"):
        return weaviate.connect_to_weaviate_cloud(
            cluster_url=f"https://{host}",
            auth_credentials=credentials,
        )

    def _env_flag(name: str) -> bool:
        # Treat only a literal "true" (any case) as enabled.
        return os.getenv(name, "false").lower() == "true"

    # Self-hosted instance: honour the full host/port/TLS env-var set.
    return weaviate.connect_to_custom(
        http_host=host,
        http_port=int(os.getenv("WEAVIATE_HTTP_PORT", "8080")),
        http_secure=_env_flag("WEAVIATE_HTTP_SECURE"),
        grpc_host=os.getenv("WEAVIATE_GRPC_HOST", "localhost"),
        grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT", "50051")),
        grpc_secure=_env_flag("WEAVIATE_GRPC_SECURE"),
        auth_credentials=credentials,
    )
def get_company_mapping() -> dict[str, str]:
    """
    Get ticker -> company name mapping from the Weaviate knowledge base.

    Lazily populated: the first call triggers ``get_all_symbols()``, which
    fills both the symbol cache and this company-name cache in one pass.

    Returns:
        Dictionary mapping ticker symbols to company names
    """
    if _cached_companies is None:
        get_all_symbols()  # populates both caches
    return _cached_companies or {}
def filter_symbols_with_llm(query: str, symbols: List[str]) -> List[str]:
    """Synchronous facade over :func:`filter_symbols_with_llm_async`.

    Args:
        query: Natural-language query describing what to filter for
        symbols: List of all available symbols

    Returns:
        Filtered list of relevant symbols
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread – safe to drive the coroutine here.
        return asyncio.run(filter_symbols_with_llm_async(query, symbols))

    # Already inside an event loop (e.g. Jupyter / Gradio): hand the
    # coroutine to a worker thread so asyncio.run can own a fresh loop.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor() as pool:
        future = pool.submit(
            asyncio.run, filter_symbols_with_llm_async(query, symbols)
        )
        return future.result()
+ + Args: + query: Natural-language query describing the type of companies or time period, + e.g. 'List all the top automotive stocks of 2012' + use_llm_filter: Whether to use LLM for filtering (default: True) + + Returns: + Sorted list of filtered stock symbols relevant to the query. + """ + all_symbols = get_all_symbols() + normalized_query = _normalize_query(query) + + deterministic = _deterministic_sector_filter(normalized_query, all_symbols) + + if not use_llm_filter: + # For non-LLM mode, prefer deterministic matches when available. + return deterministic or all_symbols + + # Check if API key is available + api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + if not api_key: + print("Warning: No LLM API key set, returning all symbols without filtering") + return deterministic or all_symbols + + # Use LLM to filter symbols based on query + filtered = filter_symbols_with_llm(normalized_query, all_symbols) + + explicit_sector = _has_explicit_sector_term(normalized_query) + + # If the user explicitly asked for a known sector, constrain to that sector + # to avoid model drift (e.g. TSLA leaking into "tech"). + # Only hard-constrain to deterministic matches when that rule has enough + # coverage. A single-ticker deterministic hit (e.g., energy -> TSLA) is too + # narrow for broad sector queries and should not suppress LLM discoveries. + if explicit_sector and deterministic and len(deterministic) >= 2: + constrained = sorted(set(filtered) & set(deterministic)) + if constrained: + return constrained + return deterministic + + # Otherwise combine LLM output with deterministic hints so obvious matches + # are not dropped by the model. + merged = sorted(set(filtered) | set(deterministic)) + + # Prefer merged candidate set; fall back gracefully if model returned nothing. 
+ if merged: + return merged + if deterministic: + return deterministic + + return all_symbols + + +# Keep backward-compatible alias +find_relevant_sp500_symbols = find_relevant_symbols + + +# OpenAI tool schema +TOOL_SCHEMA = { + "type": "function", + "function": { + "name": "find_relevant_symbols", + "description": ( + "Find relevant stock symbols from the Weaviate financial news knowledge base " + "The tool uses an LLM internally to filter " + "symbols based on the query, returning only symbols that match the specified " + "criteria (sector, industry, time period, ranking, etc.)." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + "Natural-language query describing the type of companies or " + "time period, e.g. 'List all the top automotive stocks of 2012' " + "or 'top 3 tech stocks of 2010'." + ), + } + }, + "required": ["query"], + }, + }, +} + + +# Tool implementation mapping for OpenAI agent integration +TOOL_IMPLEMENTATIONS = { + "find_relevant_symbols": find_relevant_symbols, + "find_relevant_sp500_symbols": find_relevant_symbols, # backward compat +} + + +def run_agent_with_tool(user_query: str, client) -> str: + """ + Run an OpenAI agent that can use the symbol-filtering tool. 
+ + Args: + user_query: User's natural-language query + client: OpenAI client instance + + Returns: + Final response from the agent + """ + import json + + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": user_query}], + tools=[TOOL_SCHEMA], + ) + + choice = response.choices[0] + tool_calls = getattr(choice.message, "tool_calls", None) + + if tool_calls: + # Process the first tool call + tool_call = tool_calls[0] + tool_name = tool_call.function.name + tool_args = tool_call.function.arguments # JSON string + + args = json.loads(tool_args) + result = TOOL_IMPLEMENTATIONS[tool_name](**args) + + # Send tool result back to the model for final answer + followup = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "user", "content": user_query}, + { + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_name, + "content": str(result), + }, + ], + ) + return followup.choices[0].message.content + + # If no tool call, just return the model's direct answer + return choice.message.content diff --git a/src/hitachione/tools/performance_analysis_tool/test.py b/src/hitachione/tools/performance_analysis_tool/test.py new file mode 100644 index 0000000..0cfaced --- /dev/null +++ b/src/hitachione/tools/performance_analysis_tool/test.py @@ -0,0 +1,289 @@ +""" +Test harness for the Performance Analysis Tool (Weaviate-backed). 

Run:
    cd src/hitachione/tools/performance_analysis_tool
    python3 test.py all          # full suite
    python3 test.py data         # data retrieval only (no LLM)
    python3 test.py analyse      # full analysis (requires LLM key)
    python3 test.py schema       # show tool schema
    python3 test.py interactive  # interactive ticker input
"""

import json
import os
import sys
import time
from pathlib import Path

from dotenv import load_dotenv

# Load .env from project root
load_dotenv(Path(__file__).resolve().parents[4] / ".env")

# NOTE: imported after load_dotenv() so tool.py sees the env vars
# (e.g. WEAVIATE_COLLECTION_NAME) that it reads at import time.
from tool import (
    TOOL_SCHEMA,
    analyse_stock_performance,
    get_ticker_data,
)

# ── Tickers known to exist in the Weaviate collection ──
KNOWN_TICKERS = ["AAPL", "AMZN", "GOOGL", "JPM", "META", "MSFT", "NVDA", "TSLA", "V", "WMT"]
UNKNOWN_TICKER = "ZZZZZ"


# ──────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────

def _section(title: str) -> None:
    # Print a banner separating test sections in the console output.
    print("\n" + "=" * 80)
    print(f" {title}")
    print("=" * 80 + "\n")


def _pass(msg: str) -> None:
    # Green-check style success line.
    print(f" ✓ {msg}")


def _fail(msg: str) -> None:
    # Cross-mark style failure line.
    print(f" ✗ {msg}")


# ──────────────────────────────────────────────────────────────────────────
# Tests
# ──────────────────────────────────────────────────────────────────────────

def test_tool_schema() -> None:
    """Validate the tool schema structure."""
    _section("Tool Schema")
    print(json.dumps(TOOL_SCHEMA, indent=2))

    assert TOOL_SCHEMA["type"] == "function"
    fn = TOOL_SCHEMA["function"]
    assert fn["name"] == "analyse_stock_performance"
    assert "ticker" in fn["parameters"]["properties"]
    assert "ticker" in fn["parameters"]["required"]
    _pass("Schema is valid")


def test_data_retrieval_known_ticker() -> None:
    """Retrieve data for known tickers and verify structure.

    NOTE(review): assumes the live Weaviate collection actually contains
    records for these tickers — confirm before treating a failure as a bug.
    """
    _section("Data Retrieval — Known Tickers")

    for ticker in ["AAPL", "TSLA", "JPM"]:
        t0 = time.time()
        data = get_ticker_data(ticker)
        elapsed = time.time() - t0

        assert isinstance(data, dict), f"Expected dict, got {type(data)}"
        for key in ("price_data", "earnings", "news"):
            assert key in data, f"Missing key '{key}' for {ticker}"

        total = sum(len(v) for v in data.values())
        _pass(
            f"{ticker}: {len(data['price_data'])} price, "
            f"{len(data['earnings'])} earnings, "
            f"{len(data['news'])} news ({elapsed:.2f}s, {total} total)"
        )

        # At least one data source should have records
        assert total > 0, f"No data returned for known ticker {ticker}"


def test_data_retrieval_unknown_ticker() -> None:
    """Verify graceful handling of an unknown ticker."""
    _section("Data Retrieval — Unknown Ticker")

    data = get_ticker_data(UNKNOWN_TICKER)
    total = sum(len(v) for v in data.values())
    assert total == 0, f"Expected 0 records for {UNKNOWN_TICKER}, got {total}"
    _pass(f"{UNKNOWN_TICKER}: 0 records as expected")


def test_data_retrieval_case_insensitive() -> None:
    """Verify ticker is uppercased automatically."""
    _section("Data Retrieval — Case Insensitivity")

    data_upper = get_ticker_data("AAPL")
    data_lower = get_ticker_data("aapl")

    assert len(data_upper["price_data"]) == len(data_lower["price_data"]), \
        "Price data count differs between 'AAPL' and 'aapl'"
    _pass("'AAPL' and 'aapl' return same results")


def test_price_data_fields() -> None:
    """Verify price records contain expected fields."""
    _section("Price Data — Field Validation")

    data = get_ticker_data("TSLA")
    if not data["price_data"]:
        # Soft-fail instead of assert: absence of data is an environment
        # issue, not a field-structure defect.
        _fail("No price data for TSLA")
        return

    expected_fields = {"date", "open", "high", "low", "close"}
    for i, rec in enumerate(data["price_data"][:3]):
        present = set(rec.keys()) & expected_fields
        assert present == expected_fields, (
            f"Record {i} missing fields: {expected_fields - present}"
        )
    _pass(f"First {min(3, len(data['price_data']))} records have all OHLC fields")


def test_price_data_sorted() -> None:
    """Verify price records are sorted by date."""
    _section("Price Data — Sort Order")

    data = get_ticker_data("GOOGL")
    dates = [r["date"] for r in data["price_data"] if "date" in r]
    assert dates == sorted(dates), "Price data is not sorted by date"
    _pass(f"GOOGL: {len(dates)} price records sorted correctly")


def test_analyse_unknown_ticker() -> None:
    """Full analysis on unknown ticker returns None score."""
    _section("Full Analysis — Unknown Ticker")

    result = analyse_stock_performance(UNKNOWN_TICKER)
    assert result["ticker"] == UNKNOWN_TICKER
    assert result["performance_score"] is None
    assert result["data_summary"]["price_records"] == 0
    _pass(f"{UNKNOWN_TICKER}: score=None, outlook={result['outlook']}")


def test_analyse_known_ticker() -> None:
    """Full analysis on a known ticker returns valid structure."""
    _section("Full Analysis — Known Tickers (LLM)")

    api_key = (
        os.getenv("OPENAI_API_KEY")
        or os.getenv("GEMINI_API_KEY")
        or os.getenv("GOOGLE_API_KEY")
    )
    if not api_key:
        print(" ⚠️ No LLM API key — skipping LLM analysis tests")
        return

    for ticker in ["AAPL", "NVDA"]:
        t0 = time.time()
        result = analyse_stock_performance(ticker)
        elapsed = time.time() - t0

        assert isinstance(result, dict)
        assert result["ticker"] == ticker
        assert isinstance(result["performance_score"], int)
        assert 1 <= result["performance_score"] <= 10
        assert result["outlook"] in ("Bullish", "Bearish", "Volatile", "Sideways")
        assert len(result["justification"]) > 20
        assert result["data_summary"]["price_records"] > 0

        _pass(
            f"{ticker}: score={result['performance_score']}, "
            f"outlook={result['outlook']}, {elapsed:.1f}s"
        )
        print(f" Justification: {result['justification'][:120]}...")


def test_analyse_multiple_tickers() -> None:
    """Analyse several tickers to confirm consistency."""
    _section("Full Analysis — All Known Tickers (LLM)")

    api_key = (
        os.getenv("OPENAI_API_KEY")
        or os.getenv("GEMINI_API_KEY")
        or os.getenv("GOOGLE_API_KEY")
    )
    if not api_key:
        print(" ⚠️ No LLM API key — skipping")
        return

    for ticker in KNOWN_TICKERS:
        result = analyse_stock_performance(ticker)
        score = result["performance_score"]
        outlook = result["outlook"]
        ds = result["data_summary"]
        _pass(
            f"{ticker:5s}: score={score:>2}, outlook={outlook:8s} "
            f"(price={ds['price_records']}, earn={ds['earnings_records']}, "
            f"news={ds['news_records']})"
        )


# ──────────────────────────────────────────────────────────────────────────
# Interactive mode
# ──────────────────────────────────────────────────────────────────────────

def interactive() -> None:
    # REPL-style loop: broad exception handling is intentional so one bad
    # ticker or transient service error does not kill the session.
    _section("Interactive Mode")
    print(f"Available tickers: {', '.join(KNOWN_TICKERS)}")
    print("Enter a ticker (or 'quit' to exit)\n")

    while True:
        try:
            ticker = input("Ticker> ").strip()
            if ticker.lower() in ("quit", "exit", "q"):
                break
            if not ticker:
                continue

            print(f"\nAnalysing {ticker.upper()}...")
            t0 = time.time()
            result = analyse_stock_performance(ticker)
            elapsed = time.time() - t0

            print(json.dumps(result, indent=2))
            print(f"({elapsed:.1f}s)\n")

        except KeyboardInterrupt:
            print("\n")
            break
        except Exception as e:
            print(f" ✗ Error: {e}\n")


# ──────────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────────

def main() -> None:
    print("\n" + "=" * 80)
    print(" Performance Analysis Tool (Weaviate) — Test Harness")
    print("=" * 80)

    if len(sys.argv) > 1:
        mode = sys.argv[1].lower()
    else:
        # No CLI argument: show the menu and prompt interactively;
        # any unrecognised choice defaults to running everything.
        print("\nModes:")
        print(" 1. all — Run all tests")
        print(" 2. data — Data retrieval tests only (no LLM)")
        print(" 3. analyse — Full analysis tests (requires LLM)")
        print(" 4. schema — Display tool schema")
        print(" 5. interactive — Interactive ticker input")

        choice = input("\nSelect (1-5) or press Enter for 'all': ").strip()
        mode = {"1": "all", "2": "data", "3": "analyse", "4": "schema", "5": "interactive"}.get(choice, "all")

    if mode in ("all", "schema"):
        test_tool_schema()

    if mode in ("all", "data"):
        test_data_retrieval_known_ticker()
        test_data_retrieval_unknown_ticker()
        test_data_retrieval_case_insensitive()
        test_price_data_fields()
        test_price_data_sorted()

    if mode in ("all", "analyse"):
        test_analyse_unknown_ticker()
        test_analyse_known_ticker()
        test_analyse_multiple_tickers()

    if mode == "interactive":
        interactive()

    _section("Test Harness Complete")


if __name__ == "__main__":
    main()
diff --git a/src/hitachione/tools/performance_analysis_tool/tool.py b/src/hitachione/tools/performance_analysis_tool/tool.py
new file mode 100644
index 0000000..5ae74ce
--- /dev/null
+++ b/src/hitachione/tools/performance_analysis_tool/tool.py
@@ -0,0 +1,318 @@
"""
Tool for analysing stock performance using the Weaviate knowledge base.

Queries the Weaviate financial news collection for price history, earnings
transcripts, and financial news, then uses an LLM to produce a performance
rating score (1-10), future outlook, and justification.
+""" + +import asyncio +import json +import os +from pathlib import Path +from typing import Any, List + +import weaviate +from weaviate.auth import AuthApiKey +from weaviate.classes.query import Filter +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) +from utils.client_manager import AsyncClientManager + +# --------------------------------------------------------------------------- +# Weaviate helpers +# --------------------------------------------------------------------------- + +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") + + +def _get_weaviate_sync_client(): + """Create a synchronous Weaviate client from environment variables.""" + http_host = os.getenv("WEAVIATE_HTTP_HOST", "localhost") + api_key = os.getenv("WEAVIATE_API_KEY", "") + + if http_host.endswith(".weaviate.cloud"): + return weaviate.connect_to_weaviate_cloud( + cluster_url=f"https://{http_host}", + auth_credentials=AuthApiKey(api_key), + ) + + return weaviate.connect_to_custom( + http_host=http_host, + http_port=int(os.getenv("WEAVIATE_HTTP_PORT", "8080")), + http_secure=os.getenv("WEAVIATE_HTTP_SECURE", "false").lower() == "true", + grpc_host=os.getenv("WEAVIATE_GRPC_HOST", "localhost"), + grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT", "50051")), + grpc_secure=os.getenv("WEAVIATE_GRPC_SECURE", "false").lower() == "true", + auth_credentials=AuthApiKey(api_key), + ) + + +# --------------------------------------------------------------------------- +# Data retrieval +# --------------------------------------------------------------------------- + + +def get_ticker_data(ticker: str) -> dict[str, list[dict]]: + """Retrieve all Weaviate data for a given ticker, grouped by source. + + Returns a dict with keys ``price_data``, ``earnings``, ``news`` — each + containing a list of property dicts. 
+ """ + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + ticker_filter = Filter.by_property("ticker").equal(ticker.upper()) + + # --- price data (stock_market) --- + price_response = col.query.fetch_objects( + filters=( + ticker_filter + & Filter.by_property("dataset_source").equal("stock_market") + ), + limit=100, + return_properties=[ + "date", "open", "high", "low", "close", "volume", "text", + ], + ) + price_data = [ + {k: v for k, v in obj.properties.items() if v is not None} + for obj in price_response.objects + ] + + # --- earnings transcripts --- + earnings_response = col.query.fetch_objects( + filters=( + ticker_filter + & Filter.by_property("dataset_source").equal( + "sp500_earnings_transcripts" + ) + ), + limit=100, + return_properties=[ + "date", "quarter", "fiscal_year", "fiscal_quarter", "text", + "title", + ], + ) + earnings = [ + {k: v for k, v in obj.properties.items() if v is not None} + for obj in earnings_response.objects + ] + + # --- news (bloomberg + yahoo) --- + news_response = col.query.fetch_objects( + filters=( + ticker_filter + & ( + Filter.by_property("dataset_source").equal( + "bloomberg_financial_news" + ) + | Filter.by_property("dataset_source").equal( + "yahoo_finance_news" + ) + ) + ), + limit=100, + return_properties=["date", "title", "text", "category"], + ) + # Also grab news that *mention* this ticker (mentioned_companies) + mentioned_response = col.query.fetch_objects( + filters=Filter.by_property("mentioned_companies").contains_any( + [ticker.upper()] + ), + limit=50, + return_properties=["date", "title", "text", "category"], + ) + + seen_titles: set[str] = set() + news: list[dict] = [] + for obj in list(news_response.objects) + list(mentioned_response.objects): + props = {k: v for k, v in obj.properties.items() if v is not None} + title = props.get("title", "") + if title not in seen_titles: + seen_titles.add(title) + news.append(props) + + return { + "price_data": 
sorted(price_data, key=lambda d: d.get("date", "")), + "earnings": sorted(earnings, key=lambda d: d.get("date", "")), + "news": sorted(news, key=lambda d: d.get("date", "")), + } + + finally: + client.close() + + +# --------------------------------------------------------------------------- +# LLM-based performance scoring +# --------------------------------------------------------------------------- + +_client_manager = None + + +def _get_client_manager() -> AsyncClientManager: + global _client_manager + if _client_manager is None: + _client_manager = AsyncClientManager() + return _client_manager + + +async def _analyse_with_llm(ticker: str, data: dict[str, list[dict]]) -> dict: + """Send retrieved data to an LLM and get a structured performance analysis.""" + cm = _get_client_manager() + + # Build context sections + price_summary = "\n".join( + d.get("text", json.dumps(d)) for d in data["price_data"] + ) or "No price data available." + + earnings_summary = "\n---\n".join( + d.get("text", json.dumps(d)) for d in data["earnings"] + ) or "No earnings data available." + + news_summary = "\n---\n".join( + f"[{d.get('date','')}] {d.get('title','')}: {str(d.get('text',''))[:500]}" + for d in data["news"] + ) or "No news articles available." + + prompt = f"""You are a Stock Performance Analyst. Analyse the ticker "{ticker}" using ONLY the data provided below. 
+ +## Price History +{price_summary} + +## Earnings Transcripts +{earnings_summary} + +## News Articles +{news_summary} + +Based on the data above, produce a JSON object (and NOTHING else) with exactly these keys: + +{{ + "ticker": "{ticker}", + "performance_score": , + "outlook": "", + "justification": "<2-4 sentence explanation citing specific data points>" +}} + +Scoring guide: + 1-4 → Negative (declining price, poor earnings, negative news) + 5 → Neutral + 6-10 → Positive (rising price, strong earnings, positive sentiment) +""" + + response = await cm.openai_client.chat.completions.create( + model=cm.configs.default_worker_model, + messages=[{"role": "user", "content": prompt}], + temperature=0, + ) + + content = response.choices[0].message.content.strip() + + # Extract JSON from potential markdown code fences + if "```json" in content: + content = content.split("```json")[1].split("```")[0].strip() + elif "```" in content: + content = content.split("```")[1].split("```")[0].strip() + + return json.loads(content) + + +def _run_async(coro): + """Run an async coroutine, handling nested event loops.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, coro).result() + return asyncio.run(coro) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def analyse_stock_performance(ticker: str) -> dict: + """Analyse a stock's performance using Weaviate knowledge base data. + + Retrieves price history, earnings transcripts, and news articles from the + Weaviate financial news collection, then uses an LLM to produce a + structured performance analysis. + + Parameters + ---------- + ticker : str + Stock ticker symbol (e.g. ``"AAPL"``, ``"TSLA"``). 
+ + Returns + ------- + dict + A dictionary with keys: + - ``ticker`` (str) + - ``performance_score`` (int, 1–10) + - ``outlook`` (str, one of Bullish/Bearish/Volatile/Sideways) + - ``justification`` (str) + - ``data_summary`` (dict with counts of price/earnings/news records) + """ + ticker = ticker.upper().strip() + data = get_ticker_data(ticker) + + if not any(data.values()): + return { + "ticker": ticker, + "performance_score": None, + "outlook": "Unknown", + "justification": f"No data found for ticker {ticker} in the knowledge base.", + "data_summary": {"price_records": 0, "earnings_records": 0, "news_records": 0}, + } + + analysis = _run_async(_analyse_with_llm(ticker, data)) + analysis["data_summary"] = { + "price_records": len(data["price_data"]), + "earnings_records": len(data["earnings"]), + "news_records": len(data["news"]), + } + return analysis + + +# --------------------------------------------------------------------------- +# OpenAI tool schema +# --------------------------------------------------------------------------- + +TOOL_SCHEMA = { + "type": "function", + "function": { + "name": "analyse_stock_performance", + "description": ( + "Analyse a stock's performance using the Weaviate financial knowledge " + "base. Returns a performance score (1-10), " + "future outlook, and justification based on price history, earnings " + "transcripts, and news articles." + ), + "parameters": { + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "Stock ticker symbol, e.g. 
'AAPL', 'TSLA', 'GOOGL'.", + } + }, + "required": ["ticker"], + }, + }, +} + +TOOL_IMPLEMENTATIONS = { + "analyse_stock_performance": analyse_stock_performance, +} \ No newline at end of file diff --git a/src/hitachione/tools/sentiment_analysis_tool/test.py b/src/hitachione/tools/sentiment_analysis_tool/test.py new file mode 100644 index 0000000..c25208b --- /dev/null +++ b/src/hitachione/tools/sentiment_analysis_tool/test.py @@ -0,0 +1,502 @@ +""" +Test harness for the Sentiment Analysis Tool (Weaviate-backed). + +Run: + cd src/hitachione/tools/sentiment_analysis_tool + python3 -W ignore::ResourceWarning test.py all # full suite + python3 -W ignore::ResourceWarning test.py data # Weaviate queries only (no LLM) + python3 -W ignore::ResourceWarning test.py sentiment # LLM sentiment analysis tests + python3 -W ignore::ResourceWarning test.py schema # show tool schemas + python3 -W ignore::ResourceWarning test.py interactive # interactive prompt +""" + +import json +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +from tool import ( + TOOL_SCHEMAS, + TOOL_IMPLEMENTATIONS, + TOOL_SCHEMA_TICKER, + TOOL_SCHEMA_YEAR, + TOOL_SCHEMA_NEWS, + TOOL_SCHEMA_TEXT, + query_weaviate_by_ticker, + query_weaviate_by_year, + query_weaviate_by_topic, + analyze_sentiment_sync, + analyze_ticker_sentiment_sync, + analyze_year_sentiment_sync, + analyze_news_sentiment_sync, + EXAMPLE_TEXT, +) + +# ── Known data in the Weaviate collection ── +KNOWN_TICKERS = ["AAPL", "AMZN", "GOOGL", "JPM", "META", "MSFT", "NVDA", "TSLA", "V", "WMT"] +UNKNOWN_TICKER = "ZZZZZ" + + +# ────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────── + +def _section(title: str) -> None: + print("\n" + "=" * 80) + print(f" {title}") + print("=" * 80 + "\n") + + +def _pass(msg: 
str) -> None:
    # Green-check style success line.
    print(f" ✓ {msg}")


def _fail(msg: str) -> None:
    # Cross-mark style failure line.
    print(f" ✗ {msg}")


def _has_llm_key() -> bool:
    # True if any supported LLM provider key is configured.
    return bool(
        os.getenv("OPENAI_API_KEY")
        or os.getenv("GEMINI_API_KEY")
        or os.getenv("GOOGLE_API_KEY")
    )


# ──────────────────────────────────────────────────────────────────────────
# Schema tests
# ──────────────────────────────────────────────────────────────────────────

def test_tool_schemas() -> None:
    """Validate all tool schema structures."""
    _section("Tool Schemas")

    for schema in TOOL_SCHEMAS:
        assert schema["type"] == "function"
        fn = schema["function"]
        assert "name" in fn
        assert "description" in fn
        assert "parameters" in fn
        assert fn["parameters"]["type"] == "object"
        assert "required" in fn["parameters"]
        _pass(f"Schema '{fn['name']}' is valid")

    # Verify implementation mapping
    for schema in TOOL_SCHEMAS:
        name = schema["function"]["name"]
        assert name in TOOL_IMPLEMENTATIONS, f"No implementation for '{name}'"
        assert callable(TOOL_IMPLEMENTATIONS[name])
    _pass(f"All {len(TOOL_SCHEMAS)} schemas have implementations")

    print("\nSchemas:")
    for schema in TOOL_SCHEMAS:
        print(json.dumps(schema, indent=2))


# ──────────────────────────────────────────────────────────────────────────
# Weaviate data-retrieval tests (no LLM required)
# ──────────────────────────────────────────────────────────────────────────

def test_query_by_ticker_known() -> None:
    """Retrieve data for known tickers and verify results."""
    _section("Query by Ticker — Known Tickers")

    for ticker in ["AAPL", "TSLA", "JPM"]:
        t0 = time.time()
        records = query_weaviate_by_ticker(ticker, limit=20)
        elapsed = time.time() - t0

        assert isinstance(records, list), f"Expected list, got {type(records)}"
        assert len(records) > 0, f"No records returned for {ticker}"

        # Each record should have text
        for rec in records:
            assert "text" in rec or "title" in rec, f"Record missing text/title for {ticker}"

        _pass(f"{ticker}: {len(records)} records ({elapsed:.2f}s)")


def test_query_by_ticker_unknown() -> None:
    """Verify unknown ticker returns empty list."""
    _section("Query by Ticker — Unknown Ticker")

    records = query_weaviate_by_ticker(UNKNOWN_TICKER)
    assert isinstance(records, list)
    assert len(records) == 0, f"Expected 0 records for {UNKNOWN_TICKER}, got {len(records)}"
    _pass(f"{UNKNOWN_TICKER}: 0 records as expected")


def test_query_by_ticker_case_insensitive() -> None:
    """Verify ticker lookup is case-insensitive."""
    _section("Query by Ticker — Case Insensitivity")

    upper = query_weaviate_by_ticker("AAPL", limit=10)
    lower = query_weaviate_by_ticker("aapl", limit=10)

    assert len(upper) == len(lower), (
        f"Count differs: 'AAPL' returned {len(upper)}, 'aapl' returned {len(lower)}"
    )
    _pass("'AAPL' and 'aapl' return same number of results")


def test_query_by_year() -> None:
    """Retrieve data for a year range."""
    _section("Query by Year")

    # Try several years — the dataset may span different years
    for year in [2020, 2023, 2024]:
        records = query_weaviate_by_year(year, limit=10)
        assert isinstance(records, list)
        if records:
            for rec in records:
                date_str = str(rec.get("date", ""))
                # NOTE(review): checks the year prefix only — assumes dates
                # are ISO-formatted (YYYY-...); confirm against the schema.
                assert str(year) in date_str[:4], (
                    f"Record date '{date_str}' doesn't match year {year}"
                )
            _pass(f"Year {year}: {len(records)} records")
        else:
            print(f" - Year {year}: no records (may not be in dataset)")


def test_query_by_topic() -> None:
    """Retrieve articles matching a topic via BM25 search."""
    _section("Query by Topic (BM25)")

    for topic in ["earnings", "inflation", "technology"]:
        t0 = time.time()
        records = query_weaviate_by_topic(topic, limit=5)
        elapsed = time.time() - t0

        assert isinstance(records, list)
        # BM25 should find something for broad financial topics
        _pass(f"'{topic}': {len(records)} records ({elapsed:.2f}s)")

        if records:
            # Spot-check structure
            for rec in records:
                assert isinstance(rec, dict)


def test_query_record_properties() -> None:
    """Verify returned records contain expected properties."""
    _section("Record Properties")

    records = query_weaviate_by_ticker("MSFT", limit=5)
    assert len(records) > 0, "No records for MSFT"

    expected = {"text", "title", "date", "dataset_source"}
    for i, rec in enumerate(records[:3]):
        present = set(rec.keys()) & expected
        missing = expected - present
        if not missing:
            _pass(f"Record {i}: has all expected properties")
        else:
            # Some properties may be null for certain dataset sources
            print(f" - Record {i}: missing {missing} (may be null)")


# ──────────────────────────────────────────────────────────────────────────
# LLM sentiment analysis tests
# ──────────────────────────────────────────────────────────────────────────

def test_analyze_sentiment_free_text() -> None:
    """Analyze sentiment of free-form text."""
    _section("Free Text Sentiment (LLM)")

    if not _has_llm_key():
        print(" ⚠️ No LLM API key — skipping")
        return

    t0 = time.time()
    result = analyze_sentiment_sync(EXAMPLE_TEXT)
    elapsed = time.time() - t0

    assert isinstance(result, dict)
    assert "label" in result
    assert "rating" in result
    assert "rationale" in result
    assert result["label"] in ("Negative", "Neutral", "Positive", "unknown")
    if result["rating"] is not None:
        assert 1 <= result["rating"] <= 10
        # Verify label matches rating (1-4 Negative, 5 Neutral, 6-10 Positive)
        if result["rating"] <= 4:
            assert result["label"] == "Negative"
        elif result["rating"] == 5:
            assert result["label"] == "Neutral"
        else:
            assert result["label"] == "Positive"

    _pass(
        f"rating={result['rating']}, label={result['label']}, "
        f"{elapsed:.1f}s"
    )
    print(f" Rationale: {result['rationale'][:120]}")


def test_analyze_ticker_sentiment_known() -> None:
    """Analyze ticker sentiment for known tickers."""
    _section("Ticker Sentiment — Known (LLM)")

    if not _has_llm_key():
        print(" ⚠️ No LLM API key — skipping")
        return

    for ticker in ["AAPL", "NVDA"]:
        t0 = time.time()
        result = analyze_ticker_sentiment_sync(ticker)
        elapsed = time.time() - t0

        assert isinstance(result, dict)
        assert result["ticker"] == ticker
        assert result["rating"] is not None, f"Rating is None for {ticker}"
        assert isinstance(result["rating"], int)
        assert 1 <= result["rating"] <= 10
        # Verify label matches rating scale
        if result["rating"] <= 4:
            assert result["label"] == "Negative", f"Expected Negative for rating {result['rating']}"
        elif result["rating"] == 5:
            assert result["label"] == "Neutral", f"Expected Neutral for rating {result['rating']}"
        else:
            assert result["label"] == "Positive", f"Expected Positive for rating {result['rating']}"
        assert len(result.get("rationale", "")) > 10
        assert isinstance(result.get("references", []), list)

        _pass(
            f"{ticker}: rating={result['rating']}, label={result['label']}, "
            f"{elapsed:.1f}s"
        )
        print(f" Rationale: {result['rationale'][:120]}")


def test_analyze_ticker_sentiment_unknown() -> None:
    """Analyze ticker sentiment for unknown ticker returns graceful result."""
    _section("Ticker Sentiment — Unknown")

    result = analyze_ticker_sentiment_sync(UNKNOWN_TICKER)
    assert isinstance(result, dict)
    assert result["ticker"] == UNKNOWN_TICKER
    assert result["rating"] is None
    assert result["label"] == "unknown"
    assert "no data" in result["rationale"].lower() or "not found" in result["rationale"].lower() or len(result["rationale"]) > 0
    _pass(f"{UNKNOWN_TICKER}: rating=None, label=unknown (no LLM call needed)")


def test_analyze_year_sentiment() -> None:
    """Analyze year sentiment (requires data for that year in Weaviate)."""
    _section("Year Sentiment (LLM)")

    if not _has_llm_key():
        print(" ⚠️ No LLM API key — skipping")
        return

    # First find a year that has data
    for year in [2024, 2023, 2020]:
        records = query_weaviate_by_year(year, limit=5)
        if records:
            break
    else:
        # for/else: no break means no year had any records
        print(" ⚠️ No year data found in KB — skipping")
        return

    t0 = time.time()
    result = analyze_year_sentiment_sync(year)
    elapsed = time.time() - t0

    assert isinstance(result, dict)
    assert result["year"] == year
    assert "label" in result
    assert "rating" in result
    assert "rationale" in result
    if result["rating"] is not None:
        assert 1 <= result["rating"] <= 10
        if result["rating"] <= 4:
            assert result["label"] == "Negative"
        elif result["rating"] == 5:
            assert result["label"] == "Neutral"
        else:
            assert result["label"] == "Positive"

    _pass(
        f"Year {year}: rating={result['rating']}, label={result['label']}, "
        f"{elapsed:.1f}s"
    )
    print(f" Rationale: {result['rationale'][:120]}")


def test_analyze_news_sentiment() -> None:
    """Analyze news sentiment by topic query."""
    _section("News Sentiment by Topic (LLM)")

    if not _has_llm_key():
        print(" ⚠️ No LLM API key — skipping")
        return

    t0 = time.time()
    result = analyze_news_sentiment_sync("technology stocks earnings")
    elapsed = time.time() - t0

    assert isinstance(result, dict)
    assert "label" in result
    assert "rating" in result
    assert "rationale" in result
    if result["rating"] is not None:
        assert 1 <= result["rating"] <= 10

    _pass(
        f"rating={result['rating']}, label={result['label']}, "
        f"{elapsed:.1f}s"
    )
    print(f" Rationale: {result['rationale'][:120]}")


def test_analyze_news_sentiment_no_results() -> None:
    """Topic query with no matching results returns graceful output."""
    _section("News Sentiment — No Results")

    result = analyze_news_sentiment_sync("xyzzy_no_such_topic_9999")
    assert isinstance(result, dict)
    assert result["label"] == "unknown"
    assert result["rating"] is None
    _pass("No-match query: label=unknown, rating=None (no LLM call)")


def test_all_tickers_sentiment() -> None:
    """Run sentiment analysis across all known tickers."""
    _section("All Known Tickers Sentiment (LLM)")

    if not _has_llm_key():
        print(" ⚠️ No LLM API key — skipping")
        return

    for ticker in KNOWN_TICKERS:
        t0 = time.time()
        result = analyze_ticker_sentiment_sync(ticker)
        elapsed = time.time() - t0

        assert isinstance(result, dict)
        assert result["ticker"] == ticker
        assert isinstance(result["rating"], int)
        assert 1 <= result["rating"] <= 10
        # Verify label matches rating
        if result["rating"] <= 4:
            expected_label = "Negative"
        elif result["rating"] == 5:
            expected_label = "Neutral"
        else:
            expected_label = "Positive"
        assert result["label"] == expected_label

        _pass(
            f"{ticker}: rating={result['rating']}, label={result['label']}, "
            f"{elapsed:.1f}s"
        )


# ──────────────────────────────────────────────────────────────────────────
# Interactive mode
# ──────────────────────────────────────────────────────────────────────────

def interactive() -> None:
    """Prompt the user for a ticker / year / topic / text and run analysis."""
    _section("Interactive Mode")

    print("Options:")
    print(" 1. Ticker sentiment (e.g. AAPL)")
    print(" 2. Year sentiment (e.g. 2024)")
    print(" 3. Topic sentiment (e.g. inflation)")
    print(" 4. Free text sentiment")
    print()

    choice = input("Choose (1-4): ").strip()

    if choice == "1":
        ticker = input("Ticker: ").strip().upper()
        result = analyze_ticker_sentiment_sync(ticker)
    elif choice == "2":
        year = int(input("Year: ").strip())
        result = analyze_year_sentiment_sync(year)
    elif choice == "3":
        topic = input("Topic: ").strip()
        result = analyze_news_sentiment_sync(topic)
    elif choice == "4":
        text = input("Text: ").strip()
        result = analyze_sentiment_sync(text)
    else:
        print("Invalid choice.")
        return

    print(json.dumps(result, indent=2, ensure_ascii=False))


# ──────────────────────────────────────────────────────────────────────────
# Runner
# ──────────────────────────────────────────────────────────────────────────

# Named groups of test callables; "full" is only run when requested
# explicitly (mode "all" covers schema/data/sentiment only).
TEST_GROUPS = {
    "schema": [test_tool_schemas],
    "data": [
        test_query_by_ticker_known,
        test_query_by_ticker_unknown,
        test_query_by_ticker_case_insensitive,
        test_query_by_year,
        test_query_by_topic,
        test_query_record_properties,
    ],
    "sentiment": [
        test_analyze_sentiment_free_text,
        test_analyze_ticker_sentiment_known,
        test_analyze_ticker_sentiment_unknown,
        test_analyze_year_sentiment,
        test_analyze_news_sentiment,
        test_analyze_news_sentiment_no_results,
    ],
    "full": [
        test_all_tickers_sentiment,
    ],
}


def main() -> None:
    mode = sys.argv[1] if len(sys.argv) > 1 else "all"

    if mode == "interactive":
        interactive()
        return

    if mode == "all":
        groups = ["schema", "data", "sentiment"]
    elif mode in TEST_GROUPS:
        groups = [mode]
    else:
        print(f"Unknown mode: {mode}")
        print(f"Available: {', '.join(list(TEST_GROUPS.keys()) + ['all', 'interactive'])}")
        sys.exit(1)

    passed = 0
    failed = 0

    # Run each selected group, counting pass/fail; a failing test does not
    # abort the run.
    for group in groups:
        for test_fn in TEST_GROUPS[group]:
            try:
                test_fn()
                passed += 1
            except Exception as exc:
                _fail(f"{test_fn.__name__}: {exc}")
                failed += 1

    _section("Summary")
    print(f" Passed: {passed}")
    print(f" Failed: {failed}")

    if failed:
sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/hitachione/tools/sentiment_analysis_tool/tool.py b/src/hitachione/tools/sentiment_analysis_tool/tool.py new file mode 100644 index 0000000..ccde4c0 --- /dev/null +++ b/src/hitachione/tools/sentiment_analysis_tool/tool.py @@ -0,0 +1,658 @@ +"""Sentiment analysis tool backed by Weaviate knowledge base. + +Queries the Weaviate financial news collection for data related to a ticker, +year, or topic, then uses an LLM to produce a structured sentiment assessment. +""" + +from __future__ import annotations + +import asyncio +import json +import os +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv + +# Load .env from project root (5 levels up from this file) +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +import weaviate +from weaviate.auth import AuthApiKey +from weaviate.classes.query import Filter +from openai import AsyncOpenAI + +from utils.env_vars import Configs + + +# --------------------------------------------------------------------------- +# Weaviate helper (sync, for data retrieval) +# --------------------------------------------------------------------------- + +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") + + +def _get_weaviate_sync_client(): + """Create a synchronous Weaviate client from environment variables.""" + http_host = os.getenv("WEAVIATE_HTTP_HOST", "localhost") + api_key = os.getenv("WEAVIATE_API_KEY", "") + + if http_host.endswith(".weaviate.cloud"): + return weaviate.connect_to_weaviate_cloud( + cluster_url=f"https://{http_host}", + auth_credentials=AuthApiKey(api_key), + ) + + return weaviate.connect_to_custom( + http_host=http_host, + http_port=int(os.getenv("WEAVIATE_HTTP_PORT", "8080")), + http_secure=os.getenv("WEAVIATE_HTTP_SECURE", "false").lower() == "true", + grpc_host=os.getenv("WEAVIATE_GRPC_HOST", 
"localhost"), + grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT", "50051")), + grpc_secure=os.getenv("WEAVIATE_GRPC_SECURE", "false").lower() == "true", + auth_credentials=AuthApiKey(api_key), + ) + + +def query_weaviate_by_ticker( + ticker: str, + limit: int = 20, +) -> list[dict[str, Any]]: + """Retrieve news, earnings, and price data for a ticker from Weaviate.""" + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + ticker_filter = Filter.by_property("ticker").equal(ticker.upper()) + + response = col.query.fetch_objects( + filters=ticker_filter, + limit=limit, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + "ticker", "company", "quarter", "fiscal_year", + ], + ) + + # Also grab news that *mention* this ticker + mentioned_response = col.query.fetch_objects( + filters=Filter.by_property("mentioned_companies").contains_any( + [ticker.upper()] + ), + limit=limit, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + ], + ) + + seen_titles: set[str] = set() + results: list[dict] = [] + for obj in list(response.objects) + list(mentioned_response.objects): + props = {k: v for k, v in obj.properties.items() if v is not None} + title = props.get("title", "") + if title not in seen_titles: + seen_titles.add(title) + results.append(props) + + return sorted(results, key=lambda d: d.get("date", "")) + finally: + client.close() + + +def query_weaviate_by_year( + year: int, + limit: int = 20, +) -> list[dict[str, Any]]: + """Retrieve financial data for a specific year from Weaviate.""" + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + results: list[dict] = [] + seen_titles: set[str] = set() + + for obj in col.iterator( + include_vector=False, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + "ticker", "company", + ], + ): + date_str = obj.properties.get("date", "") + if date_str and str(year) 
in str(date_str)[:4]: + props = {k: v for k, v in obj.properties.items() if v is not None} + title = props.get("title", "") + if title not in seen_titles: + seen_titles.add(title) + results.append(props) + if len(results) >= limit: + break + + return sorted(results, key=lambda d: d.get("date", "")) + finally: + client.close() + + +def query_weaviate_by_topic( + topic: str, + limit: int = 10, +) -> list[dict[str, Any]]: + """Retrieve articles matching a topic via keyword search in Weaviate.""" + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + response = col.query.bm25( + query=topic, + limit=limit, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + "ticker", "company", + ], + ) + + results = [] + for obj in response.objects: + props = {k: v for k, v in obj.properties.items() if v is not None} + results.append(props) + + return sorted(results, key=lambda d: d.get("date", "")) + finally: + client.close() + + +# --------------------------------------------------------------------------- +# System prompts +# --------------------------------------------------------------------------- + +# Rating scale used across all prompts: +# 1-4 → Negative +# 5 → Neutral +# 6-10 → Positive + +_RATING_SCALE_TEXT = ( + "Use a 1-10 sentiment rating scale where: " + "1-4 = Negative, 5 = Neutral, 6-10 = Positive. " +) + +SYSTEM_PROMPT = ( + "You are a sentiment analysis agent. " + + _RATING_SCALE_TEXT + + "Return JSON with keys: rating (integer 1-10), " + "rationale (short string explaining the score)." +) + +NEWS_SYSTEM_PROMPT = ( + "You are a financial news sentiment analyst. " + "Given financial data (news articles, earnings transcripts, price data), " + + _RATING_SCALE_TEXT + + "Return JSON with keys: rating (integer 1-10), " + "rationale (short string explaining the score), " + "references (array of short quoted phrases from the data)." 
+) + +YEAR_SYSTEM_PROMPT = ( + "You are a financial news sentiment analyst. " + "Given snippets from financial news in a specific year, " + + _RATING_SCALE_TEXT + + "Return JSON with keys: year (integer), rating (integer 1-10), " + "rationale (short string explaining the score)." +) + +TICKER_SYSTEM_PROMPT = ( + "You are a financial sentiment analyst. " + "Given all available data (price history, earnings transcripts, news articles) " + "for a specific stock ticker, provide an overall sentiment assessment. " + + _RATING_SCALE_TEXT + + "Return JSON with keys: ticker (string), rating (integer 1-10), " + "rationale (short string explaining the score), " + "references (array of short quoted phrases from the data)." +) + +EXAMPLE_TEXT = ( + "The market showed resilience today despite inflation fears, " + "with tech stocks leading the recovery." +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _run_async(coro): + """Run an async coroutine, handling nested event loops.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, coro).result() + return asyncio.run(coro) + + +# --------------------------------------------------------------------------- +# Core analysis functions +# --------------------------------------------------------------------------- + +def _derive_label(rating: int | float | None) -> str: + """Derive a sentiment label from a 1-10 rating. 
+ + Scale: + 1-4 → Negative + 5 → Neutral + 6-10 → Positive + """ + if rating is None: + return "unknown" + r = int(round(rating)) + if r <= 4: + return "Negative" + elif r == 5: + return "Neutral" + else: + return "Positive" + + +def _parse_json_response(content: str) -> dict[str, Any]: + """Parse JSON from LLM response, handling code fences.""" + content = content.strip() + if "```json" in content: + content = content.split("```json")[1].split("```")[0].strip() + elif "```" in content: + content = content.split("```")[1].split("```")[0].strip() + try: + parsed = json.loads(content) + # If the LLM returns a list, wrap the first element or flatten + if isinstance(parsed, list): + return parsed[0] if parsed else {"rating": None, "label": "unknown", "rationale": content} + return parsed + except json.JSONDecodeError: + return {"rating": None, "label": "unknown", "rationale": content} + + +def _format_kb_data(records: list[dict[str, Any]], max_chars: int = 15000) -> str: + """Format Weaviate records into a text block for the LLM.""" + parts: list[str] = [] + total = 0 + for rec in records: + source = rec.get("dataset_source", "unknown") + date = rec.get("date", "") + title = rec.get("title", "") + text = str(rec.get("text", ""))[:1000] + entry = f"[{source} | {date}] {title}\n{text}" + if total + len(entry) > max_chars: + break + parts.append(entry) + total += len(entry) + return "\n\n".join(parts) + + +async def analyze_sentiment(text: str, model: str | None = None) -> dict[str, Any]: + """Analyze sentiment for arbitrary text.""" + configs = Configs() + client = AsyncOpenAI( + base_url=configs.openai_base_url, + api_key=configs.openai_api_key, + ) + + try: + response = await client.chat.completions.create( + model=model or configs.default_worker_model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": text}, + ], + temperature=0, + ) + content = (response.choices[0].message.content or "").strip() + finally: + await 
client.close() + + data = _parse_json_response(content) + # Normalise to rating-based output + rating = data.get("rating") + if rating is not None: + rating = int(round(rating)) + data["rating"] = rating + data["label"] = _derive_label(rating) + data.setdefault("rationale", "") + # Remove legacy score key if present + data.pop("score", None) + return data + + +async def analyze_ticker_sentiment( + ticker: str, + model: str | None = None, + limit: int = 20, +) -> dict[str, Any]: + """Analyze sentiment for a stock ticker using Weaviate KB data. + + Retrieves all available data (price history, earnings, news) for the + ticker from the Weaviate knowledge base and produces a sentiment rating. + """ + configs = Configs() + client = AsyncOpenAI( + base_url=configs.openai_base_url, + api_key=configs.openai_api_key, + ) + + records = query_weaviate_by_ticker(ticker, limit=limit) + + if not records: + return { + "ticker": ticker.upper(), + "rating": None, + "label": "unknown", + "rationale": f"No data found for ticker {ticker} in the knowledge base.", + "references": [], + } + + combined = _format_kb_data(records) + + try: + response = await client.chat.completions.create( + model=model or configs.default_worker_model, + messages=[ + {"role": "system", "content": TICKER_SYSTEM_PROMPT}, + {"role": "user", "content": combined}, + ], + temperature=0, + ) + content = (response.choices[0].message.content or "").strip() + finally: + await client.close() + + data = _parse_json_response(content) + rating = data.get("rating") + if rating is not None: + rating = int(round(rating)) + data["ticker"] = ticker.upper() + data["rating"] = rating + data["label"] = _derive_label(rating) + data.setdefault("rationale", "") + data.setdefault("references", []) + return data + + +async def analyze_financial_year_sentiment( + year: int, + model: str | None = None, + max_results: int = 20, +) -> dict[str, Any]: + """Analyze sentiment for financial news in a given year using Weaviate KB.""" + configs 
= Configs() + client = AsyncOpenAI( + base_url=configs.openai_base_url, + api_key=configs.openai_api_key, + ) + + records = query_weaviate_by_year(year, limit=max_results) + + if not records: + return { + "year": year, + "rating": None, + "label": "unknown", + "rationale": f"No data found for year {year} in the knowledge base.", + } + + combined = _format_kb_data(records) + + try: + response = await client.chat.completions.create( + model=model or configs.default_worker_model, + messages=[ + {"role": "system", "content": YEAR_SYSTEM_PROMPT}, + {"role": "user", "content": combined}, + ], + temperature=0, + ) + content = (response.choices[0].message.content or "").strip() + finally: + await client.close() + + data = _parse_json_response(content) + rating = data.get("rating") + if rating is not None: + rating = int(round(rating)) + data["year"] = year + data["rating"] = rating + data["label"] = _derive_label(rating) + data.setdefault("rationale", "") + # Remove legacy score key if present + data.pop("score", None) + return data + + +async def analyze_financial_news_sentiment( + query: str, + model: str | None = None, + limit: int = 10, +) -> dict[str, Any]: + """Analyze sentiment for financial news matching a topic query. + + Searches the Weaviate KB for articles matching the query and produces + a sentiment rating based on the retrieved content. 
+ """ + configs = Configs() + client = AsyncOpenAI( + base_url=configs.openai_base_url, + api_key=configs.openai_api_key, + ) + + records = query_weaviate_by_topic(query, limit=limit) + + if not records: + return { + "rating": None, + "label": "unknown", + "rationale": f"No articles found for query '{query}' in the knowledge base.", + "references": [], + } + + combined = _format_kb_data(records) + + try: + response = await client.chat.completions.create( + model=model or configs.default_worker_model, + messages=[ + {"role": "system", "content": NEWS_SYSTEM_PROMPT}, + {"role": "user", "content": combined}, + ], + temperature=0, + ) + content = (response.choices[0].message.content or "").strip() + finally: + await client.close() + + data = _parse_json_response(content) + rating = data.get("rating") + if rating is not None: + rating = int(round(rating)) + data["rating"] = rating + data["label"] = _derive_label(rating) + data.setdefault("rationale", "") + data.setdefault("references", []) + return data + + +# --------------------------------------------------------------------------- +# Synchronous wrappers for external callers +# --------------------------------------------------------------------------- + +def analyze_sentiment_sync(text: str, model: str | None = None) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_sentiment``.""" + return _run_async(analyze_sentiment(text, model=model)) + + +def analyze_ticker_sentiment_sync( + ticker: str, model: str | None = None, limit: int = 20, +) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_ticker_sentiment``.""" + return _run_async(analyze_ticker_sentiment(ticker, model=model, limit=limit)) + + +def analyze_year_sentiment_sync( + year: int, model: str | None = None, max_results: int = 20, +) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_financial_year_sentiment``.""" + return _run_async( + analyze_financial_year_sentiment(year, model=model, max_results=max_results) + ) + + +def 
analyze_news_sentiment_sync( + query: str, model: str | None = None, limit: int = 10, +) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_financial_news_sentiment``.""" + return _run_async( + analyze_financial_news_sentiment(query, model=model, limit=limit) + ) + + +# --------------------------------------------------------------------------- +# OpenAI tool schemas +# --------------------------------------------------------------------------- + +TOOL_SCHEMA_TICKER = { + "type": "function", + "function": { + "name": "analyze_ticker_sentiment", + "description": ( + "Analyze the overall sentiment for a stock ticker using the " + "Weaviate financial knowledge base. Returns a sentiment rating " + "(1-10: 1-4 Negative, 5 Neutral, 6-10 Positive), " + "label, rationale, and supporting references." + ), + "parameters": { + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "Stock ticker symbol, e.g. 'AAPL', 'TSLA'.", + }, + }, + "required": ["ticker"], + }, + }, +} + +TOOL_SCHEMA_YEAR = { + "type": "function", + "function": { + "name": "analyze_year_sentiment", + "description": ( + "Analyze the overall financial-news sentiment for a given year " + "using the Weaviate knowledge base. Returns a rating " + "(1-10: 1-4 Negative, 5 Neutral, 6-10 Positive), label, and rationale." + ), + "parameters": { + "type": "object", + "properties": { + "year": { + "type": "integer", + "description": "Four-digit year, e.g. 2024.", + }, + }, + "required": ["year"], + }, + }, +} + +TOOL_SCHEMA_NEWS = { + "type": "function", + "function": { + "name": "analyze_news_sentiment", + "description": ( + "Search the Weaviate knowledge base for financial news matching " + "a topic query and return a sentiment rating " + "(1-10: 1-4 Negative, 5 Neutral, 6-10 Positive), label, and rationale." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Natural-language topic query, e.g. 
'tech earnings Q3'.", + }, + }, + "required": ["query"], + }, + }, +} + +TOOL_SCHEMA_TEXT = { + "type": "function", + "function": { + "name": "analyze_sentiment", + "description": ( + "Classify arbitrary text sentiment on a 1-10 scale: " + "1-4 = Negative, 5 = Neutral, 6-10 = Positive. " + "Returns a rating, label, and rationale." + ), + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Free-form text to analyze.", + }, + }, + "required": ["text"], + }, + }, +} + +# Convenience list of all schemas +TOOL_SCHEMAS = [ + TOOL_SCHEMA_TICKER, + TOOL_SCHEMA_YEAR, + TOOL_SCHEMA_NEWS, + TOOL_SCHEMA_TEXT, +] + +TOOL_IMPLEMENTATIONS = { + "analyze_ticker_sentiment": analyze_ticker_sentiment_sync, + "analyze_year_sentiment": analyze_year_sentiment_sync, + "analyze_news_sentiment": analyze_news_sentiment_sync, + "analyze_sentiment": analyze_sentiment_sync, +} + + +# --------------------------------------------------------------------------- +# CLI (for quick interactive testing) +# --------------------------------------------------------------------------- + +def _format_output(data: dict[str, Any]) -> str: + return json.dumps(data, indent=2, ensure_ascii=False) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Sentiment analysis tool (Weaviate KB)") + parser.add_argument("text", nargs="?", default=None, help="Free text to analyze") + parser.add_argument("--model", help="Optional model override") + parser.add_argument("--ticker", help="Analyze sentiment for a stock ticker") + parser.add_argument("--year", type=int, help="Analyze sentiment for a year") + parser.add_argument("--query", help="Search KB for matching financial news") + args = parser.parse_args() + + if args.ticker: + result = analyze_ticker_sentiment_sync(args.ticker, model=args.model) + elif args.year is not None: + result = analyze_year_sentiment_sync(args.year, model=args.model) + elif args.query: + result = 
analyze_news_sentiment_sync(args.query, model=args.model) + elif args.text: + result = analyze_sentiment_sync(args.text, model=args.model) + else: + result = analyze_ticker_sentiment_sync("AAPL", model=args.model) + + print(_format_output(result)) diff --git a/src/hitachione/ui/__init__.py b/src/hitachione/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/ui/app.py b/src/hitachione/ui/app.py new file mode 100644 index 0000000..6c7b1c6 --- /dev/null +++ b/src/hitachione/ui/app.py @@ -0,0 +1,72 @@ +"""Gradio UI for the multi-agent financial intelligence system. + +Uses gr.ChatInterface (same pattern as the other bootcamp apps) +so the user types a financial prompt and gets a Markdown answer. +""" + +from __future__ import annotations + +import logging + +import gradio as gr + +from ..agents.orchestrator import Orchestrator + +logger = logging.getLogger(__name__) + +_orchestrator = Orchestrator() + + +def _respond( + message: str, + history: list[dict], +) -> str: + """Process one user query and return the assistant's answer. + + The orchestrator can take 30-60 s, during which Gradio shows a spinner. + """ + if not message.strip(): + return "*Please enter a financial query.*" + + try: + answer = _orchestrator.run(message.strip()) + except Exception as exc: + logger.exception("Orchestrator error") + return f"**Error:** {exc}" + + # Build final output + parts = [answer.markdown] + + if answer.caveats: + parts.append("\n---\n**Caveats:**") + for c in answer.caveats: + parts.append(f"- {c}") + + if answer.citations: + parts.append("\n**Citations:**") + for i, cit in enumerate(answer.citations[:5], 1): + parts.append(f"{i}. 
{cit}") + + parts.append(f"\n*Confidence: {answer.confidence:.0%}*") + + return "\n".join(parts) + + +def build_app() -> gr.ChatInterface: + """Build and return the Gradio ChatInterface app.""" + demo = gr.ChatInterface( + fn=_respond, + title="🏦 Financial Intelligence Agent", + description=( + "Ask any financial question – ranking, comparison, snapshot, " + "event reaction, and more." + ), + examples=[ + "Top 3 tech stocks of 2024", + "Compare TSLA vs AAPL vs NVDA", + "What moved NVDA after last earnings?", + ], + chatbot=gr.Chatbot(height=600), + textbox=gr.Textbox(lines=1, placeholder="Enter your financial query…"), + ) + return demo diff --git a/tests/test_hitachione/__init__.py b/tests/test_hitachione/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_hitachione/conftest.py b/tests/test_hitachione/conftest.py new file mode 100644 index 0000000..d5facac --- /dev/null +++ b/tests/test_hitachione/conftest.py @@ -0,0 +1,22 @@ +"""Unit tests for the hitachione multi-agent financial intelligence system. + +These tests use **mocks** for all external dependencies (LLM, Weaviate, network) +so they run fast, offline, and deterministically. 
+ +Run all tests:: + + uv run --env-file .env pytest -sv tests/test_hitachione/ + +Run a single file:: + + uv run --env-file .env pytest -sv tests/test_hitachione/test_orchestrator.py + +Organisation +------------ +- ``test_schemas.py`` – data-model invariants +- ``test_knowledge_retrieval.py`` – KB agent (mocked Weaviate) +- ``test_researcher.py`` – researcher parallel fan-out (mocked tools) +- ``test_synthesizer.py`` – synthesizer (mocked LLM) +- ``test_reviewer.py`` – reviewer quality gate (pure logic) +- ``test_orchestrator.py`` – end-to-end orchestrator workflow (all mocked) +""" diff --git a/tests/test_hitachione/test_knowledge_retrieval.py b/tests/test_hitachione/test_knowledge_retrieval.py new file mode 100644 index 0000000..c2d80e3 --- /dev/null +++ b/tests/test_hitachione/test_knowledge_retrieval.py @@ -0,0 +1,110 @@ +"""Tests for the Knowledge Retrieval agent with mocked Weaviate.""" + +from unittest.mock import patch, MagicMock + +import pytest +from src.hitachione.agents.knowledge_retrieval import KnowledgeRetrievalAgent +from src.hitachione.models.schemas import TaskContext + + +def _make_weaviate_obj(ticker, company, title="News", source="bloomberg", date="2024-01-01"): + """Create a mock Weaviate object with properties.""" + obj = MagicMock() + obj.properties = { + "ticker": ticker, + "company": company, + "title": title, + "dataset_source": source, + "date": date, + "text": f"Article about {company}", + } + return obj + + +class TestKnowledgeRetrievalAgent: + + @patch("src.hitachione.agents.knowledge_retrieval._weaviate_client") + def test_happy_path(self, mock_client_fn): + """BM25 returns objects → aliases and hints are extracted.""" + mock_client = MagicMock() + mock_client_fn.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.get.return_value = mock_col + + mock_col.query.bm25.return_value.objects = [ + _make_weaviate_obj("AAPL", "Apple Inc."), + _make_weaviate_obj("AAPL", "Apple Inc.", title="Second article"), + 
_make_weaviate_obj("MSFT", "Microsoft Corporation"), + ] + + agent = KnowledgeRetrievalAgent() + ctx = TaskContext(user_query="tech stocks") + result = agent.run(ctx) + + assert "Apple Inc." in result["aliases"] + assert result["aliases"]["Apple Inc."] == "AAPL" + assert "AAPL" in result["entity_hints"] + assert "MSFT" in result["entity_hints"] + assert len(result["summaries"]) == 3 + mock_client.close.assert_called_once() + + @patch("src.hitachione.agents.knowledge_retrieval._weaviate_client") + def test_no_results(self, mock_client_fn): + """Empty BM25 result → empty but valid return structure.""" + mock_client = MagicMock() + mock_client_fn.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.get.return_value = mock_col + mock_col.query.bm25.return_value.objects = [] + + agent = KnowledgeRetrievalAgent() + ctx = TaskContext(user_query="unknown sector") + result = agent.run(ctx) + + assert result["entity_hints"] == [] + assert result["aliases"] == {} + assert result["summaries"] == [] + + @patch("src.hitachione.agents.knowledge_retrieval._weaviate_client") + def test_weaviate_failure_graceful(self, mock_client_fn): + """If Weaviate throws, the agent returns partial results + logs uncertainty.""" + mock_client_fn.side_effect = RuntimeError("Connection refused") + + agent = KnowledgeRetrievalAgent() + ctx = TaskContext(user_query="test") + result = agent.run(ctx) + + assert result["entity_hints"] == [] + assert len(ctx.uncertainties) >= 1 + + @patch("src.hitachione.agents.knowledge_retrieval._weaviate_client") + def test_existing_entities_resolved(self, mock_client_fn): + """Entities already in ctx get added to aliases.""" + mock_client = MagicMock() + mock_client_fn.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.get.return_value = mock_col + mock_col.query.bm25.return_value.objects = [] + + agent = KnowledgeRetrievalAgent() + ctx = TaskContext(user_query="compare TSLA") + ctx.entities = ["TSLA"] + result = 
agent.run(ctx) + + assert result["aliases"]["TSLA"] == "TSLA" + + @patch("src.hitachione.agents.knowledge_retrieval._weaviate_client") + def test_dedup_entity_hints(self, mock_client_fn): + """Same ticker from multiple BM25 hits should appear only once.""" + mock_client = MagicMock() + mock_client_fn.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.get.return_value = mock_col + mock_col.query.bm25.return_value.objects = [ + _make_weaviate_obj("AAPL", "Apple Inc.", title="Article 1"), + _make_weaviate_obj("AAPL", "Apple Inc.", title="Article 2"), + ] + + agent = KnowledgeRetrievalAgent() + result = agent.run(TaskContext(user_query="apple")) + assert result["entity_hints"].count("AAPL") == 1 diff --git a/tests/test_hitachione/test_orchestrator.py b/tests/test_hitachione/test_orchestrator.py new file mode 100644 index 0000000..f242607 --- /dev/null +++ b/tests/test_hitachione/test_orchestrator.py @@ -0,0 +1,412 @@ +"""End-to-end tests for the Orchestrator ReAct loop (all externals mocked).""" + +from unittest.mock import MagicMock, patch + +import pytest +from src.hitachione.models.schemas import ( + CompanyResearch, Intent, ReviewFeedback, SynthesizedAnswer, TaskContext, +) + +# Patch OpenAI before importing Orchestrator so __init__ gets the mock +_mock_openai_cls = patch("src.hitachione.agents.orchestrator.OpenAI", MagicMock()) +_mock_openai_cls.start() + +from src.hitachione.agents.orchestrator import Orchestrator # noqa: E402 + + +def _quick_answer(confidence: float = 0.85, markdown: str = "## Report\nAll good."): + a = SynthesizedAnswer() + a.markdown = markdown + a.confidence = confidence + a.caveats = [] + a.raw_research = [] + return a + + +def _ok_feedback(): + return ReviewFeedback(ok=True, missing=[], notes="All checks passed") + + +def _bad_feedback(missing=None, retriable=True): + return ReviewFeedback( + ok=False, + retriable=retriable, + missing=missing or ["AAPL not researched"], + notes="Entity coverage incomplete", + ) + 
+ +# ── Happy-path ────────────────────────────────────────────────────────── + + +class TestOrchestratorHappyPath: + + @patch.object(Orchestrator, "_parse_intent") + def test_single_entity_happy_path(self, mock_parse): + """Intent → KB → Research → Synthesize → Review OK → return.""" + # _parse_intent sets entities directly on ctx + def parse_side_effect(ctx): + ctx.intent = Intent.SNAPSHOT + ctx.entities = ["AAPL"] + ctx.timeframe = "" + ctx.sector = "" + + mock_parse.side_effect = parse_side_effect + + orch = Orchestrator(max_iterations=2) + + # KB returns nothing new + orch.kb_agent = MagicMock() + orch.kb_agent.run.return_value = { + "aliases": {}, + "entity_hints": [], + "summaries": [], + } + + # Researcher returns one CompanyResearch + cr = CompanyResearch(ticker="AAPL") + cr.sentiment = {"rating": 8, "label": "Positive", "rationale": "Good"} + cr.performance = {"score": 9, "outlook": "Strong"} + orch.researcher = MagicMock() + orch.researcher.run.return_value = [cr] + + # Synthesizer returns a nice answer + orch.synthesizer = MagicMock() + orch.synthesizer.run.return_value = _quick_answer() + + # Reviewer says OK + orch.reviewer = MagicMock() + orch.reviewer.run.return_value = _ok_feedback() + + answer = orch.run("Tell me about AAPL") + assert answer.confidence >= 0.8 + assert "Report" in answer.markdown + + +# ── No-entities path ──────────────────────────────────────────────────── + + +class TestNoEntitiesPath: + + @patch.object(Orchestrator, "_parse_intent") + @patch("src.hitachione.agents.orchestrator._find_symbols", return_value=[]) + def test_no_entities_returns_caveat(self, mock_find, mock_parse): + """If no entities found at all, return user-facing caveat.""" + def parse_side_effect(ctx): + ctx.intent = Intent.RANK + ctx.entities = [] + ctx.timeframe = "" + ctx.sector = "" + + mock_parse.side_effect = parse_side_effect + + orch = Orchestrator(max_iterations=1) + orch.kb_agent = MagicMock() + orch.kb_agent.run.return_value = { + "aliases": {}, 
"entity_hints": [], "summaries": [], + } + + answer = orch.run("random question with no tickers") + assert answer.confidence == 0.0 + assert "tickers" in answer.markdown.lower() or "identify" in answer.markdown.lower() + + +# ── Reviewer rejection + reflection loop ──────────────────────────────── + + +class TestReflectionLoop: + + @patch.object(Orchestrator, "_parse_intent") + def test_reviewer_rejects_then_accepts(self, mock_parse): + """Reviewer fails on iter 1, Orchestrator retries, reviewer passes on iter 2.""" + def parse_side_effect(ctx): + ctx.intent = Intent.COMPARE + ctx.entities = ["AAPL", "MSFT"] + ctx.timeframe = "" + ctx.sector = "" + + mock_parse.side_effect = parse_side_effect + + orch = Orchestrator(max_iterations=3) + + orch.kb_agent = MagicMock() + orch.kb_agent.run.return_value = { + "aliases": {}, "entity_hints": [], "summaries": [], + } + + cr_a = CompanyResearch(ticker="AAPL") + cr_a.sentiment = {"rating": 6, "label": "Positive"} + cr_a.performance = {"score": 7} + cr_m = CompanyResearch(ticker="MSFT") + cr_m.sentiment = {"rating": 7, "label": "Positive"} + cr_m.performance = {"score": 8} + orch.researcher = MagicMock() + orch.researcher.run.return_value = [cr_a, cr_m] + + # Synthesizer: weak on iter 1, strong on iter 2 + weak = _quick_answer(confidence=0.4, markdown="Meh") + strong = _quick_answer(confidence=0.9, markdown="## Comprehensive\nGreat comparison.") + orch.synthesizer = MagicMock() + orch.synthesizer.run.side_effect = [weak, strong] + + # Reviewer: reject first, accept second + orch.reviewer = MagicMock() + orch.reviewer.run.side_effect = [ + _bad_feedback(["confidence below threshold"]), + _ok_feedback(), + ] + + answer = orch.run("Compare AAPL vs MSFT") + assert answer.confidence == 0.9 + assert orch.reviewer.run.call_count == 2 + + @patch.object(Orchestrator, "_parse_intent") + def test_max_iterations_reached(self, mock_parse): + """If reviewer never accepts, the answer is returned after max iterations with a caveat.""" + def 
parse_side_effect(ctx): + ctx.intent = Intent.SNAPSHOT + ctx.entities = ["TSLA"] + ctx.timeframe = "" + ctx.sector = "" + + mock_parse.side_effect = parse_side_effect + + orch = Orchestrator(max_iterations=2) + orch.kb_agent = MagicMock() + orch.kb_agent.run.return_value = { + "aliases": {}, "entity_hints": [], "summaries": [], + } + + cr = CompanyResearch(ticker="TSLA") + cr.sentiment = {"rating": 3, "label": "Negative"} + cr.performance = {"score": 4} + orch.researcher = MagicMock() + orch.researcher.run.return_value = [cr] + + orch.synthesizer = MagicMock() + orch.synthesizer.run.return_value = _quick_answer(confidence=0.3, markdown="Meh") + + orch.reviewer = MagicMock() + orch.reviewer.run.return_value = _bad_feedback(retriable=True) + + answer = orch.run("TSLA outlook") + assert any("incomplete" in c.lower() or "iteration" in c.lower() for c in answer.caveats) + assert orch.reviewer.run.call_count == 2 + + @patch.object(Orchestrator, "_parse_intent") + def test_non_retriable_stops_early(self, mock_parse): + """If reviewer says issues are NOT retriable (no KB data), stop on iteration 1.""" + def parse_side_effect(ctx): + ctx.intent = Intent.RANK + ctx.entities = ["XOM", "CVX"] + ctx.timeframe = "" + ctx.sector = "oil" + + mock_parse.side_effect = parse_side_effect + + orch = Orchestrator(max_iterations=3) + orch.kb_agent = MagicMock() + orch.kb_agent.run.return_value = { + "aliases": {}, "entity_hints": [], "summaries": [], + } + + cr_xom = CompanyResearch( + ticker="XOM", + sentiment={"rating": None, "rationale": "No data found for ticker XOM in the knowledge base."}, + performance={"performance_score": None, "justification": "No data found for ticker XOM in the knowledge base."}, + ) + cr_cvx = CompanyResearch( + ticker="CVX", + sentiment={"rating": None, "rationale": "No data found for ticker CVX in the knowledge base."}, + performance={"performance_score": None, "justification": "No data found for ticker CVX in the knowledge base."}, + ) + orch.researcher = 
MagicMock() + orch.researcher.run.return_value = [cr_xom, cr_cvx] + + ans = _quick_answer(confidence=0.0, markdown="No data available for these tickers.") + ans.raw_research = [cr_xom, cr_cvx] + orch.synthesizer = MagicMock() + orch.synthesizer.run.return_value = ans + + # Don't mock reviewer – let the real reviewer run to verify retriable=False + from src.hitachione.agents.reviewer import ReviewerAgent + orch.reviewer = ReviewerAgent() + + answer = orch.run("List oil stocks") + # Should stop on iteration 1 (no retry loop) + assert orch.researcher.run.call_count == 1 + # Should NOT have "max iterations" caveat + assert not any("iteration" in c.lower() for c in answer.caveats) + # Should have "knowledge base" caveat + assert any("knowledge base" in c.lower() for c in answer.caveats) + + +# ── Company retrieval gating ──────────────────────────────────────────── + + +class TestCompanyRetrievalGating: + + @patch.object(Orchestrator, "_parse_intent") + @patch("src.hitachione.agents.orchestrator._find_symbols") + def test_broad_query_calls_company_filter(self, mock_find, mock_parse): + """When intent has no explicit entities, company_filter is called.""" + def parse_side_effect(ctx): + ctx.intent = Intent.RANK + ctx.entities = [] # no explicit entities + ctx.timeframe = "" + ctx.sector = "tech" + + mock_parse.side_effect = parse_side_effect + mock_find.return_value = ["AAPL", "MSFT", "GOOGL"] + + orch = Orchestrator(max_iterations=1) + orch.kb_agent = MagicMock() + orch.kb_agent.run.return_value = { + "aliases": {}, "entity_hints": [], "summaries": [], + } + + cr_a = CompanyResearch(ticker="AAPL") + cr_a.sentiment = {"rating": 8, "label": "Positive"} + cr_a.performance = {"score": 9} + cr_m = CompanyResearch(ticker="MSFT") + cr_m.sentiment = {"rating": 7, "label": "Positive"} + cr_m.performance = {"score": 8} + cr_g = CompanyResearch(ticker="GOOGL") + cr_g.sentiment = {"rating": 7, "label": "Positive"} + cr_g.performance = {"score": 7} + orch.researcher = MagicMock() + 
orch.researcher.run.return_value = [cr_a, cr_m, cr_g] + orch.synthesizer = MagicMock() + orch.synthesizer.run.return_value = _quick_answer() + orch.reviewer = MagicMock() + orch.reviewer.run.return_value = _ok_feedback() + + answer = orch.run("Top tech stocks") + mock_find.assert_called_once() + + @patch.object(Orchestrator, "_parse_intent") + @patch("src.hitachione.agents.orchestrator._find_symbols") + def test_explicit_entities_skip_company_filter(self, mock_find, mock_parse): + """When intent has explicit entities, company_filter is NOT called.""" + def parse_side_effect(ctx): + ctx.intent = Intent.SNAPSHOT + ctx.entities = ["TSLA"] + ctx.timeframe = "" + ctx.sector = "" + + mock_parse.side_effect = parse_side_effect + + orch = Orchestrator(max_iterations=1) + orch.kb_agent = MagicMock() + orch.kb_agent.run.return_value = { + "aliases": {}, "entity_hints": [], "summaries": [], + } + + cr = CompanyResearch(ticker="TSLA") + cr.sentiment = {"rating": 5, "label": "Neutral"} + cr.performance = {"score": 6} + orch.researcher = MagicMock() + orch.researcher.run.return_value = [cr] + orch.synthesizer = MagicMock() + orch.synthesizer.run.return_value = _quick_answer() + orch.reviewer = MagicMock() + orch.reviewer.run.return_value = _ok_feedback() + + answer = orch.run("TSLA outlook") + mock_find.assert_not_called() + + +# ── Plan generation ───────────────────────────────────────────────────── + + +class TestPlanGeneration: + + def test_rank_intent_plan(self): + orch = Orchestrator() + ctx = TaskContext(user_query="Top 5 tech stocks") + ctx.intent = Intent.RANK + ctx.entities = ["AAPL", "MSFT"] + plan = orch._plan(ctx) + assert any("rank" in step.lower() or "score" in step.lower() for step in plan) + + def test_empty_entities_plan(self): + orch = Orchestrator() + ctx = TaskContext(user_query="Energy stocks") + ctx.intent = Intent.RANK + ctx.entities = [] + plan = orch._plan(ctx) + assert any("discover" in step.lower() for step in plan) + + def 
test_snapshot_intent_plan(self): + orch = Orchestrator() + ctx = TaskContext(user_query="TSLA snapshot") + ctx.intent = Intent.SNAPSHOT + ctx.entities = ["TSLA"] + plan = orch._plan(ctx) + assert any("latest" in step.lower() or "fetch" in step.lower() for step in plan) + + +# ── Reflect ───────────────────────────────────────────────────────────── + + +class TestReflect: + + def test_reflect_missing_entities(self): + orch = Orchestrator() + ctx = TaskContext(user_query="test") + ctx.entities = ["AAPL", "MSFT"] + fb = ReviewFeedback( + ok=False, + missing=["Entity AAPL not researched"], + notes="incomplete", + ) + adj = orch._reflect(ctx, fb) + assert adj["action"] == "retry_failed_entities" + assert "AAPL" in adj["entities"] + # ctx.entities narrowed to just the failed ticker + assert ctx.entities == ["AAPL"] + + def test_reflect_missing_performance_narrows_entities(self): + """If reviewer says 'META: missing performance score', retry just META.""" + orch = Orchestrator() + ctx = TaskContext(user_query="test") + ctx.entities = ["AAPL", "META", "MSFT"] + fb = ReviewFeedback( + ok=False, + missing=["META: missing performance score"], + notes="incomplete", + ) + adj = orch._reflect(ctx, fb) + assert adj["action"] == "retry_failed_entities" + assert ctx.entities == ["META"] + + def test_reflect_adds_missing_entity_to_ctx(self): + """_reflect should add the missing ticker back into ctx.entities.""" + orch = Orchestrator() + ctx = TaskContext(user_query="test") + ctx.entities = ["MSFT"] + fb = ReviewFeedback( + ok=False, + missing=["Entity AAPL not researched"], + notes="incomplete", + ) + orch._reflect(ctx, fb) + assert "AAPL" in ctx.entities + + def test_reflect_low_confidence(self): + orch = Orchestrator() + ctx = TaskContext(user_query="test") + fb = ReviewFeedback( + ok=False, + missing=["confidence too low"], + notes="bump up", + ) + adj = orch._reflect(ctx, fb) + assert adj["action"] == "broaden_search" + + def test_reflect_no_actionable_items(self): + orch = 
Orchestrator() + ctx = TaskContext(user_query="test") + fb = ReviewFeedback(ok=False, missing=["answer is short"], notes="") + adj = orch._reflect(ctx, fb) + assert adj["action"] == "none" diff --git a/tests/test_hitachione/test_researcher.py b/tests/test_hitachione/test_researcher.py new file mode 100644 index 0000000..a907411 --- /dev/null +++ b/tests/test_hitachione/test_researcher.py @@ -0,0 +1,224 @@ +"""Tests for the Researcher agent with mocked tool calls.""" + +import time +from concurrent.futures import TimeoutError +from unittest.mock import patch, MagicMock, call + +import pytest +from src.hitachione.agents.researcher import ( + ResearcherAgent, _research_one, _call_with_retry, +) +from src.hitachione.models.schemas import CompanyResearch, TaskContext, ToolError + + +# ── Fixtures / helpers ────────────────────────────────────────────────── + +def _fake_sentiment(ticker: str) -> dict: + return {"rating": 8, "label": "Positive", "rationale": f"{ticker} looks good", + "references": [f"ref1-{ticker}", f"ref2-{ticker}"]} + + +def _fake_performance(ticker: str) -> dict: + return {"performance_score": 7, "outlook": "Bullish", + "justification": f"{ticker} strong growth"} + + +def _slow_sentiment(ticker: str) -> dict: + """Simulates a 0.3s network call.""" + time.sleep(0.3) + return _fake_sentiment(ticker) + + +def _slow_performance(ticker: str) -> dict: + """Simulates a 0.3s network call.""" + time.sleep(0.3) + return _fake_performance(ticker) + + +def _failing_sentiment(ticker: str) -> dict: + raise RuntimeError(f"Sentiment API down for {ticker}") + + +def _failing_performance(ticker: str) -> dict: + raise RuntimeError(f"Performance API down for {ticker}") + + +# ── _research_one ─────────────────────────────────────────────────────── + +# Disable retry delays / limit retries to 1 for fast failure tests +_fast_fail = [ + patch("src.hitachione.agents.researcher._RETRY_BACKOFF", 0.0), + patch("src.hitachione.agents.researcher._MAX_RETRIES", 1), +] + + +def 
_apply_fast_fail(fn): + """Stack _fast_fail patches onto a test function.""" + for p in reversed(_fast_fail): + fn = p(fn) + return fn + + +class TestResearchOne: + @patch("src.hitachione.agents.researcher._performance", side_effect=_fake_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_fake_sentiment) + def test_happy_path(self, mock_sent, mock_perf): + cr = _research_one("AAPL") + assert cr.ticker == "AAPL" + assert cr.sentiment["rating"] == 8 + assert cr.performance["performance_score"] == 7 + assert cr.errors == [] + assert len(cr.news_snippets) == 2 + + @_apply_fast_fail + @patch("src.hitachione.agents.researcher._performance", side_effect=_fake_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_failing_sentiment) + def test_sentiment_failure_captured(self, mock_sent, mock_perf): + cr = _research_one("TSLA") + assert cr.sentiment == {} + assert cr.performance["performance_score"] == 7 + assert len(cr.errors) == 1 + assert cr.errors[0].tool == "sentiment" + + @_apply_fast_fail + @patch("src.hitachione.agents.researcher._performance", side_effect=_failing_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_fake_sentiment) + def test_performance_failure_captured(self, mock_sent, mock_perf): + cr = _research_one("MSFT") + assert cr.sentiment["rating"] == 8 + assert cr.performance == {} + assert len(cr.errors) == 1 + assert cr.errors[0].tool == "performance" + + @_apply_fast_fail + @patch("src.hitachione.agents.researcher._performance", side_effect=_failing_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_failing_sentiment) + def test_both_failures(self, mock_sent, mock_perf): + cr = _research_one("JPM") + assert len(cr.errors) == 2 + + +# ── ResearcherAgent.run (parallel fan-out) ────────────────────────────── + +class TestResearcherAgentRun: + @patch("src.hitachione.agents.researcher._performance", side_effect=_fake_performance) + 
@patch("src.hitachione.agents.researcher._sentiment", side_effect=_fake_sentiment) + def test_multiple_entities(self, mock_sent, mock_perf): + agent = ResearcherAgent() + ctx = TaskContext(user_query="test") + results = agent.run(ctx, ["AAPL", "MSFT", "TSLA"]) + assert len(results) == 3 + assert [r.ticker for r in results] == ["AAPL", "MSFT", "TSLA"] + assert all(r.errors == [] for r in results) + + @patch("src.hitachione.agents.researcher._performance", side_effect=_fake_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_fake_sentiment) + def test_preserves_order(self, mock_sent, mock_perf): + """Results should be in the same order as input entities.""" + agent = ResearcherAgent() + ctx = TaskContext(user_query="test") + order = ["NVDA", "AAPL", "GOOGL", "META"] + results = agent.run(ctx, order) + assert [r.ticker for r in results] == order + + @patch("src.hitachione.agents.researcher._performance", side_effect=_fake_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_fake_sentiment) + def test_empty_entities(self, mock_sent, mock_perf): + agent = ResearcherAgent() + ctx = TaskContext(user_query="test") + results = agent.run(ctx, []) + assert results == [] + + @patch("src.hitachione.agents.researcher._performance", side_effect=_slow_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_slow_sentiment) + def test_parallel_is_faster_than_sequential(self, mock_sent, mock_perf): + """With 4 tickers needing two 0.3s tool calls each (~0.6s/ticker), parallel should finish well under the sequential ~2.4s.""" + agent = ResearcherAgent(max_workers=4) + ctx = TaskContext(user_query="test") + tickers = ["AAPL", "MSFT", "GOOGL", "NVDA"] + + start = time.time() + results = agent.run(ctx, tickers) + elapsed = time.time() - start + + assert len(results) == 4 + # Sequential would be ~4 × 0.6 = 2.4s. Parallel should be < 1.5s. 
+ assert elapsed < 1.5, f"Parallel research took {elapsed:.2f}s – too slow" + + @_apply_fast_fail + @patch("src.hitachione.agents.researcher._performance", side_effect=_failing_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=_failing_sentiment) + def test_errors_propagated_to_context(self, mock_sent, mock_perf): + agent = ResearcherAgent() + ctx = TaskContext(user_query="test") + results = agent.run(ctx, ["AAPL"]) + assert len(ctx.uncertainties) >= 2 + assert any("sentiment" in u for u in ctx.uncertainties) + assert any("performance" in u for u in ctx.uncertainties) + + +# ── Timeout behaviour ────────────────────────────────────────────────── + +class TestResearcherTimeout: + @patch("src.hitachione.agents.researcher._TOOL_TIMEOUT", 0.1) + @patch("src.hitachione.agents.researcher._MAX_RETRIES", 1) + @patch("src.hitachione.agents.researcher._RETRY_BACKOFF", 0.0) + @patch("src.hitachione.agents.researcher._performance", side_effect=_fake_performance) + @patch("src.hitachione.agents.researcher._sentiment", side_effect=lambda t: time.sleep(5) or _fake_sentiment(t)) + def test_timeout_captured_as_error(self, mock_sent, mock_perf): + """If a tool exceeds _TOOL_TIMEOUT, a ToolError is recorded.""" + cr = _research_one("SLOW") + assert any(e.tool == "sentiment" and "timeout" in e.error for e in cr.errors) + # Performance should still succeed + assert cr.performance.get("performance_score") == 7 + + +# ── Retry behaviour ─────────────────────────────────────────────────── + +class TestRetryLogic: + @patch("src.hitachione.agents.researcher._RETRY_BACKOFF", 0.0) + @patch("src.hitachione.agents.researcher._MAX_RETRIES", 3) + def test_retry_succeeds_on_second_attempt(self): + """If a tool fails once then succeeds, the result is captured.""" + call_count = 0 + def flaky_sentiment(ticker: str) -> dict: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ConnectionError("Weaviate connection reset") + return _fake_sentiment(ticker) 
+ + result = _call_with_retry(flaky_sentiment, "META", "sentiment") + assert result["rating"] == 8 + assert call_count == 2 # failed once, succeeded on retry + + @patch("src.hitachione.agents.researcher._RETRY_BACKOFF", 0.0) + @patch("src.hitachione.agents.researcher._MAX_RETRIES", 3) + def test_retry_exhausted_raises(self): + """If all retries fail, the last exception is raised.""" + def always_failing(ticker: str) -> dict: + raise ConnectionError("Weaviate down permanently") + + with pytest.raises(ConnectionError, match="permanently"): + _call_with_retry(always_failing, "META", "sentiment") + + @patch("src.hitachione.agents.researcher._RETRY_BACKOFF", 0.0) + @patch("src.hitachione.agents.researcher._MAX_RETRIES", 3) + @patch("src.hitachione.agents.researcher._performance", side_effect=_fake_performance) + def test_research_one_retries_transient_failure(self, mock_perf): + """_research_one recovers from a transient sentiment failure.""" + call_count = 0 + def flaky(ticker): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise ConnectionError("gRPC reset") + return _fake_sentiment(ticker) + + with patch("src.hitachione.agents.researcher._sentiment", side_effect=flaky): + cr = _research_one("META") + + assert cr.errors == [] + assert cr.sentiment["rating"] == 8 + assert cr.performance["performance_score"] == 7 + assert call_count == 3 # failed twice, succeeded on third diff --git a/tests/test_hitachione/test_reviewer.py b/tests/test_hitachione/test_reviewer.py new file mode 100644 index 0000000..4644201 --- /dev/null +++ b/tests/test_hitachione/test_reviewer.py @@ -0,0 +1,198 @@ +"""Tests for the Reviewer agent (pure deterministic logic, no mocks needed).""" + +import pytest +from src.hitachione.agents.reviewer import ReviewerAgent +from src.hitachione.models.schemas import ( + CompanyResearch, + ReviewFeedback, + SynthesizedAnswer, + TaskContext, + ToolError, +) + + +def _make_ctx(*entities: str) -> TaskContext: + ctx = 
TaskContext(user_query="test") + ctx.entities = list(entities) + return ctx + + +def _make_answer( + research: list[CompanyResearch], + markdown: str = "x" * 100, + confidence: float = 0.8, +) -> SynthesizedAnswer: + return SynthesizedAnswer( + markdown=markdown, + confidence=confidence, + raw_research=research, + ) + + +class TestReviewerAgent: + def setup_method(self): + self.reviewer = ReviewerAgent() + + # ── happy path ────────────────────────────────────────────────────── + + def test_all_checks_pass(self): + """Complete data → ReviewFeedback.ok == True.""" + cr = CompanyResearch( + ticker="AAPL", + sentiment={"rating": 8, "label": "Positive"}, + performance={"performance_score": 7, "outlook": "Bullish"}, + ) + ctx = _make_ctx("AAPL") + fb = self.reviewer.run(ctx, _make_answer([cr])) + assert fb.ok is True + assert fb.missing == [] + + # ── entity coverage ───────────────────────────────────────────────── + + def test_missing_entity_detected(self): + """If an entity in the context was not researched, flag it.""" + cr = CompanyResearch(ticker="AAPL", sentiment={"rating": 8}, + performance={"performance_score": 7}) + ctx = _make_ctx("AAPL", "MSFT") + fb = self.reviewer.run(ctx, _make_answer([cr])) + assert fb.ok is False + assert any("MSFT" in m for m in fb.missing) + + # ── score completeness ────────────────────────────────────────────── + + def test_missing_sentiment_rating(self): + cr = CompanyResearch(ticker="AAPL", sentiment={}, + performance={"performance_score": 7}) + ctx = _make_ctx("AAPL") + fb = self.reviewer.run(ctx, _make_answer([cr])) + assert fb.ok is False + assert fb.retriable is True + assert any("sentiment" in m.lower() for m in fb.missing) + + def test_missing_performance_score(self): + cr = CompanyResearch(ticker="AAPL", sentiment={"rating": 8}, + performance={}) + ctx = _make_ctx("AAPL") + fb = self.reviewer.run(ctx, _make_answer([cr])) + assert fb.ok is False + assert fb.retriable is True + assert any("performance" in m.lower() for m in 
fb.missing) + + # ── answer quality ────────────────────────────────────────────────── + + def test_short_answer_flagged(self): + cr = CompanyResearch(ticker="AAPL", sentiment={"rating": 8}, + performance={"performance_score": 7}) + ctx = _make_ctx("AAPL") + fb = self.reviewer.run(ctx, _make_answer([cr], markdown="short")) + assert fb.ok is False + assert any("too short" in m.lower() for m in fb.missing) + + # ── confidence threshold ──────────────────────────────────────────── + + def test_low_confidence_flagged(self): + cr = CompanyResearch(ticker="AAPL", sentiment={"rating": 8}, + performance={"performance_score": 7}) + ctx = _make_ctx("AAPL") + fb = self.reviewer.run(ctx, _make_answer([cr], confidence=0.2)) + assert fb.ok is False + assert any("confidence" in m.lower() for m in fb.missing) + + # ── multiple issues ───────────────────────────────────────────────── + + def test_multiple_issues_accumulated(self): + cr = CompanyResearch(ticker="AAPL", sentiment={}, performance={}) + ctx = _make_ctx("AAPL", "MSFT") + fb = self.reviewer.run(ctx, _make_answer([cr], markdown="x", confidence=0.1)) + assert fb.ok is False + assert len(fb.missing) >= 4 # missing entity, sentiment, perf, short, confidence + + # ── broad-query entity coverage ───────────────────────────────────── + + def test_broad_query_passes_with_80_percent_coverage(self): + """With >10 entities and ≥80% researched, entity coverage passes.""" + # 15 entities, 12 researched = 80% + all_tickers = [f"T{i:03d}" for i in range(15)] + researched = [ + CompanyResearch( + ticker=t, + sentiment={"rating": 7}, + performance={"performance_score": 6}, + ) + for t in all_tickers[:12] + ] + ctx = _make_ctx(*all_tickers) + fb = self.reviewer.run(ctx, _make_answer(researched)) + # Should NOT have entity coverage issues + coverage_issues = [m for m in fb.missing if "coverage" in m.lower() or "not researched" in m.lower()] + assert coverage_issues == [] + + def test_broad_query_fails_below_80_percent(self): + """With >10 
entities and <80% researched, entity coverage fails.""" + # 15 entities, only 5 researched = 33% + all_tickers = [f"T{i:03d}" for i in range(15)] + researched = [ + CompanyResearch( + ticker=t, + sentiment={"rating": 7}, + performance={"performance_score": 6}, + ) + for t in all_tickers[:5] + ] + ctx = _make_ctx(*all_tickers) + fb = self.reviewer.run(ctx, _make_answer(researched)) + assert fb.ok is False + assert any("coverage" in m.lower() for m in fb.missing) + + # ── no-KB-data detection (not retriable) ──────────────────────── + + def test_no_kb_data_not_retriable(self): + """When tools return 'No data found in the knowledge base', issues are not retriable.""" + cr = CompanyResearch( + ticker="XOM", + sentiment={ + "rating": None, + "label": "unknown", + "rationale": "No data found for ticker XOM in the knowledge base.", + }, + performance={ + "performance_score": None, + "outlook": "Unknown", + "justification": "No data found for ticker XOM in the knowledge base.", + }, + ) + ctx = _make_ctx("XOM") + fb = self.reviewer.run(ctx, _make_answer([cr], confidence=0.0)) + assert fb.ok is False + assert fb.retriable is False + # Missing items should mention "knowledge base", NOT "missing ... 
rating" + assert all("knowledge base" in m.lower() or "confidence" in m.lower() + for m in fb.missing) + + def test_mixed_kb_data_and_error_is_retriable(self): + """If some tickers have no KB data and others have transient errors, retriable=True.""" + cr_ok = CompanyResearch( + ticker="AAPL", + sentiment={"rating": 8, "label": "Positive"}, + performance={"performance_score": 7, "outlook": "Bullish"}, + ) + cr_no_kb = CompanyResearch( + ticker="XOM", + sentiment={ + "rating": None, + "rationale": "No data found for ticker XOM in the knowledge base.", + }, + performance={ + "performance_score": None, + "justification": "No data found for ticker XOM in the knowledge base.", + }, + ) + cr_error = CompanyResearch( + ticker="META", + sentiment={}, # transient error – no rationale + performance={"performance_score": 7, "outlook": "Bullish"}, + ) + ctx = _make_ctx("AAPL", "XOM", "META") + fb = self.reviewer.run(ctx, _make_answer([cr_ok, cr_no_kb, cr_error])) + assert fb.ok is False + assert fb.retriable is True # META’s missing sentiment IS retriable diff --git a/tests/test_hitachione/test_schemas.py b/tests/test_hitachione/test_schemas.py new file mode 100644 index 0000000..d09b866 --- /dev/null +++ b/tests/test_hitachione/test_schemas.py @@ -0,0 +1,88 @@ +"""Tests for shared data models (schemas.py).""" + +import pytest +from src.hitachione.models.schemas import ( + CompanyResearch, + Intent, + ReviewFeedback, + SynthesizedAnswer, + TaskContext, + ToolError, +) + + +# ── Intent enum ────────────────────────────────────────────────────────── + +class TestIntent: + def test_all_values_exist(self): + """Every documented intent string should map to an enum member.""" + for val in ("rank", "compare", "snapshot", "event_reaction", + "fundamentals", "macro", "mixed"): + assert Intent(val).value == val + + def test_invalid_intent_raises(self): + with pytest.raises(ValueError): + Intent("nonexistent") + + +# ── TaskContext 
────────────────────────────────────────────────────────── + +class TestTaskContext: + def test_defaults(self): + ctx = TaskContext(user_query="test") + assert ctx.user_query == "test" + assert ctx.intent == Intent.MIXED + assert ctx.entities == [] + assert ctx.observations == [] + assert ctx.iteration == 0 + assert len(ctx.run_id) == 12 + + def test_run_id_unique(self): + a = TaskContext(user_query="a") + b = TaskContext(user_query="b") + assert a.run_id != b.run_id + + def test_mutable_collections_independent(self): + """Default lists should not be shared across instances.""" + a = TaskContext(user_query="a") + b = TaskContext(user_query="b") + a.entities.append("AAPL") + assert "AAPL" not in b.entities + + +# ── CompanyResearch ────────────────────────────────────────────────────── + +class TestCompanyResearch: + def test_defaults(self): + cr = CompanyResearch(ticker="TSLA") + assert cr.ticker == "TSLA" + assert cr.sentiment == {} + assert cr.performance == {} + assert cr.errors == [] + + def test_errors_accumulate(self): + cr = CompanyResearch(ticker="AAPL") + cr.errors.append(ToolError(entity="AAPL", tool="sentiment", error="timeout")) + cr.errors.append(ToolError(entity="AAPL", tool="performance", error="404")) + assert len(cr.errors) == 2 + + +# ── SynthesizedAnswer ─────────────────────────────────────────────────── + +class TestSynthesizedAnswer: + def test_defaults(self): + ans = SynthesizedAnswer() + assert ans.markdown == "" + assert ans.confidence == 0.0 + assert ans.caveats == [] + assert ans.citations == [] + assert ans.raw_research == [] + + +# ── ReviewFeedback ────────────────────────────────────────────────────── + +class TestReviewFeedback: + def test_defaults_not_ok(self): + fb = ReviewFeedback() + assert fb.ok is False + assert fb.missing == [] diff --git a/tests/test_hitachione/test_synthesizer.py b/tests/test_hitachione/test_synthesizer.py new file mode 100644 index 0000000..f278fab --- /dev/null +++ 
b/tests/test_hitachione/test_synthesizer.py @@ -0,0 +1,144 @@ +"""Tests for the Synthesizer agent with mocked LLM calls.""" + +from unittest.mock import patch, MagicMock + +import pytest + +# Patch OpenAI before importing SynthesizerAgent so __init__ gets the mock +_mock_openai_cls = patch("src.hitachione.agents.synthesizer.OpenAI", MagicMock()) +_mock_openai_cls.start() + +from src.hitachione.agents.synthesizer import ( # noqa: E402 + SynthesizerAgent, + _build_research_block, + _estimate_confidence, +) +from src.hitachione.models.schemas import ( # noqa: E402 + CompanyResearch, + Intent, + SynthesizedAnswer, + TaskContext, + ToolError, +) + + +# ── Helpers ───────────────────────────────────────────────────────────── + +def _cr(ticker: str, rating=8, score=7) -> CompanyResearch: + return CompanyResearch( + ticker=ticker, + sentiment={"rating": rating, "label": "Positive", "rationale": "Good"}, + performance={"performance_score": score, "outlook": "Bullish", + "justification": "Strong growth"}, + news_snippets=[f"{ticker} news 1", f"{ticker} news 2"], + ) + + +def _ctx(query: str = "test", intent: Intent = Intent.COMPARE) -> TaskContext: + ctx = TaskContext(user_query=query) + ctx.intent = intent + return ctx + + +# ── _estimate_confidence ──────────────────────────────────────────────── + +class TestConfidenceEstimation: + def test_full_data(self): + assert _estimate_confidence([_cr("AAPL"), _cr("MSFT")]) == 1.0 + + def test_sentiment_only(self): + cr = CompanyResearch(ticker="X", sentiment={"rating": 5}) + assert _estimate_confidence([cr]) == 0.5 + + def test_performance_only(self): + cr = CompanyResearch(ticker="X", performance={"performance_score": 5}) + assert _estimate_confidence([cr]) == 0.5 + + def test_no_data(self): + cr = CompanyResearch(ticker="X") + assert _estimate_confidence([cr]) == 0.0 + + def test_empty_list(self): + assert _estimate_confidence([]) == 0.0 + + def test_mixed_coverage(self): + full = _cr("AAPL") + empty = 
CompanyResearch(ticker="X") + assert _estimate_confidence([full, empty]) == 0.5 + + +# ── _build_research_block ────────────────────────────────────────────── + +class TestBuildResearchBlock: + def test_includes_ticker_heading(self): + block = _build_research_block([_cr("AAPL")]) + assert "## AAPL" in block + + def test_includes_sentiment(self): + block = _build_research_block([_cr("AAPL")]) + assert "Sentiment" in block + assert "rating=8" in block + + def test_includes_performance(self): + block = _build_research_block([_cr("AAPL")]) + assert "Performance" in block + + def test_includes_errors(self): + cr = CompanyResearch(ticker="X") + cr.errors.append(ToolError(entity="X", tool="sentiment", error="timeout")) + block = _build_research_block([cr]) + assert "Data gaps" in block + + def test_multiple_companies(self): + block = _build_research_block([_cr("AAPL"), _cr("MSFT")]) + assert "AAPL" in block + assert "MSFT" in block + + +# ── SynthesizerAgent.run ──────────────────────────────────────────────── + +class TestSynthesizerAgent: + def _make_agent(self, output_text: str = "# Answer\nLooks great!"): + """Create a SynthesizerAgent with a mocked LLM that returns the given text.""" + agent = SynthesizerAgent() + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = output_text + agent._llm = MagicMock() + agent._llm.chat.completions.create.return_value = mock_resp + return agent + + def test_happy_path(self): + agent = self._make_agent("# Answer\nLooks great!") + ctx = _ctx() + ans = agent.run(ctx, [_cr("AAPL")]) + + assert isinstance(ans, SynthesizedAnswer) + assert "Answer" in ans.markdown + assert ans.confidence == 1.0 + assert ans.caveats == [] + + def test_llm_failure_uses_fallback(self): + agent = SynthesizerAgent() + agent._llm = MagicMock() + agent._llm.chat.completions.create.side_effect = RuntimeError("API down") + ctx = _ctx() + ans = agent.run(ctx, [_cr("AAPL")]) + + assert "AAPL" in ans.markdown + assert 
"sentiment=8" in ans.markdown + assert "LLM unavailable" in ans.rationale + + def test_caveats_for_partial_data(self): + agent = self._make_agent("# Partial") + cr = CompanyResearch(ticker="X") + cr.errors.append(ToolError(entity="X", tool="sentiment", error="timeout")) + ans = agent.run(_ctx(), [cr]) + + assert len(ans.caveats) >= 1 + assert ans.confidence < 1.0 + + def test_citations_collected(self): + agent = self._make_agent("# Answer") + ans = agent.run(_ctx(), [_cr("AAPL")]) + assert len(ans.citations) >= 1