diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py
index 640933581..0fe4e3831 100644
--- a/packages/graphrag/graphrag/config/defaults.py
+++ b/packages/graphrag/graphrag/config/defaults.py
@@ -311,6 +311,16 @@ class SnapshotsDefaults:
     raw_graph: bool = False
 
 
+@dataclass
+class EntityResolutionDefaults:
+    """Default values for entity resolution."""
+
+    enabled: bool = False
+    # None means "use the built-in ENTITY_RESOLUTION_PROMPT"; a str is a file path.
+    # Typed str | None to match EntityResolutionConfig.prompt, which consumes it.
+    prompt: str | None = None
+    completion_model_id: str = DEFAULT_COMPLETION_MODEL_ID
+    model_instance_name: str = "entity_resolution"
+
+
 @dataclass
 class SummarizeDescriptionsDefaults:
     """Default values for summarizing descriptions."""
@@ -359,6 +369,9 @@ class GraphRagConfigDefaults:
     chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
     snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
     extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults)
+    entity_resolution: EntityResolutionDefaults = field(
+        default_factory=EntityResolutionDefaults
+    )
     extract_graph_nlp: ExtractGraphNLPDefaults = field(
         default_factory=ExtractGraphNLPDefaults
     )
diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py
index 9973d1920..3dc15d7e3 100644
--- a/packages/graphrag/graphrag/config/init_content.py
+++ b/packages/graphrag/graphrag/config/init_content.py
@@ -82,6 +82,10 @@ entity_types: [{",".join(graphrag_config_defaults.extract_graph.entity_types)}]
   max_gleanings: {graphrag_config_defaults.extract_graph.max_gleanings}
 
+entity_resolution:
+  enabled: {graphrag_config_defaults.entity_resolution.enabled}
+  completion_model_id: {graphrag_config_defaults.entity_resolution.completion_model_id}
+
 summarize_descriptions:
   completion_model_id: {graphrag_config_defaults.summarize_descriptions.completion_model_id}
   prompt: "prompts/summarize_descriptions.txt"
diff --git a/packages/graphrag/graphrag/config/models/entity_resolution_config.py
b/packages/graphrag/graphrag/config/models/entity_resolution_config.py new file mode 100644 index 000000000..5b3ef03be --- /dev/null +++ b/packages/graphrag/graphrag/config/models/entity_resolution_config.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for entity resolution.""" + +from dataclasses import dataclass +from pathlib import Path + +from pydantic import BaseModel, Field + +from graphrag.config.defaults import graphrag_config_defaults +from graphrag.prompts.index.entity_resolution import ENTITY_RESOLUTION_PROMPT + + +@dataclass +class EntityResolutionPrompts: + """Entity resolution prompt templates.""" + + resolution_prompt: str + + +class EntityResolutionConfig(BaseModel): + """Configuration section for entity resolution.""" + + enabled: bool = Field( + description="Whether to enable LLM-based entity resolution.", + default=graphrag_config_defaults.entity_resolution.enabled, + ) + completion_model_id: str = Field( + description="The model ID to use for entity resolution.", + default=graphrag_config_defaults.entity_resolution.completion_model_id, + ) + model_instance_name: str = Field( + description="The model singleton instance name. 
This primarily affects the cache storage partitioning.", + default=graphrag_config_defaults.entity_resolution.model_instance_name, + ) + prompt: str | None = Field( + description="The entity resolution prompt to use.", + default=graphrag_config_defaults.entity_resolution.prompt, + ) + + def resolved_prompts(self) -> EntityResolutionPrompts: + """Get the resolved entity resolution prompts.""" + return EntityResolutionPrompts( + resolution_prompt=Path(self.prompt).read_text(encoding="utf-8") + if self.prompt + else ENTITY_RESOLUTION_PROMPT, + ) diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index dc28da97c..27e692671 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -24,6 +24,7 @@ from graphrag.config.models.community_reports_config import CommunityReportsConfig from graphrag.config.models.drift_search_config import DRIFTSearchConfig from graphrag.config.models.embed_text_config import EmbedTextConfig +from graphrag.config.models.entity_resolution_config import EntityResolutionConfig from graphrag.config.models.extract_claims_config import ExtractClaimsConfig from graphrag.config.models.extract_graph_config import ExtractGraphConfig from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig @@ -186,6 +187,12 @@ def _validate_reporting_base_dir(self) -> None: ) """The entity extraction configuration to use.""" + entity_resolution: EntityResolutionConfig = Field( + description="The entity resolution configuration to use.", + default=EntityResolutionConfig(), + ) + """The entity resolution configuration to use.""" + summarize_descriptions: SummarizeDescriptionsConfig = Field( description="The description summarization configuration to use.", default=SummarizeDescriptionsConfig(), diff --git a/packages/graphrag/graphrag/data_model/schemas.py 
b/packages/graphrag/graphrag/data_model/schemas.py index c0926b9bb..644649e12 100644 --- a/packages/graphrag/graphrag/data_model/schemas.py +++ b/packages/graphrag/graphrag/data_model/schemas.py @@ -13,6 +13,7 @@ NODE_DEGREE = "degree" NODE_FREQUENCY = "frequency" NODE_DETAILS = "node_details" +ALTERNATIVE_NAMES = "alternative_names" # POST-PREP EDGE TABLE SCHEMA EDGE_SOURCE = "source" @@ -73,6 +74,7 @@ TITLE, TYPE, DESCRIPTION, + ALTERNATIVE_NAMES, TEXT_UNIT_IDS, NODE_FREQUENCY, NODE_DEGREE, diff --git a/packages/graphrag/graphrag/index/operations/finalize_entities.py b/packages/graphrag/graphrag/index/operations/finalize_entities.py index 71d6acc53..b75134cb4 100644 --- a/packages/graphrag/graphrag/index/operations/finalize_entities.py +++ b/packages/graphrag/graphrag/index/operations/finalize_entities.py @@ -28,6 +28,9 @@ def finalize_entities( final_entities["id"] = final_entities["human_readable_id"].apply( lambda _x: str(uuid4()) ) + # Ensure alternative_names column exists (empty when resolution is disabled) + if "alternative_names" not in final_entities.columns: + final_entities["alternative_names"] = [[] for _ in range(len(final_entities))] return final_entities.loc[ :, ENTITIES_FINAL_COLUMNS, diff --git a/packages/graphrag/graphrag/index/operations/resolve_entities/__init__.py b/packages/graphrag/graphrag/index/operations/resolve_entities/__init__.py new file mode 100644 index 000000000..ae85b818d --- /dev/null +++ b/packages/graphrag/graphrag/index/operations/resolve_entities/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Entity resolution operation package.""" diff --git a/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py new file mode 100644 index 000000000..458f939b2 --- /dev/null +++ b/packages/graphrag/graphrag/index/operations/resolve_entities/resolve_entities.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM-based entity resolution operation. + +Identifies entities with different surface forms that refer to the same +real-world entity (e.g. "Ahab" and "Captain Ahab") and unifies their titles. +""" + +import logging +from typing import TYPE_CHECKING + +import pandas as pd + +from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks + +if TYPE_CHECKING: + from graphrag_llm.completion import LLMCompletion + +logger = logging.getLogger(__name__) + + +async def resolve_entities( + entities: pd.DataFrame, + relationships: pd.DataFrame, + callbacks: WorkflowCallbacks, + model: "LLMCompletion", + prompt: str, + num_threads: int, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """Identify and merge duplicate entities with different surface forms. + + Sends all unique entity titles to the LLM in a single call, parses the + response to build a rename mapping, then applies it to entity titles and + relationship source/target columns. Each canonical entity receives an + ``alternative_names`` column listing all of its aliases. + + Parameters + ---------- + entities : pd.DataFrame + Entity DataFrame with at least a ``title`` column. + relationships : pd.DataFrame + Relationship DataFrame with ``source`` and ``target`` columns. + callbacks : WorkflowCallbacks + Progress callbacks. + model : LLMCompletion + The LLM completion model to use. + prompt : str + The entity resolution prompt template (must contain ``{entity_list}``). 
+ num_threads : int + Concurrency limit for LLM calls (reserved for future use). + + Returns + ------- + tuple[pd.DataFrame, pd.DataFrame] + Updated ``(entities, relationships)`` with unified titles and an + ``alternative_names`` column on entities. + """ + if "title" not in entities.columns: + return entities, relationships + + titles = entities["title"].dropna().unique().tolist() + if len(titles) < 2: + return entities, relationships + + logger.info( + "Running LLM entity resolution on %d unique entity names...", len(titles) + ) + + # Build numbered entity list for the prompt + entity_list = "\n".join(f"{i+1}. {name}" for i, name in enumerate(titles)) + formatted_prompt = prompt.format(entity_list=entity_list) + + try: + response = await model.completion_async(messages=formatted_prompt) + raw = (response.content or "").strip() + except Exception as e: + logger.warning("Entity resolution LLM call failed, skipping resolution: %s", e, exc_info=True) + return entities, relationships + + if "NO_DUPLICATES" in raw: + logger.info("Entity resolution: no duplicates found") + return entities, relationships + + # Parse response and build rename mapping + rename_map: dict[str, str] = {} # alias → canonical + alternatives: dict[str, set[str]] = {} # canonical → {aliases} + + for line in raw.splitlines(): + line = line.strip() + if not line or line.startswith("#") or line.startswith("Where"): + continue + parts = [p.strip() for p in line.split(",")] + indices: list[int] = [] + for p in parts: + digits = "".join(c for c in p if c.isdigit()) + if digits: + idx = int(digits) - 1 # 1-indexed → 0-indexed + if 0 <= idx < len(titles): + indices.append(idx) + if len(indices) >= 2: + canonical = titles[indices[0]] + if canonical not in alternatives: + alternatives[canonical] = set() + for alias_idx in indices[1:]: + alias = titles[alias_idx] + rename_map[alias] = canonical + alternatives[canonical].add(alias) + logger.info(" Entity resolution: '%s' → '%s'", alias, canonical) + + if not 
rename_map: + logger.info("Entity resolution complete: no duplicates found") + return entities, relationships + + logger.info("Entity resolution: merging %d duplicate names", len(rename_map)) + + # Apply renames to entity titles + entities = entities.copy() + entities["title"] = entities["title"].map(lambda t: rename_map.get(t, t)) + + # Add alternative_names column + entities["alternative_names"] = entities["title"].map( + lambda t: sorted(alternatives.get(t, set())) + ) + + # Apply renames to relationship source/target + if not relationships.empty: + relationships = relationships.copy() + if "source" in relationships.columns: + relationships["source"] = relationships["source"].map( + lambda s: rename_map.get(s, s) + ) + if "target" in relationships.columns: + relationships["target"] = relationships["target"].map( + lambda t: rename_map.get(t, t) + ) + + return entities, relationships diff --git a/packages/graphrag/graphrag/index/update/entities.py b/packages/graphrag/graphrag/index/update/entities.py index fe9bb2347..09badda23 100644 --- a/packages/graphrag/graphrag/index/update/entities.py +++ b/packages/graphrag/graphrag/index/update/entities.py @@ -44,6 +44,11 @@ def _group_and_resolve_entities( delta_entities_df["human_readable_id"] = np.arange( initial_id, initial_id + len(delta_entities_df) ) + # Ensure alternative_names column exists (may be absent in older indexes) + for df in [old_entities_df, delta_entities_df]: + if "alternative_names" not in df.columns: + df["alternative_names"] = [[] for _ in range(len(df))] + # Concat A and B combined = pd.concat( [old_entities_df, delta_entities_df], ignore_index=True, copy=False @@ -60,6 +65,9 @@ def _group_and_resolve_entities( "description": lambda x: list(x.astype(str)), # Ensure str # Concatenate nd.array into a single list "text_unit_ids": lambda x: list(itertools.chain(*x.tolist())), + "alternative_names": lambda x: sorted( + set(itertools.chain(*x.tolist())) + ), "degree": "first", # todo: we could probably 
re-compute this with the entire new graph }) .reset_index() diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph.py b/packages/graphrag/graphrag/index/workflows/extract_graph.py index dc86b180f..2d902fabc 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph.py @@ -17,6 +17,9 @@ from graphrag.index.operations.extract_graph.extract_graph import ( extract_graph as extractor, ) +from graphrag.index.operations.resolve_entities.resolve_entities import ( + resolve_entities, +) from graphrag.index.operations.summarize_descriptions.summarize_descriptions import ( summarize_descriptions, ) @@ -58,6 +61,24 @@ async def run_workflow( cache_key_creator=cache_key_creator, ) + # Entity resolution model (optional) + resolution_enabled = config.entity_resolution.enabled + resolution_model = None + resolution_prompt = "" + if resolution_enabled: + resolution_model_config = config.get_completion_model_config( + config.entity_resolution.completion_model_id + ) + resolution_prompts = config.entity_resolution.resolved_prompts() + resolution_prompt = resolution_prompts.resolution_prompt + resolution_model = create_completion( + resolution_model_config, + cache=context.cache.child( + config.entity_resolution.model_instance_name + ), + cache_key_creator=cache_key_creator, + ) + entities, relationships, raw_entities, raw_relationships = await extract_graph( text_units=text_units, callbacks=context.callbacks, @@ -72,6 +93,10 @@ async def run_workflow( max_input_tokens=config.summarize_descriptions.max_input_tokens, summarization_prompt=summarization_prompts.summarize_prompt, summarization_num_threads=config.concurrent_requests, + resolution_enabled=resolution_enabled, + resolution_model=resolution_model, + resolution_prompt=resolution_prompt, + resolution_num_threads=config.concurrent_requests, ) await context.output_table_provider.write_dataframe("entities", entities) @@ -108,6 +133,10 @@ 
async def extract_graph( max_input_tokens: int, summarization_prompt: str, summarization_num_threads: int, + resolution_enabled: bool = False, + resolution_model: "LLMCompletion | None" = None, + resolution_prompt: str = "", + resolution_num_threads: int = 1, ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: """All the steps to create the base entity graph.""" # this returns a graph for each text unit, to be merged later @@ -136,10 +165,21 @@ async def extract_graph( logger.error(error_msg) raise ValueError(error_msg) - # copy these as is before any summarization + # copy these as is before any resolution or summarization raw_entities = extracted_entities.copy() raw_relationships = extracted_relationships.copy() + # Resolve duplicate entity names before grouping by title + if resolution_enabled and resolution_model is not None: + extracted_entities, extracted_relationships = await resolve_entities( + entities=extracted_entities, + relationships=extracted_relationships, + callbacks=callbacks, + model=resolution_model, + prompt=resolution_prompt, + num_threads=resolution_num_threads, + ) + entities, relationships = await get_summarized_entities_relationships( extracted_entities=extracted_entities, extracted_relationships=extracted_relationships, diff --git a/packages/graphrag/graphrag/prompts/index/entity_resolution.py b/packages/graphrag/graphrag/prompts/index/entity_resolution.py new file mode 100644 index 000000000..071b6674b --- /dev/null +++ b/packages/graphrag/graphrag/prompts/index/entity_resolution.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing prompts definition.""" + +ENTITY_RESOLUTION_PROMPT = """ +You are an entity resolution expert. Below is a numbered list of entity names +extracted from a knowledge graph. Identify which names refer to the SAME +real-world entity and choose the best canonical name for each group of duplicates. 
+ +Rules: +- Only merge names that clearly refer to the same entity (e.g., "Ahab" and +"Captain Ahab", "USA" and "United States of America") +- Do NOT merge entities that are merely related (e.g., "Ahab" and "Moby Dick") +- Choose the most complete and commonly used name as the canonical form +- Reference entities by their number + +Output format — one group per line, canonical number first, then duplicate numbers: +3, 17 +5, 12, 28 + +Where each line means: all listed numbers refer to the same entity, and the +first number's name is the canonical form. + +If no duplicates are found, respond with exactly: NO_DUPLICATES + +Entity list: +{entity_list} +""" diff --git a/tests/unit/indexing/operations/resolve_entities/test_resolve_entities.py b/tests/unit/indexing/operations/resolve_entities/test_resolve_entities.py new file mode 100644 index 000000000..79a11cdd2 --- /dev/null +++ b/tests/unit/indexing/operations/resolve_entities/test_resolve_entities.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Unit tests for the resolve_entities operation."""
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pandas as pd
+import pytest
+
+from graphrag.index.operations.resolve_entities.resolve_entities import (
+    resolve_entities,
+)
+
+
+@pytest.fixture
+def sample_entities():
+    """Create sample entity DataFrame with known duplicates."""
+    return pd.DataFrame({
+        "title": [
+            "Captain Ahab",
+            "Moby Dick",
+            "Ahab",
+            "The Pequod",
+            "Ishmael",
+            "Pequod",
+        ],
+        "description": [
+            "Captain of the Pequod",
+            "The great white whale",
+            "The obsessed captain",
+            "The whaling ship",
+            "The narrator",
+            "A whaling vessel",
+        ],
+    })
+
+
+@pytest.fixture
+def sample_relationships():
+    """Create sample relationship DataFrame."""
+    return pd.DataFrame({
+        "source": ["Ahab", "Captain Ahab", "Ishmael", "Pequod"],
+        "target": ["Moby Dick", "The Pequod", "Pequod", "Ahab"],
+        "description": [
+            "hunts",
+            "commands",
+            "boards",
+            "carries",
+        ],
+    })
+
+
+@pytest.fixture
+def mock_callbacks():
+    """Create mock workflow callbacks."""
+    callbacks = MagicMock()
+    callbacks.progress = MagicMock()
+    return callbacks
+
+
+def _make_mock_model(response_text: str) -> MagicMock:
+    """Create a mock LLM model whose completion_async returns the given text.
+
+    resolve_entities awaits ``model.completion_async(messages=...)`` and reads
+    ``response.content``, so the text must be attached to the response object,
+    not to the model itself.
+    """
+    model = MagicMock()
+    response = MagicMock()
+    response.content = response_text
+    model.completion_async = AsyncMock(return_value=response)
+    return model
+
+
+@pytest.mark.asyncio
+async def test_no_duplicates(sample_entities, sample_relationships, mock_callbacks):
+    """When LLM finds no duplicates, entities remain unchanged."""
+    model = _make_mock_model("NO_DUPLICATES")
+
+    result_entities, result_relationships = await resolve_entities(
+        entities=sample_entities.copy(),
+        relationships=sample_relationships.copy(),
+        callbacks=mock_callbacks,
+        model=model,
+        prompt="{entity_list}",
+        num_threads=1,
+    )
+
+    # Titles should be unchanged
+    assert list(result_entities["title"]) == list(sample_entities["title"])
+    assert list(result_relationships["source"]) == list(
+        sample_relationships["source"]
+    )
+
+
+@pytest.mark.asyncio
+async def test_simple_duplicates(
+    sample_entities, sample_relationships, mock_callbacks
+):
+    """Ahab → Captain Ahab, Pequod → The Pequod."""
+    # LLM response: entity 1 (Captain Ahab) and 3 (Ahab) are the same;
+    # entity 4 (The Pequod) and 6 (Pequod) are the same.
+    model = _make_mock_model("1, 3\n4, 6")
+
+    result_entities, result_relationships = await resolve_entities(
+        entities=sample_entities.copy(),
+        relationships=sample_relationships.copy(),
+        callbacks=mock_callbacks,
+        model=model,
+        prompt="{entity_list}",
+        num_threads=1,
+    )
+
+    # "Ahab" should become "Captain Ahab"
+    titles = list(result_entities["title"])
+    assert "Ahab" not in titles
+    assert titles.count("Captain Ahab") == 2  # both rows unified
+
+    # "Pequod" should become "The Pequod"
+    assert "Pequod" not in titles
+    assert titles.count("The Pequod") == 2
+
+    # Canonical rows record their aliases in alternative_names
+    alt = result_entities.loc[
+        result_entities["title"] == "Captain Ahab", "alternative_names"
+    ].iloc[0]
+    assert alt == ["Ahab"]
+
+    # Relationships should also be renamed
+    sources = list(result_relationships["source"])
+    targets = list(result_relationships["target"])
+    assert "Ahab" not in sources
+    assert "Ahab" not in targets
+    assert "Pequod" not in sources
+    assert "Pequod" not in targets
+
+
+@pytest.mark.asyncio
+async def test_llm_failure_graceful(
+    sample_entities, sample_relationships, mock_callbacks
+):
+    """If LLM call fails, entities are returned unchanged."""
+    # The side effect must be on completion_async — that is what gets awaited.
+    model = MagicMock()
+    model.completion_async = AsyncMock(side_effect=Exception("LLM unavailable"))
+
+    result_entities, result_relationships = await resolve_entities(
+        entities=sample_entities.copy(),
+        relationships=sample_relationships.copy(),
+        callbacks=mock_callbacks,
+        model=model,
+        prompt="{entity_list}",
+        num_threads=1,
+    )
+
+    # Should fall back to no changes
+    assert list(result_entities["title"]) == list(sample_entities["title"])
+
+
+@pytest.mark.asyncio
+async def test_single_entity_skips():
+    """With fewer than 2 entities, resolution is skipped entirely."""
+    entities = pd.DataFrame({"title": ["Only One"]})
+    relationships = pd.DataFrame({"source": [], "target": []})
+    callbacks = MagicMock()
+    callbacks.progress = MagicMock()
+    model = _make_mock_model("should not be called")
+
+    result_entities, _ = await resolve_entities(
+        entities=entities,
+        relationships=relationships,
+        callbacks=callbacks,
+        model=model,
+        prompt="{entity_list}",
+        num_threads=1,
+    )
+
+    # The LLM should not have been invoked
+    model.completion_async.assert_not_called()
+    assert list(result_entities["title"]) == ["Only One"]
+
+
+@pytest.mark.asyncio
+async def test_single_llm_call(mock_callbacks):
+    """All unique titles are resolved with a single LLM request."""
+    entities = pd.DataFrame({
+        "title": ["A", "B", "C", "D", "E"],
+        "description": [""] * 5,
+    })
+    relationships = pd.DataFrame({"source": ["A"], "target": ["B"]})
+
+    model = _make_mock_model("NO_DUPLICATES")
+
+    await resolve_entities(
+        entities=entities.copy(),
+        relationships=relationships.copy(),
+        callbacks=mock_callbacks,
+        model=model,
+        prompt="{entity_list}",
+        num_threads=1,
+    )
+
+    # The implementation sends the full entity list in one call
+    assert model.completion_async.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_missing_title_column(mock_callbacks):
+    """If there's no title column, skip resolution."""
+    entities = pd.DataFrame({"name": ["A", "B"]})
+    relationships = pd.DataFrame({"source": ["A"], "target": ["B"]})
+    model = _make_mock_model("should not be called")
+
+    result_entities, _ = await resolve_entities(
+        entities=entities,
+        relationships=relationships,
+        callbacks=mock_callbacks,
+        model=model,
+        prompt="{entity_list}",
+        num_threads=1,
+    )
+
+    model.completion_async.assert_not_called()
+    assert list(result_entities.columns) == ["name"]