Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2589,6 +2589,12 @@ def grpo_train(
"vllm_metrics_logger_interval"
],
logger,
log_scalars=master_config.policy["generation"]["vllm_cfg"].get(
"vllm_metrics_log_scalars", False
),
log_timeline_plots=master_config.policy["generation"][
"vllm_cfg"
].get("vllm_metrics_log_timeline_plots", True),
)

# Plot ISL/OSL/ISL+OSL histograms to wandb
Expand Down Expand Up @@ -3763,6 +3769,12 @@ def async_grpo_train(
"vllm_metrics_logger_interval"
],
logger,
log_scalars=master_config.policy["generation"]["vllm_cfg"].get(
"vllm_metrics_log_scalars", False
),
log_timeline_plots=master_config.policy["generation"][
"vllm_cfg"
].get("vllm_metrics_log_timeline_plots", True),
)

# Plot ISL/OSL/ISL+OSL histograms to wandb
Expand Down
6 changes: 6 additions & 0 deletions nemo_rl/algorithms/grpo_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,12 @@ def grpo_train_sync(
"vllm_metrics_logger_interval"
],
logger,
log_scalars=master_config.policy["generation"]["vllm_cfg"].get(
"vllm_metrics_log_scalars", False
),
log_timeline_plots=master_config.policy["generation"][
"vllm_cfg"
].get("vllm_metrics_log_timeline_plots", True),
)

if (
Expand Down
44 changes: 37 additions & 7 deletions nemo_rl/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,8 @@ def log_generation_metrics_to_wandb(
step: int,
timeline_interval: float,
logger: Logger,
log_scalars: bool = False,
log_timeline_plots: bool = True,
) -> None:
"""Log generation metrics to wandb.

Expand All @@ -892,12 +894,40 @@ def log_generation_metrics_to_wandb(
step: Global step value
timeline_interval: Interval between timeline points (in seconds)
logger: Logger instance
log_scalars: also log per-engine scalar aggregates (mean/max/p95 across time)
so the metrics show as trackable line charts, not just per-step Media images.
log_timeline_plots: log the per-worker matplotlib timeline IMAGE plots. These are heavy
(one figure per metric per step) and slow to upload/render; set False to rely on the
lightweight scalars instead.
"""
for generation_metric in generation_logger_metrics.keys():
logger.log_plot_per_worker_timeline_metrics(
generation_logger_metrics[generation_metric],
step=step,
prefix="generation_metrics",
name=generation_metric,
timeline_interval=timeline_interval,
)
per_worker = generation_logger_metrics[generation_metric]
# Heavy per-step image plots (optional -> can dominate wandb render time).
if log_timeline_plots:
logger.log_plot_per_worker_timeline_metrics(
per_worker,
step=step,
prefix="generation_metrics",
name=generation_metric,
timeline_interval=timeline_interval,
)
# Optional per-engine scalar aggregates -> line charts trackable across steps.
if log_scalars:
for dp_idx, series in per_worker.items():
if series:
arr = np.asarray(series, dtype=float)
logger.log_metrics(
{
f"{generation_metric}/engine_{dp_idx}/mean": float(
arr.mean()
),
f"{generation_metric}/engine_{dp_idx}/max": float(
arr.max()
),
f"{generation_metric}/engine_{dp_idx}/p95": float(
np.percentile(arr, 95)
),
},
step,
prefix="generation_metrics",
)
13 changes: 13 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,10 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
"num_pending_samples": {}, # dp_idx -> list[int]
"kv_cache_usage_perc": {}, # dp_idx -> list[float]
"generation_tokens": {}, # dp_idx -> list[int]
"prefix_cache_queries": {}, # dp_idx -> list[int] (windowed delta)
"prefix_cache_hits": {}, # dp_idx -> list[int] (windowed delta)
"prefix_cache_hit_rate": {}, # dp_idx -> list[float] (windowed)
"num_preemptions": {}, # dp_idx -> list[int] (windowed delta)
}

for dp_idx, stats in zip(dp_indices, results):
Expand All @@ -935,6 +939,15 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
generation_tokens = stats.get("generation_tokens")
if generation_tokens:
vllm_logger_metrics["generation_tokens"][dp_idx] = generation_tokens
for key in (
"prefix_cache_queries",
"prefix_cache_hits",
"prefix_cache_hit_rate",
"num_preemptions",
):
val = stats.get(key)
if val:
vllm_logger_metrics[key][dp_idx] = val

return vllm_logger_metrics

Expand Down
82 changes: 81 additions & 1 deletion nemo_rl/models/generation/vllm/vllm_worker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,53 @@ def _start_vllm_metrics_logger(self) -> None:
stop_event = threading.Event()
self._vllm_metrics_logger_stop_event = stop_event

# Gauges (instantaneous readings) -- logged as sampled.
self.inflight_batch_sizes: list[int] = []
self.num_pending_samples: list[int] = []
self.kv_cache_usage_perc: list[float] = []
# vLLM exposes generation_tokens / prefix_cache_queries / prefix_cache_hits /
# num_preemptions as CUMULATIVE counters (monotonic from engine start). We log per-interval
# DELTAS rather than the lifetime totals so each step's series reflects that step's activity
# (a lifetime total would just ramp with runtime and never be step-comparable).
# prefix_cache_hit_rate is derived from the query/hit deltas. _prev_counters holds the last
# raw cumulative readings and PERSISTS across clear_vllm_logger_metrics() so the deltas stay
# correct across steps; _windowed_delta() clamps a negative delta (engine restart -> counter
# reset to 0) to 0 and re-baselines.
self.generation_tokens: list[int] = []
self.prefix_cache_queries: list[int] = []
self.prefix_cache_hits: list[int] = []
self.num_preemptions: list[int] = []
self.prefix_cache_hit_rate: list[float] = []
self._prev_counters: dict[str, int] = {
"generation_tokens": 0,
"prefix_cache_queries": 0,
"prefix_cache_hits": 0,
"num_preemptions": 0,
}

def _windowed_delta(raw: int, counter_name: str) -> int:
"""Per-interval delta of a cumulative vLLM counter.

Returns raw - previous-reading and re-baselines. A negative delta means the
engine restarted (counter reset to 0), so clamp to 0 instead of logging a large
negative spike. _prev_counters persist across clear_vllm_logger_metrics() so
deltas stay correct across steps.
"""
delta = raw - self._prev_counters[counter_name]
self._prev_counters[counter_name] = raw
return delta if delta >= 0 else 0

def _logger_loop():
# Delay a little to let engine settle
time.sleep(min(2.0, interval_s))
while True:
try:
# Gauges are appended as sampled; cumulative counters are captured raw here
# and converted to per-interval deltas after the snapshot pass.
raw_generation_tokens: int | None = None
raw_prefix_cache_queries: int | None = None
raw_prefix_cache_hits: int | None = None
raw_num_preemptions: int | None = None
for m in get_metrics_snapshot():
with self._vllm_metrics_lock:
if isinstance(m, Gauge):
Expand All @@ -364,7 +401,39 @@ def _logger_loop():
self.kv_cache_usage_perc.append(float(m.value))
elif isinstance(m, Counter):
if m.name == "vllm:generation_tokens":
self.generation_tokens.append(int(m.value))
raw_generation_tokens = int(m.value)
elif m.name == "vllm:prefix_cache_queries":
raw_prefix_cache_queries = int(m.value)
elif m.name == "vllm:prefix_cache_hits":
raw_prefix_cache_hits = int(m.value)
elif m.name == "vllm:num_preemptions":
raw_num_preemptions = int(m.value)
# Convert cumulative counters to per-interval deltas (step-comparable), and
# derive the windowed prefix-cache hit-rate from the same query/hit deltas.
with self._vllm_metrics_lock:
if raw_generation_tokens is not None:
self.generation_tokens.append(
_windowed_delta(
raw_generation_tokens, "generation_tokens"
)
)
dq = dh = None
if raw_prefix_cache_queries is not None:
dq = _windowed_delta(
raw_prefix_cache_queries, "prefix_cache_queries"
)
self.prefix_cache_queries.append(dq)
if raw_prefix_cache_hits is not None:
dh = _windowed_delta(
raw_prefix_cache_hits, "prefix_cache_hits"
)
self.prefix_cache_hits.append(dh)
if raw_num_preemptions is not None:
self.num_preemptions.append(
_windowed_delta(raw_num_preemptions, "num_preemptions")
)
if dq is not None and dh is not None and dq > 0:
self.prefix_cache_hit_rate.append(dh / dq)
except Exception:
print(
"⚠️[vLLM Metric Logger] Exception in vLLM metrics logger",
Expand Down Expand Up @@ -393,6 +462,10 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]:
"num_pending_samples": copy.deepcopy(self.num_pending_samples),
"kv_cache_usage_perc": copy.deepcopy(self.kv_cache_usage_perc),
"generation_tokens": copy.deepcopy(self.generation_tokens),
"prefix_cache_queries": copy.deepcopy(self.prefix_cache_queries),
"prefix_cache_hits": copy.deepcopy(self.prefix_cache_hits),
"prefix_cache_hit_rate": copy.deepcopy(self.prefix_cache_hit_rate),
"num_preemptions": copy.deepcopy(self.num_preemptions),
}
return metric

Expand All @@ -404,7 +477,14 @@ def clear_vllm_logger_metrics(self) -> None:
self.inflight_batch_sizes = []
self.num_pending_samples = []
self.kv_cache_usage_perc = []
# NOTE: clear only the per-step series; do NOT reset _prev_counters (handled in
# _windowed_delta) — they track the monotonic cumulative counters across steps so the
# next step's deltas (and the hit-rate) stay correct after a clear.
self.generation_tokens = []
self.prefix_cache_queries = []
self.prefix_cache_hits = []
self.prefix_cache_hit_rate = []
self.num_preemptions = []

async def post_init_async(self):
self.vllm_device_ids = await self.report_device_id_async()
Expand Down
Loading