diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index 385d6fa..55d96f6 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -159,35 +159,52 @@ def evaluate(self): comm_plans = self.dataset.get_NCCL_comm_plans() out = self.model(xs, edge_type, comm_plans) - y_pred = out.argmax(dim=-1, keepdim=True).cpu().numpy() - train_mask = self.dataset.get_mask("train").cpu().numpy() - val_mask = self.dataset.get_mask("val").cpu().numpy() - test_mask = self.dataset.get_mask("test").cpu().numpy() - y_true_train = self.dataset.get_target("train").cpu().numpy() - y_pred_val = self.dataset.get_target("val").cpu().numpy() - y_pred_test = self.dataset.get_target("test").cpu().numpy() - - train_acc = (y_pred[train_mask] == y_true_train).sum() / int(train_mask.sum()) + y_pred = out.argmax(dim=-1) # [1, num_nodes] + train_idx = self.dataset.get_mask("train") + val_idx = self.dataset.get_mask("val") + test_idx = self.dataset.get_mask("test") + y_true_train = self.dataset.get_target("train").flatten() + y_true_val = self.dataset.get_target("val").flatten() + y_true_test = self.dataset.get_target("test").flatten() + + # Move indices/targets to the same device for indexing + device = y_pred.device + train_idx = train_idx.to(device) + val_idx = val_idx.to(device) + test_idx = test_idx.to(device) + y_true_train = y_true_train.to(device) + y_true_val = y_true_val.to(device) + y_true_test = y_true_test.to(device) + + num_local_train_samples = int(train_idx.numel()) + num_local_val_samples = int(val_idx.numel()) + num_local_test_samples = int(test_idx.numel()) + + # Train accuracy (global) + if num_local_train_samples == 0: + train_acc = 0.0 + else: + train_correct = (y_pred[0, train_idx] == y_true_train).sum().item() + train_correct = GetGlobalVal(train_correct) + num_global_train_samples = GetGlobalVal(num_local_train_samples) + train_acc = train_correct / int(num_global_train_samples) + # Not guaranteed to have validation or test samples on every rank - num_local_val_samples = int(val_mask.sum()) - num_local_test_samples = int(test_mask.sum()) if num_local_val_samples == 0: val_acc = 0.0 else: - val_acc = (y_pred[val_mask] == y_pred_val).sum().item() - val_acc = GetGlobalVal(val_acc) - - num_global_val_samples = GetGlobalVal(num_local_val_samples) - val_acc = val_acc / int(num_global_val_samples) + val_correct = (y_pred[0, val_idx] == y_true_val).sum().item() + val_correct = GetGlobalVal(val_correct) + num_global_val_samples = GetGlobalVal(num_local_val_samples) + val_acc = val_correct / int(num_global_val_samples) if num_local_test_samples == 0: test_acc = 0.0 else: - test_acc = (y_pred[test_mask] == y_pred_test).sum().item() - - test_acc = GetGlobalVal(test_acc) - num_global_test_samples = GetGlobalVal(num_local_test_samples) - test_acc = test_acc / int(num_global_test_samples) + test_correct = (y_pred[0, test_idx] == y_true_test).sum().item() + test_correct = GetGlobalVal(test_correct) + num_global_test_samples = GetGlobalVal(num_local_test_samples) + test_acc = test_correct / int(num_global_test_samples) # All ranks should have the same accuracy values diff --git a/experiments/OGB-LSC/main.py b/experiments/OGB-LSC/main.py index 18a8b56..a068c75 100644 --- a/experiments/OGB-LSC/main.py +++ b/experiments/OGB-LSC/main.py @@ -107,8 +107,14 @@ def main( trainer = Trainer(graph_dataset, comm) trainer.prepare_data() trainer.train() - comm.destroy() + print("Training completed!!") + + print("Evaluating the model...") + train_acc, val_acc, test_acc = trainer.evaluate() + if comm.get_rank() == 0: + print(f"Acc: train={train_acc:.4f} val={val_acc:.4f} test={test_acc:.4f}") + comm.destroy() if dist.is_initialized(): dist.destroy_process_group()