Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,7 @@ sharedConfig:
shouldSkipInference: false
# Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines.
shouldSkipModelEvaluation: true
postProcessorConfig:
postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor
featureFlags:
should_run_glt_backend: 'True'
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ taskMetadata:
- dstNodeType: paper
relation: cites
srcNodeType: paper
postProcessorConfig:
postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor
featureFlags:
should_run_glt_backend: 'True'
data_preprocessor_num_shards: '2'
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,7 @@ sharedConfig:
shouldSkipInference: false
# Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines.
shouldSkipModelEvaluation: true
postProcessorConfig:
postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor
featureFlags:
should_run_glt_backend: 'True'
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ taskMetadata:
- dstNodeType: paper
relation: cites
srcNodeType: paper
postProcessorConfig:
postProcessorClsPath: gigl.src.post_process.impl.record_count_validating_post_processor.RecordCountValidatingPostProcessor
featureFlags:
should_run_glt_backend: 'True'
data_preprocessor_num_shards: '2'
Empty file.
151 changes: 151 additions & 0 deletions gigl/src/post_process/impl/record_count_validating_post_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from gigl.common.logger import Logger
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.graph_data import NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.bq import BqUtils
from gigl.src.post_process.lib.base_post_processor import BasePostProcessor
from snapchat.research.gbml import gbml_config_pb2

logger = Logger()


class RecordCountValidatingPostProcessor(BasePostProcessor):
    """
    Post processor that extends PostProcessor with record count validation.

    For each node type in the inference output, it validates that the
    enumerated node ids BQ table exists and that every configured output
    table (embeddings and/or predictions) exists with a matching row count.

    Only applicable for the GLT backend path.
    """

    # We need __init__ because applied_task_identifier gets injected by
    # PostProcessor._run_post_process, but this post processor has no use for it.
    def __init__(self, *, applied_task_identifier: AppliedTaskIdentifier):
        pass

    # TODO: Add edge-level validation support.

    def run_post_process(
        self,
        gbml_config_pb: gbml_config_pb2.GbmlConfig,
    ):
        """Wraps the raw config pb and runs record count validation."""
        gbml_config_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb)
        self._validate_record_counts(gbml_config_wrapper=gbml_config_wrapper)

    def _validate_output_table(
        self,
        *,
        bq_utils: BqUtils,
        node_type: str,
        table_kind: str,
        table_path: str,
        expected_count: int,
        validation_errors: list[str],
    ) -> None:
        """Validates a single inference output table against an expected row count.

        Appends a message to ``validation_errors`` (mutated in place) if the
        table does not exist or its row count differs from ``expected_count``.

        Args:
            bq_utils: BqUtils instance used to query BigQuery.
            node_type: Node type being validated (used in log/error messages).
            table_kind: Human-readable table label, e.g. "Embeddings" or
                "Predictions" (used in log/error messages).
            table_path: Fully-qualified BQ table path to validate.
            expected_count: Row count the table must match.
            validation_errors: Accumulator for error messages.
        """
        if not bq_utils.does_bq_table_exist(table_path):
            validation_errors.append(
                f"[{node_type}] {table_kind} table does not exist: {table_path}"
            )
            return
        actual = bq_utils.count_number_of_rows_in_bq_table(table_path)
        logger.info(
            f"[{node_type}] {table_kind} table ({table_path}) has {actual} rows."
        )
        if actual != expected_count:
            validation_errors.append(
                f"[{node_type}] {table_kind} row count mismatch: "
                f"expected {expected_count}, got {actual} "
                f"(table: {table_path})"
            )

    def _validate_record_counts(
        self,
        gbml_config_wrapper: GbmlConfigPbWrapper,
        bq_utils: BqUtils | None = None,
    ) -> None:
        """Validates that output BQ tables have matching record counts.

        For each node type in the inference output, checks that:
        1. The enumerated_node_ids_bq_table exists.
        2. The embeddings table (if configured) exists and has the same row count.
        3. The predictions table (if configured) exists and has the same row count.
        4. At least one of embeddings or predictions is configured.

        All errors are collected and reported together before raising.

        Args:
            gbml_config_wrapper: The GbmlConfig wrapper with access to all metadata.
            bq_utils: Optional BqUtils instance for testing. If None, creates one
                from the resource config.

        Raises:
            ValueError: If any validation errors are found, with all errors listed.
        """
        if bq_utils is None:
            # Imported lazily so unit tests can inject bq_utils without
            # requiring a resource config to be resolvable.
            from gigl.env.pipelines_config import get_resource_config

            resource_config = get_resource_config()
            bq_utils = BqUtils(project=resource_config.project)

        validation_errors: list[str] = []

        inference_output_map = (
            gbml_config_wrapper.shared_config.inference_metadata.node_type_to_inferencer_output_info_map
        )
        node_type_to_condensed = (
            gbml_config_wrapper.graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map
        )
        preprocessed_metadata = (
            gbml_config_wrapper.preprocessed_metadata_pb_wrapper.preprocessed_metadata
        )

        for node_type, inference_output in inference_output_map.items():
            condensed_node_type = node_type_to_condensed[NodeType(node_type)]
            node_metadata = (
                preprocessed_metadata.condensed_node_type_to_preprocessed_metadata[
                    int(condensed_node_type)
                ]
            )
            enumerated_table = node_metadata.enumerated_node_ids_bq_table

            # Without the enumerated ids table we have no baseline count;
            # record the error and move on to the next node type.
            if not enumerated_table:
                validation_errors.append(
                    f"[{node_type}] No enumerated_node_ids_bq_table configured."
                )
                continue

            if not bq_utils.does_bq_table_exist(enumerated_table):
                validation_errors.append(
                    f"[{node_type}] enumerated_node_ids_bq_table does not exist: {enumerated_table}"
                )
                continue

            expected_count = bq_utils.count_number_of_rows_in_bq_table(enumerated_table)
            logger.info(
                f"[{node_type}] enumerated_node_ids_bq_table ({enumerated_table}) has {expected_count} rows."
            )

            # Validate embeddings
            embeddings_path = inference_output.embeddings_path
            if embeddings_path:
                self._validate_output_table(
                    bq_utils=bq_utils,
                    node_type=node_type,
                    table_kind="Embeddings",
                    table_path=embeddings_path,
                    expected_count=expected_count,
                    validation_errors=validation_errors,
                )

            # Validate predictions
            predictions_path = inference_output.predictions_path
            if predictions_path:
                self._validate_output_table(
                    bq_utils=bq_utils,
                    node_type=node_type,
                    table_kind="Predictions",
                    table_path=predictions_path,
                    expected_count=expected_count,
                    validation_errors=validation_errors,
                )

            if not embeddings_path and not predictions_path:
                validation_errors.append(
                    f"[{node_type}] Neither embeddings_path nor predictions_path is set."
                )

        if validation_errors:
            error_summary = "\n".join(validation_errors)
            raise ValueError(
                f"Record count validation failed with {len(validation_errors)} error(s):\n"
                f"{error_summary}"
            )

        logger.info("All record count validations passed.")
16 changes: 8 additions & 8 deletions gigl/src/post_process/post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class PostProcessor:
def __run_post_process(
def _run_post_process(
self,
gbml_config_pb: gbml_config_pb2.GbmlConfig,
applied_task_identifier: AppliedTaskIdentifier,
Expand Down Expand Up @@ -66,7 +66,7 @@ def __run_post_process(
EvalMetricsCollection
] = post_processor.run_post_process(gbml_config_pb=gbml_config_pb)
if post_processor_metrics is not None:
self.__write_post_processor_metrics_to_uri(
self._write_post_processor_metrics_to_uri(
model_eval_metrics=post_processor_metrics,
gbml_config_pb=gbml_config_pb,
)
Expand All @@ -87,7 +87,7 @@ def __run_post_process(
)
gcs_utils.delete_files_in_bucket_dir(gcs_path=temp_dir_gcs_path)

def __write_post_processor_metrics_to_uri(
def _write_post_processor_metrics_to_uri(
self,
model_eval_metrics: EvalMetricsCollection,
gbml_config_pb: gbml_config_pb2.GbmlConfig,
Expand All @@ -106,15 +106,15 @@ def __write_post_processor_metrics_to_uri(
)
logger.info(f"Wrote eval metrics to {post_processor_log_metrics_uri.uri}.")

def __should_run_unenumeration(
def _should_run_unenumeration(
self, gbml_config_wrapper: GbmlConfigPbWrapper
) -> bool:
"""
When using the experimental GLT backend, we should run unenumeration in the post processor.
"""
return gbml_config_wrapper.should_use_glt_backend

def __run(
def _run(
self,
applied_task_identifier: AppliedTaskIdentifier,
task_config_uri: Uri,
Expand All @@ -124,7 +124,7 @@ def __run(
gbml_config_uri=task_config_uri
)
)
if self.__should_run_unenumeration(gbml_config_wrapper=gbml_config_wrapper):
if self._should_run_unenumeration(gbml_config_wrapper=gbml_config_wrapper):
logger.info(f"Running unenumeration for inferred assets in post processor")
unenumerate_all_inferred_bq_assets(
gbml_config_pb_wrapper=gbml_config_wrapper
Expand All @@ -133,7 +133,7 @@ def __run(
f"Finished running unenumeration for inferred assets in post processor"
)

self.__run_post_process(
self._run_post_process(
gbml_config_pb=gbml_config_wrapper.gbml_config_pb,
applied_task_identifier=applied_task_identifier,
)
Expand All @@ -150,7 +150,7 @@ def run(
resource_config_uri: Uri,
):
try:
return self.__run(
return self._run(
applied_task_identifier=applied_task_identifier,
task_config_uri=task_config_uri,
)
Expand Down
Empty file.
Empty file.
Loading