diff --git a/src/hitachione/README.md b/src/hitachione/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/agents/README.md b/src/hitachione/agents/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/kb/README.md b/src/hitachione/kb/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/tools/README.md b/src/hitachione/tools/README.md new file mode 100644 index 0000000..e69de29 diff --git a/src/hitachione/tools/company_filtering_tool/bloomberg.py b/src/hitachione/tools/company_filtering_tool/bloomberg.py new file mode 100644 index 0000000..ed98726 --- /dev/null +++ b/src/hitachione/tools/company_filtering_tool/bloomberg.py @@ -0,0 +1,145 @@ +"""Script to load and explore Bloomberg Financial News dataset. + +This dataset consists of 446,762 financial news articles from Bloomberg (2006-2013). +Dataset: danidanou/Bloomberg_Financial_News +""" + +import os +import pandas as pd + + +def load_bloomberg_data( + limit: int | None = None, + split: str = "train" +) -> pd.DataFrame: + """Load Bloomberg Financial News dataset from HuggingFace. + + Args: + limit: Optional maximum number of articles to load + split: Dataset split to load (default: "train") + + Returns: + DataFrame with financial news articles + """ + print(f"Loading Bloomberg Financial News dataset (split: {split})...") + + # Try direct parquet reading approach + try: + print("Method 1: Direct download and read as parquet...") + from huggingface_hub import hf_hub_download + import pyarrow.parquet as pq + + print("Starting download of parquet file...") + # Download the file + file_path = hf_hub_download( + repo_id="danidanou/Bloomberg_Financial_News", + filename="bloomberg_financial_data.parquet.gzip", + repo_type="dataset" + ) + print(f"Download complete. File path: {file_path}") + + print("Starting to read parquet file...") + table = pq.read_table(file_path) + print(f"Parquet table read successfully. 
Table shape: {table.shape}") + + print("Converting table to pandas DataFrame...") + df = table.to_pandas() + print(f"Conversion to pandas complete. DataFrame shape: {df.shape}") + + print(f"Successfully loaded {len(df)} articles") + + # Apply limit if specified + if limit is not None: + df = df.head(limit) + print(f"Limited to {limit} articles") + + print(f"Columns: {list(df.columns)}") + return df + + except Exception as e: + print(f"Method 1 failed: {e}") + print("\nMethod 2: Trying datasets library with no verification...") + + # Fallback: datasets library with verification disabled + try: + from datasets import load_dataset + + dataset = load_dataset( + "danidanou/Bloomberg_Financial_News", + split=split, + verification_mode="no_checks", + trust_remote_code=False + ) + + df = dataset.to_pandas() + print(f"Successfully loaded {len(df)} articles") + + if limit is not None: + df = df.head(limit) + print(f"Limited to {limit} articles") + + print(f"Columns: {list(df.columns)}") + return df + + except Exception as e: + print(f"Method 2 failed: {e}") + raise RuntimeError( + f"Unable to load Bloomberg dataset. Both methods failed.\n" + f"This dataset has known compression format issues.\n" + f"Please report this issue or try manually downloading from:\n" + f"https://huggingface.co/datasets/danidanou/Bloomberg_Financial_News" + ) + + +def explore_dataset(df: pd.DataFrame) -> None: + """Print summary statistics about the dataset. 
+ + Args: + df: DataFrame containing Bloomberg articles + """ + print("\n" + "="*60) + print("DATASET EXPLORATION") + print("="*60) + + print(f"\nTotal articles: {len(df):,}") + print(f"\nColumns: {list(df.columns)}") + print(f"\nData types:\n{df.dtypes}") + + print("\n" + "-"*60) + print("SAMPLE RECORD:") + print("-"*60) + print(df.iloc[0].to_dict()) + + # Check for text columns + text_columns = [col for col in df.columns if 'text' in col.lower() or 'content' in col.lower()] + if text_columns: + print(f"\nText columns found: {text_columns}") + for col in text_columns: + print(f"\nSample from '{col}':") + print(df[col].iloc[0][:500] + "..." if len(str(df[col].iloc[0])) > 500 else df[col].iloc[0]) + + # Check for missing values + print("\n" + "-"*60) + print("MISSING VALUES:") + print("-"*60) + print(df.isnull().sum()) + + # Text length statistics if text column exists + if text_columns: + for col in text_columns: + df[f'{col}_length'] = df[col].astype(str).str.len() + print(f"\n{col} length statistics:") + print(df[f'{col}_length'].describe()) + + +if __name__ == "__main__": + # Load a sample of the dataset for exploration + df = load_bloomberg_data(limit=1000) + + # Explore the dataset + explore_dataset(df) + + print("\n" + "="*60) + print("To load the full dataset, use:") + print(" df = load_bloomberg_data()") + print("="*60) diff --git a/src/hitachione/tools/company_filtering_tool/test.py b/src/hitachione/tools/company_filtering_tool/test.py new file mode 100644 index 0000000..c5e3bb1 --- /dev/null +++ b/src/hitachione/tools/company_filtering_tool/test.py @@ -0,0 +1,254 @@ +""" +Test harness for the company filtering tool (Weaviate-backed). + +This module provides comprehensive tests for the find_relevant_symbols tool. +Run this file to test the tool functionality. 
+""" + +import sys +from typing import List +from dotenv import load_dotenv +from pathlib import Path + +# Load .env file (project root is 5 levels up from this file) +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +import os +from tool import ( + find_relevant_symbols, + find_relevant_sp500_symbols, + get_all_symbols, + get_company_mapping, + TOOL_SCHEMA, +) + + +def print_section(title: str): + """Print a formatted section header.""" + print("\n" + "=" * 80) + print(f" {title}") + print("=" * 80 + "\n") + + +def test_symbol_extraction(): + """Test extracting all symbols from the Weaviate knowledge base.""" + print_section("Testing Symbol Extraction from Weaviate") + + print("Extracting all unique tickers from Weaviate collection...") + collection = os.getenv('WEAVIATE_COLLECTION_NAME', 'hitachi-finance-news') + print(f"(Iterates through {collection} collection)\n") + + try: + import time + start = time.time() + + symbols = get_all_symbols() + elapsed = time.time() - start + + print(f"✓ Successfully extracted {len(symbols)} unique tickers in {elapsed:.2f}s") + print(f"\nTickers: {', '.join(symbols)}") + + # Also show company mapping + companies = get_company_mapping() + print(f"\nCompany mapping:") + for ticker, company in sorted(companies.items()): + print(f" {ticker}: {company}") + + except Exception as e: + print(f"✗ Error: {e}") + + print() + + +def test_symbol_search(): + """Test the main symbol search functionality with LLM filtering.""" + print_section("Testing Tool Function with LLM Filtering") + + test_cases = [ + "List all the top automotive stocks of 2012", + "Find technology companies from 2015", + "Show me healthcare stocks", + "top 5 tech stocks", + ] + + import os + api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + if not api_key: + print("⚠️ No LLM API key set - LLM filtering disabled") + print(" Tool will return all symbols without filtering\n") + else: + print("✓ LLM API key found - 
LLM filtering enabled\n") + + for i, query in enumerate(test_cases, 1): + print(f"Test {i}/{len(test_cases)}: {query}") + print("-" * 80) + + try: + symbols = find_relevant_symbols(query) + + print(f"✓ Returned {len(symbols)} filtered symbols") + print(f" Results: {', '.join(symbols[:15])}") + if len(symbols) > 15: + print(f" ... and {len(symbols) - 15} more") + except Exception as e: + print(f"✗ Error: {e}") + + print() + + +def test_year_extraction(): + """Test year extraction from queries.""" + print_section("Legacy Test: Year Extraction") + print("Note: This functionality is no longer used.") + print("The tool now returns all symbols and lets the LLM do filtering.\n") + + +def test_keyword_detection(): + """Test keyword detection from queries.""" + print_section("Legacy Test: Keyword Detection") + print("Note: This functionality is no longer used.") + print("The tool now returns all symbols and lets the LLM do filtering.\n") + + +def test_tool_schema(): + """Display and validate the tool schema.""" + print_section("Tool Schema for OpenAI") + + import json + print(json.dumps(TOOL_SCHEMA, indent=2)) + print("\nSchema validation: ✓ Valid JSON structure") + + +def run_interactive_test(): + """Run an interactive test where user can input queries.""" + print_section("Interactive Test Mode") + + import os + api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + if not api_key: + print("⚠️ No LLM API key set") + print(" Tool will return all symbols without LLM filtering\n") + else: + print("✓ LLM filtering enabled\n") + + print("Enter queries to test the tool (or 'quit' to exit)") + print("Example: 'top 3 tech stocks of 2010'\n") + + # Load symbols once + try: + print("Loading all tickers from Weaviate...") + all_symbols = get_all_symbols() + companies = get_company_mapping() + print(f"✓ Loaded {len(all_symbols)} unique tickers:") + for t in all_symbols: + print(f" {t}: {companies.get(t, 'N/A')}") + print() + except 
Exception as e: + print(f"✗ Error loading symbols: {e}") + return + + while True: + try: + query = input("Query> ").strip() + + if query.lower() in ['quit', 'exit', 'q']: + break + + if not query: + continue + + print(f"\nProcessing: '{query}'") + + import time + start = time.time() + symbols = find_relevant_symbols(query) + elapsed = time.time() - start + + print(f"✓ Complete in {elapsed:.2f}s") + print(f" Filtered to {len(symbols)} symbols: {', '.join(symbols)}") + print() + + except KeyboardInterrupt: + print("\n\nExiting interactive mode.") + break + except Exception as e: + print(f"\n✗ Error: {e}\n") + + +def benchmark_query_performance(): + """Benchmark the performance of symbol extraction.""" + print_section("Performance Benchmarking") + + import time + + print("Benchmarking symbol extraction...\n") + + # First run (may need to read from disk) + start_time = time.time() + symbols = get_all_symbols() + elapsed = time.time() - start_time + + print(f"First call: {elapsed:.3f}s ({len(symbols)} symbols)") + + # Second run (should use cache) + start_time = time.time() + symbols = get_all_symbols() + elapsed = time.time() - start_time + + print(f"Cached call: {elapsed:.3f}s (from memory cache)") + print() + + +def main(): + """Main test harness entry point.""" + print("\n" + "=" * 80) + print(" Company Filtering Tool (Weaviate) - Test Harness") + print("=" * 80) + + if len(sys.argv) > 1: + mode = sys.argv[1].lower() + else: + print("\nAvailable test modes:") + print(" 1. all - Run all tests") + print(" 2. quick - Run quick tests (schema only)") + print(" 3. extract - Test symbol extraction") + print(" 4. interactive - Interactive query mode") + print(" 5. benchmark - Performance benchmarking") + print(" 6. 
schema - Display tool schema")
+
+        choice = input("\nSelect mode (1-6) or press Enter for 'all': ").strip()
+
+        mode_map = {
+            '1': 'all',
+            '2': 'quick',
+            '3': 'extract',
+            '4': 'interactive',
+            '5': 'benchmark',
+            '6': 'schema',
+        }
+
+        mode = mode_map.get(choice, 'all')
+
+    if mode in ['all', 'quick']:
+        test_tool_schema()
+
+    if mode in ['all', 'extract']:
+        test_symbol_extraction()
+        test_symbol_search()
+
+    if mode == 'interactive':
+        run_interactive_test()
+
+    if mode == 'benchmark':
+        benchmark_query_performance()
+
+    if mode == 'schema':
+        test_tool_schema()
+
+    print_section("Test Harness Complete")
+    # Usage hint must name THIS script (test.py) and show the mode argument;
+    # it previously referenced a nonexistent "test_sp500_tool.py" with no arg.
+    print("To run specific tests, use: python test.py <mode>")
+    print("Available modes: all, quick, extract, interactive, benchmark, schema\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/hitachione/tools/company_filtering_tool/tool.py b/src/hitachione/tools/company_filtering_tool/tool.py
new file mode 100644
index 0000000..c3eefc1
--- /dev/null
+++ b/src/hitachione/tools/company_filtering_tool/tool.py
@@ -0,0 +1,333 @@
+"""
+Tool for finding relevant stock symbols from the Weaviate financial news knowledge base.
+
+This tool queries the Weaviate financial news collection to retrieve unique
+stock tickers and uses an LLM to filter them based on user queries.
+""" + +from typing import List +from pathlib import Path +import os +import json +import asyncio + +import weaviate +from weaviate.auth import AuthApiKey +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +# Import client manager from the utils +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) +from utils.client_manager import AsyncClientManager + +# Cache for symbols and company mapping +_cached_symbols: List[str] | None = None +_cached_companies: dict[str, str] | None = None # ticker -> company name +_client_manager = None + +# Weaviate collection name (from WEAVIATE_COLLECTION_NAME env var) +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") + + +def get_client_manager() -> AsyncClientManager: + """Get or create the client manager.""" + global _client_manager + + if _client_manager is None: + _client_manager = AsyncClientManager() + + return _client_manager + + +def _get_weaviate_sync_client(): + """Create a synchronous Weaviate client from environment variables.""" + http_host = os.getenv("WEAVIATE_HTTP_HOST", "localhost") + api_key = os.getenv("WEAVIATE_API_KEY", "") + + # Weaviate Cloud uses connect_to_weaviate_cloud (single host, port 443) + if http_host.endswith(".weaviate.cloud"): + cluster_url = f"https://{http_host}" + return weaviate.connect_to_weaviate_cloud( + cluster_url=cluster_url, + auth_credentials=AuthApiKey(api_key), + ) + + # Otherwise use custom connection for self-hosted instances + return weaviate.connect_to_custom( + http_host=http_host, + http_port=int(os.getenv("WEAVIATE_HTTP_PORT", "8080")), + http_secure=os.getenv("WEAVIATE_HTTP_SECURE", "false").lower() == "true", + grpc_host=os.getenv("WEAVIATE_GRPC_HOST", "localhost"), + grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT", "50051")), + grpc_secure=os.getenv("WEAVIATE_GRPC_SECURE", "false").lower() == "true", + auth_credentials=AuthApiKey(api_key), + ) 
+ + +def get_all_symbols() -> List[str]: + """ + Get all unique stock tickers from the Weaviate knowledge base. + + Iterates through the Weaviate collection and collects + unique ticker symbols and their corresponding company names. + + Returns: + Sorted list of unique stock tickers + """ + global _cached_symbols, _cached_companies + + if _cached_symbols is not None: + return _cached_symbols + + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + tickers = set() + companies: dict[str, str] = {} + + for obj in col.iterator( + include_vector=False, + return_properties=["ticker", "company"], + ): + ticker = obj.properties.get("ticker") + company = obj.properties.get("company") + if ticker: + tickers.add(ticker) + if company and ticker not in companies: + companies[ticker] = company + + _cached_symbols = sorted(tickers) + _cached_companies = companies + return _cached_symbols + + except Exception as e: + raise RuntimeError(f"Failed to load tickers from Weaviate: {e}") + finally: + client.close() + + +def get_company_mapping() -> dict[str, str]: + """ + Get ticker -> company name mapping from the Weaviate knowledge base. + + Returns: + Dictionary mapping ticker symbols to company names + """ + if _cached_companies is None: + get_all_symbols() # populates both caches + return _cached_companies or {} + + +async def filter_symbols_with_llm_async(query: str, symbols: List[str]) -> List[str]: + """ + Use an LLM to filter symbols based on the query (async version). 
+ + Args: + query: Natural-language query describing what to filter for + symbols: List of all available symbols + + Returns: + Filtered list of relevant symbols + """ + client_manager = get_client_manager() + + # Build a readable list with company names + company_map = get_company_mapping() + symbol_list = ", ".join( + f"{s} ({company_map[s]})" if s in company_map else s for s in symbols + ) + + prompt = f"""Given this list of stock symbols from our financial knowledge base, identify and return ALL symbols that match this query: "{query}" + +Available symbols: +{symbol_list} + +Instructions: +- Return ALL matching stock symbols as a JSON array +- Focus on the query requirements (sector, industry, characteristics) +- IGNORE any numeric limits like "top N" - return ALL relevant matches +- Use your knowledge of which companies operate in which sectors +- Return ONLY a JSON object with a "symbols" array, nothing else + +Example: For "tech stocks", return ALL technology company symbols, not just the top few. 
+Response format: {{"symbols": ["AAPL", "GOOGL", "MSFT", "NVDA", ...]}}""" + + try: + response = await client_manager.openai_client.chat.completions.create( + model=client_manager.configs.default_worker_model, + messages=[{"role": "user", "content": prompt}], + temperature=0, + ) + + content = response.choices[0].message.content.strip() + + # Extract JSON from response (handle markdown code blocks) + if "```json" in content: + content = content.split("```json")[1].split("```")[0].strip() + elif "```" in content: + content = content.split("```")[1].split("```")[0].strip() + + result = json.loads(content) + filtered_symbols = result.get("symbols", []) + + # Validate that returned symbols are in the original list + valid_symbols = [s for s in filtered_symbols if s in symbols] + + return valid_symbols + + except Exception as e: + # Fallback: return all symbols if filtering fails + print(f"Warning: LLM filtering failed ({e}), returning all symbols") + return symbols + + +def filter_symbols_with_llm(query: str, symbols: List[str]) -> List[str]: + """ + Use an LLM to filter symbols based on the query (sync wrapper). + + Args: + query: Natural-language query describing what to filter for + symbols: List of all available symbols + + Returns: + Filtered list of relevant symbols + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + # We're already inside an event loop (e.g. Jupyter / Gradio) + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit( + asyncio.run, filter_symbols_with_llm_async(query, symbols) + ).result() + else: + return asyncio.run(filter_symbols_with_llm_async(query, symbols)) + + +def find_relevant_symbols(query: str, use_llm_filter: bool = True) -> List[str]: + """ + Find stock symbols relevant to the query from the Weaviate knowledge base. + Uses an LLM internally to filter symbols based on the query. 
+ + Args: + query: Natural-language query describing the type of companies or time period, + e.g. 'List all the top automotive stocks of 2012' + use_llm_filter: Whether to use LLM for filtering (default: True) + + Returns: + Sorted list of filtered stock symbols relevant to the query. + """ + all_symbols = get_all_symbols() + + if not use_llm_filter: + return all_symbols + + # Check if API key is available + api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") + if not api_key: + print("Warning: No LLM API key set, returning all symbols without filtering") + return all_symbols + + # Use LLM to filter symbols based on query + filtered = filter_symbols_with_llm(query, all_symbols) + + return sorted(filtered) + + +# Keep backward-compatible alias +find_relevant_sp500_symbols = find_relevant_symbols + + +# OpenAI tool schema +TOOL_SCHEMA = { + "type": "function", + "function": { + "name": "find_relevant_symbols", + "description": ( + "Find relevant stock symbols from the Weaviate financial news knowledge base " + "The tool uses an LLM internally to filter " + "symbols based on the query, returning only symbols that match the specified " + "criteria (sector, industry, time period, ranking, etc.)." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + "Natural-language query describing the type of companies or " + "time period, e.g. 'List all the top automotive stocks of 2012' " + "or 'top 3 tech stocks of 2010'." + ), + } + }, + "required": ["query"], + }, + }, +} + + +# Tool implementation mapping for OpenAI agent integration +TOOL_IMPLEMENTATIONS = { + "find_relevant_symbols": find_relevant_symbols, + "find_relevant_sp500_symbols": find_relevant_symbols, # backward compat +} + + +def run_agent_with_tool(user_query: str, client) -> str: + """ + Run an OpenAI agent that can use the symbol-filtering tool. 
+
+    Args:
+        user_query: User's natural-language query
+        client: OpenAI client instance
+
+    Returns:
+        Final response from the agent
+    """
+    import json
+
+    response = client.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": user_query}],
+        tools=[TOOL_SCHEMA],
+    )
+
+    choice = response.choices[0]
+    tool_calls = getattr(choice.message, "tool_calls", None)
+
+    if tool_calls:
+        # Process the first tool call
+        tool_call = tool_calls[0]
+        tool_name = tool_call.function.name
+        tool_args = tool_call.function.arguments  # JSON string
+
+        args = json.loads(tool_args)
+        result = TOOL_IMPLEMENTATIONS[tool_name](**args)
+
+        # Send tool result back to the model for the final answer.
+        # FIX: the assistant message that issued the tool call must be
+        # included before the "tool" message — the OpenAI API rejects a
+        # "tool" role message with no preceding assistant tool_calls message.
+        followup = client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[
+                {"role": "user", "content": user_query},
+                choice.message,
+                {
+                    "role": "tool",
+                    "tool_call_id": tool_call.id,
+                    "name": tool_name,
+                    "content": str(result),
+                },
+            ],
+        )
+        return followup.choices[0].message.content
+
+    # If no tool call, just return the model's direct answer
+    return choice.message.content
diff --git a/src/hitachione/tools/performance_analysis_tool/test.py b/src/hitachione/tools/performance_analysis_tool/test.py
new file mode 100644
index 0000000..0cfaced
--- /dev/null
+++ b/src/hitachione/tools/performance_analysis_tool/test.py
@@ -0,0 +1,289 @@
+"""
+Test harness for the Performance Analysis Tool (Weaviate-backed).
+ +Run: + cd src/hitachione/tools/performance_analysis_tool + python3 test.py all # full suite + python3 test.py data # data retrieval only (no LLM) + python3 test.py analyse # full analysis (requires LLM key) + python3 test.py schema # show tool schema + python3 test.py interactive # interactive ticker input +""" + +import json +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +from tool import ( + TOOL_SCHEMA, + analyse_stock_performance, + get_ticker_data, +) + +# ── Tickers known to exist in the Weaviate collection ── +KNOWN_TICKERS = ["AAPL", "AMZN", "GOOGL", "JPM", "META", "MSFT", "NVDA", "TSLA", "V", "WMT"] +UNKNOWN_TICKER = "ZZZZZ" + + +# ────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────── + +def _section(title: str) -> None: + print("\n" + "=" * 80) + print(f" {title}") + print("=" * 80 + "\n") + + +def _pass(msg: str) -> None: + print(f" ✓ {msg}") + + +def _fail(msg: str) -> None: + print(f" ✗ {msg}") + + +# ────────────────────────────────────────────────────────────────────────── +# Tests +# ────────────────────────────────────────────────────────────────────────── + +def test_tool_schema() -> None: + """Validate the tool schema structure.""" + _section("Tool Schema") + print(json.dumps(TOOL_SCHEMA, indent=2)) + + assert TOOL_SCHEMA["type"] == "function" + fn = TOOL_SCHEMA["function"] + assert fn["name"] == "analyse_stock_performance" + assert "ticker" in fn["parameters"]["properties"] + assert "ticker" in fn["parameters"]["required"] + _pass("Schema is valid") + + +def test_data_retrieval_known_ticker() -> None: + """Retrieve data for known tickers and verify structure.""" + _section("Data Retrieval — Known Tickers") + + for ticker in ["AAPL", "TSLA", "JPM"]: + t0 = time.time() + data = 
get_ticker_data(ticker) + elapsed = time.time() - t0 + + assert isinstance(data, dict), f"Expected dict, got {type(data)}" + for key in ("price_data", "earnings", "news"): + assert key in data, f"Missing key '{key}' for {ticker}" + + total = sum(len(v) for v in data.values()) + _pass( + f"{ticker}: {len(data['price_data'])} price, " + f"{len(data['earnings'])} earnings, " + f"{len(data['news'])} news ({elapsed:.2f}s, {total} total)" + ) + + # At least one data source should have records + assert total > 0, f"No data returned for known ticker {ticker}" + + +def test_data_retrieval_unknown_ticker() -> None: + """Verify graceful handling of an unknown ticker.""" + _section("Data Retrieval — Unknown Ticker") + + data = get_ticker_data(UNKNOWN_TICKER) + total = sum(len(v) for v in data.values()) + assert total == 0, f"Expected 0 records for {UNKNOWN_TICKER}, got {total}" + _pass(f"{UNKNOWN_TICKER}: 0 records as expected") + + +def test_data_retrieval_case_insensitive() -> None: + """Verify ticker is uppercased automatically.""" + _section("Data Retrieval — Case Insensitivity") + + data_upper = get_ticker_data("AAPL") + data_lower = get_ticker_data("aapl") + + assert len(data_upper["price_data"]) == len(data_lower["price_data"]), \ + "Price data count differs between 'AAPL' and 'aapl'" + _pass("'AAPL' and 'aapl' return same results") + + +def test_price_data_fields() -> None: + """Verify price records contain expected fields.""" + _section("Price Data — Field Validation") + + data = get_ticker_data("TSLA") + if not data["price_data"]: + _fail("No price data for TSLA") + return + + expected_fields = {"date", "open", "high", "low", "close"} + for i, rec in enumerate(data["price_data"][:3]): + present = set(rec.keys()) & expected_fields + assert present == expected_fields, ( + f"Record {i} missing fields: {expected_fields - present}" + ) + _pass(f"First {min(3, len(data['price_data']))} records have all OHLC fields") + + +def test_price_data_sorted() -> None: + """Verify 
price records are sorted by date.""" + _section("Price Data — Sort Order") + + data = get_ticker_data("GOOGL") + dates = [r["date"] for r in data["price_data"] if "date" in r] + assert dates == sorted(dates), "Price data is not sorted by date" + _pass(f"GOOGL: {len(dates)} price records sorted correctly") + + +def test_analyse_unknown_ticker() -> None: + """Full analysis on unknown ticker returns None score.""" + _section("Full Analysis — Unknown Ticker") + + result = analyse_stock_performance(UNKNOWN_TICKER) + assert result["ticker"] == UNKNOWN_TICKER + assert result["performance_score"] is None + assert result["data_summary"]["price_records"] == 0 + _pass(f"{UNKNOWN_TICKER}: score=None, outlook={result['outlook']}") + + +def test_analyse_known_ticker() -> None: + """Full analysis on a known ticker returns valid structure.""" + _section("Full Analysis — Known Tickers (LLM)") + + api_key = ( + os.getenv("OPENAI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or os.getenv("GOOGLE_API_KEY") + ) + if not api_key: + print(" ⚠️ No LLM API key — skipping LLM analysis tests") + return + + for ticker in ["AAPL", "NVDA"]: + t0 = time.time() + result = analyse_stock_performance(ticker) + elapsed = time.time() - t0 + + assert isinstance(result, dict) + assert result["ticker"] == ticker + assert isinstance(result["performance_score"], int) + assert 1 <= result["performance_score"] <= 10 + assert result["outlook"] in ("Bullish", "Bearish", "Volatile", "Sideways") + assert len(result["justification"]) > 20 + assert result["data_summary"]["price_records"] > 0 + + _pass( + f"{ticker}: score={result['performance_score']}, " + f"outlook={result['outlook']}, {elapsed:.1f}s" + ) + print(f" Justification: {result['justification'][:120]}...") + + +def test_analyse_multiple_tickers() -> None: + """Analyse several tickers to confirm consistency.""" + _section("Full Analysis — All Known Tickers (LLM)") + + api_key = ( + os.getenv("OPENAI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or 
os.getenv("GOOGLE_API_KEY") + ) + if not api_key: + print(" ⚠️ No LLM API key — skipping") + return + + for ticker in KNOWN_TICKERS: + result = analyse_stock_performance(ticker) + score = result["performance_score"] + outlook = result["outlook"] + ds = result["data_summary"] + _pass( + f"{ticker:5s}: score={score:>2}, outlook={outlook:8s} " + f"(price={ds['price_records']}, earn={ds['earnings_records']}, " + f"news={ds['news_records']})" + ) + + +# ────────────────────────────────────────────────────────────────────────── +# Interactive mode +# ────────────────────────────────────────────────────────────────────────── + +def interactive() -> None: + _section("Interactive Mode") + print(f"Available tickers: {', '.join(KNOWN_TICKERS)}") + print("Enter a ticker (or 'quit' to exit)\n") + + while True: + try: + ticker = input("Ticker> ").strip() + if ticker.lower() in ("quit", "exit", "q"): + break + if not ticker: + continue + + print(f"\nAnalysing {ticker.upper()}...") + t0 = time.time() + result = analyse_stock_performance(ticker) + elapsed = time.time() - t0 + + print(json.dumps(result, indent=2)) + print(f"({elapsed:.1f}s)\n") + + except KeyboardInterrupt: + print("\n") + break + except Exception as e: + print(f" ✗ Error: {e}\n") + + +# ────────────────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────────────────── + +def main() -> None: + print("\n" + "=" * 80) + print(" Performance Analysis Tool (Weaviate) — Test Harness") + print("=" * 80) + + if len(sys.argv) > 1: + mode = sys.argv[1].lower() + else: + print("\nModes:") + print(" 1. all — Run all tests") + print(" 2. data — Data retrieval tests only (no LLM)") + print(" 3. analyse — Full analysis tests (requires LLM)") + print(" 4. schema — Display tool schema") + print(" 5. 
interactive — Interactive ticker input") + + choice = input("\nSelect (1-5) or press Enter for 'all': ").strip() + mode = {"1": "all", "2": "data", "3": "analyse", "4": "schema", "5": "interactive"}.get(choice, "all") + + if mode in ("all", "schema"): + test_tool_schema() + + if mode in ("all", "data"): + test_data_retrieval_known_ticker() + test_data_retrieval_unknown_ticker() + test_data_retrieval_case_insensitive() + test_price_data_fields() + test_price_data_sorted() + + if mode in ("all", "analyse"): + test_analyse_unknown_ticker() + test_analyse_known_ticker() + test_analyse_multiple_tickers() + + if mode == "interactive": + interactive() + + _section("Test Harness Complete") + + +if __name__ == "__main__": + main() diff --git a/src/hitachione/tools/performance_analysis_tool/tool.py b/src/hitachione/tools/performance_analysis_tool/tool.py new file mode 100644 index 0000000..5ae74ce --- /dev/null +++ b/src/hitachione/tools/performance_analysis_tool/tool.py @@ -0,0 +1,318 @@ +""" +Tool for analysing stock performance using the Weaviate knowledge base. + +Queries the Weaviate financial news collection for price history, earnings +transcripts, and financial news, then uses an LLM to produce a performance +rating score (1-10), future outlook, and justification. 
+""" + +import asyncio +import json +import os +from pathlib import Path +from typing import Any, List + +import weaviate +from weaviate.auth import AuthApiKey +from weaviate.classes.query import Filter +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) +from utils.client_manager import AsyncClientManager + +# --------------------------------------------------------------------------- +# Weaviate helpers +# --------------------------------------------------------------------------- + +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") + + +def _get_weaviate_sync_client(): + """Create a synchronous Weaviate client from environment variables.""" + http_host = os.getenv("WEAVIATE_HTTP_HOST", "localhost") + api_key = os.getenv("WEAVIATE_API_KEY", "") + + if http_host.endswith(".weaviate.cloud"): + return weaviate.connect_to_weaviate_cloud( + cluster_url=f"https://{http_host}", + auth_credentials=AuthApiKey(api_key), + ) + + return weaviate.connect_to_custom( + http_host=http_host, + http_port=int(os.getenv("WEAVIATE_HTTP_PORT", "8080")), + http_secure=os.getenv("WEAVIATE_HTTP_SECURE", "false").lower() == "true", + grpc_host=os.getenv("WEAVIATE_GRPC_HOST", "localhost"), + grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT", "50051")), + grpc_secure=os.getenv("WEAVIATE_GRPC_SECURE", "false").lower() == "true", + auth_credentials=AuthApiKey(api_key), + ) + + +# --------------------------------------------------------------------------- +# Data retrieval +# --------------------------------------------------------------------------- + + +def get_ticker_data(ticker: str) -> dict[str, list[dict]]: + """Retrieve all Weaviate data for a given ticker, grouped by source. + + Returns a dict with keys ``price_data``, ``earnings``, ``news`` — each + containing a list of property dicts. 
+ """ + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + ticker_filter = Filter.by_property("ticker").equal(ticker.upper()) + + # --- price data (stock_market) --- + price_response = col.query.fetch_objects( + filters=( + ticker_filter + & Filter.by_property("dataset_source").equal("stock_market") + ), + limit=100, + return_properties=[ + "date", "open", "high", "low", "close", "volume", "text", + ], + ) + price_data = [ + {k: v for k, v in obj.properties.items() if v is not None} + for obj in price_response.objects + ] + + # --- earnings transcripts --- + earnings_response = col.query.fetch_objects( + filters=( + ticker_filter + & Filter.by_property("dataset_source").equal( + "sp500_earnings_transcripts" + ) + ), + limit=100, + return_properties=[ + "date", "quarter", "fiscal_year", "fiscal_quarter", "text", + "title", + ], + ) + earnings = [ + {k: v for k, v in obj.properties.items() if v is not None} + for obj in earnings_response.objects + ] + + # --- news (bloomberg + yahoo) --- + news_response = col.query.fetch_objects( + filters=( + ticker_filter + & ( + Filter.by_property("dataset_source").equal( + "bloomberg_financial_news" + ) + | Filter.by_property("dataset_source").equal( + "yahoo_finance_news" + ) + ) + ), + limit=100, + return_properties=["date", "title", "text", "category"], + ) + # Also grab news that *mention* this ticker (mentioned_companies) + mentioned_response = col.query.fetch_objects( + filters=Filter.by_property("mentioned_companies").contains_any( + [ticker.upper()] + ), + limit=50, + return_properties=["date", "title", "text", "category"], + ) + + seen_titles: set[str] = set() + news: list[dict] = [] + for obj in list(news_response.objects) + list(mentioned_response.objects): + props = {k: v for k, v in obj.properties.items() if v is not None} + title = props.get("title", "") + if title not in seen_titles: + seen_titles.add(title) + news.append(props) + + return { + "price_data": 
# Lazily-created singleton so repeated tool calls share one client manager.
_client_manager = None


def _get_client_manager() -> "AsyncClientManager":
    """Return the module-level :class:`AsyncClientManager`, creating it on first use."""
    global _client_manager
    if _client_manager is None:
        _client_manager = AsyncClientManager()
    return _client_manager


async def _analyse_with_llm(ticker: str, data: dict[str, list[dict]]) -> dict:
    """Send retrieved KB data to an LLM and get a structured performance analysis.

    Parameters
    ----------
    ticker : str
        Upper-cased stock ticker the data belongs to.
    data : dict
        Output of ``get_ticker_data`` with keys ``price_data``, ``earnings``
        and ``news`` (each a list of property dicts).

    Returns
    -------
    dict
        Parsed JSON with ``ticker``, ``performance_score``, ``outlook`` and
        ``justification`` keys, exactly as produced by the model.

    Raises
    ------
    json.JSONDecodeError
        If the model reply is not valid JSON after fence stripping.
    """
    cm = _get_client_manager()

    # Build context sections; fall back to a JSON dump for records with no text.
    price_summary = "\n".join(
        d.get("text", json.dumps(d)) for d in data["price_data"]
    ) or "No price data available."

    earnings_summary = "\n---\n".join(
        d.get("text", json.dumps(d)) for d in data["earnings"]
    ) or "No earnings data available."

    news_summary = "\n---\n".join(
        f"[{d.get('date','')}] {d.get('title','')}: {str(d.get('text',''))[:500]}"
        for d in data["news"]
    ) or "No news articles available."

    # BUGFIX: the JSON template previously contained empty placeholders
    # ("performance_score": , / "outlook": ""), leaving the model without the
    # intended value constraints.  Placeholders restored from the documented
    # contract of analyse_stock_performance (score 1-10; outlook one of
    # Bullish/Bearish/Volatile/Sideways).
    prompt = f"""You are a Stock Performance Analyst. Analyse the ticker "{ticker}" using ONLY the data provided below.

## Price History
{price_summary}

## Earnings Transcripts
{earnings_summary}

## News Articles
{news_summary}

Based on the data above, produce a JSON object (and NOTHING else) with exactly these keys:

{{
  "ticker": "{ticker}",
  "performance_score": <integer 1-10>,
  "outlook": "<one of: Bullish, Bearish, Volatile, Sideways>",
  "justification": "<2-4 sentence explanation citing specific data points>"
}}

Scoring guide:
 1-4 → Negative (declining price, poor earnings, negative news)
 5 → Neutral
 6-10 → Positive (rising price, strong earnings, positive sentiment)
"""

    response = await cm.openai_client.chat.completions.create(
        model=cm.configs.default_worker_model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )

    content = response.choices[0].message.content.strip()

    # Extract JSON from potential markdown code fences.
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0].strip()
    elif "```" in content:
        content = content.split("```")[1].split("```")[0].strip()

    return json.loads(content)


def _run_async(coro):
    """Run an async coroutine, handling nested event loops.

    If an event loop is already running (e.g. we are called from async code),
    the coroutine is executed via ``asyncio.run`` on a worker thread, since a
    running loop cannot be re-entered.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop and loop.is_running():
        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(asyncio.run, coro).result()
    return asyncio.run(coro)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def analyse_stock_performance(ticker: str) -> dict:
    """Analyse a stock's performance using Weaviate knowledge base data.

    Retrieves price history, earnings transcripts, and news articles from the
    Weaviate financial news collection, then uses an LLM to produce a
    structured performance analysis.

    Parameters
    ----------
    ticker : str
        Stock ticker symbol (e.g. ``"AAPL"``, ``"TSLA"``).

    Returns
    -------
    dict
        A dictionary with keys:
        - ``ticker`` (str)
        - ``performance_score`` (int, 1–10; ``None`` when no data found)
        - ``outlook`` (str, one of Bullish/Bearish/Volatile/Sideways)
        - ``justification`` (str)
        - ``data_summary`` (dict with counts of price/earnings/news records)
    """
    ticker = ticker.upper().strip()
    data = get_ticker_data(ticker)

    # No records at all for this ticker → graceful "unknown" result, no LLM call.
    if not any(data.values()):
        return {
            "ticker": ticker,
            "performance_score": None,
            "outlook": "Unknown",
            "justification": f"No data found for ticker {ticker} in the knowledge base.",
            "data_summary": {"price_records": 0, "earnings_records": 0, "news_records": 0},
        }

    analysis = _run_async(_analyse_with_llm(ticker, data))
    analysis["data_summary"] = {
        "price_records": len(data["price_data"]),
        "earnings_records": len(data["earnings"]),
        "news_records": len(data["news"]),
    }
    return analysis
'AAPL', 'TSLA', 'GOOGL'.", + } + }, + "required": ["ticker"], + }, + }, +} + +TOOL_IMPLEMENTATIONS = { + "analyse_stock_performance": analyse_stock_performance, +} \ No newline at end of file diff --git a/src/hitachione/tools/sentiment_analysis_tool/test.py b/src/hitachione/tools/sentiment_analysis_tool/test.py new file mode 100644 index 0000000..c25208b --- /dev/null +++ b/src/hitachione/tools/sentiment_analysis_tool/test.py @@ -0,0 +1,502 @@ +""" +Test harness for the Sentiment Analysis Tool (Weaviate-backed). + +Run: + cd src/hitachione/tools/sentiment_analysis_tool + python3 -W ignore::ResourceWarning test.py all # full suite + python3 -W ignore::ResourceWarning test.py data # Weaviate queries only (no LLM) + python3 -W ignore::ResourceWarning test.py sentiment # LLM sentiment analysis tests + python3 -W ignore::ResourceWarning test.py schema # show tool schemas + python3 -W ignore::ResourceWarning test.py interactive # interactive prompt +""" + +import json +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +# Load .env from project root +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +from tool import ( + TOOL_SCHEMAS, + TOOL_IMPLEMENTATIONS, + TOOL_SCHEMA_TICKER, + TOOL_SCHEMA_YEAR, + TOOL_SCHEMA_NEWS, + TOOL_SCHEMA_TEXT, + query_weaviate_by_ticker, + query_weaviate_by_year, + query_weaviate_by_topic, + analyze_sentiment_sync, + analyze_ticker_sentiment_sync, + analyze_year_sentiment_sync, + analyze_news_sentiment_sync, + EXAMPLE_TEXT, +) + +# ── Known data in the Weaviate collection ── +KNOWN_TICKERS = ["AAPL", "AMZN", "GOOGL", "JPM", "META", "MSFT", "NVDA", "TSLA", "V", "WMT"] +UNKNOWN_TICKER = "ZZZZZ" + + +# ────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────── + +def _section(title: str) -> None: + print("\n" + "=" * 80) + print(f" {title}") + print("=" * 80 + "\n") + + +def _pass(msg: 
str) -> None: + print(f" ✓ {msg}") + + +def _fail(msg: str) -> None: + print(f" ✗ {msg}") + + +def _has_llm_key() -> bool: + return bool( + os.getenv("OPENAI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or os.getenv("GOOGLE_API_KEY") + ) + + +# ────────────────────────────────────────────────────────────────────────── +# Schema tests +# ────────────────────────────────────────────────────────────────────────── + +def test_tool_schemas() -> None: + """Validate all tool schema structures.""" + _section("Tool Schemas") + + for schema in TOOL_SCHEMAS: + assert schema["type"] == "function" + fn = schema["function"] + assert "name" in fn + assert "description" in fn + assert "parameters" in fn + assert fn["parameters"]["type"] == "object" + assert "required" in fn["parameters"] + _pass(f"Schema '{fn['name']}' is valid") + + # Verify implementation mapping + for schema in TOOL_SCHEMAS: + name = schema["function"]["name"] + assert name in TOOL_IMPLEMENTATIONS, f"No implementation for '{name}'" + assert callable(TOOL_IMPLEMENTATIONS[name]) + _pass(f"All {len(TOOL_SCHEMAS)} schemas have implementations") + + print("\nSchemas:") + for schema in TOOL_SCHEMAS: + print(json.dumps(schema, indent=2)) + + +# ────────────────────────────────────────────────────────────────────────── +# Weaviate data-retrieval tests (no LLM required) +# ────────────────────────────────────────────────────────────────────────── + +def test_query_by_ticker_known() -> None: + """Retrieve data for known tickers and verify results.""" + _section("Query by Ticker — Known Tickers") + + for ticker in ["AAPL", "TSLA", "JPM"]: + t0 = time.time() + records = query_weaviate_by_ticker(ticker, limit=20) + elapsed = time.time() - t0 + + assert isinstance(records, list), f"Expected list, got {type(records)}" + assert len(records) > 0, f"No records returned for {ticker}" + + # Each record should have text + for rec in records: + assert "text" in rec or "title" in rec, f"Record missing text/title for {ticker}" + + 
def test_query_by_ticker_unknown() -> None:
    """Verify unknown ticker returns empty list."""
    _section("Query by Ticker — Unknown Ticker")

    # A nonsense symbol must yield an empty list, not raise.
    records = query_weaviate_by_ticker(UNKNOWN_TICKER)
    assert isinstance(records, list)
    assert len(records) == 0, f"Expected 0 records for {UNKNOWN_TICKER}, got {len(records)}"
    _pass(f"{UNKNOWN_TICKER}: 0 records as expected")


def test_query_by_ticker_case_insensitive() -> None:
    """Verify ticker lookup is case-insensitive."""
    _section("Query by Ticker — Case Insensitivity")

    # Same limit for both spellings; the query helper upper-cases the ticker,
    # so the result counts must match.
    upper = query_weaviate_by_ticker("AAPL", limit=10)
    lower = query_weaviate_by_ticker("aapl", limit=10)

    assert len(upper) == len(lower), (
        f"Count differs: 'AAPL' returned {len(upper)}, 'aapl' returned {len(lower)}"
    )
    _pass("'AAPL' and 'aapl' return same number of results")


def test_query_by_year() -> None:
    """Retrieve data for a year range."""
    _section("Query by Year")

    # Try several years — the dataset may span different years
    for year in [2020, 2023, 2024]:
        records = query_weaviate_by_year(year, limit=10)
        assert isinstance(records, list)
        if records:
            for rec in records:
                # Dates are string-formatted; the first 4 chars carry the year.
                date_str = str(rec.get("date", ""))
                assert str(year) in date_str[:4], (
                    f"Record date '{date_str}' doesn't match year {year}"
                )
            _pass(f"Year {year}: {len(records)} records")
        else:
            # Absence of a year is not a failure — dataset coverage varies.
            print(f" - Year {year}: no records (may not be in dataset)")
isinstance(rec, dict) + + +def test_query_record_properties() -> None: + """Verify returned records contain expected properties.""" + _section("Record Properties") + + records = query_weaviate_by_ticker("MSFT", limit=5) + assert len(records) > 0, "No records for MSFT" + + expected = {"text", "title", "date", "dataset_source"} + for i, rec in enumerate(records[:3]): + present = set(rec.keys()) & expected + missing = expected - present + if not missing: + _pass(f"Record {i}: has all expected properties") + else: + # Some properties may be null for certain dataset sources + print(f" - Record {i}: missing {missing} (may be null)") + + +# ────────────────────────────────────────────────────────────────────────── +# LLM sentiment analysis tests +# ────────────────────────────────────────────────────────────────────────── + +def test_analyze_sentiment_free_text() -> None: + """Analyze sentiment of free-form text.""" + _section("Free Text Sentiment (LLM)") + + if not _has_llm_key(): + print(" ⚠️ No LLM API key — skipping") + return + + t0 = time.time() + result = analyze_sentiment_sync(EXAMPLE_TEXT) + elapsed = time.time() - t0 + + assert isinstance(result, dict) + assert "label" in result + assert "rating" in result + assert "rationale" in result + assert result["label"] in ("Negative", "Neutral", "Positive", "unknown") + if result["rating"] is not None: + assert 1 <= result["rating"] <= 10 + # Verify label matches rating + if result["rating"] <= 4: + assert result["label"] == "Negative" + elif result["rating"] == 5: + assert result["label"] == "Neutral" + else: + assert result["label"] == "Positive" + + _pass( + f"rating={result['rating']}, label={result['label']}, " + f"{elapsed:.1f}s" + ) + print(f" Rationale: {result['rationale'][:120]}") + + +def test_analyze_ticker_sentiment_known() -> None: + """Analyze ticker sentiment for known tickers.""" + _section("Ticker Sentiment — Known (LLM)") + + if not _has_llm_key(): + print(" ⚠️ No LLM API key — skipping") + return + + 
for ticker in ["AAPL", "NVDA"]: + t0 = time.time() + result = analyze_ticker_sentiment_sync(ticker) + elapsed = time.time() - t0 + + assert isinstance(result, dict) + assert result["ticker"] == ticker + assert result["rating"] is not None, f"Rating is None for {ticker}" + assert isinstance(result["rating"], int) + assert 1 <= result["rating"] <= 10 + # Verify label matches rating scale + if result["rating"] <= 4: + assert result["label"] == "Negative", f"Expected Negative for rating {result['rating']}" + elif result["rating"] == 5: + assert result["label"] == "Neutral", f"Expected Neutral for rating {result['rating']}" + else: + assert result["label"] == "Positive", f"Expected Positive for rating {result['rating']}" + assert len(result.get("rationale", "")) > 10 + assert isinstance(result.get("references", []), list) + + _pass( + f"{ticker}: rating={result['rating']}, label={result['label']}, " + f"{elapsed:.1f}s" + ) + print(f" Rationale: {result['rationale'][:120]}") + + +def test_analyze_ticker_sentiment_unknown() -> None: + """Analyze ticker sentiment for unknown ticker returns graceful result.""" + _section("Ticker Sentiment — Unknown") + + result = analyze_ticker_sentiment_sync(UNKNOWN_TICKER) + assert isinstance(result, dict) + assert result["ticker"] == UNKNOWN_TICKER + assert result["rating"] is None + assert result["label"] == "unknown" + assert "no data" in result["rationale"].lower() or "not found" in result["rationale"].lower() or len(result["rationale"]) > 0 + _pass(f"{UNKNOWN_TICKER}: rating=None, label=unknown (no LLM call needed)") + + +def test_analyze_year_sentiment() -> None: + """Analyze year sentiment (requires data for that year in Weaviate).""" + _section("Year Sentiment (LLM)") + + if not _has_llm_key(): + print(" ⚠️ No LLM API key — skipping") + return + + # First find a year that has data + for year in [2024, 2023, 2020]: + records = query_weaviate_by_year(year, limit=5) + if records: + break + else: + print(" ⚠️ No year data found in KB 
— skipping") + return + + t0 = time.time() + result = analyze_year_sentiment_sync(year) + elapsed = time.time() - t0 + + assert isinstance(result, dict) + assert result["year"] == year + assert "label" in result + assert "rating" in result + assert "rationale" in result + if result["rating"] is not None: + assert 1 <= result["rating"] <= 10 + if result["rating"] <= 4: + assert result["label"] == "Negative" + elif result["rating"] == 5: + assert result["label"] == "Neutral" + else: + assert result["label"] == "Positive" + + _pass( + f"Year {year}: rating={result['rating']}, label={result['label']}, " + f"{elapsed:.1f}s" + ) + print(f" Rationale: {result['rationale'][:120]}") + + +def test_analyze_news_sentiment() -> None: + """Analyze news sentiment by topic query.""" + _section("News Sentiment by Topic (LLM)") + + if not _has_llm_key(): + print(" ⚠️ No LLM API key — skipping") + return + + t0 = time.time() + result = analyze_news_sentiment_sync("technology stocks earnings") + elapsed = time.time() - t0 + + assert isinstance(result, dict) + assert "label" in result + assert "rating" in result + assert "rationale" in result + if result["rating"] is not None: + assert 1 <= result["rating"] <= 10 + + _pass( + f"rating={result['rating']}, label={result['label']}, " + f"{elapsed:.1f}s" + ) + print(f" Rationale: {result['rationale'][:120]}") + + +def test_analyze_news_sentiment_no_results() -> None: + """Topic query with no matching results returns graceful output.""" + _section("News Sentiment — No Results") + + result = analyze_news_sentiment_sync("xyzzy_no_such_topic_9999") + assert isinstance(result, dict) + assert result["label"] == "unknown" + assert result["rating"] is None + _pass("No-match query: label=unknown, rating=None (no LLM call)") + + +def test_all_tickers_sentiment() -> None: + """Run sentiment analysis across all known tickers.""" + _section("All Known Tickers Sentiment (LLM)") + + if not _has_llm_key(): + print(" ⚠️ No LLM API key — skipping") + return 
+ + for ticker in KNOWN_TICKERS: + t0 = time.time() + result = analyze_ticker_sentiment_sync(ticker) + elapsed = time.time() - t0 + + assert isinstance(result, dict) + assert result["ticker"] == ticker + assert isinstance(result["rating"], int) + assert 1 <= result["rating"] <= 10 + # Verify label matches rating + if result["rating"] <= 4: + expected_label = "Negative" + elif result["rating"] == 5: + expected_label = "Neutral" + else: + expected_label = "Positive" + assert result["label"] == expected_label + + _pass( + f"{ticker}: rating={result['rating']}, label={result['label']}, " + f"{elapsed:.1f}s" + ) + + +# ────────────────────────────────────────────────────────────────────────── +# Interactive mode +# ────────────────────────────────────────────────────────────────────────── + +def interactive() -> None: + """Prompt the user for a ticker / year / topic / text and run analysis.""" + _section("Interactive Mode") + + print("Options:") + print(" 1. Ticker sentiment (e.g. AAPL)") + print(" 2. Year sentiment (e.g. 2024)") + print(" 3. Topic sentiment (e.g. inflation)") + print(" 4. 
Free text sentiment") + print() + + choice = input("Choose (1-4): ").strip() + + if choice == "1": + ticker = input("Ticker: ").strip().upper() + result = analyze_ticker_sentiment_sync(ticker) + elif choice == "2": + year = int(input("Year: ").strip()) + result = analyze_year_sentiment_sync(year) + elif choice == "3": + topic = input("Topic: ").strip() + result = analyze_news_sentiment_sync(topic) + elif choice == "4": + text = input("Text: ").strip() + result = analyze_sentiment_sync(text) + else: + print("Invalid choice.") + return + + print(json.dumps(result, indent=2, ensure_ascii=False)) + + +# ────────────────────────────────────────────────────────────────────────── +# Runner +# ────────────────────────────────────────────────────────────────────────── + +TEST_GROUPS = { + "schema": [test_tool_schemas], + "data": [ + test_query_by_ticker_known, + test_query_by_ticker_unknown, + test_query_by_ticker_case_insensitive, + test_query_by_year, + test_query_by_topic, + test_query_record_properties, + ], + "sentiment": [ + test_analyze_sentiment_free_text, + test_analyze_ticker_sentiment_known, + test_analyze_ticker_sentiment_unknown, + test_analyze_year_sentiment, + test_analyze_news_sentiment, + test_analyze_news_sentiment_no_results, + ], + "full": [ + test_all_tickers_sentiment, + ], +} + + +def main() -> None: + mode = sys.argv[1] if len(sys.argv) > 1 else "all" + + if mode == "interactive": + interactive() + return + + if mode == "all": + groups = ["schema", "data", "sentiment"] + elif mode in TEST_GROUPS: + groups = [mode] + else: + print(f"Unknown mode: {mode}") + print(f"Available: {', '.join(list(TEST_GROUPS.keys()) + ['all', 'interactive'])}") + sys.exit(1) + + passed = 0 + failed = 0 + + for group in groups: + for test_fn in TEST_GROUPS[group]: + try: + test_fn() + passed += 1 + except Exception as exc: + _fail(f"{test_fn.__name__}: {exc}") + failed += 1 + + _section("Summary") + print(f" Passed: {passed}") + print(f" Failed: {failed}") + + if failed: + 
sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/hitachione/tools/sentiment_analysis_tool/tool.py b/src/hitachione/tools/sentiment_analysis_tool/tool.py new file mode 100644 index 0000000..ccde4c0 --- /dev/null +++ b/src/hitachione/tools/sentiment_analysis_tool/tool.py @@ -0,0 +1,658 @@ +"""Sentiment analysis tool backed by Weaviate knowledge base. + +Queries the Weaviate financial news collection for data related to a ticker, +year, or topic, then uses an LLM to produce a structured sentiment assessment. +""" + +from __future__ import annotations + +import asyncio +import json +import os +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv + +# Load .env from project root (5 levels up from this file) +load_dotenv(Path(__file__).resolve().parents[4] / ".env") + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +import weaviate +from weaviate.auth import AuthApiKey +from weaviate.classes.query import Filter +from openai import AsyncOpenAI + +from utils.env_vars import Configs + + +# --------------------------------------------------------------------------- +# Weaviate helper (sync, for data retrieval) +# --------------------------------------------------------------------------- + +WEAVIATE_COLLECTION = os.getenv("WEAVIATE_COLLECTION_NAME", "Hitachi_finance_news") + + +def _get_weaviate_sync_client(): + """Create a synchronous Weaviate client from environment variables.""" + http_host = os.getenv("WEAVIATE_HTTP_HOST", "localhost") + api_key = os.getenv("WEAVIATE_API_KEY", "") + + if http_host.endswith(".weaviate.cloud"): + return weaviate.connect_to_weaviate_cloud( + cluster_url=f"https://{http_host}", + auth_credentials=AuthApiKey(api_key), + ) + + return weaviate.connect_to_custom( + http_host=http_host, + http_port=int(os.getenv("WEAVIATE_HTTP_PORT", "8080")), + http_secure=os.getenv("WEAVIATE_HTTP_SECURE", "false").lower() == "true", + grpc_host=os.getenv("WEAVIATE_GRPC_HOST", 
"localhost"), + grpc_port=int(os.getenv("WEAVIATE_GRPC_PORT", "50051")), + grpc_secure=os.getenv("WEAVIATE_GRPC_SECURE", "false").lower() == "true", + auth_credentials=AuthApiKey(api_key), + ) + + +def query_weaviate_by_ticker( + ticker: str, + limit: int = 20, +) -> list[dict[str, Any]]: + """Retrieve news, earnings, and price data for a ticker from Weaviate.""" + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + ticker_filter = Filter.by_property("ticker").equal(ticker.upper()) + + response = col.query.fetch_objects( + filters=ticker_filter, + limit=limit, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + "ticker", "company", "quarter", "fiscal_year", + ], + ) + + # Also grab news that *mention* this ticker + mentioned_response = col.query.fetch_objects( + filters=Filter.by_property("mentioned_companies").contains_any( + [ticker.upper()] + ), + limit=limit, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + ], + ) + + seen_titles: set[str] = set() + results: list[dict] = [] + for obj in list(response.objects) + list(mentioned_response.objects): + props = {k: v for k, v in obj.properties.items() if v is not None} + title = props.get("title", "") + if title not in seen_titles: + seen_titles.add(title) + results.append(props) + + return sorted(results, key=lambda d: d.get("date", "")) + finally: + client.close() + + +def query_weaviate_by_year( + year: int, + limit: int = 20, +) -> list[dict[str, Any]]: + """Retrieve financial data for a specific year from Weaviate.""" + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + results: list[dict] = [] + seen_titles: set[str] = set() + + for obj in col.iterator( + include_vector=False, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + "ticker", "company", + ], + ): + date_str = obj.properties.get("date", "") + if date_str and str(year) 
in str(date_str)[:4]: + props = {k: v for k, v in obj.properties.items() if v is not None} + title = props.get("title", "") + if title not in seen_titles: + seen_titles.add(title) + results.append(props) + if len(results) >= limit: + break + + return sorted(results, key=lambda d: d.get("date", "")) + finally: + client.close() + + +def query_weaviate_by_topic( + topic: str, + limit: int = 10, +) -> list[dict[str, Any]]: + """Retrieve articles matching a topic via keyword search in Weaviate.""" + client = _get_weaviate_sync_client() + try: + col = client.collections.get(WEAVIATE_COLLECTION) + + response = col.query.bm25( + query=topic, + limit=limit, + return_properties=[ + "text", "title", "date", "category", "dataset_source", + "ticker", "company", + ], + ) + + results = [] + for obj in response.objects: + props = {k: v for k, v in obj.properties.items() if v is not None} + results.append(props) + + return sorted(results, key=lambda d: d.get("date", "")) + finally: + client.close() + + +# --------------------------------------------------------------------------- +# System prompts +# --------------------------------------------------------------------------- + +# Rating scale used across all prompts: +# 1-4 → Negative +# 5 → Neutral +# 6-10 → Positive + +_RATING_SCALE_TEXT = ( + "Use a 1-10 sentiment rating scale where: " + "1-4 = Negative, 5 = Neutral, 6-10 = Positive. " +) + +SYSTEM_PROMPT = ( + "You are a sentiment analysis agent. " + + _RATING_SCALE_TEXT + + "Return JSON with keys: rating (integer 1-10), " + "rationale (short string explaining the score)." +) + +NEWS_SYSTEM_PROMPT = ( + "You are a financial news sentiment analyst. " + "Given financial data (news articles, earnings transcripts, price data), " + + _RATING_SCALE_TEXT + + "Return JSON with keys: rating (integer 1-10), " + "rationale (short string explaining the score), " + "references (array of short quoted phrases from the data)." 
+) + +YEAR_SYSTEM_PROMPT = ( + "You are a financial news sentiment analyst. " + "Given snippets from financial news in a specific year, " + + _RATING_SCALE_TEXT + + "Return JSON with keys: year (integer), rating (integer 1-10), " + "rationale (short string explaining the score)." +) + +TICKER_SYSTEM_PROMPT = ( + "You are a financial sentiment analyst. " + "Given all available data (price history, earnings transcripts, news articles) " + "for a specific stock ticker, provide an overall sentiment assessment. " + + _RATING_SCALE_TEXT + + "Return JSON with keys: ticker (string), rating (integer 1-10), " + "rationale (short string explaining the score), " + "references (array of short quoted phrases from the data)." +) + +EXAMPLE_TEXT = ( + "The market showed resilience today despite inflation fears, " + "with tech stocks leading the recovery." +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _run_async(coro): + """Run an async coroutine, handling nested event loops.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, coro).result() + return asyncio.run(coro) + + +# --------------------------------------------------------------------------- +# Core analysis functions +# --------------------------------------------------------------------------- + +def _derive_label(rating: int | float | None) -> str: + """Derive a sentiment label from a 1-10 rating. 
+ + Scale: + 1-4 → Negative + 5 → Neutral + 6-10 → Positive + """ + if rating is None: + return "unknown" + r = int(round(rating)) + if r <= 4: + return "Negative" + elif r == 5: + return "Neutral" + else: + return "Positive" + + +def _parse_json_response(content: str) -> dict[str, Any]: + """Parse JSON from LLM response, handling code fences.""" + content = content.strip() + if "```json" in content: + content = content.split("```json")[1].split("```")[0].strip() + elif "```" in content: + content = content.split("```")[1].split("```")[0].strip() + try: + parsed = json.loads(content) + # If the LLM returns a list, wrap the first element or flatten + if isinstance(parsed, list): + return parsed[0] if parsed else {"rating": None, "label": "unknown", "rationale": content} + return parsed + except json.JSONDecodeError: + return {"rating": None, "label": "unknown", "rationale": content} + + +def _format_kb_data(records: list[dict[str, Any]], max_chars: int = 15000) -> str: + """Format Weaviate records into a text block for the LLM.""" + parts: list[str] = [] + total = 0 + for rec in records: + source = rec.get("dataset_source", "unknown") + date = rec.get("date", "") + title = rec.get("title", "") + text = str(rec.get("text", ""))[:1000] + entry = f"[{source} | {date}] {title}\n{text}" + if total + len(entry) > max_chars: + break + parts.append(entry) + total += len(entry) + return "\n\n".join(parts) + + +async def analyze_sentiment(text: str, model: str | None = None) -> dict[str, Any]: + """Analyze sentiment for arbitrary text.""" + configs = Configs() + client = AsyncOpenAI( + base_url=configs.openai_base_url, + api_key=configs.openai_api_key, + ) + + try: + response = await client.chat.completions.create( + model=model or configs.default_worker_model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": text}, + ], + temperature=0, + ) + content = (response.choices[0].message.content or "").strip() + finally: + await 
client.close() + + data = _parse_json_response(content) + # Normalise to rating-based output + rating = data.get("rating") + if rating is not None: + rating = int(round(rating)) + data["rating"] = rating + data["label"] = _derive_label(rating) + data.setdefault("rationale", "") + # Remove legacy score key if present + data.pop("score", None) + return data + + +async def analyze_ticker_sentiment( + ticker: str, + model: str | None = None, + limit: int = 20, +) -> dict[str, Any]: + """Analyze sentiment for a stock ticker using Weaviate KB data. + + Retrieves all available data (price history, earnings, news) for the + ticker from the Weaviate knowledge base and produces a sentiment rating. + """ + configs = Configs() + client = AsyncOpenAI( + base_url=configs.openai_base_url, + api_key=configs.openai_api_key, + ) + + records = query_weaviate_by_ticker(ticker, limit=limit) + + if not records: + return { + "ticker": ticker.upper(), + "rating": None, + "label": "unknown", + "rationale": f"No data found for ticker {ticker} in the knowledge base.", + "references": [], + } + + combined = _format_kb_data(records) + + try: + response = await client.chat.completions.create( + model=model or configs.default_worker_model, + messages=[ + {"role": "system", "content": TICKER_SYSTEM_PROMPT}, + {"role": "user", "content": combined}, + ], + temperature=0, + ) + content = (response.choices[0].message.content or "").strip() + finally: + await client.close() + + data = _parse_json_response(content) + rating = data.get("rating") + if rating is not None: + rating = int(round(rating)) + data["ticker"] = ticker.upper() + data["rating"] = rating + data["label"] = _derive_label(rating) + data.setdefault("rationale", "") + data.setdefault("references", []) + return data + + +async def analyze_financial_year_sentiment( + year: int, + model: str | None = None, + max_results: int = 20, +) -> dict[str, Any]: + """Analyze sentiment for financial news in a given year using Weaviate KB.""" + configs 
async def analyze_financial_year_sentiment(
    year: int,
    model: str | None = None,
    max_results: int = 20,
) -> dict[str, Any]:
    """Analyze sentiment for financial news in a given year using Weaviate KB.

    Args:
        year: Four-digit year to analyze, e.g. 2024.
        model: Optional model override; defaults to the configured worker model.
        max_results: Maximum number of KB records to retrieve.

    Returns:
        Dict with ``year``, ``rating`` (1-10 int or None), ``label``, and
        ``rationale`` keys.
    """
    # Fetch KB records before opening a client so the empty case returns
    # without leaking an unclosed AsyncOpenAI client (original created the
    # client before this check).
    records = query_weaviate_by_year(year, limit=max_results)

    if not records:
        return {
            "year": year,
            "rating": None,
            "label": "unknown",
            "rationale": f"No data found for year {year} in the knowledge base.",
        }

    combined = _format_kb_data(records)

    configs = Configs()
    client = AsyncOpenAI(
        base_url=configs.openai_base_url,
        api_key=configs.openai_api_key,
    )
    try:
        response = await client.chat.completions.create(
            model=model or configs.default_worker_model,
            messages=[
                {"role": "system", "content": YEAR_SYSTEM_PROMPT},
                {"role": "user", "content": combined},
            ],
            temperature=0,
        )
        content = (response.choices[0].message.content or "").strip()
    finally:
        # Release the HTTP client whether or not the request succeeded.
        await client.close()

    data = _parse_json_response(content)
    rating = data.get("rating")
    if rating is not None:
        rating = int(round(rating))
    data["year"] = year
    data["rating"] = rating
    data["label"] = _derive_label(rating)
    data.setdefault("rationale", "")
    # Remove legacy score key if present
    data.pop("score", None)
    return data
+ """ + configs = Configs() + client = AsyncOpenAI( + base_url=configs.openai_base_url, + api_key=configs.openai_api_key, + ) + + records = query_weaviate_by_topic(query, limit=limit) + + if not records: + return { + "rating": None, + "label": "unknown", + "rationale": f"No articles found for query '{query}' in the knowledge base.", + "references": [], + } + + combined = _format_kb_data(records) + + try: + response = await client.chat.completions.create( + model=model or configs.default_worker_model, + messages=[ + {"role": "system", "content": NEWS_SYSTEM_PROMPT}, + {"role": "user", "content": combined}, + ], + temperature=0, + ) + content = (response.choices[0].message.content or "").strip() + finally: + await client.close() + + data = _parse_json_response(content) + rating = data.get("rating") + if rating is not None: + rating = int(round(rating)) + data["rating"] = rating + data["label"] = _derive_label(rating) + data.setdefault("rationale", "") + data.setdefault("references", []) + return data + + +# --------------------------------------------------------------------------- +# Synchronous wrappers for external callers +# --------------------------------------------------------------------------- + +def analyze_sentiment_sync(text: str, model: str | None = None) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_sentiment``.""" + return _run_async(analyze_sentiment(text, model=model)) + + +def analyze_ticker_sentiment_sync( + ticker: str, model: str | None = None, limit: int = 20, +) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_ticker_sentiment``.""" + return _run_async(analyze_ticker_sentiment(ticker, model=model, limit=limit)) + + +def analyze_year_sentiment_sync( + year: int, model: str | None = None, max_results: int = 20, +) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_financial_year_sentiment``.""" + return _run_async( + analyze_financial_year_sentiment(year, model=model, max_results=max_results) + ) + + +def 
analyze_news_sentiment_sync( + query: str, model: str | None = None, limit: int = 10, +) -> dict[str, Any]: + """Synchronous wrapper for ``analyze_financial_news_sentiment``.""" + return _run_async( + analyze_financial_news_sentiment(query, model=model, limit=limit) + ) + + +# --------------------------------------------------------------------------- +# OpenAI tool schemas +# --------------------------------------------------------------------------- + +TOOL_SCHEMA_TICKER = { + "type": "function", + "function": { + "name": "analyze_ticker_sentiment", + "description": ( + "Analyze the overall sentiment for a stock ticker using the " + "Weaviate financial knowledge base. Returns a sentiment rating " + "(1-10: 1-4 Negative, 5 Neutral, 6-10 Positive), " + "label, rationale, and supporting references." + ), + "parameters": { + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "Stock ticker symbol, e.g. 'AAPL', 'TSLA'.", + }, + }, + "required": ["ticker"], + }, + }, +} + +TOOL_SCHEMA_YEAR = { + "type": "function", + "function": { + "name": "analyze_year_sentiment", + "description": ( + "Analyze the overall financial-news sentiment for a given year " + "using the Weaviate knowledge base. Returns a rating " + "(1-10: 1-4 Negative, 5 Neutral, 6-10 Positive), label, and rationale." + ), + "parameters": { + "type": "object", + "properties": { + "year": { + "type": "integer", + "description": "Four-digit year, e.g. 2024.", + }, + }, + "required": ["year"], + }, + }, +} + +TOOL_SCHEMA_NEWS = { + "type": "function", + "function": { + "name": "analyze_news_sentiment", + "description": ( + "Search the Weaviate knowledge base for financial news matching " + "a topic query and return a sentiment rating " + "(1-10: 1-4 Negative, 5 Neutral, 6-10 Positive), label, and rationale." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Natural-language topic query, e.g. 
# Topic-search sentiment tool: one required natural-language query.
TOOL_SCHEMA_NEWS = {
    "type": "function",
    "function": {
        "name": "analyze_news_sentiment",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "description": "Natural-language topic query, e.g. 'tech earnings Q3'.",
                    "type": "string",
                },
            },
            "required": ["query"],
        },
        "description": (
            "Search the Weaviate knowledge base for financial news matching "
            "a topic query and return a sentiment rating "
            "(1-10: 1-4 Negative, 5 Neutral, 6-10 Positive), label, and rationale."
        ),
    },
}

# Free-text sentiment tool: classifies arbitrary text without KB lookup.
TOOL_SCHEMA_TEXT = {
    "type": "function",
    "function": {
        "name": "analyze_sentiment",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "description": "Free-form text to analyze.",
                    "type": "string",
                },
            },
            "required": ["text"],
        },
        "description": (
            "Classify arbitrary text sentiment on a 1-10 scale: "
            "1-4 = Negative, 5 = Neutral, 6-10 = Positive. "
            "Returns a rating, label, and rationale."
        ),
    },
}

# Convenience list of all schemas
TOOL_SCHEMAS = [
    TOOL_SCHEMA_TICKER,
    TOOL_SCHEMA_YEAR,
    TOOL_SCHEMA_NEWS,
    TOOL_SCHEMA_TEXT,
]

# Maps each schema's function name to its synchronous implementation.
TOOL_IMPLEMENTATIONS = {
    "analyze_ticker_sentiment": analyze_ticker_sentiment_sync,
    "analyze_year_sentiment": analyze_year_sentiment_sync,
    "analyze_news_sentiment": analyze_news_sentiment_sync,
    "analyze_sentiment": analyze_sentiment_sync,
}


# ---------------------------------------------------------------------------
# CLI (for quick interactive testing)
# ---------------------------------------------------------------------------

def _format_output(data: dict[str, Any]) -> str:
    """Render a result dict as pretty-printed, non-ASCII-safe JSON."""
    return json.dumps(data, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Sentiment analysis tool (Weaviate KB)"
    )
    arg_parser.add_argument("text", nargs="?", default=None, help="Free text to analyze")
    arg_parser.add_argument("--model", help="Optional model override")
    arg_parser.add_argument("--ticker", help="Analyze sentiment for a stock ticker")
    arg_parser.add_argument("--year", type=int, help="Analyze sentiment for a year")
    arg_parser.add_argument("--query", help="Search KB for matching financial news")
    ns = arg_parser.parse_args()

    # Mode precedence: ticker > year > query > free text > AAPL demo run.
    if ns.ticker:
        outcome = analyze_ticker_sentiment_sync(ns.ticker, model=ns.model)
    elif ns.year is not None:
        outcome = analyze_year_sentiment_sync(ns.year, model=ns.model)
    elif ns.query:
        outcome = analyze_news_sentiment_sync(ns.query, model=ns.model)
    elif ns.text:
        outcome = analyze_sentiment_sync(ns.text, model=ns.model)
    else:
        outcome = analyze_ticker_sentiment_sync("AAPL", model=ns.model)

    print(_format_output(outcome))