From 00d993e9442720dda1402f1163d7ed5f1153c140 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 10 Feb 2026 21:01:37 +0000 Subject: [PATCH 1/6] Add RecordCountValidatingPostProcessor --- .../configs/e2e_het_dblp_sup_task_config.yaml | 2 + .../configs/e2e_hom_cora_sup_task_config.yaml | 2 + .../e2e_het_dblp_sup_gs_task_config.yaml | 2 + .../e2e_hom_cora_sup_gs_task_config.yaml | 2 + gigl/src/post_process/impl/__init__.py | 0 .../record_count_validating_post_processor.py | 160 ++++++++ gigl/src/post_process/post_processor.py | 16 +- .../integration/src/post_process/__init__.py | 0 .../src/post_process/impl/__init__.py | 0 ...rd_count_validating_post_processor_test.py | 353 ++++++++++++++++++ 10 files changed, 529 insertions(+), 8 deletions(-) create mode 100644 gigl/src/post_process/impl/__init__.py create mode 100644 gigl/src/post_process/impl/record_count_validating_post_processor.py create mode 100644 tests/integration/src/post_process/__init__.py create mode 100644 tests/integration/src/post_process/impl/__init__.py create mode 100644 tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py diff --git a/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml b/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml index 8531fd081..c5edac2df 100644 --- a/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml +++ b/examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml @@ -70,5 +70,7 @@ sharedConfig: shouldSkipInference: false # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines. shouldSkipModelEvaluation: true +postProcessorConfig: + postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor featureFlags: should_run_glt_backend: 'True' diff --git a/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml b/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml index 606f13c29..8e0794a69 100644 --- a/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml +++ b/examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml @@ -36,6 +36,8 @@ taskMetadata: - dstNodeType: paper relation: cites srcNodeType: paper +postProcessorConfig: + postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor featureFlags: should_run_glt_backend: 'True' data_preprocessor_num_shards: '2' diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml index 1ebf9acb7..580733f47 100644 --- a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml @@ -76,5 +76,7 @@ sharedConfig: shouldSkipInference: false # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines. shouldSkipModelEvaluation: true +postProcessorConfig: + postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor featureFlags: should_run_glt_backend: 'True' diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml index 84e0badef..29de10763 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml @@ -44,6 +44,8 @@ taskMetadata: - dstNodeType: paper relation: cites srcNodeType: paper +postProcessorConfig: + postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor featureFlags: should_run_glt_backend: 'True' data_preprocessor_num_shards: '2' diff --git a/gigl/src/post_process/impl/__init__.py b/gigl/src/post_process/impl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gigl/src/post_process/impl/record_count_validating_post_processor.py b/gigl/src/post_process/impl/record_count_validating_post_processor.py new file mode 100644 index 000000000..9097dd77c --- /dev/null +++ b/gigl/src/post_process/impl/record_count_validating_post_processor.py @@ -0,0 +1,160 @@ +from gigl.common import Uri +from gigl.common.logger import Logger +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.graph_data import NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.post_process.post_processor import PostProcessor + +logger = Logger() + + +class RecordCountValidatingPostProcessor(PostProcessor): + """ + Post processor that extends PostProcessor with record count validation. + + Runs all standard PostProcessor logic (unenumeration, user-defined + post-processing, metric export, cleanup), then validates that for each + node type, the unenumerated output tables (embeddings, predictions) have + the same number of rows as the corresponding enumerated_node_ids_bq_table. + + Only applicable for the GLT backend path. + """ + + # TODO: Add edge-level validation support. + + def _run( + self, + applied_task_identifier: AppliedTaskIdentifier, + task_config_uri: Uri, + ): + # Run all standard PostProcessor logic first + super()._run( + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + ) + + # Then validate record counts + gbml_config_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=task_config_uri + ) + self._validate_record_counts(gbml_config_wrapper=gbml_config_wrapper) + + def _validate_record_counts( + self, + gbml_config_wrapper: GbmlConfigPbWrapper, + bq_utils: BqUtils | None = None, + ) -> None: + """Validates that output BQ tables have matching record counts. + + For each node type in the inference output, checks that: + 1. The enumerated_node_ids_bq_table exists. + 2. The embeddings table (if configured) exists and has the same row count. + 3. The predictions table (if configured) exists and has the same row count. + 4. At least one of embeddings or predictions is configured. + + All errors are collected and reported together before raising. + + Args: + gbml_config_wrapper: The GbmlConfig wrapper with access to all metadata. + bq_utils: Optional BqUtils instance for testing. If None, creates one + from the resource config. + + Raises: + ValueError: If any validation errors are found, with all errors listed. + """ + if bq_utils is None: + from gigl.env.pipelines_config import get_resource_config + + resource_config = get_resource_config() + bq_utils = BqUtils(project=resource_config.project) + + validation_errors: list[str] = [] + + inference_output_map = ( + gbml_config_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map + ) + node_type_to_condensed = ( + gbml_config_wrapper.graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map + ) + preprocessed_metadata = ( + gbml_config_wrapper.preprocessed_metadata_pb_wrapper.preprocessed_metadata + ) + + for node_type, inference_output in inference_output_map.items(): + condensed_node_type = node_type_to_condensed[NodeType(node_type)] + node_metadata = ( + preprocessed_metadata.condensed_node_type_to_preprocessed_metadata[ + int(condensed_node_type) + ] + ) + enumerated_table = node_metadata.enumerated_node_ids_bq_table + + if not enumerated_table: + validation_errors.append( + f"[{node_type}] No enumerated_node_ids_bq_table configured." + ) + continue + + if not bq_utils.does_bq_table_exist(enumerated_table): + validation_errors.append( + f"[{node_type}] enumerated_node_ids_bq_table does not exist: {enumerated_table}" + ) + continue + + expected_count = bq_utils.count_number_of_rows_in_bq_table(enumerated_table) + logger.info( + f"[{node_type}] enumerated_node_ids_bq_table ({enumerated_table}) has {expected_count} rows." + ) + + # Validate embeddings + embeddings_path = inference_output.embeddings_path + if embeddings_path: + if not bq_utils.does_bq_table_exist(embeddings_path): + validation_errors.append( + f"[{node_type}] Embeddings table does not exist: {embeddings_path}" + ) + else: + actual = bq_utils.count_number_of_rows_in_bq_table(embeddings_path) + logger.info( + f"[{node_type}] Embeddings table ({embeddings_path}) has {actual} rows." + ) + if actual != expected_count: + validation_errors.append( + f"[{node_type}] Embeddings row count mismatch: " + f"expected {expected_count}, got {actual} " + f"(table: {embeddings_path})" + ) + + # Validate predictions + predictions_path = inference_output.predictions_path + if predictions_path: + if not bq_utils.does_bq_table_exist(predictions_path): + validation_errors.append( + f"[{node_type}] Predictions table does not exist: {predictions_path}" + ) + else: + actual = bq_utils.count_number_of_rows_in_bq_table(predictions_path) + logger.info( + f"[{node_type}] Predictions table ({predictions_path}) has {actual} rows." + ) + if actual != expected_count: + validation_errors.append( + f"[{node_type}] Predictions row count mismatch: " + f"expected {expected_count}, got {actual} " + f"(table: {predictions_path})" + ) + + if not embeddings_path and not predictions_path: + validation_errors.append( + f"[{node_type}] Neither embeddings_path nor predictions_path is set." + ) + + if validation_errors: + error_summary = "\n".join(validation_errors) + raise ValueError( + f"Record count validation failed with {len(validation_errors)} error(s):\n" + f"{error_summary}" + ) + + logger.info("All record count validations passed.") diff --git a/gigl/src/post_process/post_processor.py b/gigl/src/post_process/post_processor.py index e9f9f8828..3213b94ef 100644 --- a/gigl/src/post_process/post_processor.py +++ b/gigl/src/post_process/post_processor.py @@ -30,7 +30,7 @@ class PostProcessor: - def __run_post_process( + def _run_post_process( self, gbml_config_pb: gbml_config_pb2.GbmlConfig, applied_task_identifier: AppliedTaskIdentifier, @@ -66,7 +66,7 @@ def __run_post_process( EvalMetricsCollection ] = post_processor.run_post_process(gbml_config_pb=gbml_config_pb) if post_processor_metrics is not None: - self.__write_post_processor_metrics_to_uri( + self._write_post_processor_metrics_to_uri( model_eval_metrics=post_processor_metrics, gbml_config_pb=gbml_config_pb, ) @@ -87,7 +87,7 @@ def __run_post_process( ) gcs_utils.delete_files_in_bucket_dir(gcs_path=temp_dir_gcs_path) - def __write_post_processor_metrics_to_uri( + def _write_post_processor_metrics_to_uri( self, model_eval_metrics: EvalMetricsCollection, gbml_config_pb: gbml_config_pb2.GbmlConfig, @@ -106,7 +106,7 @@ def __write_post_processor_metrics_to_uri( ) logger.info(f"Wrote eval metrics to {post_processor_log_metrics_uri.uri}.") - def __should_run_unenumeration( + def _should_run_unenumeration( self, gbml_config_wrapper: GbmlConfigPbWrapper ) -> bool: """ @@ -114,7 +114,7 @@ def __should_run_unenumeration( """ return gbml_config_wrapper.should_use_glt_backend - def __run( + def _run( self, applied_task_identifier: AppliedTaskIdentifier, task_config_uri: Uri, @@ -124,7 +124,7 @@ def __run( gbml_config_uri=task_config_uri ) ) - if self.__should_run_unenumeration(gbml_config_wrapper=gbml_config_wrapper): + if self._should_run_unenumeration(gbml_config_wrapper=gbml_config_wrapper): logger.info(f"Running unenumeration for inferred assets in post processor") unenumerate_all_inferred_bq_assets( gbml_config_pb_wrapper=gbml_config_wrapper @@ -133,7 +133,7 @@ def __run( f"Finished running unenumeration for inferred assets in post processor" ) - self.__run_post_process( + self._run_post_process( gbml_config_pb=gbml_config_wrapper.gbml_config_pb, applied_task_identifier=applied_task_identifier, ) @@ -150,7 +150,7 @@ def run( resource_config_uri: Uri, ): try: - return self.__run( + return self._run( applied_task_identifier=applied_task_identifier, task_config_uri=task_config_uri, ) diff --git a/tests/integration/src/post_process/__init__.py b/tests/integration/src/post_process/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/src/post_process/impl/__init__.py b/tests/integration/src/post_process/impl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py new file mode 100644 index 000000000..cc94c0a36 --- /dev/null +++ b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py @@ -0,0 +1,353 @@ +import tempfile +import time +import uuid + +from absl.testing import absltest + +from gigl.common import LocalUri +from gigl.common.utils.proto_utils import ProtoUtils +from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.post_process.impl.record_count_validating_post_processor import ( + RecordCountValidatingPostProcessor, +) +from snapchat.research.gbml import ( + gbml_config_pb2, + graph_schema_pb2, + preprocessed_metadata_pb2, +) +from tests.test_assets.test_case import TestCase + + +class RecordCountValidatingPostProcessorTest(TestCase): + """Integration tests for RecordCountValidatingPostProcessor against real BigQuery.""" + + def setUp(self): + resource_config = get_resource_config() + self._test_unique_name = f"gigl_rcv_test_{uuid.uuid4().hex[:12]}" + + self._bq_utils = BqUtils(project=resource_config.project) + self._bq_project = resource_config.project + self._bq_dataset = resource_config.temp_assets_bq_dataset_name + self._proto_utils = ProtoUtils() + + self._tables_to_cleanup: list[str] = [] + self._validator = RecordCountValidatingPostProcessor() + + def tearDown(self): + for table_path in self._tables_to_cleanup: + self._bq_utils.delete_bq_table_if_exist( + bq_table_path=table_path, not_found_ok=True + ) + + def _make_table_path(self, suffix: str) -> str: + table_name = f"{self._test_unique_name}_{suffix}" + return self._bq_utils.join_path(self._bq_project, self._bq_dataset, table_name) + + def _create_test_table(self, table_path: str, num_rows: int) -> None: + """Create a BQ table with the specified number of rows and track for cleanup.""" + self._tables_to_cleanup.append(table_path) + formatted = self._bq_utils.format_bq_path(table_path) + create_query = f""" + CREATE OR REPLACE TABLE `{formatted}` AS + SELECT + id, + CONCAT('node_', CAST(id AS STRING)) as node_id + FROM UNNEST(GENERATE_ARRAY(0, {num_rows - 1})) as id + """ + self._bq_utils.run_query(query=create_query, labels={}) + time.sleep(1) + + def _build_gbml_config_wrapper( + self, + node_types: list[str], + enumerated_tables: dict[str, str], + embeddings_tables: dict[str, str] | None = None, + predictions_tables: dict[str, str] | None = None, + ) -> GbmlConfigPbWrapper: + """Build a GbmlConfigPbWrapper with the given table paths. + + Args: + node_types: List of string node types. + enumerated_tables: Map of node_type -> enumerated_node_ids_bq_table path. + embeddings_tables: Map of node_type -> embeddings BQ table path. + predictions_tables: Map of node_type -> predictions BQ table path. + """ + embeddings_tables = embeddings_tables or {} + predictions_tables = predictions_tables or {} + + # Build graph metadata with condensed node type and edge type mappings. + # GbmlConfigPbWrapper requires both maps to be non-empty. + graph_metadata = graph_schema_pb2.GraphMetadata() + for i, nt in enumerate(node_types): + graph_metadata.node_types.append(nt) + graph_metadata.condensed_node_type_map[i] = nt + + # Add a dummy edge type so condensed_edge_type_map is non-empty + dummy_edge = graph_schema_pb2.EdgeType( + src_node_type=node_types[0], + relation="dummy", + dst_node_type=node_types[0], + ) + graph_metadata.edge_types.append(dummy_edge) + graph_metadata.condensed_edge_type_map[0].CopyFrom(dummy_edge) + + # Build preprocessed metadata + preprocessed_metadata = preprocessed_metadata_pb2.PreprocessedMetadata() + for i, nt in enumerate(node_types): + node_output = ( + preprocessed_metadata.condensed_node_type_to_preprocessed_metadata[i] + ) + node_output.enumerated_node_ids_bq_table = enumerated_tables.get(nt, "") + + # Write preprocessed metadata to a temp file + f = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") + preprocessed_metadata_uri = LocalUri(f.name) + self._proto_utils.write_proto_to_yaml( + proto=preprocessed_metadata, uri=preprocessed_metadata_uri + ) + + # Build the main GbmlConfig + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.graph_metadata.CopyFrom(graph_metadata) + gbml_config.shared_config.preprocessed_metadata_uri = ( + preprocessed_metadata_uri.uri + ) + + # Populate inference metadata + for nt in node_types: + inference_output = gbml_config.shared_config.inference_metadata.node_type_to_inferencer_output_info_map[ + nt + ] + if nt in embeddings_tables: + inference_output.embeddings_path = embeddings_tables[nt] + if nt in predictions_tables: + inference_output.predictions_path = predictions_tables[nt] + + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + def test_validation_passes_when_counts_match(self): + """All tables exist and row counts match — validation should pass.""" + num_rows = 25 + enum_table = self._make_table_path("enum_paper") + emb_table = self._make_table_path("emb_paper") + + self._create_test_table(enum_table, num_rows) + self._create_test_table(emb_table, num_rows) + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + embeddings_tables={"paper": emb_table}, + ) + + # Should not raise + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_validation_fails_on_row_count_mismatch(self): + """Embeddings table has fewer rows — should raise ValueError.""" + enum_table = self._make_table_path("enum_paper") + emb_table = self._make_table_path("emb_paper") + + self._create_test_table(enum_table, 50) + self._create_test_table(emb_table, 30) + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + embeddings_tables={"paper": emb_table}, + ) + + with self.assertRaises(ValueError): + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_validation_fails_on_missing_embeddings_table(self): + """embeddings_path is set but BQ table does not exist — should raise ValueError.""" + enum_table = self._make_table_path("enum_paper") + self._create_test_table(enum_table, 25) + + nonexistent_emb_table = self._make_table_path("emb_nonexistent") + # Do NOT create this table — it should not exist + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + embeddings_tables={"paper": nonexistent_emb_table}, + ) + + with self.assertRaises(ValueError): + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_validation_fails_on_missing_predictions_table(self): + """predictions_path is set but BQ table does not exist — should raise ValueError.""" + enum_table = self._make_table_path("enum_paper") + self._create_test_table(enum_table, 25) + + nonexistent_pred_table = self._make_table_path("pred_nonexistent") + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + predictions_tables={"paper": nonexistent_pred_table}, + ) + + with self.assertRaises(ValueError): + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_validation_fails_when_no_output_paths_set(self): + """Neither embeddings nor predictions is set — should raise ValueError.""" + enum_table = self._make_table_path("enum_paper") + self._create_test_table(enum_table, 25) + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + # No embeddings or predictions + ) + + with self.assertRaises(ValueError): + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_validation_with_multiple_node_types(self): + """Heterogeneous graph with multiple node types — each validated independently.""" + num_rows_author = 30 + num_rows_paper = 50 + + enum_author = self._make_table_path("enum_author") + emb_author = self._make_table_path("emb_author") + enum_paper = self._make_table_path("enum_paper") + emb_paper = self._make_table_path("emb_paper") + + self._create_test_table(enum_author, num_rows_author) + self._create_test_table(emb_author, num_rows_author) + self._create_test_table(enum_paper, num_rows_paper) + self._create_test_table(emb_paper, num_rows_paper) + + wrapper = self._build_gbml_config_wrapper( + node_types=["author", "paper"], + enumerated_tables={ + "author": enum_author, + "paper": enum_paper, + }, + embeddings_tables={ + "author": emb_author, + "paper": emb_paper, + }, + ) + + # Should not raise + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_validation_partial_output_only_embeddings(self): + """Only embeddings (no predictions) for a node type — should pass.""" + num_rows = 40 + enum_table = self._make_table_path("enum_paper") + emb_table = self._make_table_path("emb_paper") + + self._create_test_table(enum_table, num_rows) + self._create_test_table(emb_table, num_rows) + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + embeddings_tables={"paper": emb_table}, + # No predictions — that's fine + ) + + # Should not raise + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_multiple_errors_collected(self): + """One table missing AND one count mismatch — both errors appear in exception.""" + enum_author = self._make_table_path("enum_author") + emb_author = self._make_table_path("emb_author") + enum_paper = self._make_table_path("enum_paper") + + self._create_test_table(enum_author, 30) + self._create_test_table(emb_author, 20) # Mismatch: 20 != 30 + self._create_test_table(enum_paper, 50) + + nonexistent_emb_paper = self._make_table_path("emb_paper_missing") + + wrapper = self._build_gbml_config_wrapper( + node_types=["author", "paper"], + enumerated_tables={ + "author": enum_author, + "paper": enum_paper, + }, + embeddings_tables={ + "author": emb_author, + "paper": nonexistent_emb_paper, + }, + ) + + with self.assertRaises(ValueError) as ctx: + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + error_message = str(ctx.exception) + self.assertIn("2 error(s)", error_message) + self.assertIn("[author]", error_message) + self.assertIn("[paper]", error_message) + + def test_validation_passes_with_both_embeddings_and_predictions(self): + """Both embeddings and predictions tables exist with matching counts.""" + num_rows = 35 + enum_table = self._make_table_path("enum_paper") + emb_table = self._make_table_path("emb_paper") + pred_table = self._make_table_path("pred_paper") + + self._create_test_table(enum_table, num_rows) + self._create_test_table(emb_table, num_rows) + self._create_test_table(pred_table, num_rows) + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + embeddings_tables={"paper": emb_table}, + predictions_tables={"paper": pred_table}, + ) + + # Should not raise + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + def test_validation_fails_on_predictions_count_mismatch(self): + """Predictions table has more rows — should raise ValueError.""" + enum_table = self._make_table_path("enum_paper") + pred_table = self._make_table_path("pred_paper") + + self._create_test_table(enum_table, 40) + self._create_test_table(pred_table, 60) + + wrapper = self._build_gbml_config_wrapper( + node_types=["paper"], + enumerated_tables={"paper": enum_table}, + predictions_tables={"paper": pred_table}, + ) + + with self.assertRaises(ValueError): + self._validator._validate_record_counts( + gbml_config_wrapper=wrapper, bq_utils=self._bq_utils + ) + + +if __name__ == "__main__": + absltest.main() From ba2af27ab19afbc312657fb2a14973dc40d2d1d4 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 10 Feb 2026 21:59:36 +0000 Subject: [PATCH 2/6] cleanup --- ...rd_count_validating_post_processor_test.py | 139 ++++++++++-------- 1 file changed, 77 insertions(+), 62 deletions(-) diff --git a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py index cc94c0a36..38ae83790 100644 --- a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py +++ b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py @@ -1,3 +1,4 @@ +import os import tempfile import time import uuid @@ -7,7 +8,7 @@ from gigl.common import LocalUri from gigl.common.utils.proto_utils import ProtoUtils from gigl.env.pipelines_config import get_resource_config -from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.utils.bq import BqUtils from gigl.src.post_process.impl.record_count_validating_post_processor import ( RecordCountValidatingPostProcessor, @@ -33,13 +34,16 @@ def setUp(self): self._proto_utils = ProtoUtils() self._tables_to_cleanup: list[str] = [] - self._validator = RecordCountValidatingPostProcessor() + self._temp_files_to_cleanup: list[str] = [] def tearDown(self): for table_path in self._tables_to_cleanup: self._bq_utils.delete_bq_table_if_exist( bq_table_path=table_path, not_found_ok=True ) + for temp_file in self._temp_files_to_cleanup: + if os.path.exists(temp_file): + os.remove(temp_file) def _make_table_path(self, suffix: str) -> str: table_name = f"{self._test_unique_name}_{suffix}" @@ -59,14 +63,23 @@ def _create_test_table(self, table_path: str, num_rows: int) -> None: self._bq_utils.run_query(query=create_query, labels={}) time.sleep(1) - def _build_gbml_config_wrapper( + def _track_temp_file(self, path: str) -> None: + """Track a temp file for cleanup in tearDown.""" + self._temp_files_to_cleanup.append(path) + + def _build_task_config_uri( self, node_types: list[str], enumerated_tables: dict[str, str], embeddings_tables: dict[str, str] | None = None, predictions_tables: dict[str, str] | None = None, - ) -> GbmlConfigPbWrapper: - """Build a GbmlConfigPbWrapper with the given table paths. + ) -> LocalUri: + """Build a GbmlConfig proto, write it to a temp YAML file, and return the URI. + + The config is set up so that PostProcessor._run() completes without side effects: + - GLT backend is NOT enabled (skips unenumeration). + - post_processor_cls_path is empty (skips user-defined post-processing). + - should_skip_automatic_temp_asset_cleanup is true (skips GCS cleanup). Args: node_types: List of string node types. @@ -84,7 +97,11 @@ def _build_gbml_config_wrapper( graph_metadata.node_types.append(nt) graph_metadata.condensed_node_type_map[i] = nt - # Add a dummy edge type so condensed_edge_type_map is non-empty + # We add a dummy edge type because GbmlConfigPbWrapper requires + # condensed_edge_type_map to be non-empty in order to initialize + # the graph_metadata_pb_wrapper (see GbmlConfigPbWrapper.__load_graph_metadata_pb_wrapper, + # gigl/src/common/types/pb_wrappers/gbml_config.py:181-184). + # Our validator only checks node-level tables, not edge-level tables. dummy_edge = graph_schema_pb2.EdgeType( src_node_type=node_types[0], relation="dummy", @@ -93,7 +110,7 @@ def _build_gbml_config_wrapper( graph_metadata.edge_types.append(dummy_edge) graph_metadata.condensed_edge_type_map[0].CopyFrom(dummy_edge) - # Build preprocessed metadata + # Build preprocessed metadata and write to temp file preprocessed_metadata = preprocessed_metadata_pb2.PreprocessedMetadata() for i, nt in enumerate(node_types): node_output = ( @@ -101,9 +118,11 @@ def _build_gbml_config_wrapper( ) node_output.enumerated_node_ids_bq_table = enumerated_tables.get(nt, "") - # Write preprocessed metadata to a temp file - f = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") - preprocessed_metadata_uri = LocalUri(f.name) + preprocessed_metadata_file = tempfile.NamedTemporaryFile( + delete=False, suffix=".yaml" + ) + preprocessed_metadata_uri = LocalUri(preprocessed_metadata_file.name) + self._track_temp_file(preprocessed_metadata_file.name) self._proto_utils.write_proto_to_yaml( proto=preprocessed_metadata, uri=preprocessed_metadata_uri ) @@ -114,6 +133,7 @@ def _build_gbml_config_wrapper( gbml_config.shared_config.preprocessed_metadata_uri = ( preprocessed_metadata_uri.uri ) + gbml_config.shared_config.should_skip_automatic_temp_asset_cleanup = True # Populate inference metadata for nt in node_types: @@ -125,7 +145,22 @@ def _build_gbml_config_wrapper( if nt in predictions_tables: inference_output.predictions_path = predictions_tables[nt] - return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + # Write GbmlConfig to temp file + task_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") + task_config_uri = LocalUri(task_config_file.name) + self._track_temp_file(task_config_file.name) + self._proto_utils.write_proto_to_yaml(proto=gbml_config, uri=task_config_uri) + + return task_config_uri + + def _run_validator(self, task_config_uri: LocalUri) -> None: + """Run the RecordCountValidatingPostProcessor via its public run() API.""" + validator = RecordCountValidatingPostProcessor() + validator.run( + applied_task_identifier=AppliedTaskIdentifier(self._test_unique_name), + task_config_uri=task_config_uri, + resource_config_uri=get_resource_config().get_resource_config_uri, + ) def test_validation_passes_when_counts_match(self): """All tables exist and row counts match — validation should pass.""" @@ -136,88 +171,78 @@ def test_validation_passes_when_counts_match(self): self._create_test_table(enum_table, num_rows) self._create_test_table(emb_table, num_rows) - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, ) # Should not raise - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + self._run_validator(task_config_uri) def test_validation_fails_on_row_count_mismatch(self): - """Embeddings table has fewer rows — should raise ValueError.""" + """Embeddings table has fewer rows — should raise SystemExit.""" enum_table = self._make_table_path("enum_paper") emb_table = self._make_table_path("emb_paper") self._create_test_table(enum_table, 50) self._create_test_table(emb_table, 30) - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, ) - with self.assertRaises(ValueError): - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + with self.assertRaises(SystemExit): + self._run_validator(task_config_uri) def test_validation_fails_on_missing_embeddings_table(self): - """embeddings_path is set but BQ table does not exist — should raise ValueError.""" + """embeddings_path is set but BQ table does not exist — should raise SystemExit.""" enum_table = self._make_table_path("enum_paper") self._create_test_table(enum_table, 25) nonexistent_emb_table = self._make_table_path("emb_nonexistent") # Do NOT create this table — it should not exist - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": nonexistent_emb_table}, ) - with self.assertRaises(ValueError): - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + with self.assertRaises(SystemExit): + self._run_validator(task_config_uri) def test_validation_fails_on_missing_predictions_table(self): - """predictions_path is set but BQ table does not exist — should raise ValueError.""" + """predictions_path is set but BQ table does not exist — should raise SystemExit.""" enum_table = self._make_table_path("enum_paper") self._create_test_table(enum_table, 25) nonexistent_pred_table = self._make_table_path("pred_nonexistent") - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, predictions_tables={"paper": nonexistent_pred_table}, ) - with self.assertRaises(ValueError): - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + with self.assertRaises(SystemExit): + self._run_validator(task_config_uri) def test_validation_fails_when_no_output_paths_set(self): - """Neither embeddings nor predictions is set — should raise ValueError.""" + """Neither embeddings nor predictions is set — should raise SystemExit.""" enum_table = self._make_table_path("enum_paper") self._create_test_table(enum_table, 25) - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, # No embeddings or predictions ) - with self.assertRaises(ValueError): - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + with self.assertRaises(SystemExit): + self._run_validator(task_config_uri) def test_validation_with_multiple_node_types(self): """Heterogeneous graph with multiple node types — each validated independently.""" @@ -234,7 +259,7 @@ def test_validation_with_multiple_node_types(self): self._create_test_table(enum_paper, num_rows_paper) self._create_test_table(emb_paper, num_rows_paper) - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["author", "paper"], enumerated_tables={ "author": enum_author, @@ -247,9 +272,7 @@ def test_validation_with_multiple_node_types(self): ) # Should not raise - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + self._run_validator(task_config_uri) def test_validation_partial_output_only_embeddings(self): """Only embeddings (no predictions) for a node type — should pass.""" @@ -260,7 +283,7 @@ def test_validation_partial_output_only_embeddings(self): self._create_test_table(enum_table, num_rows) self._create_test_table(emb_table, num_rows) - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, @@ -268,9 +291,7 @@ def test_validation_partial_output_only_embeddings(self): ) # Should not raise - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + self._run_validator(task_config_uri) def test_multiple_errors_collected(self): """One table missing AND one count mismatch — both errors appear in exception.""" @@ -284,7 +305,7 @@ def test_multiple_errors_collected(self): nonexistent_emb_paper = self._make_table_path("emb_paper_missing") - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["author", "paper"], enumerated_tables={ "author": enum_author, @@ -296,10 +317,8 @@ def test_multiple_errors_collected(self): }, ) - with self.assertRaises(ValueError) as ctx: - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + with self.assertRaises(SystemExit) as ctx: + self._run_validator(task_config_uri) error_message = str(ctx.exception) self.assertIn("2 error(s)", error_message) @@ -317,7 +336,7 @@ def test_validation_passes_with_both_embeddings_and_predictions(self): self._create_test_table(emb_table, num_rows) self._create_test_table(pred_table, num_rows) - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, @@ -325,28 +344,24 @@ def test_validation_passes_with_both_embeddings_and_predictions(self): ) # Should not raise - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + self._run_validator(task_config_uri) def test_validation_fails_on_predictions_count_mismatch(self): - """Predictions table has more rows — should raise ValueError.""" + """Predictions table has more rows — should raise SystemExit.""" enum_table = self._make_table_path("enum_paper") pred_table = self._make_table_path("pred_paper") self._create_test_table(enum_table, 40) self._create_test_table(pred_table, 60) - wrapper = self._build_gbml_config_wrapper( + task_config_uri = self._build_task_config_uri( node_types=["paper"], enumerated_tables={"paper": enum_table}, predictions_tables={"paper": pred_table}, ) - with self.assertRaises(ValueError): - self._validator._validate_record_counts( - gbml_config_wrapper=wrapper, bq_utils=self._bq_utils - ) + with self.assertRaises(SystemExit): + self._run_validator(task_config_uri) if __name__ == "__main__": From db44fd9ba8dbe01c61d188226f28d782da248fe3 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 10 Feb 2026 22:55:55 +0000 Subject: [PATCH 3/6] fix --- .../impl/record_count_validating_post_processor_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py index 38ae83790..9a656ddc3 100644 --- a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py +++ b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py @@ -5,7 +5,7 @@ from absl.testing import absltest -from gigl.common import LocalUri +from gigl.common import LocalUri, UriFactory from gigl.common.utils.proto_utils import ProtoUtils from gigl.env.pipelines_config import get_resource_config from gigl.src.common.types import AppliedTaskIdentifier @@ -159,7 +159,9 @@ def _run_validator(self, task_config_uri: LocalUri) -> None: validator.run( applied_task_identifier=AppliedTaskIdentifier(self._test_unique_name), task_config_uri=task_config_uri, - resource_config_uri=get_resource_config().get_resource_config_uri, + resource_config_uri=UriFactory.create_uri( + get_resource_config().get_resource_config_uri + ), ) def test_validation_passes_when_counts_match(self): From cd3eb71679d02336f5ebe112394d4206ab30b55e Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 11 Feb 2026 17:59:19 +0000 Subject: [PATCH 4/6] fix --- .../record_count_validating_post_processor.py | 29 ++----- ...rd_count_validating_post_processor_test.py | 87 ++++++++----------- 2 files changed, 45 insertions(+), 71 deletions(-) diff --git a/gigl/src/post_process/impl/record_count_validating_post_processor.py b/gigl/src/post_process/impl/record_count_validating_post_processor.py index 9097dd77c..c8645a9b1 100644 --- a/gigl/src/post_process/impl/record_count_validating_post_processor.py +++ b/gigl/src/post_process/impl/record_count_validating_post_processor.py @@ -1,43 +1,28 @@ -from gigl.common import Uri from gigl.common.logger import Logger -from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.graph_data import NodeType from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.bq import BqUtils -from gigl.src.post_process.post_processor import PostProcessor +from gigl.src.post_process.lib.base_post_processor import BasePostProcessor +from snapchat.research.gbml import gbml_config_pb2 logger = Logger() -class RecordCountValidatingPostProcessor(PostProcessor): +class RecordCountValidatingPostProcessor(BasePostProcessor): """ - Post processor that extends PostProcessor with record count validation. + Post processor that extends PostProcessor with record count validation. - Runs all standard PostProcessor logic (unenumeration, user-defined - post-processing, metric export, cleanup), then validates that for each - node type, the unenumerated output tables (embeddings, predictions) have - the same number of rows as the corresponding enumerated_node_ids_bq_table. Only applicable for the GLT backend path. """ # TODO: Add edge-level validation support. - def _run( + def run_post_process( self, - applied_task_identifier: AppliedTaskIdentifier, - task_config_uri: Uri, + gbml_config_pb: gbml_config_pb2.GbmlConfig, ): - # Run all standard PostProcessor logic first - super()._run( - applied_task_identifier=applied_task_identifier, - task_config_uri=task_config_uri, - ) - - # Then validate record counts - gbml_config_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( - gbml_config_uri=task_config_uri - ) + gbml_config_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb) self._validate_record_counts(gbml_config_wrapper=gbml_config_wrapper) def _validate_record_counts( diff --git a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py index 9a656ddc3..a67b7cebf 100644 --- a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py +++ b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py @@ -5,10 +5,9 @@ from absl.testing import absltest -from gigl.common import LocalUri, UriFactory +from gigl.common import LocalUri from gigl.common.utils.proto_utils import ProtoUtils from gigl.env.pipelines_config import get_resource_config -from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.utils.bq import BqUtils from gigl.src.post_process.impl.record_count_validating_post_processor import ( RecordCountValidatingPostProcessor, @@ -67,13 +66,13 @@ def _track_temp_file(self, path: str) -> None: """Track a temp file for cleanup in tearDown.""" self._temp_files_to_cleanup.append(path) - def _build_task_config_uri( + def _build_gbml_config_pb( self, node_types: list[str], enumerated_tables: dict[str, str], embeddings_tables: dict[str, str] | None = None, predictions_tables: dict[str, str] | None = None, - ) -> LocalUri: + ) -> gbml_config_pb2.GbmlConfig: """Build a GbmlConfig proto, write it to a temp YAML file, and return the URI. The config is set up so that PostProcessor._run() completes without side effects: @@ -145,23 +144,13 @@ def _build_task_config_uri( if nt in predictions_tables: inference_output.predictions_path = predictions_tables[nt] - # Write GbmlConfig to temp file - task_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") - task_config_uri = LocalUri(task_config_file.name) - self._track_temp_file(task_config_file.name) - self._proto_utils.write_proto_to_yaml(proto=gbml_config, uri=task_config_uri) + return gbml_config - return task_config_uri - - def _run_validator(self, task_config_uri: LocalUri) -> None: + def _run_validator(self, gbml_config_pb: gbml_config_pb2.GbmlConfig) -> None: """Run the RecordCountValidatingPostProcessor via its public run() API.""" validator = RecordCountValidatingPostProcessor() - validator.run( - applied_task_identifier=AppliedTaskIdentifier(self._test_unique_name), - task_config_uri=task_config_uri, - resource_config_uri=UriFactory.create_uri( - get_resource_config().get_resource_config_uri - ), + validator.run_post_process( + gbml_config_pb=gbml_config_pb, ) def test_validation_passes_when_counts_match(self): @@ -173,78 +162,78 @@ def test_validation_passes_when_counts_match(self): self._create_test_table(enum_table, num_rows) self._create_test_table(emb_table, num_rows) - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, ) # Should not raise - self._run_validator(task_config_uri) + self._run_validator(gbml_config_pb) def test_validation_fails_on_row_count_mismatch(self): - """Embeddings table has fewer rows — should raise SystemExit.""" + """Embeddings table has fewer rows — should raise ValueError.""" enum_table = self._make_table_path("enum_paper") emb_table = self._make_table_path("emb_paper") self._create_test_table(enum_table, 50) self._create_test_table(emb_table, 30) - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, ) - with self.assertRaises(SystemExit): - self._run_validator(task_config_uri) + with self.assertRaises(ValueError): + self._run_validator(gbml_config_pb) def test_validation_fails_on_missing_embeddings_table(self): - """embeddings_path is set but BQ table does not exist — should raise SystemExit.""" + """embeddings_path is set but BQ table does not exist — should raise ValueError.""" enum_table = self._make_table_path("enum_paper") self._create_test_table(enum_table, 25) nonexistent_emb_table = self._make_table_path("emb_nonexistent") # Do NOT create this table — it should not exist - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": nonexistent_emb_table}, ) - with self.assertRaises(SystemExit): - self._run_validator(task_config_uri) + with self.assertRaises(ValueError): + self._run_validator(gbml_config_pb) def test_validation_fails_on_missing_predictions_table(self): - """predictions_path is set but BQ table does not exist — should raise SystemExit.""" + """predictions_path is set but BQ table does not exist — should raise ValueError.""" enum_table = self._make_table_path("enum_paper") self._create_test_table(enum_table, 25) nonexistent_pred_table = self._make_table_path("pred_nonexistent") - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, predictions_tables={"paper": nonexistent_pred_table}, ) - with self.assertRaises(SystemExit): - self._run_validator(task_config_uri) + with self.assertRaises(ValueError): + self._run_validator(gbml_config_pb) def test_validation_fails_when_no_output_paths_set(self): - """Neither embeddings nor predictions is set — should raise SystemExit.""" + """Neither embeddings nor predictions is set — should raise ValueError.""" enum_table = self._make_table_path("enum_paper") self._create_test_table(enum_table, 25) - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, # No embeddings or predictions ) - with self.assertRaises(SystemExit): - self._run_validator(task_config_uri) + with self.assertRaises(ValueError): + self._run_validator(gbml_config_pb) def test_validation_with_multiple_node_types(self): """Heterogeneous graph with multiple node types — each validated independently.""" @@ -261,7 +250,7 @@ def test_validation_with_multiple_node_types(self): self._create_test_table(enum_paper, num_rows_paper) self._create_test_table(emb_paper, num_rows_paper) - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["author", "paper"], enumerated_tables={ "author": enum_author, @@ -274,7 +263,7 @@ def test_validation_with_multiple_node_types(self): ) # Should not raise - self._run_validator(task_config_uri) + self._run_validator(gbml_config_pb) def test_validation_partial_output_only_embeddings(self): """Only embeddings (no predictions) for a node type — should pass.""" @@ -285,7 +274,7 @@ def test_validation_partial_output_only_embeddings(self): self._create_test_table(enum_table, num_rows) self._create_test_table(emb_table, num_rows) - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, @@ -293,7 +282,7 @@ def test_validation_partial_output_only_embeddings(self): ) # Should not raise - self._run_validator(task_config_uri) + self._run_validator(gbml_config_pb) def test_multiple_errors_collected(self): """One table missing AND one count mismatch — both errors appear in exception.""" @@ -307,7 +296,7 @@ def test_multiple_errors_collected(self): nonexistent_emb_paper = self._make_table_path("emb_paper_missing") - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["author", "paper"], enumerated_tables={ "author": enum_author, @@ -319,8 +308,8 @@ def test_multiple_errors_collected(self): }, ) - with self.assertRaises(SystemExit) as ctx: - self._run_validator(task_config_uri) + with self.assertRaises(ValueError) as ctx: + self._run_validator(gbml_config_pb) error_message = str(ctx.exception) self.assertIn("2 error(s)", error_message) @@ -338,7 +327,7 @@ def test_validation_passes_with_both_embeddings_and_predictions(self): self._create_test_table(emb_table, num_rows) self._create_test_table(pred_table, num_rows) - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, embeddings_tables={"paper": emb_table}, @@ -346,24 +335,24 @@ def test_validation_passes_with_both_embeddings_and_predictions(self): ) # Should not raise - self._run_validator(task_config_uri) + self._run_validator(gbml_config_pb) def test_validation_fails_on_predictions_count_mismatch(self): - """Predictions table has more rows — should raise SystemExit.""" + """Predictions table has more rows — should raise ValueError.""" enum_table = self._make_table_path("enum_paper") pred_table = self._make_table_path("pred_paper") self._create_test_table(enum_table, 40) self._create_test_table(pred_table, 60) - task_config_uri = self._build_task_config_uri( + gbml_config_pb = self._build_gbml_config_pb( node_types=["paper"], enumerated_tables={"paper": enum_table}, predictions_tables={"paper": pred_table}, ) - with self.assertRaises(SystemExit): - self._run_validator(task_config_uri) + with self.assertRaises(ValueError): + self._run_validator(gbml_config_pb) if __name__ == "__main__": From 68f7cf8fe611894c320be798ea47a96df7afb7c8 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 11 Feb 2026 23:12:52 +0000 Subject: [PATCH 5/6] fix --- .../impl/record_count_validating_post_processor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gigl/src/post_process/impl/record_count_validating_post_processor.py b/gigl/src/post_process/impl/record_count_validating_post_processor.py index c8645a9b1..2b1d5436d 100644 --- a/gigl/src/post_process/impl/record_count_validating_post_processor.py +++ b/gigl/src/post_process/impl/record_count_validating_post_processor.py @@ -1,4 +1,5 @@ from gigl.common.logger import Logger +from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.graph_data import NodeType from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.bq import BqUtils @@ -16,6 +17,11 @@ class RecordCountValidatingPostProcessor(BasePostProcessor): Only applicable for the GLT backend path. """ + # We need __init__ as applied_task_identified gets injected PostProcessor._run_post_process + # But we have no need for it. + def __init__(self, applied_task_identifier: AppliedTaskIdentifier): + pass + # TODO: Add edge-level validation support. def run_post_process( From 284092881027bee78ca7966219e9cf8d9441ccd2 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 20 Feb 2026 16:42:39 +0000 Subject: [PATCH 6/6] update --- .../impl/record_count_validating_post_processor.py | 2 +- .../impl/record_count_validating_post_processor_test.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/gigl/src/post_process/impl/record_count_validating_post_processor.py b/gigl/src/post_process/impl/record_count_validating_post_processor.py index 2b1d5436d..29a345c3a 100644 --- a/gigl/src/post_process/impl/record_count_validating_post_processor.py +++ b/gigl/src/post_process/impl/record_count_validating_post_processor.py @@ -19,7 +19,7 @@ class RecordCountValidatingPostProcessor(BasePostProcessor): # We need __init__ as applied_task_identified gets injected PostProcessor._run_post_process # But we have no need for it. - def __init__(self, applied_task_identifier: AppliedTaskIdentifier): + def __init__(self, *, applied_task_identifier: AppliedTaskIdentifier): pass # TODO: Add edge-level validation support. diff --git a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py index a67b7cebf..a19a208a7 100644 --- a/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py +++ b/tests/integration/src/post_process/impl/record_count_validating_post_processor_test.py @@ -8,6 +8,7 @@ from gigl.common import LocalUri from gigl.common.utils.proto_utils import ProtoUtils from gigl.env.pipelines_config import get_resource_config +from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.utils.bq import BqUtils from gigl.src.post_process.impl.record_count_validating_post_processor import ( RecordCountValidatingPostProcessor, @@ -148,7 +149,9 @@ def _build_gbml_config_pb( def _run_validator(self, gbml_config_pb: gbml_config_pb2.GbmlConfig) -> None: """Run the RecordCountValidatingPostProcessor via its public run() API.""" - validator = RecordCountValidatingPostProcessor() + validator = RecordCountValidatingPostProcessor( + applied_task_identifier=AppliedTaskIdentifier("foo") + ) validator.run_post_process( gbml_config_pb=gbml_config_pb, )