diff --git a/.env.example b/.env.example index dc518b501..b5d5443dd 100644 --- a/.env.example +++ b/.env.example @@ -76,3 +76,6 @@ TARGET_REPO_PATH=. # Ollama base URL (without /v1 suffix) OLLAMA_BASE_URL=http://localhost:11434 + +ALLOWED_PROJECT_ROOTS=/path/to/project/root +MCP_MODE=query diff --git a/.gitignore b/.gitignore index 4b6211856..b6d03267a 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ PROJECT.md .DS_Store .pypi_cache.json .omc +openspec diff --git a/README.md b/README.md index 5ef87d4e0..c94c47040 100644 --- a/README.md +++ b/README.md @@ -513,17 +513,70 @@ The agent will incorporate the guidance from your reference documents when sugge Code-Graph-RAG can run as an MCP (Model Context Protocol) server, enabling seamless integration with Claude Code and other MCP clients. +### MCP Dual Mode System (v0.0.60+) + +The MCP server now supports two distinct modes with different capabilities and security profiles: + +#### Query Mode (Production Recommended) +**Read-only access** for safe codebase exploration and analysis. + +**Available Tools:** +- `list_projects` - List all indexed projects +- `query_code_graph` - Natural language graph queries +- `get_code_snippet` - Retrieve source code by qualified name +- `list_directory` / `read_file` - Browse directory structure and read file contents + +**Use Cases:** +- Production environments where code modification is not allowed +- Code review and exploration +- Documentation generation +- Architecture analysis + +#### Edit Mode (Development) +**Full access** including file editing and database management. + +**Additional Tools (beyond Query mode):** +- `write_file` - Create or overwrite files +- `surgical_replace_code` - Precise code editing +- `delete_project` - Remove projects from graph +- `wipe_database` - Complete database reset (dangerous!) 
+- `index_repository` - Build/update knowledge graph + +**Use Cases:** +- Local development environments +- Code refactoring assistance +- Automated code generation +- Database maintenance + ### Quick Setup +#### Query Mode (Recommended for Production) + +```bash +claude mcp add --transport stdio code-graph-rag \ + --env TARGET_REPO_PATH="$(pwd)" \ + --env MCP_MODE=query \ + --env CYPHER_PROVIDER=openai \ + --env CYPHER_MODEL=gpt-4 \ + --env CYPHER_API_KEY=your-api-key \ + -- uv run --directory /path/to/code-graph-rag code-graph-rag mcp-server +``` + +#### Edit Mode (For Development) + ```bash claude mcp add --transport stdio code-graph-rag \ - --env TARGET_REPO_PATH=/absolute/path/to/your/project \ + --env TARGET_REPO_PATH="$(pwd)" \ + --env MCP_MODE=edit \ + --env ALLOWED_PROJECT_ROOTS="$(pwd)" \ --env CYPHER_PROVIDER=openai \ --env CYPHER_MODEL=gpt-4 \ --env CYPHER_API_KEY=your-api-key \ -- uv run --directory /path/to/code-graph-rag code-graph-rag mcp-server ``` +**Important:** Always set `ALLOWED_PROJECT_ROOTS` in Edit mode to restrict file operations to specific directories. + ### Available Tools @@ -543,13 +596,48 @@ claude mcp add --transport stdio code-graph-rag \ ### Example Usage +#### Query Mode ``` -> Index this repository > What functions call UserService.create_user? +> Show me all classes that implement Repository +> List all modules in the utils package +> Get the source code for AuthService.login +``` + +#### Edit Mode +``` +> Index this repository > Update the login function to add rate limiting +> Refactor this class to use dependency injection +> Delete the deprecated project from the graph ``` -For detailed setup, see [Claude Code Setup Guide](docs/claude-code-setup.md). 
+### Security Configuration + +For Edit mode, always restrict access with `ALLOWED_PROJECT_ROOTS`: + +```bash +# Single project +--env ALLOWED_PROJECT_ROOTS="/path/to/project" + +# Multiple projects (comma-separated) +--env ALLOWED_PROJECT_ROOTS="/path/to/project1,/path/to/project2" +``` + +This ensures file operations cannot modify files outside the specified directories. + +### Mode Selection Guide + +| Scenario | Recommended Mode | Reasoning | +|----------|-----------------|-----------| +| Production code review | Query | Prevents accidental modifications | +| Development work | Edit | Allows code generation and editing | +| CI/CD pipelines | Query | Read-only analysis is sufficient | +| Local experimentation | Edit | Full control for testing | +| Multi-project analysis | Query | Safe exploration across projects | +| Code refactoring | Edit | Requires write access | + +For detailed setup and configuration examples, see [Claude Code Setup Guide](docs/claude-code-setup.md) and [Security Best Practices](docs/security-best-practices.md). 
## 📊 Graph Schema @@ -560,20 +648,20 @@ The knowledge graph uses the following node types and relationships: | Label | Properties | |-----|----------| -| Project | `{name: string}` | -| Package | `{qualified_name: string, name: string, path: string}` | -| Folder | `{path: string, name: string}` | -| File | `{path: string, name: string, extension: string}` | -| Module | `{qualified_name: string, name: string, path: string}` | -| Class | `{qualified_name: string, name: string, decorators: list[string]}` | -| Function | `{qualified_name: string, name: string, decorators: list[string]}` | -| Method | `{qualified_name: string, name: string, decorators: list[string]}` | -| Interface | `{qualified_name: string, name: string}` | -| Enum | `{qualified_name: string, name: string}` | -| Type | `{qualified_name: string, name: string}` | -| Union | `{qualified_name: string, name: string}` | -| ModuleInterface | `{qualified_name: string, name: string, path: string}` | -| ModuleImplementation | `{qualified_name: string, name: string, path: string, implements_module: string}` | +| Project | `{name: string, absolute_path: string, project_name: string}` | +| Package | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}` | +| Folder | `{path: string, name: string, absolute_path: string, project_name: string}` | +| File | `{path: string, name: string, extension: string, absolute_path: string, project_name: string}` | +| Module | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}` | +| Class | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, decorators: list[string]}` | +| Function | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, decorators: list[string]}` | +| Method | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, decorators: 
list[string]}` | +| Interface | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}` | +| Enum | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}` | +| Type | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}` | +| Union | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}` | +| ModuleInterface | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}` | +| ModuleImplementation | `{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, implements_module: string}` | | ExternalPackage | `{name: string, version_spec: string}` | @@ -653,6 +741,10 @@ Configuration is managed through environment variables in `.env` file: - `TARGET_REPO_PATH`: Default repository path (default: `.`) - `LOCAL_MODEL_ENDPOINT`: Fallback endpoint for Ollama (default: `http://localhost:11434/v1`) +### MCP Server Configuration +- `MCP_MODE`: MCP server operation mode - `query` (read-only) or `edit` (full access). Default: `edit`. **Recommended: Use `query` mode for production environments.** +- `ALLOWED_PROJECT_ROOTS`: Comma-separated list of allowed project root paths for file operations in Edit mode. This is a critical security setting that restricts file read/write operations to specified directories. 
Example: `/path/to/project1,/path/to/project2` + ### Custom Ignore Patterns You can specify additional directories to exclude by creating a `.cgrignore` file in your repository root: diff --git a/codebase_rag/config.py b/codebase_rag/config.py index 31848e4d1..8ea440aa6 100644 --- a/codebase_rag/config.py +++ b/codebase_rag/config.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from loguru import logger -from pydantic import Field +from pydantic import Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict from . import constants as cs @@ -17,6 +17,19 @@ load_dotenv() +def _parse_frozenset_of_strings(value: str | frozenset[str] | None) -> frozenset[Path]: + if value is None: + return frozenset() + if isinstance(value, frozenset): + return frozenset(Path(path) for path in value) + if isinstance(value, str): + if value.strip(): + return frozenset( + Path(path.strip()) for path in value.split(",") if path.strip() + ) + return frozenset() + + class ApiKeyInfoEntry(TypedDict): env_var: str url: str @@ -171,7 +184,21 @@ def ollama_endpoint(self) -> str: return f"{self.OLLAMA_BASE_URL.rstrip('/')}/v1" TARGET_REPO_PATH: str = "." 
+ ALLOWED_PROJECT_ROOTS: str = "" SHELL_COMMAND_TIMEOUT: int = 30 + MCP_MODE: str = "edit" + + @field_validator("MCP_MODE") + @classmethod + def _validate_mcp_mode(cls, v: str) -> str: + if v not in ("query", "edit"): + raise ValueError("MCP_MODE must be 'query' or 'edit'") + return v + + @property + def allowed_project_roots_set(self) -> frozenset[Path]: + return _parse_frozenset_of_strings(self.ALLOWED_PROJECT_ROOTS) + SHELL_COMMAND_ALLOWLIST: frozenset[str] = frozenset( { "ls", diff --git a/codebase_rag/constants.py b/codebase_rag/constants.py index 4ef971d8a..5bd6373a5 100644 --- a/codebase_rag/constants.py +++ b/codebase_rag/constants.py @@ -181,6 +181,8 @@ class GoogleProviderType(StrEnum): KEY_VERSION_SPEC = "version_spec" KEY_PREFIX = "prefix" KEY_PROJECT_NAME = "project_name" +KEY_ABSOLUTE_PATH = "absolute_path" +EXTERNAL_PROJECT_NAME = "__external__" KEY_IS_EXTERNAL = "is_external" ERR_SUBSTR_ALREADY_EXISTS = "already exists" @@ -419,11 +421,10 @@ class RelationshipType(StrEnum): CYPHER_QUERY_EMBEDDINGS = """ MATCH (m:Module)-[:DEFINES]->(n) -WHERE (n:Function OR n:Method) - AND m.qualified_name STARTS WITH $project_name + '.' 
+WHERE n.project_name = $project_name RETURN id(n) AS node_id, n.qualified_name AS qualified_name, n.start_line AS start_line, n.end_line AS end_line, - m.path AS path + n.path AS path """ diff --git a/codebase_rag/cypher_queries.py b/codebase_rag/cypher_queries.py index 8d70bae4e..8f3cc372f 100644 --- a/codebase_rag/cypher_queries.py +++ b/codebase_rag/cypher_queries.py @@ -13,47 +13,58 @@ CYPHER_EXAMPLE_DECORATED_FUNCTIONS = f"""MATCH (n:Function|Method) WHERE ANY(d IN n.decorators WHERE toLower(d) IN ['flow', 'task']) -RETURN n.name AS name, n.qualified_name AS qualified_name, labels(n) AS type +RETURN n.name AS name, n.qualified_name AS qualified_name, labels(n) AS type, + n.path AS relative_path, n.absolute_path AS absolute_path, n.project_name AS project_name LIMIT {CYPHER_DEFAULT_LIMIT}""" CYPHER_EXAMPLE_CONTENT_BY_PATH = f"""MATCH (n) WHERE n.path IS NOT NULL AND n.path STARTS WITH 'workflows' -RETURN n.name AS name, n.path AS path, labels(n) AS type +RETURN n.name AS name, n.path AS relative_path, n.absolute_path AS absolute_path, + n.project_name AS project_name, labels(n) AS type LIMIT {CYPHER_DEFAULT_LIMIT}""" CYPHER_EXAMPLE_KEYWORD_SEARCH = f"""MATCH (n) WHERE toLower(n.name) CONTAINS 'database' OR (n.qualified_name IS NOT NULL AND toLower(n.qualified_name) CONTAINS 'database') -RETURN n.name AS name, n.qualified_name AS qualified_name, labels(n) AS type +RETURN n.name AS name, n.qualified_name AS qualified_name, labels(n) AS type, + n.path AS relative_path, n.absolute_path AS absolute_path, n.project_name AS project_name LIMIT {CYPHER_DEFAULT_LIMIT}""" CYPHER_EXAMPLE_FIND_FILE = """MATCH (f:File) WHERE toLower(f.name) = 'readme.md' AND f.path = 'README.md' -RETURN f.path as path, f.name as name, labels(f) as type""" +RETURN f.path AS relative_path, f.absolute_path AS absolute_path, f.project_name AS project_name, + f.name as name, labels(f) as type""" CYPHER_EXAMPLE_README = f"""MATCH (f:File) WHERE toLower(f.name) CONTAINS 'readme' -RETURN f.path AS 
path, f.name AS name, labels(f) AS type +RETURN f.path AS relative_path, f.absolute_path AS absolute_path, f.project_name AS project_name, + f.name AS name, labels(f) AS type LIMIT {CYPHER_DEFAULT_LIMIT}""" CYPHER_EXAMPLE_PYTHON_FILES = f"""MATCH (f:File) WHERE f.extension = '.py' -RETURN f.path AS path, f.name AS name, labels(f) AS type +RETURN f.path AS relative_path, f.absolute_path AS absolute_path, f.project_name AS project_name, + f.name AS name, labels(f) AS type LIMIT {CYPHER_DEFAULT_LIMIT}""" CYPHER_EXAMPLE_TASKS = f"""MATCH (n:Function|Method) WHERE 'task' IN n.decorators -RETURN n.qualified_name AS qualified_name, n.name AS name, labels(n) AS type +RETURN n.qualified_name AS qualified_name, n.name AS name, labels(n) AS type, + n.path AS relative_path, n.absolute_path AS absolute_path, n.project_name AS project_name LIMIT {CYPHER_DEFAULT_LIMIT}""" CYPHER_EXAMPLE_FILES_IN_FOLDER = f"""MATCH (f:File) WHERE f.path STARTS WITH 'services' -RETURN f.path AS path, f.name AS name, labels(f) AS type +RETURN f.path AS relative_path, f.absolute_path AS absolute_path, f.project_name AS project_name, + f.name AS name, labels(f) AS type LIMIT {CYPHER_DEFAULT_LIMIT}""" -CYPHER_EXAMPLE_LIMIT_ONE = """MATCH (f:File) RETURN f.path as path, f.name as name, labels(f) as type LIMIT 1""" +CYPHER_EXAMPLE_LIMIT_ONE = """MATCH (f:File) +RETURN f.path AS relative_path, f.absolute_path AS absolute_path, f.project_name AS project_name, + f.name as name, labels(f) as type LIMIT 1""" CYPHER_EXAMPLE_CLASS_METHODS = f"""MATCH (c:Class)-[:DEFINES_METHOD]->(m:Method) WHERE c.qualified_name ENDS WITH '.UserService' -RETURN m.name AS name, m.qualified_name AS qualified_name, labels(m) AS type +RETURN m.name AS name, m.qualified_name AS qualified_name, labels(m) AS type, + m.path AS relative_path, m.absolute_path AS absolute_path, m.project_name AS project_name LIMIT {CYPHER_DEFAULT_LIMIT}""" CYPHER_EXPORT_NODES = """ @@ -70,16 +81,18 @@ CYPHER_SET_PROPS_RETURN_COUNT = "SET r += 
row.props\nRETURN count(r) as created" CYPHER_GET_FUNCTION_SOURCE_LOCATION = """ -MATCH (m:Module)-[:DEFINES]->(n) +MATCH (n) WHERE id(n) = $node_id RETURN n.qualified_name AS qualified_name, n.start_line AS start_line, - n.end_line AS end_line, m.path AS path + n.end_line AS end_line, n.path AS relative_path, + n.absolute_path AS absolute_path, n.project_name AS project_name """ CYPHER_FIND_BY_QUALIFIED_NAME = """ MATCH (n) WHERE n.qualified_name = $qn -OPTIONAL MATCH (m:Module)-[*]-(n) -RETURN n.name AS name, n.start_line AS start, n.end_line AS end, m.path AS path, n.docstring AS docstring +RETURN n.name AS name, n.start_line AS start, n.end_line AS end, + n.path AS relative_path, n.absolute_path AS absolute_path, + n.project_name AS project_name, n.docstring AS docstring LIMIT 1 """ @@ -94,7 +107,9 @@ def build_nodes_by_ids_query(node_ids: list[int]) -> str: MATCH (n) WHERE id(n) IN [{placeholders}] RETURN id(n) AS node_id, n.qualified_name AS qualified_name, - labels(n) AS type, n.name AS name + labels(n) AS type, n.name AS name, + n.path AS relative_path, n.absolute_path AS absolute_path, + n.project_name AS project_name ORDER BY n.qualified_name """ @@ -126,3 +141,11 @@ def build_merge_relationship_query( ) query += CYPHER_SET_PROPS_RETURN_COUNT if has_props else CYPHER_RETURN_COUNT return query + + +def build_project_name_indexes() -> list[str]: + return [ + build_index_query("Function", "project_name"), + build_index_query("Method", "project_name"), + build_index_query("Class", "project_name"), + ] diff --git a/codebase_rag/decorators.py b/codebase_rag/decorators.py index b315ba643..1575c253b 100644 --- a/codebase_rag/decorators.py +++ b/codebase_rag/decorators.py @@ -13,6 +13,7 @@ LoadableProtocol, PathValidatorProtocol, ) +from .utils.path_utils import validate_allowed_path def ensure_loaded[T](func: Callable[..., T]) -> Callable[..., T]: @@ -70,10 +71,10 @@ async def wrapper(self: PathValidatorProtocol, *args, **kwargs) -> T: 
file_path=str(file_path_str), error_message=ex.ACCESS_DENIED ) try: - full_path = (self.project_root / file_path_str).resolve() - project_root = self.project_root.resolve() - full_path.relative_to(project_root) - except (ValueError, RuntimeError): + full_path = validate_allowed_path( + file_path_str, self.project_root, self.allowed_roots + ) + except PermissionError: return result_factory( file_path=file_path_str, error_message=ls.FILE_OUTSIDE_ROOT.format(action="access"), diff --git a/codebase_rag/graph_updater.py b/codebase_rag/graph_updater.py index 2620d2bcb..4ee4dda71 100644 --- a/codebase_rag/graph_updater.py +++ b/codebase_rag/graph_updater.py @@ -262,8 +262,14 @@ def _is_dependency_file(self, file_name: str, filepath: Path) -> bool: ) def run(self) -> None: + absolute_path = str(self.repo_path.resolve()) + self.ingestor.ensure_node_batch( - cs.NODE_PROJECT, {cs.KEY_NAME: self.project_name} + cs.NODE_PROJECT, + { + cs.KEY_NAME: self.project_name, + cs.KEY_ABSOLUTE_PATH: absolute_path, + }, ) logger.info(ls.ENSURING_PROJECT.format(name=self.project_name)) @@ -369,7 +375,7 @@ def _generate_semantic_embeddings(self) -> None: logger.info(ls.PASS_4_EMBEDDINGS) results = self.ingestor.fetch_all( - cs.CYPHER_QUERY_EMBEDDINGS, {"project_name": self.project_name + "."} + cs.CYPHER_QUERY_EMBEDDINGS, {"project_name": self.project_name} ) if not results: diff --git a/codebase_rag/logs.py b/codebase_rag/logs.py index 3e075c877..7e0c6d444 100644 --- a/codebase_rag/logs.py +++ b/codebase_rag/logs.py @@ -593,6 +593,7 @@ MCP_ERROR_WRITE = "[MCP] Error writing file: {error}" MCP_LIST_DIR = "[MCP] list_directory: {path}" MCP_ERROR_LIST_DIR = "[MCP] Error listing directory: {error}" +QUERY_MODE_WRITE_BLOCKED = "[MCP] Write operation blocked in query mode for: {path}" # (H) MCP server logs MCP_SERVER_INFERRED_ROOT = "[GraphCode MCP] Using inferred project root: {path}" @@ -608,10 +609,14 @@ MCP_SERVER_UNKNOWN_TOOL = "[GraphCode MCP] Unknown tool: {name}" MCP_SERVER_TOOL_ERROR = 
"[GraphCode MCP] Error executing tool '{name}': {error}" MCP_SERVER_STARTING = "[GraphCode MCP] Starting MCP server..." +MCP_SERVER_MODE = "[GraphCode MCP] Server running in mode: {mode}" MCP_SERVER_CREATED = "[GraphCode MCP] Server created, starting stdio transport..." MCP_SERVER_CONNECTED = "[GraphCode MCP] Connected to Memgraph at {host}:{port}" MCP_SERVER_FATAL_ERROR = "[GraphCode MCP] Fatal error: {error}" MCP_SERVER_SHUTDOWN = "[GraphCode MCP] Shutting down server..." +MCP_TOOLS_REGISTRY_MODE = ( + "[GraphCode MCP] MCPToolsRegistry initialized in '{mode}' mode" +) # (H) Exclude prompt logs EXCLUDE_INVALID_INDEX = "Invalid index: {index} (out of range)" @@ -621,3 +626,9 @@ MODEL_SWITCHED = "Model switched to: {model}" MODEL_SWITCH_FAILED = "Failed to switch model: {error}" MODEL_CURRENT = "Current model: {model}" + +# (H) Path parse logs +METHOD_PATH_CALC_FAILED = "Failed to calculate paths for method {qn}: {error}" +NO_ABSOLUTE_PATH_FALLBACK = ( + "No absolute_path found for {qn}, falling back to relative path" +) diff --git a/codebase_rag/main.py b/codebase_rag/main.py index af58a84a4..c649b082f 100644 --- a/codebase_rag/main.py +++ b/codebase_rag/main.py @@ -984,7 +984,10 @@ def _initialize_services_and_agent( shell_commander = ShellCommander( project_root=repo_path, timeout=settings.SHELL_COMMAND_TIMEOUT ) - directory_lister = DirectoryLister(project_root=repo_path) + directory_lister = DirectoryLister( + project_root=repo_path, + allowed_roots=settings.allowed_project_roots_set, + ) document_analyzer = DocumentAnalyzer(project_root=repo_path) query_tool = create_query_tool(ingestor, cypher_generator, app_context.console) diff --git a/codebase_rag/mcp/server.py b/codebase_rag/mcp/server.py index 9218a2d93..cac381938 100644 --- a/codebase_rag/mcp/server.py +++ b/codebase_rag/mcp/server.py @@ -67,6 +67,9 @@ def create_server() -> tuple[Server, MemgraphIngestor]: logger.info(lg.MCP_SERVER_INIT_SERVICES) + mode = settings.MCP_MODE + 
logger.info(lg.MCP_SERVER_MODE.format(mode=mode)) + ingestor = MemgraphIngestor( host=settings.MEMGRAPH_HOST, port=settings.MEMGRAPH_PORT, @@ -79,6 +82,7 @@ def create_server() -> tuple[Server, MemgraphIngestor]: project_root=str(project_root), ingestor=ingestor, cypher_gen=cypher_generator, + mode=mode, ) logger.info(lg.MCP_SERVER_INIT_SUCCESS) diff --git a/codebase_rag/mcp/tools.py b/codebase_rag/mcp/tools.py index 5d1d2f7f5..0db2ea32a 100644 --- a/codebase_rag/mcp/tools.py +++ b/codebase_rag/mcp/tools.py @@ -6,6 +6,7 @@ from codebase_rag import constants as cs from codebase_rag import logs as lg from codebase_rag import tool_errors as te +from codebase_rag.config import settings from codebase_rag.graph_updater import GraphUpdater from codebase_rag.models import ToolMetadata from codebase_rag.parser_loader import load_parsers @@ -36,6 +37,8 @@ QueryResultDict, ) +from ..utils.path_utils import validate_allowed_path + class MCPToolsRegistry: def __init__( @@ -43,18 +46,32 @@ def __init__( project_root: str, ingestor: MemgraphIngestor, cypher_gen: CypherGenerator, + mode: str = "edit", ) -> None: self.project_root = project_root self.ingestor = ingestor self.cypher_gen = cypher_gen + self.mode = mode self.parsers, self.queries = load_parsers() - self.code_retriever = CodeRetriever(project_root, ingestor) - self.file_editor = FileEditor(project_root=project_root) - self.file_reader = FileReader(project_root=project_root) - self.file_writer = FileWriter(project_root=project_root) - self.directory_lister = DirectoryLister(project_root=project_root) + self.code_retriever = CodeRetriever( + project_root, ingestor, allowed_roots=settings.allowed_project_roots_set + ) + self.file_editor = FileEditor(project_root=project_root, mode=mode) + self.file_reader = FileReader( + project_root=project_root, + mode=mode, + allowed_roots=settings.allowed_project_roots_set, + ) + self.file_writer = FileWriter(project_root=project_root, mode=mode) + + 
logger.info(lg.MCP_TOOLS_REGISTRY_MODE.format(mode=mode)) + + self.directory_lister = DirectoryLister( + project_root=project_root, + allowed_roots=settings.allowed_project_roots_set, + ) self._query_tool = create_query_tool( ingestor=ingestor, cypher_gen=cypher_gen, console=None @@ -67,186 +84,201 @@ def __init__( directory_lister=self.directory_lister ) - self._tools: dict[str, ToolMetadata] = { - cs.MCPToolName.LIST_PROJECTS: ToolMetadata( - name=cs.MCPToolName.LIST_PROJECTS, - description=td.MCP_TOOLS[cs.MCPToolName.LIST_PROJECTS], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={}, - required=[], - ), - handler=self.list_projects, - returns_json=True, - ), - cs.MCPToolName.DELETE_PROJECT: ToolMetadata( - name=cs.MCPToolName.DELETE_PROJECT, - description=td.MCP_TOOLS[cs.MCPToolName.DELETE_PROJECT], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.PROJECT_NAME: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_PROJECT_NAME, - ) - }, - required=[cs.MCPParamName.PROJECT_NAME], - ), - handler=self.delete_project, - returns_json=True, - ), - cs.MCPToolName.WIPE_DATABASE: ToolMetadata( - name=cs.MCPToolName.WIPE_DATABASE, - description=td.MCP_TOOLS[cs.MCPToolName.WIPE_DATABASE], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.CONFIRM: MCPInputSchemaProperty( - type=cs.MCPSchemaType.BOOLEAN, - description=td.MCP_PARAM_CONFIRM, - ) - }, - required=[cs.MCPParamName.CONFIRM], + self._tools: dict[str, ToolMetadata] = self._build_tools() + + def _build_tools(self) -> dict[str, ToolMetadata]: + tools: dict[str, ToolMetadata] = {} + + tools.update( + { + cs.MCPToolName.QUERY_CODE_GRAPH: ToolMetadata( + name=cs.MCPToolName.QUERY_CODE_GRAPH, + description=td.MCP_TOOLS[cs.MCPToolName.QUERY_CODE_GRAPH], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.NATURAL_LANGUAGE_QUERY: 
MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_NATURAL_LANGUAGE_QUERY, + ) + }, + required=[cs.MCPParamName.NATURAL_LANGUAGE_QUERY], + ), + handler=self.query_code_graph, + returns_json=True, ), - handler=self.wipe_database, - returns_json=False, - ), - cs.MCPToolName.INDEX_REPOSITORY: ToolMetadata( - name=cs.MCPToolName.INDEX_REPOSITORY, - description=td.MCP_TOOLS[cs.MCPToolName.INDEX_REPOSITORY], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={}, - required=[], + cs.MCPToolName.GET_CODE_SNIPPET: ToolMetadata( + name=cs.MCPToolName.GET_CODE_SNIPPET, + description=td.MCP_TOOLS[cs.MCPToolName.GET_CODE_SNIPPET], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.QUALIFIED_NAME: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_QUALIFIED_NAME, + ) + }, + required=[cs.MCPParamName.QUALIFIED_NAME], + ), + handler=self.get_code_snippet, + returns_json=True, ), - handler=self.index_repository, - returns_json=False, - ), - cs.MCPToolName.QUERY_CODE_GRAPH: ToolMetadata( - name=cs.MCPToolName.QUERY_CODE_GRAPH, - description=td.MCP_TOOLS[cs.MCPToolName.QUERY_CODE_GRAPH], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.NATURAL_LANGUAGE_QUERY: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_NATURAL_LANGUAGE_QUERY, - ) - }, - required=[cs.MCPParamName.NATURAL_LANGUAGE_QUERY], + cs.MCPToolName.LIST_DIRECTORY: ToolMetadata( + name=cs.MCPToolName.LIST_DIRECTORY, + description=td.MCP_TOOLS[cs.MCPToolName.LIST_DIRECTORY], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.DIRECTORY_PATH: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_DIRECTORY_PATH, + default=cs.MCP_DEFAULT_DIRECTORY, + ) + }, + required=[], + ), + handler=self.list_directory, + returns_json=False, ), - 
handler=self.query_code_graph, - returns_json=True, - ), - cs.MCPToolName.GET_CODE_SNIPPET: ToolMetadata( - name=cs.MCPToolName.GET_CODE_SNIPPET, - description=td.MCP_TOOLS[cs.MCPToolName.GET_CODE_SNIPPET], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.QUALIFIED_NAME: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_QUALIFIED_NAME, - ) - }, - required=[cs.MCPParamName.QUALIFIED_NAME], + cs.MCPToolName.LIST_PROJECTS: ToolMetadata( + name=cs.MCPToolName.LIST_PROJECTS, + description=td.MCP_TOOLS[cs.MCPToolName.LIST_PROJECTS], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={}, + required=[], + ), + handler=self.list_projects, + returns_json=True, ), - handler=self.get_code_snippet, - returns_json=True, - ), - cs.MCPToolName.SURGICAL_REPLACE_CODE: ToolMetadata( - name=cs.MCPToolName.SURGICAL_REPLACE_CODE, - description=td.MCP_TOOLS[cs.MCPToolName.SURGICAL_REPLACE_CODE], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.FILE_PATH: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_FILE_PATH, - ), - cs.MCPParamName.TARGET_CODE: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_TARGET_CODE, - ), - cs.MCPParamName.REPLACEMENT_CODE: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_REPLACEMENT_CODE, - ), - }, - required=[ - cs.MCPParamName.FILE_PATH, - cs.MCPParamName.TARGET_CODE, - cs.MCPParamName.REPLACEMENT_CODE, - ], + cs.MCPToolName.READ_FILE: ToolMetadata( + name=cs.MCPToolName.READ_FILE, + description=td.MCP_TOOLS[cs.MCPToolName.READ_FILE], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.FILE_PATH: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_FILE_PATH, + ), + cs.MCPParamName.OFFSET: MCPInputSchemaProperty( + 
type=cs.MCPSchemaType.INTEGER, + description=td.MCP_PARAM_OFFSET, + ), + cs.MCPParamName.LIMIT: MCPInputSchemaProperty( + type=cs.MCPSchemaType.INTEGER, + description=td.MCP_PARAM_LIMIT, + ), + }, + required=[cs.MCPParamName.FILE_PATH], + ), + handler=self.read_file, + returns_json=False, ), - handler=self.surgical_replace_code, - returns_json=False, - ), - cs.MCPToolName.READ_FILE: ToolMetadata( - name=cs.MCPToolName.READ_FILE, - description=td.MCP_TOOLS[cs.MCPToolName.READ_FILE], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.FILE_PATH: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_FILE_PATH, + } + ) + + if self.mode == "edit": + tools.update( + { + cs.MCPToolName.SURGICAL_REPLACE_CODE: ToolMetadata( + name=cs.MCPToolName.SURGICAL_REPLACE_CODE, + description=td.MCP_TOOLS[cs.MCPToolName.SURGICAL_REPLACE_CODE], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.FILE_PATH: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_FILE_PATH, + ), + cs.MCPParamName.TARGET_CODE: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_TARGET_CODE, + ), + cs.MCPParamName.REPLACEMENT_CODE: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_REPLACEMENT_CODE, + ), + }, + required=[ + cs.MCPParamName.FILE_PATH, + cs.MCPParamName.TARGET_CODE, + cs.MCPParamName.REPLACEMENT_CODE, + ], ), - cs.MCPParamName.OFFSET: MCPInputSchemaProperty( - type=cs.MCPSchemaType.INTEGER, - description=td.MCP_PARAM_OFFSET, + handler=self.surgical_replace_code, + returns_json=False, + ), + cs.MCPToolName.WRITE_FILE: ToolMetadata( + name=cs.MCPToolName.WRITE_FILE, + description=td.MCP_TOOLS[cs.MCPToolName.WRITE_FILE], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.FILE_PATH: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + 
description=td.MCP_PARAM_FILE_PATH, + ), + cs.MCPParamName.CONTENT: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_CONTENT, + ), + }, + required=[ + cs.MCPParamName.FILE_PATH, + cs.MCPParamName.CONTENT, + ], ), - cs.MCPParamName.LIMIT: MCPInputSchemaProperty( - type=cs.MCPSchemaType.INTEGER, - description=td.MCP_PARAM_LIMIT, + handler=self.write_file, + returns_json=False, + ), + cs.MCPToolName.DELETE_PROJECT: ToolMetadata( + name=cs.MCPToolName.DELETE_PROJECT, + description=td.MCP_TOOLS[cs.MCPToolName.DELETE_PROJECT], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.PROJECT_NAME: MCPInputSchemaProperty( + type=cs.MCPSchemaType.STRING, + description=td.MCP_PARAM_PROJECT_NAME, + ) + }, + required=[cs.MCPParamName.PROJECT_NAME], ), - }, - required=[cs.MCPParamName.FILE_PATH], - ), - handler=self.read_file, - returns_json=False, - ), - cs.MCPToolName.WRITE_FILE: ToolMetadata( - name=cs.MCPToolName.WRITE_FILE, - description=td.MCP_TOOLS[cs.MCPToolName.WRITE_FILE], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.FILE_PATH: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_FILE_PATH, + handler=self.delete_project, + returns_json=True, + ), + cs.MCPToolName.WIPE_DATABASE: ToolMetadata( + name=cs.MCPToolName.WIPE_DATABASE, + description=td.MCP_TOOLS[cs.MCPToolName.WIPE_DATABASE], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={ + cs.MCPParamName.CONFIRM: MCPInputSchemaProperty( + type=cs.MCPSchemaType.BOOLEAN, + description=td.MCP_PARAM_CONFIRM, + ) + }, + required=[cs.MCPParamName.CONFIRM], ), - cs.MCPParamName.CONTENT: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_CONTENT, + handler=self.wipe_database, + returns_json=False, + ), + cs.MCPToolName.INDEX_REPOSITORY: ToolMetadata( + name=cs.MCPToolName.INDEX_REPOSITORY, + 
description=td.MCP_TOOLS[cs.MCPToolName.INDEX_REPOSITORY], + input_schema=MCPInputSchema( + type=cs.MCPSchemaType.OBJECT, + properties={}, + required=[], ), - }, - required=[ - cs.MCPParamName.FILE_PATH, - cs.MCPParamName.CONTENT, - ], - ), - handler=self.write_file, - returns_json=False, - ), - cs.MCPToolName.LIST_DIRECTORY: ToolMetadata( - name=cs.MCPToolName.LIST_DIRECTORY, - description=td.MCP_TOOLS[cs.MCPToolName.LIST_DIRECTORY], - input_schema=MCPInputSchema( - type=cs.MCPSchemaType.OBJECT, - properties={ - cs.MCPParamName.DIRECTORY_PATH: MCPInputSchemaProperty( - type=cs.MCPSchemaType.STRING, - description=td.MCP_PARAM_DIRECTORY_PATH, - default=cs.MCP_DEFAULT_DIRECTORY, - ) - }, - required=[], - ), - handler=self.list_directory, - returns_json=False, - ), - } + handler=self.index_repository, + returns_json=False, + ), + } + ) + + return tools async def list_projects(self) -> ListProjectsResult: logger.info(lg.MCP_LISTING_PROJECTS) @@ -374,10 +406,14 @@ async def read_file( logger.info(lg.MCP_READ_FILE.format(path=file_path, offset=offset, limit=limit)) try: if offset is not None or limit is not None: - full_path = Path(self.project_root) / file_path + project_root_path = Path(self.project_root).resolve() + safe_path = validate_allowed_path( + file_path, project_root_path, self.file_reader.allowed_roots + ) + start = offset if offset is not None else 0 - with open(full_path, encoding=cs.ENCODING_UTF8) as f: + with safe_path.open("r", encoding=cs.ENCODING_UTF8) as f: skipped_count = sum(1 for _ in itertools.islice(f, start)) if limit is not None: @@ -449,9 +485,11 @@ def create_mcp_tools_registry( project_root: str, ingestor: MemgraphIngestor, cypher_gen: CypherGenerator, + mode: str = "edit", ) -> MCPToolsRegistry: return MCPToolsRegistry( project_root=project_root, ingestor=ingestor, cypher_gen=cypher_gen, + mode=mode, ) diff --git a/codebase_rag/parsers/class_ingest/cpp_modules.py b/codebase_rag/parsers/class_ingest/cpp_modules.py index a5db9bc47..e3f310d8d 
100644 --- a/codebase_rag/parsers/class_ingest/cpp_modules.py +++ b/codebase_rag/parsers/class_ingest/cpp_modules.py @@ -83,7 +83,7 @@ def _process_export_module( { cs.KEY_QUALIFIED_NAME: interface_qn, cs.KEY_NAME: module_name, - cs.KEY_PATH: str(file_path.relative_to(repo_path)), + cs.KEY_PATH: file_path.relative_to(repo_path).as_posix(), cs.KEY_MODULE_TYPE: cs.CPP_MODULE_TYPE_INTERFACE, }, ) @@ -117,7 +117,7 @@ def _process_module_implementation( { cs.KEY_QUALIFIED_NAME: impl_qn, cs.KEY_NAME: f"{module_name}{cs.CPP_IMPL_SUFFIX}", - cs.KEY_PATH: str(file_path.relative_to(repo_path)), + cs.KEY_PATH: file_path.relative_to(repo_path).as_posix(), cs.KEY_IMPLEMENTS_MODULE: module_name, cs.KEY_MODULE_TYPE: cs.CPP_MODULE_TYPE_IMPLEMENTATION, }, diff --git a/codebase_rag/parsers/class_ingest/mixin.py b/codebase_rag/parsers/class_ingest/mixin.py index 2ba3f8f8c..d442be642 100644 --- a/codebase_rag/parsers/class_ingest/mixin.py +++ b/codebase_rag/parsers/class_ingest/mixin.py @@ -10,6 +10,7 @@ from ... import constants as cs from ... 
import logs from ...types_defs import ASTNode, PropertyDict +from ...utils.path_utils import calculate_paths from ..java import utils as java_utils from ..py import resolve_class_name from ..rs import utils as rs_utils @@ -142,6 +143,16 @@ def _process_class_node( cs.KEY_DOCSTRING: self._get_docstring(class_node), cs.KEY_IS_EXPORTED: is_exported, } + + if file_path: + paths = calculate_paths( + file_path=file_path, + repo_path=self.repo_path, + ) + class_props[cs.KEY_PATH] = paths["relative_path"] + class_props[cs.KEY_ABSOLUTE_PATH] = paths["absolute_path"] + class_props[cs.KEY_PROJECT_NAME] = self.project_name + self.ingestor.ensure_node_batch(node_type, class_props) self.function_registry[class_qn] = node_type if class_name: @@ -160,7 +171,9 @@ def _process_class_node( self._resolve_to_qn, self.function_registry, ) - self._ingest_class_methods(class_node, class_qn, language, lang_queries) + self._ingest_class_methods( + class_node, class_qn, language, lang_queries, file_path + ) def _ingest_rust_impl_methods( self, @@ -183,6 +196,7 @@ def _ingest_rust_impl_methods( method_captures = method_cursor.captures(body_node) for method_node in method_captures.get(cs.CAPTURE_FUNCTION, []): if isinstance(method_node, Node): + file_path = self.module_qn_to_file_path.get(module_qn) ingest_method( method_node, class_qn, @@ -192,6 +206,9 @@ def _ingest_rust_impl_methods( self.simple_name_lookup, self._get_docstring, language, + file_path=file_path, + repo_path=self.repo_path if file_path else None, + project_name=self.project_name, ) def _ingest_class_methods( @@ -200,6 +217,7 @@ def _ingest_class_methods( class_qn: str, language: cs.SupportedLanguage, lang_queries: LanguageQueries, + file_path: Path | None, ) -> None: body_node = class_node.child_by_field_name("body") method_query = lang_queries[cs.QUERY_FUNCTIONS] @@ -233,6 +251,9 @@ def _ingest_class_methods( language, self._extract_decorators, method_qualified_name, + file_path=file_path, + repo_path=self.repo_path if 
file_path else None, + project_name=self.project_name, ) def _process_inline_modules( diff --git a/codebase_rag/parsers/definition_processor.py b/codebase_rag/parsers/definition_processor.py index 8110140f8..ed809d85e 100644 --- a/codebase_rag/parsers/definition_processor.py +++ b/codebase_rag/parsers/definition_processor.py @@ -8,6 +8,7 @@ from .. import constants as cs from .. import logs as ls from ..types_defs import ASTNode, FunctionRegistryTrieProtocol, SimpleNameLookup +from ..utils.path_utils import calculate_paths from .class_ingest import ClassIngestMixin from .dependency_parser import parse_dependencies from .function_ingest import FunctionIngestMixin @@ -94,12 +95,19 @@ def process_file( ) self.module_qn_to_file_path[module_qn] = file_path + paths = calculate_paths( + file_path=file_path, + repo_path=self.repo_path, + ) + self.ingestor.ensure_node_batch( cs.NodeLabel.MODULE, { cs.KEY_QUALIFIED_NAME: module_qn, cs.KEY_NAME: file_path.name, - cs.KEY_PATH: relative_path_str, + cs.KEY_PATH: paths["relative_path"], + cs.KEY_ABSOLUTE_PATH: paths["absolute_path"], + cs.KEY_PROJECT_NAME: self.project_name, }, ) @@ -109,7 +117,7 @@ def process_file( (cs.NodeLabel.PACKAGE, cs.KEY_QUALIFIED_NAME, parent_container_qn) if parent_container_qn else ( - (cs.NodeLabel.FOLDER, cs.KEY_PATH, str(parent_rel_path)) + (cs.NodeLabel.FOLDER, cs.KEY_PATH, parent_rel_path.as_posix()) if parent_rel_path != Path(".") else (cs.NodeLabel.PROJECT, cs.KEY_NAME, self.project_name) ) diff --git a/codebase_rag/parsers/function_ingest.py b/codebase_rag/parsers/function_ingest.py index 1d32186e0..c8a1b8396 100644 --- a/codebase_rag/parsers/function_ingest.py +++ b/codebase_rag/parsers/function_ingest.py @@ -14,10 +14,12 @@ ASTNode, FunctionRegistryTrieProtocol, NodeType, + PathInfo, PropertyDict, SimpleNameLookup, ) from ..utils.fqn_resolver import resolve_fqn_from_ast +from ..utils.path_utils import calculate_paths from .cpp import utils as cpp_utils from .lua import utils as lua_utils 
from .rs import utils as rs_utils @@ -160,6 +162,8 @@ def _handle_cpp_out_of_class_method(self, func_node: Node, module_qn: str) -> bo ) class_qn = f"{module_qn}.{class_name_normalized}" + file_path = self.module_qn_to_file_path.get(module_qn) + ingest_method( method_node=func_node, container_qn=class_qn, @@ -170,6 +174,9 @@ def _handle_cpp_out_of_class_method(self, func_node: Node, module_qn: str) -> bo get_docstring_func=self._get_docstring, language=cs.SupportedLanguage.CPP, extract_decorators_func=self._extract_decorators, + file_path=file_path, + repo_path=self.repo_path if file_path else None, + project_name=self.project_name, ) return True @@ -238,7 +245,15 @@ def _register_function( language: cs.SupportedLanguage, lang_config: LanguageSpec, ) -> None: - func_props = self._build_function_props(func_node, resolution) + file_path = self.module_qn_to_file_path.get(module_qn) + paths = None + if file_path: + paths = calculate_paths( + file_path=file_path, + repo_path=self.repo_path, + ) + + func_props = self._build_function_props(func_node, resolution, paths) logger.info( ls.FUNC_FOUND.format(name=resolution.name, qn=resolution.qualified_name) ) @@ -253,9 +268,12 @@ def _register_function( ) def _build_function_props( - self, func_node: Node, resolution: FunctionResolution + self, + func_node: Node, + resolution: FunctionResolution, + paths: PathInfo | None = None, ) -> PropertyDict: - return { + props: PropertyDict = { cs.KEY_QUALIFIED_NAME: resolution.qualified_name, cs.KEY_NAME: resolution.name, cs.KEY_DECORATORS: self._extract_decorators(func_node), @@ -265,6 +283,13 @@ def _build_function_props( cs.KEY_IS_EXPORTED: resolution.is_exported, } + if paths: + props[cs.KEY_PATH] = paths["relative_path"] + props[cs.KEY_ABSOLUTE_PATH] = paths["absolute_path"] + props[cs.KEY_PROJECT_NAME] = self.project_name + + return props + def _create_function_relationships( self, func_node: Node, diff --git a/codebase_rag/parsers/import_processor.py 
b/codebase_rag/parsers/import_processor.py index 99c3a8526..db1f41609 100644 --- a/codebase_rag/parsers/import_processor.py +++ b/codebase_rag/parsers/import_processor.py @@ -238,6 +238,7 @@ def _ensure_external_module_node(self, module_path: str, full_name: str) -> None cs.KEY_QUALIFIED_NAME: module_path, cs.KEY_PATH: full_name, cs.KEY_IS_EXTERNAL: True, + cs.KEY_PROJECT_NAME: cs.EXTERNAL_PROJECT_NAME, }, ) diff --git a/codebase_rag/parsers/structure_processor.py b/codebase_rag/parsers/structure_processor.py index 9b4065bd3..51bb2ea76 100644 --- a/codebase_rag/parsers/structure_processor.py +++ b/codebase_rag/parsers/structure_processor.py @@ -6,7 +6,7 @@ from .. import logs from ..services import IngestorProtocol from ..types_defs import LanguageQueries, NodeIdentifier -from ..utils.path_utils import should_skip_path +from ..utils.path_utils import calculate_paths, should_skip_path class StructureProcessor: @@ -73,12 +73,20 @@ def identify_structure(self) -> None: logger.info( logs.STRUCT_IDENTIFIED_PACKAGE.format(package_qn=package_qn) ) + + paths = calculate_paths( + file_path=root, + repo_path=self.repo_path, + ) + self.ingestor.ensure_node_batch( cs.NodeLabel.PACKAGE, { cs.KEY_QUALIFIED_NAME: package_qn, cs.KEY_NAME: root.name, - cs.KEY_PATH: relative_root.as_posix(), + cs.KEY_PATH: paths["relative_path"], + cs.KEY_ABSOLUTE_PATH: paths["absolute_path"], + cs.KEY_PROJECT_NAME: self.project_name, }, ) parent_identifier = self._get_parent_identifier( @@ -94,9 +102,20 @@ def identify_structure(self) -> None: logger.info( logs.STRUCT_IDENTIFIED_FOLDER.format(relative_root=relative_root) ) + + paths = calculate_paths( + file_path=root, + repo_path=self.repo_path, + ) + self.ingestor.ensure_node_batch( cs.NodeLabel.FOLDER, - {cs.KEY_PATH: relative_root.as_posix(), cs.KEY_NAME: root.name}, + { + cs.KEY_PATH: paths["relative_path"], + cs.KEY_ABSOLUTE_PATH: paths["absolute_path"], + cs.KEY_NAME: root.name, + cs.KEY_PROJECT_NAME: self.project_name, + }, ) 
parent_identifier = self._get_parent_identifier( parent_rel_path, parent_container_qn @@ -108,7 +127,6 @@ def identify_structure(self) -> None: ) def process_generic_file(self, file_path: Path, file_name: str) -> None: - relative_filepath = file_path.relative_to(self.repo_path).as_posix() relative_root = file_path.parent.relative_to(self.repo_path) parent_container_qn = self.structural_elements.get(relative_root) @@ -116,17 +134,24 @@ def process_generic_file(self, file_path: Path, file_name: str) -> None: relative_root, parent_container_qn ) + paths = calculate_paths( + file_path=file_path, + repo_path=self.repo_path, + ) + self.ingestor.ensure_node_batch( cs.NodeLabel.FILE, { - cs.KEY_PATH: relative_filepath, + cs.KEY_PATH: paths["relative_path"], + cs.KEY_ABSOLUTE_PATH: paths["absolute_path"], cs.KEY_NAME: file_name, cs.KEY_EXTENSION: file_path.suffix, + cs.KEY_PROJECT_NAME: self.project_name, }, ) self.ingestor.ensure_relationship_batch( parent_identifier, cs.RelationshipType.CONTAINS_FILE, - (cs.NodeLabel.FILE, cs.KEY_PATH, relative_filepath), + (cs.NodeLabel.FILE, cs.KEY_PATH, paths["relative_path"]), ) diff --git a/codebase_rag/parsers/utils.py b/codebase_rag/parsers/utils.py index b164a5022..36e0affa1 100644 --- a/codebase_rag/parsers/utils.py +++ b/codebase_rag/parsers/utils.py @@ -2,6 +2,7 @@ from collections.abc import Callable from functools import lru_cache +from pathlib import Path from typing import TYPE_CHECKING, NamedTuple from loguru import logger @@ -17,6 +18,7 @@ SimpleNameLookup, TreeSitterNodeProtocol, ) +from ..utils.path_utils import calculate_paths if TYPE_CHECKING: from ..language_spec import LanguageSpec @@ -83,6 +85,9 @@ def ingest_method( language: cs.SupportedLanguage | None = None, extract_decorators_func: Callable[[ASTNode], list[str]] | None = None, method_qualified_name: str | None = None, + file_path: Path | str | None = None, + repo_path: Path | str | None = None, + project_name: str | None = None, ) -> None: if language == 
cs.SupportedLanguage.CPP: from .cpp import utils as cpp_utils @@ -110,6 +115,18 @@ def ingest_method( cs.KEY_DOCSTRING: get_docstring_func(method_node), } + if file_path and repo_path and project_name: + try: + paths = calculate_paths( + file_path=file_path, + repo_path=repo_path, + ) + method_props[cs.KEY_PATH] = paths["relative_path"] + method_props[cs.KEY_ABSOLUTE_PATH] = paths["absolute_path"] + method_props[cs.KEY_PROJECT_NAME] = project_name + except (ValueError, TypeError) as e: + logger.warning(logs.METHOD_PATH_CALC_FAILED.format(qn=method_qn, error=e)) + logger.info(logs.METHOD_FOUND.format(name=method_name, qn=method_qn)) ingestor.ensure_node_batch(cs.NodeLabel.METHOD, method_props) function_registry[method_qn] = NodeType.METHOD diff --git a/codebase_rag/providers/base.py b/codebase_rag/providers/base.py index 37f5cb462..e198c4833 100644 --- a/codebase_rag/providers/base.py +++ b/codebase_rag/providers/base.py @@ -7,7 +7,7 @@ import httpx from loguru import logger from pydantic_ai.models.google import GoogleModel, GoogleModelSettings -from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel +from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.google import GoogleProvider as PydanticGoogleProvider from pydantic_ai.providers.openai import OpenAIProvider as PydanticOpenAIProvider @@ -24,7 +24,7 @@ def __init__(self, **config: str | int | None) -> None: @abstractmethod def create_model( self, model_id: str, **kwargs: str | int | None - ) -> GoogleModel | OpenAIResponsesModel | OpenAIChatModel: + ) -> GoogleModel | OpenAIChatModel: pass @abstractmethod @@ -118,11 +118,11 @@ def validate_config(self) -> None: def create_model( self, model_id: str, **kwargs: str | int | None - ) -> OpenAIResponsesModel: + ) -> OpenAIChatModel: self.validate_config() provider = PydanticOpenAIProvider(api_key=self.api_key, base_url=self.endpoint) - return OpenAIResponsesModel(model_id, provider=provider) + return 
OpenAIChatModel(model_id, provider=provider) class OllamaProvider(ModelProvider): diff --git a/codebase_rag/schemas.py b/codebase_rag/schemas.py index 553d52d86..4e392d656 100644 --- a/codebase_rag/schemas.py +++ b/codebase_rag/schemas.py @@ -38,6 +38,7 @@ class CodeSnippet(BaseModel): qualified_name: str source_code: str file_path: str + project_name: str | None = None line_start: int line_end: int docstring: str | None = None diff --git a/codebase_rag/services/graph_service.py b/codebase_rag/services/graph_service.py index 7a8d95e02..e6ddd41e7 100644 --- a/codebase_rag/services/graph_service.py +++ b/codebase_rag/services/graph_service.py @@ -35,6 +35,7 @@ build_index_query, build_merge_node_query, build_merge_relationship_query, + build_project_name_indexes, wrap_with_unwind, ) from ..types_defs import ( @@ -201,6 +202,13 @@ def _ensure_indexes(self) -> None: self._execute_query(build_index_query(label, prop)) except Exception: pass + + for index_query in build_project_name_indexes(): + try: + self._execute_query(index_query) + except Exception: + pass + logger.info(ls.MG_INDEXES_DONE) def ensure_node_batch( diff --git a/codebase_rag/tests/test_decorators.py b/codebase_rag/tests/test_decorators.py index 366122047..d00279051 100644 --- a/codebase_rag/tests/test_decorators.py +++ b/codebase_rag/tests/test_decorators.py @@ -162,6 +162,7 @@ def __init__( class MockService: project_root = Path("/project") + allowed_roots: frozenset[Path] | None = None @validate_project_path(ResultType, "file_path") async def read(self, file_path: Path) -> ResultType: @@ -183,6 +184,7 @@ def __init__( class MockService: project_root = Path("/project") + allowed_roots: frozenset[Path] | None = None @validate_project_path(ResultType, "file_path") async def read(self, file_path: Path) -> ResultType: @@ -204,6 +206,7 @@ def __init__( class MockService: project_root = Path("/project") + allowed_roots: frozenset[Path] | None = None @validate_project_path(ResultType, "file_path") async def 
read(self, file_path: Path) -> ResultType: @@ -224,6 +227,7 @@ def __init__( class MockService: project_root = Path("/project") + allowed_roots: frozenset[Path] | None = None @validate_project_path(ResultType, "file_path") async def save(self, content: str, file_path: Path) -> ResultType: diff --git a/codebase_rag/tests/test_graph_service.py b/codebase_rag/tests/test_graph_service.py index c31b30741..4a45cd8b1 100644 --- a/codebase_rag/tests/test_graph_service.py +++ b/codebase_rag/tests/test_graph_service.py @@ -5,7 +5,7 @@ import pytest from codebase_rag.constants import NODE_UNIQUE_CONSTRAINTS -from codebase_rag.cypher_queries import wrap_with_unwind +from codebase_rag.cypher_queries import build_project_name_indexes, wrap_with_unwind from codebase_rag.services.graph_service import MemgraphIngestor @@ -285,7 +285,9 @@ def fail_then_succeed(query: str) -> None: with patch.object(ingestor, "_execute_query", side_effect=fail_then_succeed): ingestor.ensure_constraints() - expected_queries = len(NODE_UNIQUE_CONSTRAINTS) * 2 + expected_queries = len(NODE_UNIQUE_CONSTRAINTS) * 2 + len( + build_project_name_indexes() + ) assert call_count == expected_queries diff --git a/codebase_rag/tests/test_provider_classes.py b/codebase_rag/tests/test_provider_classes.py index 1475914a0..f446d7563 100644 --- a/codebase_rag/tests/test_provider_classes.py +++ b/codebase_rag/tests/test_provider_classes.py @@ -5,7 +5,7 @@ import pytest from pydantic_ai.models.google import GoogleModel -from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel +from pydantic_ai.models.openai import OpenAIChatModel from codebase_rag.constants import GoogleProviderType, Provider from codebase_rag.providers.base import ( @@ -59,7 +59,7 @@ def validate_config(self) -> None: def create_model( self, model_id: str, **kwargs: str | int | None - ) -> GoogleModel | OpenAIResponsesModel | OpenAIChatModel: + ) -> GoogleModel | OpenAIChatModel: return MagicMock(spec=GoogleModel) 
register_provider("custom", CustomProvider) @@ -241,21 +241,21 @@ def test_google_model_creation_with_thinking_budget( assert call_kwargs["settings"] == mock_settings @patch("codebase_rag.providers.base.PydanticOpenAIProvider") - @patch("codebase_rag.providers.base.OpenAIResponsesModel") + @patch("codebase_rag.providers.base.OpenAIChatModel") def test_openai_model_creation( - self, mock_openai_model: Any, mock_openai_provider: Any + self, mock_openai_chat_model: Any, mock_openai_provider: Any ) -> None: provider = OpenAIProvider(api_key="sk-test-key") mock_model = MagicMock() - mock_openai_model.return_value = mock_model + mock_openai_chat_model.return_value = mock_model provider.create_model("gpt-4o") mock_openai_provider.assert_called_once_with( api_key="sk-test-key", base_url="https://api.openai.com/v1" ) - mock_openai_model.assert_called_once_with( + mock_openai_chat_model.assert_called_once_with( "gpt-4o", provider=mock_openai_provider.return_value ) diff --git a/codebase_rag/tool_errors.py b/codebase_rag/tool_errors.py index 25540a976..274f853f7 100644 --- a/codebase_rag/tool_errors.py +++ b/codebase_rag/tool_errors.py @@ -49,6 +49,9 @@ CODE_ENTITY_NOT_FOUND = "Entity not found in graph." CODE_MISSING_LOCATION = "Graph entry is missing location data." +# (H) Tool operation errors +WRITE_QUERY_MODE_BLOCKED = "Write operations are not allowed in query mode" + # (H) File writer errors FILE_WRITER_SECURITY = ( "Security risk: Attempted to create file outside of project root: {path}" diff --git a/codebase_rag/tools/code_retrieval.py b/codebase_rag/tools/code_retrieval.py index 2e6331dcd..258714b21 100644 --- a/codebase_rag/tools/code_retrieval.py +++ b/codebase_rag/tools/code_retrieval.py @@ -11,13 +11,24 @@ from ..cypher_queries import CYPHER_FIND_BY_QUALIFIED_NAME from ..schemas import CodeSnippet from ..services import QueryProtocol +from ..utils.path_utils import validate_allowed_path from . 
import tool_descriptions as td class CodeRetriever: - def __init__(self, project_root: str, ingestor: QueryProtocol): + def __init__( + self, + project_root: str, + ingestor: QueryProtocol, + allowed_roots: frozenset[Path] | None = None, + ): self.project_root = Path(project_root).resolve() self.ingestor = ingestor + self.allowed_roots = ( + frozenset(root.resolve() for root in allowed_roots) + if allowed_roots + else None + ) logger.info(ls.CODE_RETRIEVER_INIT.format(root=self.project_root)) async def find_code_snippet(self, qualified_name: str) -> CodeSnippet: @@ -39,23 +50,38 @@ async def find_code_snippet(self, qualified_name: str) -> CodeSnippet: ) res = results[0] - file_path_str = res.get("path") + project_name = res.get("project_name") start_line = res.get("start") end_line = res.get("end") - if not all([file_path_str, start_line, end_line]): + absolute_path_str = res.get("absolute_path") + relative_path_str = res.get("relative_path") + + if absolute_path_str: + file_path_obj = Path(absolute_path_str) + elif relative_path_str: + file_path_obj = validate_allowed_path( + relative_path_str, self.project_root, self.allowed_roots + ) + logger.warning(ls.NO_ABSOLUTE_PATH_FALLBACK.format(qn=qualified_name)) + else: + file_path_obj = None + + if not (file_path_obj and start_line and end_line): return CodeSnippet( qualified_name=qualified_name, source_code="", - file_path=file_path_str or "", + file_path=str(file_path_obj) if file_path_obj else "", + project_name=project_name, line_start=0, line_end=0, found=False, error_message=te.CODE_MISSING_LOCATION, ) - full_path = self.project_root / file_path_str - with full_path.open("r", encoding=ENCODING_UTF8) as f: + assert file_path_obj is not None + + with file_path_obj.open("r", encoding=ENCODING_UTF8) as f: all_lines = f.readlines() snippet_lines = all_lines[start_line - 1 : end_line] @@ -64,7 +90,8 @@ async def find_code_snippet(self, qualified_name: str) -> CodeSnippet: return CodeSnippet( 
qualified_name=qualified_name, source_code=source_code, - file_path=file_path_str, + file_path=str(file_path_obj), + project_name=project_name, line_start=start_line, line_end=end_line, docstring=res.get("docstring"), diff --git a/codebase_rag/tools/directory_lister.py b/codebase_rag/tools/directory_lister.py index 01136a193..a03028687 100644 --- a/codebase_rag/tools/directory_lister.py +++ b/codebase_rag/tools/directory_lister.py @@ -13,8 +13,11 @@ class DirectoryLister: - def __init__(self, project_root: str): + def __init__(self, project_root: str, allowed_roots: frozenset[Path] | None = None): self.project_root = Path(project_root).resolve() + self.allowed_roots = frozenset( + {self.project_root} | ({root.resolve() for root in allowed_roots or []}) + ) def list_directory_contents(self, directory_path: str) -> str: target_path = self._get_safe_path(directory_path) @@ -38,12 +41,16 @@ def _get_safe_path(self, file_path: str) -> Path: else: safe_path = (self.project_root / file_path).resolve() - try: - safe_path.relative_to(self.project_root.resolve()) - except ValueError as e: - raise PermissionError(ex.ACCESS_DENIED) from e + is_allowed = False + for allowed_root in self.allowed_roots: + try: + safe_path.relative_to(allowed_root) + is_allowed = True + break + except ValueError: + continue - if not str(safe_path).startswith(str(self.project_root.resolve())): + if not is_allowed: raise PermissionError(ex.ACCESS_DENIED) return safe_path diff --git a/codebase_rag/tools/file_editor.py b/codebase_rag/tools/file_editor.py index 650da823e..8bd360058 100644 --- a/codebase_rag/tools/file_editor.py +++ b/codebase_rag/tools/file_editor.py @@ -16,12 +16,15 @@ from ..parser_loader import load_parsers from ..schemas import EditResult from ..types_defs import FunctionMatch +from ..utils.path_utils import validate_allowed_path from . 
import tool_descriptions as td class FileEditor: - def __init__(self, project_root: str = ".") -> None: + def __init__(self, project_root: str = ".", mode: str = "edit") -> None: self.project_root = Path(project_root).resolve() + self.mode = mode + self.allowed_roots: frozenset[Path] | None = None self.dmp = diff_match_patch.diff_match_patch() self.parsers, _ = load_parsers() logger.info(ls.FILE_EDITOR_INIT.format(root=self.project_root)) @@ -204,10 +207,15 @@ def apply_patch_to_file(self, file_path: str, patch_text: str) -> bool: def replace_code_block( self, file_path: str, target_block: str, replacement_block: str ) -> bool: + if self.mode == "query": + logger.error(ls.QUERY_MODE_WRITE_BLOCKED.format(path=file_path)) + return False + logger.info(ls.TOOL_FILE_EDIT_SURGICAL.format(path=file_path)) try: - full_path = (self.project_root / file_path).resolve() - full_path.relative_to(self.project_root) + full_path = validate_allowed_path( + file_path, self.project_root, self.allowed_roots + ) if not full_path.is_file(): logger.error(ls.EDITOR_FILE_NOT_FOUND.format(path=file_path)) @@ -245,7 +253,7 @@ def replace_code_block( logger.success(ls.TOOL_FILE_EDIT_SURGICAL_SUCCESS.format(path=file_path)) return True - except ValueError: + except PermissionError: logger.error(ls.FILE_OUTSIDE_ROOT.format(action=cs.FileAction.EDIT)) return False except Exception as e: diff --git a/codebase_rag/tools/file_reader.py b/codebase_rag/tools/file_reader.py index 1b5f8618b..95daf0bd6 100644 --- a/codebase_rag/tools/file_reader.py +++ b/codebase_rag/tools/file_reader.py @@ -14,8 +14,19 @@ class FileReader: - def __init__(self, project_root: str = "."): + def __init__( + self, + project_root: str = ".", + mode: str = "edit", + allowed_roots: frozenset[Path] | None = None, + ): self.project_root = Path(project_root).resolve() + self.mode = mode + self.allowed_roots = ( + frozenset(root.resolve() for root in allowed_roots) + if allowed_roots + else None + ) 
logger.info(ls.FILE_READER_INIT.format(root=self.project_root)) async def read_file(self, file_path: str) -> FileReadResult: diff --git a/codebase_rag/tools/file_writer.py b/codebase_rag/tools/file_writer.py index 4f3110b3b..240409b31 100644 --- a/codebase_rag/tools/file_writer.py +++ b/codebase_rag/tools/file_writer.py @@ -14,8 +14,10 @@ class FileWriter: - def __init__(self, project_root: str = "."): + def __init__(self, project_root: str = ".", mode: str = "edit"): self.project_root = Path(project_root).resolve() + self.mode = mode + self.allowed_roots: frozenset[Path] | None = None logger.info(ls.FILE_WRITER_INIT.format(root=self.project_root)) async def create_file(self, file_path: str, content: str) -> FileCreationResult: @@ -26,6 +28,14 @@ async def create_file(self, file_path: str, content: str) -> FileCreationResult: async def _create_validated( self, file_path: Path, content: str ) -> FileCreationResult: + if self.mode == "query": + logger.error(ls.QUERY_MODE_WRITE_BLOCKED.format(path=file_path)) + return FileCreationResult( + file_path=str(file_path), + success=False, + error_message=te.WRITE_QUERY_MODE_BLOCKED, + ) + try: file_path.parent.mkdir(parents=True, exist_ok=True) file_path.write_text(content, encoding=cs.ENCODING_UTF8) diff --git a/codebase_rag/tools/semantic_search.py b/codebase_rag/tools/semantic_search.py index e7aa9c5b2..71cf97171 100644 --- a/codebase_rag/tools/semantic_search.py +++ b/codebase_rag/tools/semantic_search.py @@ -100,7 +100,7 @@ def get_function_source_code(node_id: int) -> str | None: return None result = results[0] - file_path = result.get("path") + file_path = result.get("absolute_path") or result.get("relative_path") start_line = result.get("start_line") end_line = result.get("end_line") diff --git a/codebase_rag/types_defs.py b/codebase_rag/types_defs.py index fb293147b..dd2678155 100644 --- a/codebase_rag/types_defs.py +++ b/codebase_rag/types_defs.py @@ -133,6 +133,9 @@ class PathValidatorProtocol(Protocol): @property 
def project_root(self) -> Path: ... + @property + def allowed_roots(self) -> frozenset[Path] | None: ... + class TreeSitterNodeProtocol(Protocol): @property @@ -379,6 +382,8 @@ class CodeSnippetResultDict(TypedDict, total=False): qualified_name: str source_code: str file_path: str + relative_path: str | None + project_name: str | None line_start: int line_end: int docstring: str | None @@ -437,38 +442,60 @@ class RelationshipSchema(NamedTuple): NODE_SCHEMAS: tuple[NodeSchema, ...] = ( - NodeSchema(NodeLabel.PROJECT, "{name: string}"), NodeSchema( - NodeLabel.PACKAGE, "{qualified_name: string, name: string, path: string}" + NodeLabel.PROJECT, "{name: string, absolute_path: string, project_name: string}" ), - NodeSchema(NodeLabel.FOLDER, "{path: string, name: string}"), - NodeSchema(NodeLabel.FILE, "{path: string, name: string, extension: string}"), NodeSchema( - NodeLabel.MODULE, "{qualified_name: string, name: string, path: string}" + NodeLabel.PACKAGE, + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}", + ), + NodeSchema( + NodeLabel.FOLDER, + "{path: string, name: string, absolute_path: string, project_name: string}", + ), + NodeSchema( + NodeLabel.FILE, + "{path: string, name: string, extension: string, absolute_path: string, project_name: string}", + ), + NodeSchema( + NodeLabel.MODULE, + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}", ), NodeSchema( NodeLabel.CLASS, - "{qualified_name: string, name: string, decorators: list[string]}", + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, decorators: list[string]}", ), NodeSchema( NodeLabel.FUNCTION, - "{qualified_name: string, name: string, decorators: list[string]}", + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, decorators: list[string]}", ), NodeSchema( NodeLabel.METHOD, - "{qualified_name: string, name: 
string, decorators: list[string]}", + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, decorators: list[string]}", + ), + NodeSchema( + NodeLabel.INTERFACE, + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}", + ), + NodeSchema( + NodeLabel.ENUM, + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}", + ), + NodeSchema( + NodeLabel.TYPE, + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}", + ), + NodeSchema( + NodeLabel.UNION, + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}", ), - NodeSchema(NodeLabel.INTERFACE, "{qualified_name: string, name: string}"), - NodeSchema(NodeLabel.ENUM, "{qualified_name: string, name: string}"), - NodeSchema(NodeLabel.TYPE, "{qualified_name: string, name: string}"), - NodeSchema(NodeLabel.UNION, "{qualified_name: string, name: string}"), NodeSchema( NodeLabel.MODULE_INTERFACE, - "{qualified_name: string, name: string, path: string}", + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string}", ), NodeSchema( NodeLabel.MODULE_IMPLEMENTATION, - "{qualified_name: string, name: string, path: string, implements_module: string}", + "{qualified_name: string, name: string, path: string, absolute_path: string, project_name: string, implements_module: string}", ), NodeSchema(NodeLabel.EXTERNAL_PACKAGE, "{name: string, version_spec: string}"), ) @@ -556,3 +583,8 @@ class RelationshipSchema(NamedTuple): (NodeLabel.FUNCTION, NodeLabel.METHOD), ), ) + + +class PathInfo(TypedDict): + relative_path: str + absolute_path: str diff --git a/codebase_rag/utils/path_utils.py b/codebase_rag/utils/path_utils.py index 5c9bbf5b5..5f6f74ee2 100644 --- a/codebase_rag/utils/path_utils.py +++ b/codebase_rag/utils/path_utils.py @@ -1,6 +1,7 @@ from pathlib import Path from .. 
import constants as cs +from ..types_defs import PathInfo def should_skip_path( @@ -25,3 +26,45 @@ def should_skip_path( ): return False return not cs.IGNORE_PATTERNS.isdisjoint(dir_parts) + + +def calculate_paths( + file_path: Path | str, + repo_path: Path | str, +) -> PathInfo: + file_path = Path(file_path) + repo_path = Path(repo_path) + relative_path = file_path.relative_to(repo_path).as_posix() + absolute_path = str(file_path.resolve()) + + return PathInfo( + relative_path=relative_path, + absolute_path=absolute_path, + ) + + +def validate_allowed_path( + file_path: str | Path, + project_root: Path, + allowed_roots: frozenset[Path] | None = None, +) -> Path: + if isinstance(file_path, str): + file_path = Path(file_path) + + if file_path.is_absolute(): + safe_path = file_path.resolve() + else: + safe_path = (project_root / file_path).resolve() + + all_roots = {project_root} + if allowed_roots: + all_roots.update(allowed_roots) + + for allowed_root in all_roots: + try: + safe_path.relative_to(allowed_root) + return safe_path + except ValueError: + continue + + raise PermissionError(f"Path outside allowed roots: {file_path}") diff --git a/realtime_updater.py b/realtime_updater.py index 4fd95d5bc..1fc2d17fe 100644 --- a/realtime_updater.py +++ b/realtime_updater.py @@ -71,7 +71,7 @@ def dispatch(self, event: FileSystemEvent) -> None: return path = Path(src_path) - relative_path_str = str(path.relative_to(self.updater.repo_path)) + relative_path_str = path.relative_to(self.updater.repo_path).as_posix() logger.warning( logs.CHANGE_DETECTED.format(event_type=event.event_type, path=path) diff --git a/tests/test_cross_project_access.py b/tests/test_cross_project_access.py new file mode 100644 index 000000000..ace919350 --- /dev/null +++ b/tests/test_cross_project_access.py @@ -0,0 +1,110 @@ +import tempfile +from pathlib import Path + +import pytest + +from codebase_rag import constants as cs + + +@pytest.fixture +def temp_projects(): + with 
tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + project_a = tmpdir / "project_a" + project_a.mkdir() + (project_a / "utils.py").write_text( + """ +def parse_json(data): + '''Parse JSON data''' + import json + return json.loads(data) +""", + encoding="utf-8", + ) + + project_b = tmpdir / "project_b" + project_b.mkdir() + (project_b / "helpers.py").write_text( + """ +def format_output(data): + '''Format output data''' + return str(data) +""", + encoding="utf-8", + ) + + yield { + "project_a": project_a, + "project_b": project_b, + "base_dir": tmpdir, + } + + +@pytest.mark.integration +class TestCrossProjectAccess: + def test_index_multiple_projects(self, temp_projects): + project_a = temp_projects["project_a"] + project_b = temp_projects["project_b"] + + assert (project_a / "utils.py").exists() + assert (project_b / "helpers.py").exists() + + content_a = (project_a / "utils.py").read_text(encoding="utf-8") + content_b = (project_b / "helpers.py").read_text(encoding="utf-8") + + assert "parse_json" in content_a + assert "format_output" in content_b + + def test_absolute_path_calculation(self, temp_projects): + from codebase_rag.utils.path_utils import calculate_paths + + project_a = temp_projects["project_a"] + file_path = project_a / "utils.py" + + paths1 = calculate_paths( + file_path=file_path, + repo_path=project_a, + ) + + paths2 = calculate_paths( + file_path=file_path, + repo_path=project_a, + ) + + assert paths1["absolute_path"] == paths2["absolute_path"] + + def test_path_fields_in_schema(self): + from codebase_rag.constants import KEY_ABSOLUTE_PATH, KEY_PROJECT_NAME + from codebase_rag.schemas import CodeSnippet + + assert KEY_ABSOLUTE_PATH == "absolute_path" + assert KEY_PROJECT_NAME == "project_name" + assert cs.EXTERNAL_PROJECT_NAME == "__external__" + + snippet = CodeSnippet( + qualified_name="test.func", + source_code="def test(): pass", + file_path="/absolute/path/test.py", + project_name="test_project", + line_start=1, + 
line_end=2, + ) + + assert snippet.file_path == "/absolute/path/test.py" + assert snippet.project_name == "test_project" + + +@pytest.mark.integration +class TestExternalModuleHandling: + def test_query_filtering_external_modules(self): + mock_nodes = [ + {"project_name": "project_a", "name": "internal_func"}, + {"project_name": "__external__", "name": "json_loads"}, + {"project_name": "project_b", "name": "helper_func"}, + ] + + internal_nodes = [n for n in mock_nodes if n["project_name"] != "__external__"] + + assert len(internal_nodes) == 2 + assert all(n["project_name"] != "__external__" for n in internal_nodes) diff --git a/uv.lock b/uv.lock index aa1977b86..4380d5306 100644 --- a/uv.lock +++ b/uv.lock @@ -461,7 +461,7 @@ wheels = [ [[package]] name = "code-graph-rag" -version = "0.0.58" +version = "0.0.60" source = { editable = "." } dependencies = [ { name = "click" },