diff --git a/deployment/configs/e2e_glt_gs_resource_config.yaml b/deployment/configs/e2e_glt_gs_resource_config.yaml index 097cf60d9..4839c8f48 100644 --- a/deployment/configs/e2e_glt_gs_resource_config.yaml +++ b/deployment/configs/e2e_glt_gs_resource_config.yaml @@ -1,5 +1,6 @@ # Diffs from e2e_glt_resource_config.yaml # - Swap vertex_ai_inferencer_config for vertex_ai_graph_store_inferencer_config +# - Swap vertex_ai_trainer_config for vertex_ai_graph_store_trainer_config shared_resource_config: resource_labels: cost_resource_group_tag: dev_experiments_COMPONENT @@ -26,11 +27,17 @@ preprocessor_config: machine_type: "n2d-highmem-64" disk_size_gb: 300 trainer_resource_config: - vertex_ai_trainer_config: - machine_type: n1-highmem-32 - gpu_type: NVIDIA_TESLA_T4 - gpu_limit: 2 - num_replicas: 2 + vertex_ai_graph_store_trainer_config: + graph_store_pool: + machine_type: n2-highmem-32 + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 inferencer_resource_config: vertex_ai_graph_store_inferencer_config: graph_store_pool: diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml index 139abfc72..251eea4b6 100644 --- a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml @@ -29,7 +29,6 @@ datasetConfig: dataPreprocessorArgs: # This argument is specific for the `PassthroughPreprocessorConfigForMockedAssets` preprocessor to indicate which dataset we should be using mocked_dataset_name: 'dblp_node_anchor_edge_features_lp' -# TODO(kmonte): Add GS trainer trainerConfig: trainerArgs: # Example argument to trainer @@ -49,7 +48,15 @@ trainerConfig: ("paper", "to", "author"): [15, 15], ("author", "to", "paper"): [20, 
20] } - command: python -m examples.link_prediction.heterogeneous_training + command: python -m examples.link_prediction.graph_store.heterogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.3, "num_test": 0.3, "supervision_edge_types": [("author", "to", "paper")]}' + ssl_positive_label_percentage: "0.15" + num_server_sessions: "1" # TODO(kmonte): Move to user-defined server code inferencerConfig: inferencerArgs: diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml index 50e1c2b64..f972310e7 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml @@ -15,13 +15,19 @@ datasetConfig: dataPreprocessorArgs: # This argument is specific for the `PassthroughPreprocessorConfigForMockedAssets` preprocessor to indicate which dataset we should be using mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' -# TODO(kmonte): Add GS trainer trainerConfig: trainerArgs: # Example argument to trainer log_every_n_batch: "50" # Frequency in which we log batch information num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case - command: python -m examples.link_prediction.homogeneous_training + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: 
"gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.1, "num_test": 0.1}' + num_server_sessions: "1" # TODO(kmonte): Move to user-defined server code inferencerConfig: inferencerArgs: diff --git a/examples/link_prediction/graph_store/configs/example_resource_config.yaml b/examples/link_prediction/graph_store/configs/example_resource_config.yaml index aa6618fd7..869f627ca 100644 --- a/examples/link_prediction/graph_store/configs/example_resource_config.yaml +++ b/examples/link_prediction/graph_store/configs/example_resource_config.yaml @@ -46,13 +46,18 @@ preprocessor_config: max_num_workers: 4 machine_type: "n2-standard-16" disk_size_gb: 300 -# TODO(kmonte): Update trainer_resource_config: - vertex_ai_trainer_config: - machine_type: n1-standard-16 - gpu_type: NVIDIA_TESLA_T4 - gpu_limit: 2 - num_replicas: 2 + vertex_ai_graph_store_trainer_config: + graph_store_pool: + machine_type: n2-highmem-32 + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 inferencer_resource_config: vertex_ai_graph_store_inferencer_config: graph_store_pool: diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py new file mode 100644 index 000000000..a448e2fc3 --- /dev/null +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -0,0 +1,974 @@ +""" +This file contains an example for how to run heterogeneous training in **graph store mode** using GiGL. + +Graph Store Mode vs Standard Mode: +---------------------------------- +Graph store mode uses a heterogeneous cluster architecture with two distinct sub-clusters: + 1. **Storage Cluster (graph_store_pool)**: Dedicated machines for storing and serving the graph + data. 
These are typically high-memory machines without GPUs (e.g., n2-highmem-32). + 2. **Compute Cluster (compute_pool)**: Dedicated machines for running model training. + These typically have GPUs attached (e.g., n1-standard-16 with NVIDIA_TESLA_T4). + +This separation allows for: + - Independent scaling of storage and compute resources + - Better memory utilization (graph data stays on storage nodes) + - Cost optimization by using appropriate hardware for each role + +In contrast, the standard training mode (see `examples/link_prediction/heterogeneous_training.py`) +uses a homogeneous cluster where each machine handles both graph storage and computation. + +Key Implementation Differences: +------------------------------- +This file (graph store mode): + - Uses `RemoteDistDataset` to connect to a remote graph store cluster + - Uses `init_compute_process` to initialize the compute node connection to storage + - Obtains cluster topology via `get_graph_store_info()` which returns `GraphStoreInfo` + - Uses `mp_sharing_dict` for efficient tensor sharing between local processes + - Fetches ABLP input via `RemoteDistDataset.get_ablp_input()` for the train/val/test splits + - Fetches random negative node IDs via `RemoteDistDataset.get_node_ids()` + +Standard mode (`heterogeneous_training.py`): + - Uses `DistDataset` with `build_dataset_from_task_config_uri` where each node loads its partition + - Manually manages distributed process groups with master IP and port + - Each machine stores its own partition of the graph data + +To run this file with GiGL orchestration, set the fields similar to below: + +trainerConfig: + trainerArgs: + log_every_n_batch: "50" + ssl_positive_label_percentage: "0.05" + command: python -m examples.link_prediction.graph_store.heterogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: 
"gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.1, "num_test": 0.1}' + ssl_positive_label_percentage: "0.05" + num_server_sessions: "1" +featureFlags: + should_run_glt_backend: 'True' + +Note: Ensure you use a resource config with `vertex_ai_graph_store_trainer_config` when +running in graph store mode. + +You can run this example in a full pipeline with `make run_het_dblp_sup_gs_e2e_test` from GiGL root. + +Note that the DBLP Dataset does not have specified labeled edges so we use the `ssl_positive_label_percentage` +field in the config to indicate what percentage of edges we should select as self-supervised labeled edges. +""" + +import argparse +import gc +import os +import statistics +import sys +import time +from collections.abc import Iterator, MutableMapping +from dataclasses import dataclass +from typing import Literal, Optional, Union + +import torch +import torch.distributed +import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_heterogeneous_model +from torch_geometric.data import HeteroData + +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed import DistABLPLoader +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.graph_store.compute import ( + init_compute_process, + shutdown_compute_proccess, +) +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils import get_available_device, get_graph_store_info +from gigl.env.distributed import GraphStoreInfo +from gigl.nn import LinkPredictionGNN, RetrievalLoss +from gigl.src.common.types.graph_data import EdgeType, NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from 
gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict +from gigl.utils.iterator import InfiniteIterator +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +# We don't see logs for graph store mode for whatever reason. +# TODO(#442): Revert this once the GCP issues are resolved. +def flush(): + sys.stdout.write("\n") + sys.stdout.flush() + sys.stderr.write("\n") + sys.stderr.flush() + + +def _sync_metric_across_processes(metric: torch.Tensor) -> float: + """ + Takes the average of a training metric across multiple processes. Note that this function requires DDP to be initialized. + Args: + metric (torch.Tensor): The metric, expressed as a torch Tensor, which should be synced across multiple processes + Returns: + float: The average of the provided metric across all training processes + """ + assert is_distributed_available_and_initialized(), "DDP is not initialized" + # Make a copy of the local loss tensor + loss_tensor = metric.detach().clone() + print(f"---Rank {torch.distributed.get_rank()} loss tensor: {loss_tensor}") + torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) + return loss_tensor.item() / torch.distributed.get_world_size() + + +def _setup_dataloaders( + dataset: RemoteDistDataset, + split: Literal["train", "val", "test"], + cluster_info: GraphStoreInfo, + supervision_edge_type: EdgeType, + num_neighbors: Union[list[int], dict[EdgeType, list[int]]], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + device: torch.device, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, +) -> tuple[DistABLPLoader, DistNeighborLoader]: + """ + Sets up main and random dataloaders for training and testing purposes using a remote graph store dataset. + Args: + dataset (RemoteDistDataset): Remote dataset connected to the graph store cluster. + split (Literal["train", "val", "test"]): The current split which we are loading data for. 
+ cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + supervision_edge_type (EdgeType): The supervision edge type to use for training. + num_neighbors: Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + main_batch_size (int): Batch size for main dataloader with query and labeled nodes. + random_batch_size (int): Batch size for random negative dataloader. + device (torch.device): Device to put loaded subgraphs on. + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for the channel during sampling. + process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. + Returns: + DistABLPLoader: Dataloader for loading main batch data with query and labeled nodes. + DistNeighborLoader: Dataloader for loading random negative data. + """ + rank = torch.distributed.get_rank() + + if dataset.fetch_edge_dir() == "in": + query_node_type = supervision_edge_type.dst_node_type + labeled_node_type = supervision_edge_type.src_node_type + anchor_node_type = query_node_type + else: + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + anchor_node_type = query_node_type + + print( + f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.fetch_edge_dir()}" + ) + + shuffle = split == "train" + + # In graph store mode, we fetch ABLP input (anchors + positive/negative labels) from the storage cluster. + # This returns dict[server_rank, (anchors, pos_labels, neg_labels)] which the DistABLPLoader knows how to handle. 
+ print(f"---Rank {rank} fetching ABLP input for split={split}") + flush() + ablp_input = dataset.fetch_ablp_input( + split=split, + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + anchor_node_type=anchor_node_type, + supervision_edge_type=supervision_edge_type, + ) + pos_labels = [a.labels[supervision_edge_type][0].shape for a in ablp_input.values()] + print(f"---Rank {rank} split {split} ABLP input sizes: main_loader: {[a.anchor_nodes.shape for a in ablp_input.values()]}, pos labels: {pos_labels}") + main_loader = DistABLPLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=ablp_input, + num_workers=sampling_workers_per_process, + batch_size=main_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + print(f"---Rank {rank} finished setting up main loader for split={split}") + flush() + + # We need to wait for all processes to finish initializing the main_loader before creating the + # random_negative_loader so that its initialization doesn't compete for memory with the main_loader. + torch.distributed.barrier() + + # For the random negative loader, we get all node IDs of the labeled node type from the storage cluster. 
+ all_node_ids = dataset.fetch_node_ids( + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + node_type=labeled_node_type, + ) + + print(f"---Rank {rank} split {split} all node ids sizes: {[n.shape for n in all_node_ids.values()]}") + flush() + + random_negative_loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=(labeled_node_type, all_node_ids), + num_workers=sampling_workers_per_process, + batch_size=random_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + print( + f"---Rank {rank} finished setting up random negative loader for split={split}" + ) + flush() + + # Wait for all processes to finish initializing the random_loader + torch.distributed.barrier() + + return main_loader, random_negative_loader + + +def _compute_loss( + model: LinkPredictionGNN, + main_data: HeteroData, + random_negative_data: HeteroData, + loss_fn: RetrievalLoss, + supervision_edge_type: EdgeType, + edge_dir: str, + device: torch.device, +) -> torch.Tensor: + """ + With the provided model and loss function, computes the forward pass on the main batch data and random negative data. 
+ Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_data (HeteroData): The batch of data containing query nodes, positive nodes, and hard negative nodes + random_negative_data (HeteroData): The batch of data containing random negative nodes + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + supervision_edge_type (EdgeType): The supervision edge type to use for training in format query_node -> relation -> labeled_node + edge_dir (str): Direction of the supervision edge + device (torch.device): Device for training or validation + Returns: + torch.Tensor: Final loss for the current batch on the current process + """ + # Extract relevant node types from the supervision edge + if edge_dir == "out": + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + else: + query_node_type = supervision_edge_type.dst_node_type + labeled_node_type = supervision_edge_type.src_node_type + + # print(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") + if query_node_type == labeled_node_type: + inference_node_types = [query_node_type] + else: + inference_node_types = [query_node_type, labeled_node_type] + + # Forward pass through encoder + # print(f"Computing loss for main data: {main_data}") + # print(f"Computing loss for random negative data: {random_negative_data}") + # print(f"Using model: {model}") + flush() + main_embeddings = model( + data=main_data, output_node_types=inference_node_types, device=device + ) + random_negative_embeddings = model( + data=random_negative_data, + output_node_types=inference_node_types, + device=device, + ) + + # Extracting local query, random negative, positive, hard_negative, and random_negative indices. 
+ query_node_idx: torch.Tensor = torch.arange( + main_data[query_node_type].batch_size + ).to(device) + random_negative_batch_size = random_negative_data[labeled_node_type].batch_size + + positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( + device + ) + repeated_query_node_idx = query_node_idx.repeat_interleave( + torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) + ) + if hasattr(main_data, "y_negative"): + hard_negative_idx: torch.Tensor = torch.cat( + list(main_data.y_negative.values()) + ).to(device) + else: + hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) + + # Use local IDs to get the corresponding embeddings in the tensors + + repeated_query_embeddings = main_embeddings[query_node_type][ + repeated_query_node_idx + ] + positive_node_embeddings = main_embeddings[labeled_node_type][positive_idx] + hard_negative_embeddings = main_embeddings[labeled_node_type][hard_negative_idx] + random_negative_embeddings = random_negative_embeddings[labeled_node_type][ + :random_negative_batch_size + ] + + # Decode the query embeddings and the candidate embeddings + + repeated_candidate_scores = model.decode( + query_embeddings=repeated_query_embeddings, + candidate_embeddings=torch.cat( + [ + positive_node_embeddings, + hard_negative_embeddings, + random_negative_embeddings, + ], + dim=0, + ), + ) + + # Compute the global candidate ids and concatenate into a single tensor + + global_candidate_ids = torch.cat( + ( + main_data[labeled_node_type].node[positive_idx], + main_data[labeled_node_type].node[hard_negative_idx], + random_negative_data[labeled_node_type].node[:random_negative_batch_size], + ) + ) + + global_repeated_query_ids = main_data[query_node_type].node[repeated_query_node_idx] + + # Feed scores and ids into the RetrievalLoss forward pass to get the final loss + + loss = loss_fn( + repeated_candidate_scores=repeated_candidate_scores, + candidate_ids=global_candidate_ids, + 
repeated_query_ids=global_repeated_query_ids, + device=device, + ) + + return loss + + +@dataclass(frozen=True) +class TrainingProcessArgs: + """ + Arguments for the heterogeneous training process in graph store mode. + + Attributes: + local_world_size (int): Number of training processes spawned by each machine. + cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + mp_sharing_dict (MutableMapping[str, torch.Tensor]): Shared dictionary for efficient tensor + sharing between local processes. + supervision_edge_type (EdgeType): The supervision edge type for training. + model_uri (Uri): URI to save/load the trained model state dict. + hid_dim (int): Hidden dimension of the model. + out_dim (int): Output dimension of the model. + node_type_to_feature_dim (dict[NodeType, int]): Mapping of node types to their feature dimensions. + edge_type_to_feature_dim (dict[EdgeType, int]): Mapping of edge types to their feature dimensions. + num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. + process_start_gap_seconds (int): Time to sleep between dataloader initializations. + main_batch_size (int): Batch size for main dataloader. + random_batch_size (int): Batch size for random negative dataloader. + learning_rate (float): Learning rate for the optimizer. + weight_decay (float): Weight decay for the optimizer. + num_max_train_batches (int): Maximum number of training batches across all processes. + num_val_batches (int): Number of validation batches across all processes. + val_every_n_batch (int): Frequency to run validation during training. + log_every_n_batch (int): Frequency to log batch information during training. + should_skip_training (bool): If True, skip training and only run testing. 
+ """ + + # Distributed context + local_world_size: int + cluster_info: GraphStoreInfo + mp_sharing_dict: MutableMapping[str, torch.Tensor] + + # Data + supervision_edge_type: EdgeType + + # Model + model_uri: Uri + hid_dim: int + out_dim: int + node_type_to_feature_dim: dict[NodeType, int] + edge_type_to_feature_dim: dict[EdgeType, int] + + # Sampling config + num_neighbors: Union[list[int], dict[EdgeType, list[int]]] + sampling_workers_per_process: int + sampling_worker_shared_channel_size: str + process_start_gap_seconds: int + + # Training hyperparameters + main_batch_size: int + random_batch_size: int + learning_rate: float + weight_decay: float + num_max_train_batches: int + num_val_batches: int + val_every_n_batch: int + log_every_n_batch: int + should_skip_training: bool + + +def _training_process( + local_rank: int, + args: TrainingProcessArgs, +) -> None: + """ + This function is spawned by each machine for training a GNN model using graph store mode. + Args: + local_rank (int): Process number on the current machine + args (TrainingProcessArgs): Dataclass containing all training process arguments + """ + + # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster + # and sets up torch.distributed with the appropriate backend (NCCL if CUDA available, gloo otherwise). 
+ print( + f"Initializing compute process for local_rank {local_rank} in machine {args.cluster_info.compute_node_rank}" + ) + flush() + init_compute_process(local_rank, args.cluster_info) + dataset = RemoteDistDataset( + args.cluster_info, local_rank, mp_sharing_dict=args.mp_sharing_dict + ) + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + print( + f"---Current training process rank: {rank}, training process world size: {world_size}" + ) + flush() + + # We use one training device for each local process + device = get_available_device(local_process_rank=local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + print(f"---Rank {rank} training process set device {device}") + + loss_fn = RetrievalLoss( + loss=torch.nn.CrossEntropyLoss(reduction="mean"), + temperature=0.07, + remove_accidental_hits=True, + ) + + if not args.should_skip_training: + train_main_loader, train_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="train", + cluster_info=args.cluster_info, + supervision_edge_type=args.supervision_edge_type, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + train_main_loader_iter = InfiniteIterator(train_main_loader) + train_random_negative_loader_iter = InfiniteIterator( + train_random_negative_loader + ) + + val_main_loader, val_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="val", + cluster_info=args.cluster_info, + supervision_edge_type=args.supervision_edge_type, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + 
device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + val_main_loader_iter = InfiniteIterator(val_main_loader) + val_random_negative_loader_iter = InfiniteIterator(val_random_negative_loader) + + model = init_example_gigl_heterogeneous_model( + node_type_to_feature_dim=args.node_type_to_feature_dim, + edge_type_to_feature_dim=args.edge_type_to_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + ) + optimizer = torch.optim.AdamW( + params=model.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + print(f"Model initialized on rank {rank} training device {device}\n{model}") + flush() + + # We add a barrier to wait for all processes to finish preparing the dataloader and initializing the model + torch.distributed.barrier() + + # Entering the training loop + training_start_time = time.time() + batch_idx = 0 + avg_train_loss = 0.0 + last_n_batch_avg_loss: list[float] = [] + last_n_batch_time: list[float] = [] + num_max_train_batches_per_process = args.num_max_train_batches // world_size + num_val_batches_per_process = args.num_val_batches // world_size + print( + f"num_max_train_batches_per_process is set to {num_max_train_batches_per_process}" + ) + + model.train() + + batch_start = time.time() + for main_data, random_data in zip( + train_main_loader_iter, train_random_negative_loader_iter + ): + if batch_idx >= num_max_train_batches_per_process: + print( + f"num_max_train_batches_per_process={num_max_train_batches_per_process} reached, " + f"stopping training on machine {args.cluster_info.compute_node_rank} local rank {local_rank}" + ) + break + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + supervision_edge_type=args.supervision_edge_type, + edge_dir=dataset.fetch_edge_dir(), + 
device=device, + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + avg_train_loss = _sync_metric_across_processes(metric=loss) + last_n_batch_avg_loss.append(avg_train_loss) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if ( + batch_idx % args.log_every_n_batch == 0 or batch_idx < 10 + ): # Log the first 10 batches to ensure the model is initialized correctly + print( + f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + print( + f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + print( + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + ) + last_n_batch_avg_loss.clear() + flush() + + if batch_idx % args.val_every_n_batch == 0: + print(f"rank={rank}, batch={batch_idx}, validating...") + model.eval() + _run_validation_loops( + model=model, + main_loader=val_main_loader_iter, + random_negative_loader=val_random_negative_loader_iter, + loss_fn=loss_fn, + supervision_edge_type=args.supervision_edge_type, + edge_dir=dataset.fetch_edge_dir(), + device=device, + log_every_n_batch=args.log_every_n_batch, + num_batches=num_val_batches_per_process, + ) + model.train() + else: + print(f"rank={rank} ended training early - no break condition was met") + print(f"---Rank {rank} finished training") + flush() + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + # We explicitly shutdown all the dataloaders to reduce their memory footprint. 
+ train_main_loader.shutdown() + train_random_negative_loader.shutdown() + val_main_loader.shutdown() + val_random_negative_loader.shutdown() + + # We save the model on the process with rank 0. + if torch.distributed.get_rank() == 0: + print( + f"Training loop finished, took {time.time() - training_start_time:.3f} seconds, saving model to {args.model_uri}" + ) + save_state_dict( + model=model.unwrap_from_ddp(), save_to_path_uri=args.model_uri + ) + flush() + + else: # should_skip_training is True, meaning we should only run testing + state_dict = load_state_dict_from_uri( + load_from_uri=args.model_uri, device=device + ) + model = init_example_gigl_heterogeneous_model( + node_type_to_feature_dim=args.node_type_to_feature_dim, + edge_type_to_feature_dim=args.edge_type_to_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + state_dict=state_dict, + ) + print(f"Model initialized on rank {rank} training device {device}\n{model}") + + print(f"---Rank {rank} started testing") + flush() + testing_start_time = time.time() + + model.eval() + + test_main_loader, test_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="test", + cluster_info=args.cluster_info, + supervision_edge_type=args.supervision_edge_type, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + # Since we are doing testing, we only want to go through the data once. 
+ test_main_loader_iter = iter(test_main_loader) + test_random_negative_loader_iter = iter(test_random_negative_loader) + + _run_validation_loops( + model=model, + main_loader=test_main_loader_iter, + random_negative_loader=test_random_negative_loader_iter, + loss_fn=loss_fn, + supervision_edge_type=args.supervision_edge_type, + edge_dir=dataset.fetch_edge_dir(), + device=device, + log_every_n_batch=args.log_every_n_batch, + ) + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + test_main_loader.shutdown() + test_random_negative_loader.shutdown() + + print( + f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" + ) + flush() + + # Graph store mode cleanup: shutdown the compute process connection to the storage cluster. + shutdown_compute_proccess() + gc.collect() + + print( + f"---Rank {rank} finished all training and testing, shut down compute process" + ) + flush() + + +@torch.inference_mode() +def _run_validation_loops( + model: LinkPredictionGNN, + main_loader: Iterator[HeteroData], + random_negative_loader: Iterator[HeteroData], + loss_fn: RetrievalLoss, + supervision_edge_type: EdgeType, + edge_dir: str, + device: torch.device, + log_every_n_batch: int, + num_batches: Optional[int] = None, +) -> None: + """ + Runs validation using the provided models and dataloaders. + This function is shared for both validation while training and testing after training has completed. 
+ Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_loader (Iterator[HeteroData]): Dataloader for loading main batch data with query and labeled nodes + random_negative_loader (Iterator[HeteroData]): Dataloader for loading random negative data + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + supervision_edge_type (EdgeType): The supervision edge type to use for training + edge_dir (Literal["in", "out"]): Direction of the supervision edge + device (torch.device): Device to use for training or testing + log_every_n_batch (int): The frequency we should log batch information + num_batches (Optional[int]): The number of batches to run the validation loop for. + """ + + rank = torch.distributed.get_rank() + + print( + f"Running validation loop on rank={rank}, log_every_n_batch={log_every_n_batch}, num_batches={num_batches}" + ) + if num_batches is None: + if isinstance(main_loader, InfiniteIterator) or isinstance( + random_negative_loader, InfiniteIterator + ): + raise ValueError( + "Must set `num_batches` field when the provided data loaders are wrapped with InfiniteIterator" + ) + + batch_idx = 0 + batch_losses: list[float] = [] + last_n_batch_time: list[float] = [] + batch_start = time.time() + + while True: + if num_batches and batch_idx >= num_batches: + print( + f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop" + ) + break + try: + main_data = next(main_loader) + random_data = next(random_negative_loader) + except StopIteration: + print( + f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop" + ) + break + + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + supervision_edge_type=supervision_edge_type, + edge_dir=edge_dir, + device=device, + ) + + batch_losses.append(loss.item()) + last_n_batch_time.append(time.time() - batch_start) + 
batch_start = time.time() + batch_idx += 1 + if batch_idx % log_every_n_batch == 0: + print(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") + if torch.cuda.is_available(): + torch.cuda.synchronize() + print( + f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + flush() + if len(batch_losses) == 0: + print( + f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0" + ) + flush() + local_avg_loss = 0.0 + else: + local_avg_loss = statistics.mean(batch_losses) + global_avg_val_loss = _sync_metric_across_processes( + metric=torch.tensor(local_avg_loss, device=device) + ) + print(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") + flush() + + return + + +def _run_example_training( + task_config_uri: str, +): + """ + Runs an example training + testing loop using GiGL Orchestration in graph store mode. + Args: + task_config_uri (str): Path to YAML-serialized GbmlConfig proto. 
+ """ + program_start_time = time.time() + mp.set_start_method("spawn") + print(f"Starting sub process method: {mp.get_start_method()}") + + # Step 1: Initialize global process group to get cluster info + torch.distributed.init_process_group(backend="gloo") + print( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + cluster_info = get_graph_store_info() + print(f"Cluster info: {cluster_info}") + torch.distributed.destroy_process_group() + print(f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool") + flush() + + # Step 2: Read config + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + # Training Hyperparameters + trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + + if torch.cuda.is_available(): + default_local_world_size = torch.cuda.device_count() + else: + default_local_world_size = 2 + local_world_size = int( + trainer_args.get("local_world_size", str(default_local_world_size)) + ) + + if torch.cuda.is_available(): + if local_world_size > torch.cuda.device_count(): + raise ValueError( + f"Specified a local world size of {local_world_size} which exceeds the number of devices {torch.cuda.device_count()}" + ) + + fanout = trainer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + sampling_workers_per_process: int = int( + trainer_args.get("sampling_workers_per_process", "4") + ) + + main_batch_size = int(trainer_args.get("main_batch_size", "4")) + random_batch_size = int(trainer_args.get("random_batch_size", "4")) + + hid_dim = int(trainer_args.get("hid_dim", "16")) + out_dim = int(trainer_args.get("out_dim", "16")) + + sampling_worker_shared_channel_size: str = trainer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + process_start_gap_seconds = 
int(trainer_args.get("process_start_gap_seconds", "0")) + log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + + learning_rate = float(trainer_args.get("learning_rate", "0.0005")) + weight_decay = float(trainer_args.get("weight_decay", "0.0005")) + num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) + num_val_batches = int(trainer_args.get("num_val_batches", "100")) + val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) + + print( + f"Got training args local_world_size={local_world_size}, \ + num_neighbors={num_neighbors}, \ + sampling_workers_per_process={sampling_workers_per_process}, \ + main_batch_size={main_batch_size}, \ + random_batch_size={random_batch_size}, \ + hid_dim={hid_dim}, \ + out_dim={out_dim}, \ + sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ + process_start_gap_seconds={process_start_gap_seconds}, \ + log_every_n_batch={log_every_n_batch}, \ + learning_rate={learning_rate}, \ + weight_decay={weight_decay}, \ + num_max_train_batches={num_max_train_batches}, \ + num_val_batches={num_val_batches}, \ + val_every_n_batch={val_every_n_batch}" + ) + + # Step 3: Extract model/data config + graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + + node_type_to_feature_dim: dict[NodeType, int] = { + graph_metadata.condensed_node_type_to_node_type_map[ + condensed_node_type + ]: node_feature_dim + for condensed_node_type, node_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map.items() + } + + edge_type_to_feature_dim: dict[EdgeType, int] = { + graph_metadata.condensed_edge_type_to_edge_type_map[ + condensed_edge_type + ]: edge_feature_dim + for condensed_edge_type, edge_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map.items() + } + + model_uri = UriFactory.create_uri( + 
gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training + + supervision_edge_types = ( + gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_types() + ) + if len(supervision_edge_types) != 1: + raise NotImplementedError( + "GiGL Training currently only supports 1 supervision edge type." + ) + supervision_edge_type = supervision_edge_types[0] + + # Step 4: Create shared dict for inter-process tensor sharing + mp_sharing_dict = mp.Manager().dict() + + # Step 5: Spawn training processes + print("--- Launching training processes ...\n") + flush() + start_time = time.time() + + training_args = TrainingProcessArgs( + local_world_size=local_world_size, + cluster_info=cluster_info, + mp_sharing_dict=mp_sharing_dict, + supervision_edge_type=supervision_edge_type, + model_uri=model_uri, + hid_dim=hid_dim, + out_dim=out_dim, + node_type_to_feature_dim=node_type_to_feature_dim, + edge_type_to_feature_dim=edge_type_to_feature_dim, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + num_max_train_batches=num_max_train_batches, + num_val_batches=num_val_batches, + val_every_n_batch=val_every_n_batch, + log_every_n_batch=log_every_n_batch, + should_skip_training=should_skip_training, + ) + + torch.multiprocessing.spawn( + _training_process, + args=(training_args,), + nprocs=local_world_size, + join=True, + ) + print(f"--- Training finished, took {time.time() - start_time} seconds") + print( + f"--- Program finished, which took {time.time() - program_start_time:.2f} seconds" + ) + flush() + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser( + description="Arguments for distributed model training on VertexAI (graph store mode)" + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + args, unused_args = parser.parse_known_args() + print(f"Unused arguments: {unused_args}") + + _run_example_training( + task_config_uri=args.task_config_uri, + ) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py new file mode 100644 index 000000000..816fb1edc --- /dev/null +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -0,0 +1,949 @@ +""" +This file contains an example for how to run homogeneous link prediction training in +**graph store mode** using GiGL. + +Graph Store Mode vs Standard (Colocated) Mode +---------------------------------------------- +Graph store mode uses a heterogeneous cluster architecture with two distinct sub-clusters: + + 1. **Storage Cluster (graph_store_pool)**: Dedicated machines for storing and serving the graph + data. These are typically high-memory machines without GPUs (e.g., n2-highmem-32). + 2. **Compute Cluster (compute_pool)**: Dedicated machines for running model training. + These typically have GPUs attached (e.g., n1-standard-16 with NVIDIA_TESLA_T4). + +This separation allows for independent scaling of storage and compute resources, better memory +utilization (graph data stays on storage nodes), and cost optimization by using appropriate +hardware for each role. + +In contrast, the standard colocated training mode +(see ``examples/link_prediction/homogeneous_training.py``) uses a homogeneous cluster where each +machine handles both graph storage and computation. 
+ +Key Implementation Differences +------------------------------ + ++---------------------------+----------------------------------------------+----------------------------------------------+ +| Aspect | Standard (``homogeneous_training.py``) | Graph Store (this file) | ++===========================+==============================================+==============================================+ +| **Dataset class** | ``DistDataset`` (local partition) | ``RemoteDistDataset`` (RPC to storage) | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Dataset loading** | ``build_dataset_from_task_config_uri()`` | Storage nodes build data; compute nodes | +| | loads and partitions data locally | connect via ``init_compute_process()`` | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Process group init** | Manual ``init_process_group`` with master | ``init_process_group(gloo)`` to | +| | IP/port, ``destroy_process_group``, then | ``get_graph_store_info()``, then | +| | re-init in spawned processes | ``destroy_process_group``; spawned processes | +| | | call ``init_compute_process()`` | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Split/label access** | ``dataset.train_node_ids`` / | ``dataset.fetch_ablp_input(split=...)`` | +| | ``dataset.val_node_ids`` / | fetches anchors + labels from storage via | +| | ``dataset.test_node_ids`` via | RPC | +| | ``to_homogeneous()`` | | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Random negative nodes** | ``dataset.node_ids`` via | ``dataset.fetch_node_ids()`` fetches from | +| | ``to_homogeneous()`` | storage via RPC | 
++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Cluster info** | ``machine_rank``, ``machine_world_size``, | ``GraphStoreInfo`` dataclass from | +| | ``master_ip_address`` extracted manually | ``get_graph_store_info()`` encapsulates all | +| | | topology | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Inter-process sharing** | N/A (each process loads own partition) | ``mp_sharing_dict`` for efficient tensor | +| | | sharing between local processes | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Cleanup** | ``torch.distributed.destroy_process_group()`` | ``shutdown_compute_proccess()`` disconnects | +| | | from storage cluster | ++---------------------------+----------------------------------------------+----------------------------------------------+ + +Data Splitting and Storage Pipeline +------------------------------------ +Before training begins, the **storage cluster** prepares the graph data including train/val/test +splits. The flow is: + +1. **Splitter configuration**: The ``splitter_cls_path`` and ``splitter_kwargs`` are specified in + the YAML config under ``graphStoreStorageConfig.storageArgs``. The storage entry point + (``storage_main.py``) parses these via ``argparse`` and dynamically imports the splitter class + using ``import_obj()``. The kwargs string is evaluated with ``ast.literal_eval`` and passed to + the splitter constructor (e.g. ``DistNodeAnchorLinkSplitter(**splitter_kwargs)``). + +2. **ABLP input fetching** (at training time): ``RemoteDistDataset.fetch_ablp_input(split=...)`` + issues an RPC to the storage cluster and returns a ``dict[int, ABLPInputNodes]`` keyed by + storage rank. 
Each ``ABLPInputNodes`` contains ``anchor_nodes``, ``positive_labels``, and + optional ``negative_labels`` tensors for the requested split. + +3. **Node ID fetching**: ``RemoteDistDataset.fetch_node_ids()`` similarly fetches all node IDs + from storage, used for the random negative sampling loader. + +Because the storage cluster owns the split, compute nodes see train/val/test as first-class +properties of the remote dataset. + +Config Example +-------------- +To run this file with GiGL orchestration, set the fields similar to below:: + + trainerConfig: + trainerArgs: + log_every_n_batch: "50" + num_neighbors: "[10, 10]" + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": true, "num_val": 0.1, "num_test": 0.1}' + num_server_sessions: "1" + featureFlags: + should_run_glt_backend: 'True' + +Note: Ensure you use a resource config with ``vertex_ai_graph_store_trainer_config`` when +running in graph store mode. + +You can run this example in a full pipeline with ``make run_hom_cora_sup_gs_e2e_test`` from +GiGL root. 
+""" + +import argparse +import gc +import os +import statistics +import sys +import time +from collections.abc import Iterator, MutableMapping +from dataclasses import dataclass +from typing import Literal, Optional + +import torch +import torch.distributed +import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_homogeneous_model +from torch_geometric.data import Data + +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed import DistABLPLoader +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.graph_store.compute import ( + init_compute_process, + shutdown_compute_proccess, +) +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils import get_available_device, get_graph_store_info +from gigl.env.distributed import GraphStoreInfo +from gigl.nn import LinkPredictionGNN, RetrievalLoss +from gigl.src.common.types.graph_data import EdgeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict +from gigl.utils.iterator import InfiniteIterator +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +# We don't see logs for graph store mode for whatever reason. +# TODO(#442): Revert this once the GCP issues are resolved. +def flush(): + sys.stdout.write("\n") + sys.stdout.flush() + sys.stderr.write("\n") + sys.stderr.flush() + + +def _sync_metric_across_processes(metric: torch.Tensor) -> float: + """ + Takes the average of a training metric across multiple processes. Note that this function requires DDP to be initialized. 
+ Args: + metric (torch.Tensor): The metric, expressed as a torch Tensor, which should be synced across multiple processes + Returns: + float: The average of the provided metric across all training processes + """ + assert is_distributed_available_and_initialized(), "DDP is not initialized" + loss_tensor = metric.detach().clone() + torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) + return loss_tensor.item() / torch.distributed.get_world_size() + + +def _setup_dataloaders( + dataset: RemoteDistDataset, + split: Literal["train", "val", "test"], + cluster_info: GraphStoreInfo, + num_neighbors: list[int] | dict[EdgeType, list[int]], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + device: torch.device, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, +) -> tuple[DistABLPLoader, DistNeighborLoader]: + """ + Sets up main and random dataloaders for training and testing purposes using a remote graph store dataset. + Args: + dataset (RemoteDistDataset): Remote dataset connected to the graph store cluster. + split (Literal["train", "val", "test"]): The current split which we are loading data for. + cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + num_neighbors: Fanout for subgraph sampling. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + main_batch_size (int): Batch size for main dataloader with query and labeled nodes. + random_batch_size (int): Batch size for random negative dataloader. + device (torch.device): Device to put loaded subgraphs on. + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) for the channel during sampling. + process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. + Returns: + DistABLPLoader: Dataloader for loading main batch data with query and labeled nodes. 
+ DistNeighborLoader: Dataloader for loading random negative data. + """ + rank = torch.distributed.get_rank() + + shuffle = split == "train" + + # In graph store mode, we fetch ABLP input (anchors + positive/negative labels) from the storage cluster. + # For homogeneous graphs, no node type or supervision edge type wrapper is needed. + logger.info(f"---Rank {rank} fetching ABLP input for split={split}") + flush() + ablp_input = dataset.fetch_ablp_input( + split=split, + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + ) + + for storage_rank, ablp_nodes in ablp_input.items(): + print( + f"Rank {rank} split={split}: storage_rank={storage_rank}, " + f"num_anchors={ablp_nodes.anchor_nodes.shape}, " + f"labels: {ablp_nodes.labels}" + ) + flush() + + main_loader = DistABLPLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=ablp_input, + num_workers=sampling_workers_per_process, + batch_size=main_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up main loader for split={split}") + flush() + + # We need to wait for all processes to finish initializing the main_loader before creating the + # random_negative_loader so that its initialization doesn't compete for memory. + torch.distributed.barrier() + + # For the random negative loader, we get all node IDs from the storage cluster. 
+ all_node_ids = dataset.fetch_node_ids( + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + ) + + for storage_rank, node_ids_tensor in all_node_ids.items(): + print( + f"Rank {rank} split={split}: random_negative storage_rank={storage_rank}, " + f"num_node_ids={node_ids_tensor.shape}" + ) + flush() + + random_negative_loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=all_node_ids, + num_workers=sampling_workers_per_process, + batch_size=random_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info( + f"---Rank {rank} finished setting up random negative loader for split={split}" + ) + flush() + + # Wait for all processes to finish initializing the random_loader + torch.distributed.barrier() + + return main_loader, random_negative_loader + + +def _compute_loss( + model: LinkPredictionGNN, + main_data: Data, + random_negative_data: Data, + loss_fn: RetrievalLoss, + device: torch.device, +) -> torch.Tensor: + """ + With the provided model and loss function, computes the forward pass on the main batch data and random negative data. 
+ Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_data (Data): The batch of data containing query nodes, positive nodes, and hard negative nodes + random_negative_data (Data): The batch of data containing random negative nodes + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + device (torch.device): Device for training or validation + Returns: + torch.Tensor: Final loss for the current batch on the current process + """ + # print(f"Computing loss for main data: {main_data}") + # print(f"Computing loss for random negative data: {random_negative_data}") + # print(f"Using model: {model}") + flush() + # Forward pass through encoder + main_embeddings = model(data=main_data, device=device) + random_negative_embeddings = model(data=random_negative_data, device=device) + + query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) + random_negative_batch_size = random_negative_data.batch_size + + positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( + device + ) + repeated_query_node_idx = query_node_idx.repeat_interleave( + torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) + ) + if hasattr(main_data, "y_negative"): + hard_negative_idx: torch.Tensor = torch.cat( + list(main_data.y_negative.values()) + ).to(device) + else: + hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) + + # Use local IDs to get the corresponding embeddings in the tensors + + repeated_query_embeddings = main_embeddings[repeated_query_node_idx] + positive_node_embeddings = main_embeddings[positive_idx] + hard_negative_embeddings = main_embeddings[hard_negative_idx] + random_negative_embeddings = random_negative_embeddings[:random_negative_batch_size] + + repeated_candidate_scores = model.decode( + query_embeddings=repeated_query_embeddings, + candidate_embeddings=torch.cat( + [ + positive_node_embeddings, + hard_negative_embeddings, + 
random_negative_embeddings, + ], + dim=0, + ), + ) + + global_candidate_ids = torch.cat( + ( + main_data.node[positive_idx], + main_data.node[hard_negative_idx], + random_negative_data.node[:random_negative_batch_size], + ) + ) + + global_repeated_query_ids = main_data.node[repeated_query_node_idx] + + loss = loss_fn( + repeated_candidate_scores=repeated_candidate_scores, + candidate_ids=global_candidate_ids, + repeated_query_ids=global_repeated_query_ids, + device=device, + ) + + return loss + + +@dataclass(frozen=True) +class TrainingProcessArgs: + """ + Arguments for the homogeneous training process in graph store mode. + + Attributes: + local_world_size (int): Number of training processes spawned by each machine. + cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + mp_sharing_dict (MutableMapping[str, torch.Tensor]): Shared dictionary for efficient tensor + sharing between local processes. + model_uri (Uri): URI to save/load the trained model state dict. + hid_dim (int): Hidden dimension of the model. + out_dim (int): Output dimension of the model. + node_feature_dim (int): Input node feature dimension for the model. + edge_feature_dim (int): Input edge feature dimension for the model. + num_neighbors (list[int] | dict[EdgeType, list[int]]): Fanout for subgraph sampling. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. + process_start_gap_seconds (int): Time to sleep between dataloader initializations. + main_batch_size (int): Batch size for main dataloader. + random_batch_size (int): Batch size for random negative dataloader. + learning_rate (float): Learning rate for the optimizer. + weight_decay (float): Weight decay for the optimizer. + num_max_train_batches (int): Maximum number of training batches across all processes. 
+ num_val_batches (int): Number of validation batches across all processes. + val_every_n_batch (int): Frequency to run validation during training. + log_every_n_batch (int): Frequency to log batch information during training. + should_skip_training (bool): If True, skip training and only run testing. + """ + + # Distributed context + local_world_size: int + cluster_info: GraphStoreInfo + mp_sharing_dict: MutableMapping[str, torch.Tensor] + + # Model + model_uri: Uri + hid_dim: int + out_dim: int + node_feature_dim: int + edge_feature_dim: int + + # Sampling config + num_neighbors: list[int] | dict[EdgeType, list[int]] + sampling_workers_per_process: int + sampling_worker_shared_channel_size: str + process_start_gap_seconds: int + + # Training hyperparameters + main_batch_size: int + random_batch_size: int + learning_rate: float + weight_decay: float + num_max_train_batches: int + num_val_batches: int + val_every_n_batch: int + log_every_n_batch: int + should_skip_training: bool + + +def _training_process( + local_rank: int, + args: TrainingProcessArgs, +) -> None: + """ + This function is spawned by each machine for training a GNN model using graph store mode. + Args: + local_rank (int): Process number on the current machine + args (TrainingProcessArgs): Dataclass containing all training process arguments + """ + + # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster + # and sets up torch.distributed with the appropriate backend (NCCL if CUDA available, gloo otherwise). 
+ logger.info( + f"Initializing compute process for local_rank {local_rank} in machine {args.cluster_info.compute_node_rank}" + ) + flush() + init_compute_process(local_rank, args.cluster_info) + dataset = RemoteDistDataset( + args.cluster_info, local_rank, mp_sharing_dict=args.mp_sharing_dict + ) + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + logger.info( + f"---Current training process rank: {rank}, training process world size: {world_size}" + ) + flush() + + device = get_available_device(local_process_rank=local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + logger.info(f"---Rank {rank} training process set device {device}") + + loss_fn = RetrievalLoss( + loss=torch.nn.CrossEntropyLoss(reduction="mean"), + temperature=0.07, + remove_accidental_hits=True, + ) + + if not args.should_skip_training: + train_main_loader, train_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="train", + cluster_info=args.cluster_info, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + train_main_loader_iter = InfiniteIterator(train_main_loader) + train_random_negative_loader_iter = InfiniteIterator( + train_random_negative_loader + ) + + val_main_loader, val_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="val", + cluster_info=args.cluster_info, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + 
process_start_gap_seconds=args.process_start_gap_seconds, + ) + + val_main_loader_iter = InfiniteIterator(val_main_loader) + val_random_negative_loader_iter = InfiniteIterator(val_random_negative_loader) + + model = init_example_gigl_homogeneous_model( + node_feature_dim=args.node_feature_dim, + edge_feature_dim=args.edge_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + ) + + optimizer = torch.optim.AdamW( + params=model.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + flush() + + # We add a barrier to wait for all processes to finish preparing the dataloader and initializing the model + torch.distributed.barrier() + + # Entering the training loop + training_start_time = time.time() + batch_idx = 0 + avg_train_loss = 0.0 + last_n_batch_avg_loss: list[float] = [] + last_n_batch_time: list[float] = [] + num_max_train_batches_per_process = args.num_max_train_batches // world_size + num_val_batches_per_process = args.num_val_batches // world_size + logger.info( + f"num_max_train_batches_per_process is set to {num_max_train_batches_per_process}" + ) + + model.train() + + batch_start = time.time() + for main_data, random_data in zip( + train_main_loader_iter, train_random_negative_loader_iter + ): + if batch_idx >= num_max_train_batches_per_process: + logger.info( + f"num_max_train_batches_per_process={num_max_train_batches_per_process} reached, " + f"stopping training on machine {args.cluster_info.compute_node_rank} local rank {local_rank}" + ) + break + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + device=device, + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + avg_train_loss = _sync_metric_across_processes(metric=loss) + last_n_batch_avg_loss.append(avg_train_loss) + 
last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % args.log_every_n_batch == 0: + logger.info( + f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + logger.info( + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + ) + last_n_batch_avg_loss.clear() + flush() + + if batch_idx % args.val_every_n_batch == 0: + logger.info(f"rank={rank}, batch={batch_idx}, validating...") + model.eval() + _run_validation_loops( + model=model, + main_loader=val_main_loader_iter, + random_negative_loader=val_random_negative_loader_iter, + loss_fn=loss_fn, + device=device, + log_every_n_batch=args.log_every_n_batch, + num_batches=num_val_batches_per_process, + ) + model.train() + + logger.info(f"---Rank {rank} finished training") + flush() + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + # We explicitly shutdown all the dataloaders to reduce their memory footprint. + train_main_loader.shutdown() + train_random_negative_loader.shutdown() + val_main_loader.shutdown() + val_random_negative_loader.shutdown() + + # We save the model on the process with rank 0. 
+ if torch.distributed.get_rank() == 0: + logger.info( + f"Training loop finished, took {time.time() - training_start_time:.3f} seconds, saving model to {args.model_uri}" + ) + save_state_dict( + model=model.unwrap_from_ddp(), save_to_path_uri=args.model_uri + ) + flush() + + else: # should_skip_training is True, meaning we should only run testing + state_dict = load_state_dict_from_uri( + load_from_uri=args.model_uri, device=device + ) + model = init_example_gigl_homogeneous_model( + node_feature_dim=args.node_feature_dim, + edge_feature_dim=args.edge_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + state_dict=state_dict, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + + logger.info(f"---Rank {rank} started testing") + flush() + testing_start_time = time.time() + model.eval() + + test_main_loader, test_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="test", + cluster_info=args.cluster_info, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + # Since we are doing testing, we only want to go through the data once. 
+ test_main_loader_iter = iter(test_main_loader) + test_random_negative_loader_iter = iter(test_random_negative_loader) + + _run_validation_loops( + model=model, + main_loader=test_main_loader_iter, + random_negative_loader=test_random_negative_loader_iter, + loss_fn=loss_fn, + device=device, + log_every_n_batch=args.log_every_n_batch, + ) + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + test_main_loader.shutdown() + test_random_negative_loader.shutdown() + + logger.info( + f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" + ) + flush() + + # Graph store mode cleanup: shutdown the compute process connection to the storage cluster. + shutdown_compute_proccess() + gc.collect() + + logger.info( + f"---Rank {rank} finished all training and testing, shut down compute process" + ) + flush() + + +@torch.inference_mode() +def _run_validation_loops( + model: LinkPredictionGNN, + main_loader: Iterator[Data], + random_negative_loader: Iterator[Data], + loss_fn: RetrievalLoss, + device: torch.device, + log_every_n_batch: int, + num_batches: Optional[int] = None, +) -> None: + """ + Runs validation using the provided models and dataloaders. + Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_loader (Iterator[Data]): Dataloader for loading main batch data + random_negative_loader (Iterator[Data]): Dataloader for loading random negative data + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + device (torch.device): Device to use for training or testing + log_every_n_batch (int): The frequency we should log batch information + num_batches (Optional[int]): The number of batches to run the validation loop for. 
+ """ + rank = torch.distributed.get_rank() + + logger.info( + f"Running validation loop on rank={rank}, log_every_n_batch={log_every_n_batch}, num_batches={num_batches}" + ) + if num_batches is None: + if isinstance(main_loader, InfiniteIterator) or isinstance( + random_negative_loader, InfiniteIterator + ): + raise ValueError( + "Must set `num_batches` field when the provided data loaders are wrapped with InfiniteIterator" + ) + + batch_idx = 0 + batch_losses: list[float] = [] + last_n_batch_time: list[float] = [] + batch_start = time.time() + + while True: + if num_batches and batch_idx >= num_batches: + print( + f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop with batch_idx={batch_idx} and num_batches={num_batches}" + ) + flush() + break + try: + main_data = next(main_loader) + except StopIteration: + print( + f"Rank {torch.distributed.get_rank()} MAIN loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}" + ) + flush() + break + try: + random_data = next(random_negative_loader) + except StopIteration: + print( + f"Rank {torch.distributed.get_rank()} RANDOM NEGATIVE loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}" + ) + flush() + break + + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + device=device, + ) + + batch_losses.append(loss.item()) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % log_every_n_batch == 0: + print(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + flush() + if batch_losses: + local_avg_loss = 
statistics.mean(batch_losses) + else: + print( + f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0" + ) + flush() + local_avg_loss = 0.0 + print( + f"rank={rank} finished validation loop, num_batches_processed={len(batch_losses)}, local loss: {local_avg_loss:.6f}" + ) + flush() + global_avg_val_loss = _sync_metric_across_processes( + metric=torch.tensor(local_avg_loss, device=device) + ) + print(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") + flush() + + return + + +def _run_example_training( + task_config_uri: str, +): + """ + Runs an example training + testing loop using GiGL Orchestration in graph store mode. + Args: + task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + """ + program_start_time = time.time() + mp.set_start_method("spawn") + logger.info(f"Starting sub process method: {mp.get_start_method()}") + + # Step 1: Initialize global process group to get cluster info + torch.distributed.init_process_group(backend="gloo") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + cluster_info = get_graph_store_info() + print(f"Cluster info: {cluster_info}") + flush() + torch.distributed.destroy_process_group() + logger.info( + f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" + ) + flush() + + # Step 2: Read config + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + # Training Hyperparameters + trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + + if torch.cuda.is_available(): + default_local_world_size = torch.cuda.device_count() + else: + default_local_world_size = 2 + local_world_size = int( + trainer_args.get("local_world_size", str(default_local_world_size)) + ) + + if torch.cuda.is_available(): + if 
local_world_size > torch.cuda.device_count(): + raise ValueError( + f"Specified a local world size of {local_world_size} which exceeds the number of devices {torch.cuda.device_count()}" + ) + + fanout = trainer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + sampling_workers_per_process: int = int( + trainer_args.get("sampling_workers_per_process", "4") + ) + + main_batch_size = int(trainer_args.get("main_batch_size", "16")) + random_batch_size = int(trainer_args.get("random_batch_size", "16")) + + hid_dim = int(trainer_args.get("hid_dim", "16")) + out_dim = int(trainer_args.get("out_dim", "16")) + + sampling_worker_shared_channel_size: str = trainer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) + log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + + learning_rate = float(trainer_args.get("learning_rate", "0.0005")) + weight_decay = float(trainer_args.get("weight_decay", "0.0005")) + num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) + num_val_batches = int(trainer_args.get("num_val_batches", "100")) + val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) + + logger.info( + f"Got training args local_world_size={local_world_size}, \ + num_neighbors={num_neighbors}, \ + sampling_workers_per_process={sampling_workers_per_process}, \ + main_batch_size={main_batch_size}, \ + random_batch_size={random_batch_size}, \ + hid_dim={hid_dim}, \ + out_dim={out_dim}, \ + sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ + process_start_gap_seconds={process_start_gap_seconds}, \ + log_every_n_batch={log_every_n_batch}, \ + learning_rate={learning_rate}, \ + weight_decay={weight_decay}, \ + num_max_train_batches={num_max_train_batches}, \ + num_val_batches={num_val_batches}, \ + val_every_n_batch={val_every_n_batch}" + ) + + # Step 3: Extract model/data config 
+ graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + + node_feature_dim = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_node_type + ] + edge_feature_dim = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_edge_type + ] + + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training + + # Step 4: Create shared dict for inter-process tensor sharing + mp_sharing_dict = mp.Manager().dict() + + # Step 5: Spawn training processes + logger.info("--- Launching training processes ...\n") + flush() + start_time = time.time() + + training_args = TrainingProcessArgs( + local_world_size=local_world_size, + cluster_info=cluster_info, + mp_sharing_dict=mp_sharing_dict, + model_uri=model_uri, + hid_dim=hid_dim, + out_dim=out_dim, + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + num_max_train_batches=num_max_train_batches, + num_val_batches=num_val_batches, + val_every_n_batch=val_every_n_batch, + log_every_n_batch=log_every_n_batch, + should_skip_training=should_skip_training, + ) + + torch.multiprocessing.spawn( + _training_process, + args=(training_args,), + nprocs=local_world_size, + join=True, + ) + logger.info(f"--- Training finished, took {time.time() - start_time} seconds") + logger.info( + f"--- Program finished, which took {time.time() - 
program_start_time:.2f} seconds" + ) + flush() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed model training on VertexAI (graph store mode)" + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") + + _run_example_training( + task_config_uri=args.task_config_uri, + ) diff --git a/examples/link_prediction/graph_store/storage_main.py b/examples/link_prediction/graph_store/storage_main.py index 4de929df1..ff777378e 100644 --- a/examples/link_prediction/graph_store/storage_main.py +++ b/examples/link_prediction/graph_store/storage_main.py @@ -72,6 +72,7 @@ """ import argparse +import ast import os from distutils.util import strtobool from typing import Literal, Optional, Union @@ -80,6 +81,7 @@ from gigl.common import Uri, UriFactory from gigl.common.logger import Logger +from gigl.common.utils.os_utils import import_obj from gigl.distributed.graph_store.storage_utils import ( build_storage_dataset, run_storage_server, @@ -101,7 +103,6 @@ def storage_node_process( should_load_tf_records_in_parallel: bool = True, tf_record_uri_pattern: str = r".*-of-.*\.tfrecord(\.gz)?$", ssl_positive_label_percentage: Optional[float] = None, - storage_world_backend: Optional[str] = None, ) -> None: """Run a storage node process. @@ -113,8 +114,7 @@ def storage_node_process( nodes for coordination (server comms). 2. Builds the dataset via :func:`~gigl.distributed.graph_store.storage_utils.build_storage_dataset`. - 3. Obtains free ports from the master node. - 4. Destroys the coordination process group and spawns one + 3. Destroys the coordination process group and spawns one :func:`~gigl.distributed.graph_store.storage_utils.run_storage_server` per session. @@ -123,12 +123,14 @@ def storage_node_process( cluster_info (GraphStoreInfo): The cluster information. task_config_uri (Uri): The task config URI. 
sample_edge_direction (Literal["in", "out"]): The sample edge direction. + num_server_sessions (int): Number of server sessions to run. For training, this should be 1 + (a single session for the entire training + testing lifecycle). For inference, this should + be one session per inference node type. splitter (Optional[Union[DistNodeAnchorLinkSplitter, DistNodeSplitter]]): The splitter to use. If None, will not split the dataset. tf_record_uri_pattern (str): The TF Record URI pattern. ssl_positive_label_percentage (Optional[float]): The percentage of edges to select as self-supervised labels. Must be None if supervised edge labels are provided in advance. If 0.1 is provided, 10% of the edges will be selected as self-supervised labels. - storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group. """ init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}" logger.info( @@ -167,7 +169,6 @@ def storage_node_process( if __name__ == "__main__": - # TODO(kmonte): We want to expose splitter class here probably. parser = argparse.ArgumentParser() parser.add_argument("--task_config_uri", type=str, required=True) parser.add_argument("--resource_config_uri", type=str, required=True) @@ -177,14 +178,57 @@ def storage_node_process( parser.add_argument( "--should_load_tf_records_in_parallel", type=str, default="True" ) + # Splitter configuration: use import_obj to dynamically load a splitter class. + # This is needed for training (where the dataset needs train/val/test splits) but not for inference. + parser.add_argument( + "--splitter_cls_path", + type=str, + default=None, + help="Fully qualified import path to splitter class, e.g. 'gigl.utils.data_splitters.DistNodeAnchorLinkSplitter'", + ) + parser.add_argument( + "--splitter_kwargs", + type=str, + default=None, + help="Python dict literal of keyword arguments for the splitter constructor, " + "parsed with ast.literal_eval. 
Tuples are supported directly, e.g. " + "'supervision_edge_types': [('paper', 'to', 'author')].", + ) + parser.add_argument( + "--ssl_positive_label_percentage", + type=str, + default=None, + help="Percentage of edges to select as self-supervised labels. " + "Must be None if supervised edge labels are provided in advance.", + ) args = parser.parse_args() logger.info(f"Running storage node with arguments: {args}") + # Build splitter from args if provided. + # We use ast.literal_eval instead of json.loads so that Python tuples (e.g. for EdgeType) + # can be passed directly in the splitter_kwargs string without needing custom serialization. + splitter: Optional[Union[DistNodeAnchorLinkSplitter, DistNodeSplitter]] = None + ssl_positive_label_percentage: Optional[float] = None + if args.splitter_cls_path: + splitter_cls = import_obj(args.splitter_cls_path) + splitter_kwargs = ( + ast.literal_eval(args.splitter_kwargs) if args.splitter_kwargs else {} + ) + splitter = splitter_cls(**splitter_kwargs) + logger.info(f"Built splitter: {splitter}") + + if args.ssl_positive_label_percentage: + ssl_positive_label_percentage = float(args.ssl_positive_label_percentage) + # Setup cluster-wide (e.g. storage and compute nodes) Torch Distributed process group. # This is needed so we can get the cluster information (e.g. number of storage and compute nodes) and rank/world_size. torch.distributed.init_process_group(backend="gloo") cluster_info = get_graph_store_info() - # Tear down the """"global""" process group so we can have a server-specific process group. + logger.info(f"Cluster info: {cluster_info}") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + # Tear down the "global" process group so we can have a server-specific process group. 
torch.distributed.destroy_process_group() storage_node_process( storage_rank=cluster_info.storage_node_rank, @@ -192,6 +236,8 @@ def storage_node_process( task_config_uri=UriFactory.create_uri(args.task_config_uri), sample_edge_direction=args.sample_edge_direction, num_server_sessions=args.num_server_sessions, + splitter=splitter, + ssl_positive_label_percentage=ssl_positive_label_percentage, should_load_tf_records_in_parallel=bool( strtobool(args.should_load_tf_records_in_parallel) ), diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index d4ae3e452..b8d9ff63d 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -562,6 +562,8 @@ def shutdown(self) -> None: torch.futures.wait_all(rpc_futures) self._shutdowned = True + _MAX_EPOCH_CATCH_UP_RETRIES: int = 10 + # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: self._num_recv = 0 @@ -570,7 +572,32 @@ def __iter__(self) -> Self: elif self._is_mp_worker: self._mp_producer.produce_all() else: - rpc_futures: list[torch.futures.Future[None]] = [] + self._request_new_epoch_production() + self._channel.reset() + self._epoch += 1 + return self + + def _request_new_epoch_production(self) -> None: + """Request production from all servers, retrying only on genuine epoch skew. + + In graph store mode, multiple GPUs on the same compute node share a + producer per server (same ``worker_key``). Only the first GPU to call + ``start_new_epoch_sampling`` for a given epoch triggers + ``produce_all()``; subsequent calls at the same epoch are no-ops + because the data is already flowing through the shared buffer. + + Two distinct cases are handled: + + * **Same epoch** (``self._epoch >= max_server_epoch``): another GPU + already triggered production for this epoch. Data is in the shared + buffer — return immediately without retrying. 
+ * **Behind** (``self._epoch < max_server_epoch``): our epoch is + genuinely stale. Fast-forward past the server's epoch and retry so + ``produce_all()`` is guaranteed to fire. This typically resolves in + two iterations (first detects staleness, second triggers). + """ + for attempt in range(self._MAX_EPOCH_CATCH_UP_RETRIES): + rpc_futures: list[torch.futures.Future[tuple[int, bool]]] = [] for server_rank, producer_id in zip( self._server_rank_list, self._producer_id_list ): @@ -581,7 +608,33 @@ def __iter__(self) -> Self: self._epoch, ) rpc_futures.append(fut) - torch.futures.wait_all(rpc_futures) - self._channel.reset() - self._epoch += 1 - return self + + results = [fut.wait() for fut in rpc_futures] + any_produced = any(produced for _, produced in results) + + if any_produced: + return + + # No server produced — check whether we are genuinely behind or + # another GPU sharing the same producer simply beat us. + max_server_epoch = max(server_epoch for server_epoch, _ in results) + + if self._epoch >= max_server_epoch: + # Another GPU already triggered production for this epoch. + # Data is flowing through the shared buffer — nothing to do. + return + + # Our epoch is genuinely behind the server's. Fast-forward and + # retry so the next RPC has epoch > max_server_epoch. + logger.warning( + f"Epoch skew detected: client epoch {self._epoch} behind " + f"server epoch {max_server_epoch}. Retrying with epoch " + f"{max_server_epoch + 1} (attempt {attempt + 1})." + ) + self._epoch = max_server_epoch + 1 + + raise RuntimeError( + f"Failed to trigger production after " + f"{self._MAX_EPOCH_CATCH_UP_RETRIES} attempts. " + f"This indicates a persistent epoch skew." 
+ ) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 1506432b2..47880cfdc 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -551,17 +551,30 @@ def destroy_sampling_producer(self, producer_id: int) -> None: self._msg_buffer_pool.pop(producer_id) self._epoch.pop(producer_id) - def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> None: - r"""Start a new epoch sampling tasks for a specific sampling producer - with its producer id. + def start_new_epoch_sampling( + self, producer_id: int, epoch: int + ) -> tuple[int, bool]: + """Start a new epoch sampling for a specific sampling producer. + + Args: + producer_id: The unique id of the sampling producer. + epoch: The epoch requested by the client. + + Returns: + A tuple of (server_epoch, produced) where server_epoch is the + current epoch on the server after this call and produced indicates + whether ``produce_all()`` was triggered. 
""" with self._producer_lock[producer_id]: cur_epoch = self._epoch[producer_id] + produced = False if cur_epoch < epoch: self._epoch[producer_id] = epoch producer = self._producer_pool.get(producer_id, None) if producer is not None: producer.produce_all() + produced = True + return self._epoch[producer_id], produced def fetch_one_sampled_message( self, producer_id: int diff --git a/gigl/distributed/graph_store/remote_dist_sampling_worker_options.md b/gigl/distributed/graph_store/remote_dist_sampling_worker_options.md new file mode 100644 index 000000000..f3d439501 --- /dev/null +++ b/gigl/distributed/graph_store/remote_dist_sampling_worker_options.md @@ -0,0 +1,184 @@ +# `RemoteDistSamplingWorkerOptions` Deep Dive + +## Class Definition + +Defined in the installed GLT package at: +`graphlearn_torch/distributed/dist_options.py` (lines 210-291) + +It extends `_BasicDistSamplingWorkerOptions` (lines 26-117) and is designed for **Graph Store (server-client) mode**, +where sampling workers run on remote storage servers and results are sent back to compute nodes. 
+ +## All Fields/Knobs + +### Inherited from `_BasicDistSamplingWorkerOptions` (lines 26-117) + +| Field | Type | Default | Description | +|---|---|---|---| +| `num_workers` | `int` | `1` | Number of sampling worker subprocesses to launch on the server for this client | +| `worker_devices` | `list[torch.device] \| None` | `None` | Device assignment per worker; auto-assigned if `None` | +| `worker_concurrency` | `int` | `4` | Max concurrent seed batches each worker processes simultaneously (clamped to [1, 32]) | +| `master_addr` | `str` | env `MASTER_ADDR` | Master address for RPC init of the sampling worker group | +| `master_port` | `int` | env `MASTER_PORT` + 1 | Master port for RPC init of the sampling worker group | +| `num_rpc_threads` | `int \| None` | `None` | RPC threads per sampling worker; auto-set to `min(num_partitions, 16)` if `None` | +| `rpc_timeout` | `float` | `180` | Timeout (seconds) for all RPC requests during sampling | + +### Specific to `RemoteDistSamplingWorkerOptions` (lines 210-291) + +| Field | Type | Default | Description | +|---|---|---|---| +| `server_rank` | `int \| list[int]` | auto-assigned | Which storage server(s) to create sampling workers on | +| `buffer_size` | `int \| str` | `"{num_workers * 64}MB"` | Size of server-side shared-memory buffer for sampled messages | +| `buffer_capacity` | computed | `num_workers * worker_concurrency` | Max messages the server-side buffer can hold | +| `prefetch_size` | `int` | `4` | Max prefetched messages on the **client** side (must be <= `buffer_capacity`) | +| `worker_key` | `str \| None` | `None` | Deduplication key -- same key reuses existing producer on server | +| `use_all2all` | `bool` | `False` | Use all2all collective for feature collection instead of point-to-point RPC | +| `glt_graph` | any | `None` | GraphScope only (not used by GiGL) | +| `workload_type` | `str \| None` | `None` | GraphScope only (not used by GiGL) | + +## How Each Field Is Used & Client vs. 
Server + +### `server_rank` -- CLIENT side + +The client reads this to know which servers to talk to: + +- `dist_loader.py:170-171` -- expanded to `_server_rank_list` +- `dist_loader.py:178` -- `request_server(self._server_rank_list[0], DistServer.get_dataset_meta)` to fetch metadata +- `dist_loader.py:188` -- loops over servers calling `DistServer.create_sampling_producer` +- `dist_loader.py:194` -- passed to `RemoteReceivingChannel` for receiving results +- `dist_loader.py:305-306` -- `request_server(server_rank, DistServer.start_new_epoch_sampling, ...)` + +### `num_workers` -- SERVER side + +Serialized and sent to the server via RPC. On the server: + +- `dist_sampling_producer.py:184` -- `self.num_workers = self.worker_options.num_workers` +- `dist_sampling_producer.py:208` -- spawns that many subprocesses via `mp_context.Process(...)` +- Also drives the defaults for `buffer_size` (line 281) and `buffer_capacity` (line 279) + +### `worker_devices` -- SERVER side + +- Auto-assigned via `_assign_worker_devices()` (`dist_options.py:113-116`) if `None` +- Used in `_sampling_worker_loop` (line 87): `current_device = worker_options.worker_devices[rank]` + +### `worker_concurrency` -- SERVER side + +Controls async parallelism within each sampling worker: + +- `_sampling_worker_loop:106` -- passed to `DistNeighborSampler(..., concurrency=worker_options.worker_concurrency)` +- In `ConcurrentEventLoop` (`event_loop.py:47`): creates a `BoundedSemaphore(concurrency)` limiting concurrent seed batches +- Also drives `buffer_capacity = num_workers * worker_concurrency` (line 279) + +### `master_addr` / `master_port` -- SERVER side + +Used by sampling worker subprocesses to form their own RPC group for cross-partition sampling: + +- `_sampling_worker_loop:93-98` -- `init_rpc(master_addr=..., master_port=..., ...)` + +### `num_rpc_threads` -- SERVER side + +- `_sampling_worker_loop:82-85` -- if `None`, auto-set to `min(data.num_partitions, 16)` +- Line 91: 
`torch.set_num_threads(num_rpc_threads + 1)` +- Line 93: passed to `init_rpc(...)` which sets `TensorPipeRpcBackendOptions.num_worker_threads` + +### `rpc_timeout` -- SERVER side + +- `_sampling_worker_loop:97` -- `init_rpc(..., rpc_timeout=...)` +- Sets the timeout for RPCs made by sampling workers when fetching graph partitions from other servers + +### `buffer_size` -- SERVER side + +- GLT `dist_server.py:158`: `ShmChannel(worker_options.buffer_capacity, worker_options.buffer_size)` +- GiGL `dist_server.py:456-457`: same usage +- Controls the total bytes of shared memory allocated for the message queue + +### `buffer_capacity` -- SERVER side + +- Computed as `num_workers * worker_concurrency` (`dist_options.py:279`) +- Passed as the first arg to `ShmChannel(capacity, size)` -- max messages before producers block + +### `prefetch_size` -- CLIENT side + +- `dist_loader.py:196` -- `RemoteReceivingChannel(..., prefetch_size)` +- In `remote_channel.py:47`: `self.prefetch_size = prefetch_size` +- Line 56: `queue.Queue(maxsize=self.prefetch_size * len(self.server_rank_list))` +- Lines 120-131: controls how many async RPC fetch requests are in-flight per server at any time + +### `worker_key` -- SERVER side (during producer creation) + +- GLT `dist_server.py:152`: `producer_id = self._worker_key2producer_id.get(worker_options.worker_key)` -- if already exists, reuses the producer +- GiGL `dist_server.py:444-453`: same pattern with per-producer locks + +### `use_all2all` -- SERVER side + +- `_sampling_worker_loop:73-80` -- if True, initializes `torch.distributed` process group (gloo backend) +- `dist_neighbor_sampler.py:749-753` -- switches from per-type `async_get()` to `get_all2all()` for feature collection + +## Client vs. 
Server Summary + +| Field | Side | Purpose | +|---|---|---| +| `server_rank` | **Client** | Which servers to send RPCs to | +| `num_workers` | **Server** | Sampling subprocesses per server | +| `worker_devices` | **Server** | Device per subprocess | +| `worker_concurrency` | **Server** | Concurrent batches per subprocess | +| `master_addr` / `master_port` | **Server** | RPC group for cross-partition sampling | +| `num_rpc_threads` | **Server** | RPC threads per sampling subprocess | +| `rpc_timeout` | **Server** | Timeout for cross-partition RPCs | +| `buffer_size` | **Server** | Shared-memory buffer bytes | +| `buffer_capacity` | **Server** | Shared-memory buffer message count | +| `prefetch_size` | **Client** | Prefetched messages per server | +| `worker_key` | **Server** | Producer deduplication | +| `use_all2all` | **Server** | Collective vs point-to-point features | + +The entire options object is **serialized and sent via RPC** from client to server (at `dist_loader.py:188` via +`DistServer.create_sampling_producer`). The server reads the server-side fields; the client reads `server_rank` and +`prefetch_size` locally. 
+ +## How GiGL Uses It + +### `DistNeighborLoader._setup_for_graph_store()` + +**File:** `gigl/distributed/distributed_neighborloader.py:386-395` + +```python +worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for i in range(num_workers)], + master_addr=dataset.cluster_info.storage_cluster_master_ip, + buffer_size=channel_size, # defaults to "4GB" + master_port=sampling_port, + worker_key=worker_key, # unique per compute rank + loader instance + prefetch_size=prefetch_size, # default 4 +) +``` + +GiGL talks to **all** storage servers (`server_rank=list(range(num_storage_nodes))`), always uses **CPU** sampling, and +assigns a unique `worker_key` per compute rank + loader instance (`distributed_neighborloader.py:384`). + +Notably, GiGL **bypasses GLT's `DistLoader.__init__`** in `_init_graph_store_connections()` (lines 609-837), +dispatching `create_sampling_producer` RPCs sequentially per compute node to avoid GLT's `ThreadPoolExecutor` deadlock +at large scale. 
+ +### `DistABLPLoader._setup_for_graph_store()` + +**File:** `gigl/distributed/dist_ablp_neighborloader.py:799-808` + +```python +worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for _ in range(num_workers)], + worker_concurrency=worker_concurrency, + master_addr=dataset.cluster_info.storage_cluster_master_ip, + master_port=sampling_port, + worker_key=worker_key, + prefetch_size=prefetch_size, +) +``` + +Nearly identical, except: + +- Explicitly passes `worker_concurrency` (default `4`) +- Does **not** set `buffer_size` (uses GLT default of `num_workers * 64 MB` instead of GiGL's `4GB`) +- Uses `ThreadPoolExecutor` for setup (lines 1002-1015) rather than the sequential barrier approach diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 9e05ead4f..254552fe6 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -189,7 +189,7 @@ def __init__( num_val: float = 0.1, num_test: float = 0.1, hash_function: Callable[[torch.Tensor], torch.Tensor] = _fast_hash, - supervision_edge_types: Optional[list[EdgeType]] = None, + supervision_edge_types: Optional[list[Union[EdgeType, PyGEdgeType]]] = None, should_convert_labels_to_edges: bool = True, ): """Initializes the DistNodeAnchorLinkSplitter. @@ -199,7 +199,7 @@ def __init__( num_val (float): The percentage of nodes to use for training. Defaults to 0.1 (10%). num_test (float): The percentage of nodes to use for validation. Defaults to 0.1 (10%). hash_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): The hash function to use. Defaults to `_fast_hash`. - supervision_edge_types (Optional[list[EdgeType]]): The supervision edge types we should use for splitting. + supervision_edge_types (Optional[list[Union[EdgeType, PyGEdgeType]]]): The supervision edge types we should use for splitting. 
Must be provided if we are splitting a heterogeneous graph. If None, uses the default message passing edge type in the graph. should_convert_labels_to_edges (bool): Whether label should be converted into an edge type in the graph. If provided, will make `gigl.distributed.build_dataset` convert all labels into edges, and will infer positive and negative edge types based on @@ -232,7 +232,10 @@ def __init__( # also be ("user", "positive", "story"), meaning that all edges in the loaded edge index tensor with this edge type will be treated as a labeled # edge and will be used for splitting. - self._supervision_edge_types: Sequence[EdgeType] = supervision_edge_types + self._supervision_edge_types: Sequence[EdgeType] = [ + EdgeType(*supervision_edge_type) + for supervision_edge_type in supervision_edge_types + ] self._labeled_edge_types: Sequence[EdgeType] if should_convert_labels_to_edges: labeled_edge_types = [ diff --git a/gigl/utils/iterator.py b/gigl/utils/iterator.py index 63f809083..219cf4fdb 100644 --- a/gigl/utils/iterator.py +++ b/gigl/utils/iterator.py @@ -1,6 +1,8 @@ from collections.abc import Iterable, Iterator from typing import TypeVar +import torch + _T = TypeVar("_T") @@ -20,5 +22,13 @@ def __next__(self) -> _T: try: return next(self._iter) except StopIteration: + if torch.distributed.is_initialized(): + print( + f"rank={torch.distributed.get_rank()}: InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" + ) + else: + print( + f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" + ) self._iter = iter(self._iterable) return next(self._iter) diff --git a/tests/e2e_tests/e2e_tests.yaml b/tests/e2e_tests/e2e_tests.yaml index 0f3691e80..c6735ce27 100644 --- a/tests/e2e_tests/e2e_tests.yaml +++ b/tests/e2e_tests/e2e_tests.yaml @@ -1,24 +1,24 @@ # Combined e2e test configurations for GiGL # This file contains all the test specifications that can be run via the e2e test script tests: - cora_nalp_test: - 
task_config_uri: "gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_snc_test: - task_config_uri: "gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_udl_test: - task_config_uri: "gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - dblp_nalp_test: - task_config_uri: "gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - hom_cora_sup_test: - task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" - het_dblp_sup_test: - task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + # cora_nalp_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_snc_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_udl_test: + # task_config_uri: 
"gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # dblp_nalp_test: + # task_config_uri: "gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # hom_cora_sup_test: + # task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + # het_dblp_sup_test: + # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index f29a1fd97..34b9c4471 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -217,11 +217,13 @@ def _run_compute_train_tests( ) _assert_ablp_input(cluster_info, ablp_result) + # For labeled homogeneous, pass the dict directly (not as tuple) + input_nodes = ablp_result ablp_loader = DistABLPLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], - input_nodes=ablp_result, + input_nodes=input_nodes, pin_memory_device=torch.device("cpu"), num_workers=2, worker_concurrency=2, @@ -246,6 +248,7 @@ 
def _run_compute_train_tests( for i, (ablp_batch, random_negative_batch) in enumerate( zip(ablp_loader, random_negative_loader) ): + # Verify batch structure assert hasattr(ablp_batch, "y_positive"), "Batch should have y_positive labels" # y_positive should be dict mapping local anchor idx -> local label indices assert isinstance(