Draft
22 commits
06b957b
feat: topology-aware inference placement for non-colocated vLLM
ananthsub May 28, 2026
d5970ef
bug fix in topology-aware placement
youngeunkwon0405 May 20, 2026
2c7af4e
fix: make vllm placement group init idempotent
youngeunkwon0405 Jun 3, 2026
db9e24e
fix: restore SBATCH singleton dependency in ray.sub
terrykong Jun 5, 2026
69545b9
fix(ray.sub): force base-10 for TOPO_RANK fallbacks to avoid invalid …
terrykong Jun 5, 2026
b51e143
Revert "fix: make vllm placement group init idempotent"
terrykong Jun 5, 2026
b01c048
test
terrykong Jun 5, 2026
1826f03
fix: remove unnecessary try/except in _get_gpu_id_info
terrykong Jun 5, 2026
c1f547a
feat: add segment_size to ClusterConfig and exemplar configs
terrykong Jun 5, 2026
1556d11
feat: add topology-aware placement to SFT
terrykong Jun 5, 2026
5ca28b2
feat: add topology-aware placement to DPO
terrykong Jun 5, 2026
dd1813c
feat: add topology-aware placement to distillation
terrykong Jun 5, 2026
adcd78d
test: add unit tests for topology-aware placement
terrykong Jun 5, 2026
ec11276
refactor: extract prepare_segment_topology to eliminate topology boil…
terrykong Jun 5, 2026
1d18106
feat: add topology-aware placement to RM
terrykong Jun 5, 2026
03682f9
fix: use generation_config[backend] to detect vLLM vs SGLang for infe…
terrykong Jun 5, 2026
cae36c7
fix: assert bundle count matches world_size after topology sort
terrykong Jun 5, 2026
f70121f
fix: make ClusterConfig.segment_size NotRequired to avoid breaking ex…
terrykong Jun 5, 2026
dc00082
ci: trigger DCO re-check with updated base branch
terrykong Jun 5, 2026
c05984a
Merge remote-tracking branch 'origin/main' into tde/megatron_inf_debug
tdene Jun 21, 2026
1aee1ba
fix: Reconcile #2315 and #2612
tdene May 29, 2026
48c74fa
fix: Reconcile #2612 and #2267
tdene Jun 21, 2026
1 change: 1 addition & 0 deletions examples/configs/distillation_math.yaml
@@ -281,3 +281,4 @@ cluster:
   num_nodes: 1
   master_port_range_low: 25000
   master_port_range_high: 28000
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/distillation_math_megatron.yaml
@@ -188,3 +188,4 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
@@ -311,3 +311,4 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -441,6 +441,7 @@ cluster:
   # (32768-60999 on stock Linux). See ray.sub for the full port layout.
   master_port_range_low: 25000
   master_port_range_high: 28000
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable

# TransferQueue-mediated data plane for sync GRPO.
# Off by default — the legacy grpo_train trainer never engages this.
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B_megatron.yaml
@@ -210,3 +210,4 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/grpo_math_70B_megatron.yaml
@@ -64,3 +64,4 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 8
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/grpo_math_8B.yaml
@@ -60,3 +60,4 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/grpo_math_8B_megatron.yaml
@@ -76,3 +76,4 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/grpo_math_qwen30ba3b_megatron.yaml
@@ -77,3 +77,4 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 8
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/grpo_rm_1B.yaml
@@ -42,3 +42,4 @@ env:
 cluster:
   gpus_per_node: 2
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/rm.yaml
@@ -220,3 +220,4 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
@@ -290,3 +290,4 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2_megatron.yaml
@@ -42,3 +42,4 @@ logger:
name: llama8b
 cluster:
   num_nodes: 2
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
1 change: 1 addition & 0 deletions examples/configs/sft_vlm_3B.yaml
@@ -55,3 +55,4 @@ logger:
 cluster:
   gpus_per_node: 2
   num_nodes: 1
+  segment_size: null # Nodes per NVLink domain segment for topology-aware alignment; null to disable
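The configs above all add `segment_size` as an optional key, commit f70121f makes it `NotRequired` on `ClusterConfig`, and the algorithm code further down reads it with `cluster_config.get("segment_size")`. A minimal sketch of that pattern follows. The `ClusterConfig` here is a simplified, hypothetical mirror of the fields visible in these hunks (the real definition in `nemo_rl/distributed/virtual_cluster.py` is not shown in this diff and may differ), and `read_segment_size` is an illustrative helper name.

```python
from typing import Optional, TypedDict

try:  # typing.NotRequired is available on Python 3.11+
    from typing import NotRequired
except ImportError:  # older interpreters need the typing_extensions backport
    from typing_extensions import NotRequired


class ClusterConfig(TypedDict):
    # Simplified, hypothetical subset of the real ClusterConfig.
    gpus_per_node: int
    num_nodes: int
    segment_size: NotRequired[Optional[int]]


def read_segment_size(cluster_config: ClusterConfig) -> Optional[int]:
    """Return None when the key is absent and also when YAML sets it to null."""
    return cluster_config.get("segment_size")


# A config written before this change: the key is simply missing.
legacy_cfg: ClusterConfig = {"gpus_per_node": 8, "num_nodes": 1}
# A config following the exemplars above: the key is present but null.
disabled_cfg: ClusterConfig = {"gpus_per_node": 8, "num_nodes": 8, "segment_size": None}
# A config that opts in with 4 nodes per NVLink domain segment.
aligned_cfg: ClusterConfig = {"gpus_per_node": 8, "num_nodes": 8, "segment_size": 4}
```

Under these assumptions, existing configs without the key and configs that set it to `null` are treated the same way by callers.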
38 changes: 35 additions & 3 deletions nemo_rl/algorithms/distillation.py
@@ -51,6 +51,7 @@
 from nemo_rl.distributed.virtual_cluster import (
     ClusterConfig,
     RayVirtualCluster,
+    prepare_segment_topology,
 )
 from nemo_rl.environments.interfaces import EnvironmentInterface
 from nemo_rl.environments.nemo_gym import (
from nemo_rl.environments.nemo_gym import (
@@ -305,6 +306,7 @@ def setup(
     # ==========================
     print("\n▶ Setting up compute cluster...", flush=True)
     colocated_inference = generation_config["colocated"]["enabled"]
+    segment_size = cluster_config.get("segment_size")
     enable_nemo_gym = bool(env_configs) and _should_use_nemo_gym(master_config)
     nemo_gym_actor: Optional[EnvironmentInterface] = None
     if enable_nemo_gym:
@@ -315,22 +317,27 @@ def setup(
ray_cur_node_id = None

     if colocated_inference:
+        num_nodes = cluster_config["num_nodes"]
+        node_resource_constraints, _, _ = prepare_segment_topology(
+            segment_size, num_nodes
+        )
         cluster = RayVirtualCluster(
             name="distillation_cluster",
-            bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
-            * cluster_config["num_nodes"],
+            bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] * num_nodes,
             use_gpus=True,
             num_gpus_per_node=cluster_config["gpus_per_node"],
             max_colocated_worker_groups=1
             if generation_config["backend"] == "megatron"
             else 3,
             port_range_low=cluster_config.get("master_port_range_low"),
             port_range_high=cluster_config.get("master_port_range_high"),
+            segment_size=segment_size,
+            node_resource_constraints=node_resource_constraints,
         )
         train_cluster = cluster
         inference_cluster = cluster
         print(
-            f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes",
+            f" ✓ Ray cluster initialized with {num_nodes} nodes",
             flush=True,
         )
     else:
@@ -379,6 +386,27 @@ def setup(
)
train_nodes -= inference_nodes

+        # Topology-aware node selection for non-colocated distillation
+        node_resource_constraints = None
+        inference_node_resource_constraints = None
+        inference_segment_size = None
+        node_resource_constraints, remaining_node_ids, topology = (
+            prepare_segment_topology(segment_size, train_nodes, role="training")
+        )
+        if node_resource_constraints is not None and inference_nodes > 0:
+            nodes_per_instance = (
+                inference_gpus_per_node + cluster_config["gpus_per_node"] - 1
+            ) // cluster_config["gpus_per_node"]
+            if nodes_per_instance > 1 and inference_nodes % nodes_per_instance == 0:
+                remaining_topology = {nid: topology[nid] for nid in remaining_node_ids}
+                inference_node_resource_constraints, _, _ = prepare_segment_topology(
+                    nodes_per_instance,
+                    inference_nodes,
+                    topology=remaining_topology,
+                    role="inference",
+                )
+                inference_segment_size = nodes_per_instance
+
         # create clusters
         train_cluster = RayVirtualCluster(
             name="distillation_train_cluster",
@@ -388,6 +416,8 @@ def setup(
             max_colocated_worker_groups=3,
             port_range_low=cluster_config.get("master_port_range_low"),
             port_range_high=cluster_config.get("master_port_range_high"),
+            segment_size=segment_size,
+            node_resource_constraints=node_resource_constraints,
         )
         inference_cluster = RayVirtualCluster(
             name="distillation_inference_cluster",
@@ -397,6 +427,8 @@ def setup(
             max_colocated_worker_groups=3,
             port_range_low=cluster_config.get("master_port_range_low"),
             port_range_high=cluster_config.get("master_port_range_high"),
+            segment_size=inference_segment_size,
+            node_resource_constraints=inference_node_resource_constraints,
         )
         print(
             f" ✓ Separate clusters created: train={train_nodes}x{train_gpus_per_node}GPUs, inference={inference_nodes}x{inference_gpus_per_node}GPUs",
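The `nodes_per_instance` expression in the non-colocated distillation hunk is an integer ceiling division of the form `(a + b - 1) // b`. A small worked sketch of that arithmetic follows. The helper name `ceil_div` and the sample values are illustrative assumptions, not part of the PR.

```python
def ceil_div(numerator: int, denominator: int) -> int:
    """Integer ceiling division for non-negative numerators: (a + b - 1) // b."""
    if numerator < 0 or denominator <= 0:
        raise ValueError("expected numerator >= 0 and denominator > 0")
    return (numerator + denominator - 1) // denominator


# Same shape as the distillation hunk: the numerator is inference_gpus_per_node
# and the denominator is cluster_config["gpus_per_node"] (values are made up).
gpus_per_node = 8
per_node_results = [ceil_div(n, gpus_per_node) for n in (1, 4, 8)]
print(per_node_results)

# For comparison, a per-instance GPU count as the numerator (for example an
# instance that needs 16 GPUs on 8-GPU nodes) spans more than one node.
print(ceil_div(16, gpus_per_node))
```

With `inference_gpus_per_node` as the numerator, the result is 1 whenever `inference_gpus_per_node <= gpus_per_node`, so the `nodes_per_instance > 1` branch only runs when the per-node inference GPU count exceeds the cluster's `gpus_per_node`. The GRPO hunk in this same diff divides a per-instance GPU count instead, so the two paths may behave differently; whether that is intended is not stated in the diff.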
11 changes: 8 additions & 3 deletions nemo_rl/algorithms/dpo.py
@@ -32,6 +32,7 @@
 from nemo_rl.distributed.virtual_cluster import (
     ClusterConfig,
     RayVirtualCluster,
+    prepare_segment_topology,
 )
 from nemo_rl.models.policy import PolicyConfig
 from nemo_rl.models.policy.interfaces import PolicyInterface
@@ -233,17 +234,21 @@ def setup(
     # Cluster
     # ==========================
     print("\n▶ Setting up compute cluster...")
+    num_nodes = cluster_config["num_nodes"]
+    segment_size = cluster_config.get("segment_size")
+    node_resource_constraints, _, _ = prepare_segment_topology(segment_size, num_nodes)
     cluster = RayVirtualCluster(
         name="dpo_cluster",
-        bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
-        * cluster_config["num_nodes"],
+        bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] * num_nodes,
         use_gpus=True,
         num_gpus_per_node=cluster_config["gpus_per_node"],
         max_colocated_worker_groups=1,
         port_range_low=cluster_config.get("master_port_range_low"),
         port_range_high=cluster_config.get("master_port_range_high"),
+        segment_size=segment_size,
+        node_resource_constraints=node_resource_constraints,
     )
-    print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
+    print(f" ✓ Ray cluster initialized with {num_nodes} nodes")

     # ==========================
     # Training
128 changes: 127 additions & 1 deletion nemo_rl/algorithms/grpo.py
@@ -64,8 +64,11 @@
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
 from nemo_rl.distributed.virtual_cluster import (
+    TOPO_RANK_UNKNOWN,
     ClusterConfig,
     RayVirtualCluster,
+    get_ray_cluster_topology,
+    prepare_segment_topology,
 )
 from nemo_rl.environments.interfaces import EnvironmentInterface
 from nemo_rl.environments.nemo_gym import (
@@ -507,6 +510,7 @@ def _spinup_nemo_gym(base_urls, model_name):
return actor, time.perf_counter() - t0

     total_nodes = cluster_config["num_nodes"]
+    segment_size = cluster_config.get("segment_size")
     if rm_env_enabled:
         rm_resource = env_configs["reward_model"]["resources"]
         rm_nodes = rm_resource["num_nodes"]
@@ -535,6 +539,9 @@ def _spinup_nemo_gym(base_urls, model_name):
         else:
             policy_gpus_per_node = cluster_config["gpus_per_node"]

+        node_resource_constraints, _, _ = prepare_segment_topology(
+            segment_size, policy_nodes
+        )
         cluster = RayVirtualCluster(
             name="grpo_policy_cluster",
             bundle_ct_per_node_list=[policy_gpus_per_node] * policy_nodes,
@@ -545,6 +552,8 @@ def _spinup_nemo_gym(base_urls, model_name):
             else 2,
             port_range_low=cluster_config.get("master_port_range_low"),
             port_range_high=cluster_config.get("master_port_range_high"),
+            segment_size=segment_size,
+            node_resource_constraints=node_resource_constraints,
         )
         train_cluster = cluster
         inference_cluster = cluster
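The call sites in this diff show `prepare_segment_topology` taking a segment size and a node count, and returning a triple of constraints, remaining node IDs, and a topology mapping; its body is not part of the diff. As a rough illustration of what segment-aligned selection could look like, here is a toy sketch. The function `select_segment_aligned_nodes`, the `(domain, topo_rank)` tuple layout, and the selection policy are all assumptions made for illustration, not the PR's implementation.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

# Assumed layout: node_id -> (nvlink_domain, topo_rank). Hypothetical.
Topology = Dict[str, Tuple[str, int]]


def select_segment_aligned_nodes(
    topology: Topology, segment_size: Optional[int], num_nodes: int
) -> Tuple[Optional[List[str]], List[str]]:
    """Pick num_nodes nodes as whole segments of segment_size nodes per domain.

    Returns (selected, remaining). selected is None when alignment is disabled
    or cannot be satisfied, in which case every node is left in remaining.
    """
    all_ids = sorted(topology)
    if segment_size is None or segment_size <= 0 or num_nodes % segment_size != 0:
        return None, all_ids

    by_domain: Dict[str, List[str]] = defaultdict(list)
    for node_id in all_ids:
        by_domain[topology[node_id][0]].append(node_id)

    selected: List[str] = []
    for domain in sorted(by_domain):
        # Order nodes inside a domain by topo_rank, then take whole segments only.
        nodes = sorted(by_domain[domain], key=lambda nid: topology[nid][1])
        usable = len(nodes) - len(nodes) % segment_size
        for start in range(0, usable, segment_size):
            if len(selected) == num_nodes:
                break
            selected.extend(nodes[start : start + segment_size])

    if len(selected) < num_nodes:
        return None, all_ids
    chosen = set(selected)
    return selected, [nid for nid in all_ids if nid not in chosen]


toy_topology: Topology = {
    "n0": ("A", 0), "n1": ("A", 1), "n2": ("A", 2),
    "n3": ("B", 0), "n4": ("B", 1),
}
selected, remaining = select_segment_aligned_nodes(toy_topology, 2, 4)
```

In this toy run, domain A contributes one complete two-node segment and its third node is left over, which is the kind of incomplete segment the PR's comments describe trimming.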
@@ -610,6 +619,102 @@ def _spinup_nemo_gym(base_urls, model_name):
)
train_nodes -= inference_nodes

+        assert train_nodes > 0 and inference_nodes > 0, (
+            f"Non-colocated mode requires train_nodes > 0 and inference_nodes > 0, "
+            f"got train_nodes={train_nodes}, inference_nodes={inference_nodes}"
+        )
+
+        # Build topology-aware domain constraints for placement groups.
+        # Each selected node's bundles are pinned to a specific NVLink domain so
+        # that EP groups stay within high-bandwidth switch fabrics.
+        #
+        # NOTE: segment_size is also passed to RayVirtualCluster and used later
+        # by _sort_bundle_indices_by_topology to trim incomplete domain segments
+        # when ordering ranks. When constraints successfully pin nodes to
+        # complete segments, that post-placement trimming is a no-op. It serves
+        # as defense-in-depth for the fallback path where constraints are absent.
+        node_resource_constraints = None
+        inference_node_resource_constraints = None
+        inference_segment_size = None
+        if segment_size is not None:
+            topology = get_ray_cluster_topology()
+            num_alive_nodes = len(topology)
+            required_nodes = train_nodes + inference_nodes
+            assert num_alive_nodes >= required_nodes, (
+                f"Not enough alive Ray nodes for all roles: "
+                f"need {required_nodes} (train={train_nodes} + inference={inference_nodes}), "
+                f"but only {num_alive_nodes} alive nodes found"
+            )
+            node_resource_constraints, remaining_node_ids, topology = (
+                prepare_segment_topology(
+                    segment_size, train_nodes, topology=topology, role="training"
+                )
+            )
+            # Warn if any selected training node lacks topo_rank; domain pinning
+            # still works but intra-domain rank ordering will be arbitrary.
+            if node_resource_constraints is not None:
+                training_node_ids = set(topology) - set(remaining_node_ids)
+                nodes_missing_topo_rank = [
+                    nid
+                    for nid in training_node_ids
+                    if topology[nid][1] == TOPO_RANK_UNKNOWN
+                ]
+                if nodes_missing_topo_rank:
+                    print(
+                        f" ⚠ {len(nodes_missing_topo_rank)} selected training nodes have NVLink domain "
+                        f"info but no topo_rank; intra-domain rank ordering may be suboptimal",
+                        flush=True,
+                    )
+
+                # Inference topology: each inference instance spans
+                # nodes_per_instance nodes; keep those within one domain
+                # so cross-node all-reduce uses NVLink, not InfiniBand.
+                #
+                # For vLLM: total GPUs per instance = TP * PP (separate dimensions).
+                # For SGLang: gpus_per_server already includes all parallelism
+                # dimensions (TP, DP-attention, PP are internal subdivisions),
+                # so we use it directly without multiplying by pp_size.
+                # For Megatron: inference reuses the training megatron_cfg
+                # parallelism, so an instance spans TP * PP * CP GPUs.
+                if generation_config["backend"] == "megatron":
+                    megatron_cfg = policy_config["megatron_cfg"]
+                    gpus_per_instance = (
+                        megatron_cfg["tensor_model_parallel_size"]
+                        * megatron_cfg["pipeline_model_parallel_size"]
+                        * megatron_cfg["context_parallel_size"]
+                    )
+                elif generation_config["backend"] == "vllm":
+                    vllm_cfg = generation_config.get("vllm_cfg", {})
+                    gpus_per_instance = vllm_cfg["tensor_parallel_size"] * vllm_cfg.get(
+                        "pipeline_parallel_size", 1
+                    )
+                else:
+                    sglang_cfg = generation_config.get("sglang_cfg", {})
+                    gpus_per_instance = sglang_cfg.get("gpus_per_server", 1)
+                nodes_per_instance = (
+                    gpus_per_instance + inference_gpus_per_node - 1
+                ) // inference_gpus_per_node
+                if nodes_per_instance > 1 and inference_nodes % nodes_per_instance == 0:
+                    remaining_topology = {
+                        nid: topology[nid] for nid in remaining_node_ids
+                    }
+                    inference_node_resource_constraints, _, _ = (
+                        prepare_segment_topology(
+                            nodes_per_instance,
+                            inference_nodes,
+                            topology=remaining_topology,
+                            role="inference",
+                        )
+                    )
+                    inference_segment_size = nodes_per_instance
+                elif nodes_per_instance > 1:
+                    print(
+                        f" ⚠ inference_nodes={inference_nodes} is not divisible by "
+                        f"nodes_per_instance={nodes_per_instance} (gpus_per_instance={gpus_per_instance}); "
+                        f"skipping inference topology constraints",
+                        flush=True,
+                    )
+
         # initialize train cluster
         train_cluster = RayVirtualCluster(
             name="grpo_train_cluster",
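The hunk above derives the number of nodes one inference instance spans from a backend-specific GPU count. The same rules can be sketched as two standalone functions. The names `gpus_per_inference_instance` and `nodes_per_instance` and the example configs are hypothetical; the config keys are the ones that appear in the hunk.

```python
from typing import Any, Mapping


def gpus_per_inference_instance(
    backend: str,
    generation_config: Mapping[str, Any],
    policy_config: Mapping[str, Any],
) -> int:
    """GPUs used by one inference instance, following the rules in the hunk."""
    if backend == "megatron":
        cfg = policy_config["megatron_cfg"]
        # Megatron inference reuses training parallelism: TP * PP * CP.
        return (
            cfg["tensor_model_parallel_size"]
            * cfg["pipeline_model_parallel_size"]
            * cfg["context_parallel_size"]
        )
    if backend == "vllm":
        cfg = generation_config["vllm_cfg"]
        # vLLM: TP and PP are separate dimensions, so multiply them.
        return cfg["tensor_parallel_size"] * cfg.get("pipeline_parallel_size", 1)
    # Any other backend is treated as SGLang, as in the hunk's else branch:
    # gpus_per_server already covers all parallelism dimensions.
    return generation_config.get("sglang_cfg", {}).get("gpus_per_server", 1)


def nodes_per_instance(gpus_per_instance: int, gpus_per_node: int) -> int:
    """Ceiling division: how many nodes a single instance spans."""
    return (gpus_per_instance + gpus_per_node - 1) // gpus_per_node


# Illustrative values only.
vllm_gen_cfg = {"vllm_cfg": {"tensor_parallel_size": 8, "pipeline_parallel_size": 2}}
vllm_gpus = gpus_per_inference_instance("vllm", vllm_gen_cfg, {})
vllm_nodes = nodes_per_instance(vllm_gpus, gpus_per_node=8)

megatron_policy_cfg = {
    "megatron_cfg": {
        "tensor_model_parallel_size": 4,
        "pipeline_model_parallel_size": 2,
        "context_parallel_size": 1,
    }
}
megatron_gpus = gpus_per_inference_instance("megatron", {}, megatron_policy_cfg)
```

With these example values the vLLM instance spans two 8-GPU nodes, so inference topology constraints would apply only if `inference_nodes` is a multiple of 2. The Megatron instance fits on one node, so the hunk's `nodes_per_instance > 1` check would skip the constraints.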
@@ -619,13 +724,21 @@ def _spinup_nemo_gym(base_urls, model_name):
             max_colocated_worker_groups=1,
             port_range_low=cluster_config.get("master_port_range_low"),
             port_range_high=cluster_config.get("master_port_range_high"),
+            segment_size=segment_size,
+            node_resource_constraints=node_resource_constraints,
         )
+        # When domain constraints are set, eagerly create placement groups
+        # so training claims the constrained nodes before inference can grab them.
+        if node_resource_constraints is not None:
+            train_cluster.get_placement_groups()
         print(
             f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node",
             flush=True,
         )

-        # initialize inference cluster
+        # Create inference cluster with topology constraints so TP groups
+        # stay within NVLink domains. Eagerly initialize PGs when constraints
+        # are set so inference claims domain-aligned nodes first.
         inference_cluster = RayVirtualCluster(
             name="grpo_inference_cluster",
             bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes,
@@ -634,7 +747,20 @@ def _spinup_nemo_gym(base_urls, model_name):
             max_colocated_worker_groups=1,
             port_range_low=cluster_config.get("master_port_range_low"),
             port_range_high=cluster_config.get("master_port_range_high"),
+            segment_size=inference_segment_size,
+            node_resource_constraints=inference_node_resource_constraints,
         )
+        if inference_node_resource_constraints is not None:
+            if generation_config["backend"] == "megatron":
+                MegatronGeneration.init_cluster_placement_groups(
+                    inference_cluster, policy_config
+                )
+            elif generation_config["backend"] == "sglang":
+                SGLangGeneration.init_cluster_placement_groups(inference_cluster)
+            else:
+                VllmGeneration.init_cluster_placement_groups(
+                    inference_cluster, generation_config
+                )
         print(
             f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node",
             flush=True,
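The comments in the hunks above explain that placement groups are created eagerly so that the training cluster claims its pinned nodes before the inference cluster can take them. The ordering concern can be illustrated with a toy first-come, first-served pool. `ToyNodePool` and `place` are hypothetical stand-ins and do not model Ray's scheduler, which may leave an unsatisfiable request pending rather than raising.

```python
from typing import List, Optional, Set


class ToyNodePool:
    """Toy stand-in for a scheduler: nodes are handed out first come, first served."""

    def __init__(self, node_ids: List[str]) -> None:
        self.free: List[str] = list(node_ids)

    def claim(self, count: int, pinned: Optional[Set[str]] = None) -> List[str]:
        # A pinned claim may only use the listed nodes; an unpinned one takes any.
        candidates = [n for n in self.free if pinned is None or n in pinned]
        if len(candidates) < count:
            raise RuntimeError(f"need {count} nodes, only {len(candidates)} eligible")
        taken = candidates[:count]
        self.free = [n for n in self.free if n not in taken]
        return taken


def place(train_first: bool) -> str:
    pool = ToyNodePool(["n0", "n1", "n2", "n3"])
    train_pins = {"n0", "n1"}  # training is pinned to one (assumed) NVLink domain
    try:
        if train_first:
            pool.claim(2, pinned=train_pins)  # eager: training reserves its nodes
            pool.claim(2)                     # inference takes what is left
        else:
            pool.claim(2)                     # unpinned claim happens to take n0, n1
            pool.claim(2, pinned=train_pins)  # training can no longer be satisfied
    except RuntimeError:
        return "failed"
    return "ok"


train_first_result = place(train_first=True)
inference_first_result = place(train_first=False)
```

In this toy model the pinned request succeeds only when it is issued first, which is the situation the eager `get_placement_groups()` call in the diff is described as guarding against.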