diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index b963a6d4ff..6363122fe8 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2589,6 +2589,12 @@ def grpo_train( "vllm_metrics_logger_interval" ], logger, + log_scalars=master_config.policy["generation"]["vllm_cfg"].get( + "vllm_metrics_log_scalars", False + ), + log_timeline_plots=master_config.policy["generation"][ + "vllm_cfg" + ].get("vllm_metrics_log_timeline_plots", True), ) # Plot ISL/OSL/ISL+OSL histograms to wandb @@ -3763,6 +3769,12 @@ def async_grpo_train( "vllm_metrics_logger_interval" ], logger, + log_scalars=master_config.policy["generation"]["vllm_cfg"].get( + "vllm_metrics_log_scalars", False + ), + log_timeline_plots=master_config.policy["generation"][ + "vllm_cfg" + ].get("vllm_metrics_log_timeline_plots", True), ) # Plot ISL/OSL/ISL+OSL histograms to wandb diff --git a/nemo_rl/algorithms/grpo_sync.py b/nemo_rl/algorithms/grpo_sync.py index 4ece3f81da..8652ad10e0 100644 --- a/nemo_rl/algorithms/grpo_sync.py +++ b/nemo_rl/algorithms/grpo_sync.py @@ -1219,6 +1219,12 @@ def grpo_train_sync( "vllm_metrics_logger_interval" ], logger, + log_scalars=master_config.policy["generation"]["vllm_cfg"].get( + "vllm_metrics_log_scalars", False + ), + log_timeline_plots=master_config.policy["generation"][ + "vllm_cfg" + ].get("vllm_metrics_log_timeline_plots", True), ) if ( diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index c51076402b..60177e07c3 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -884,6 +884,8 @@ def log_generation_metrics_to_wandb( step: int, timeline_interval: float, logger: Logger, + log_scalars: bool = False, + log_timeline_plots: bool = True, ) -> None: """Log generation metrics to wandb. @@ -892,12 +894,40 @@ def log_generation_metrics_to_wandb( step: Global step value timeline_interval: Interval between timeline points (in seconds) logger: Logger instance + log_scalars: also log per-engine scalar aggregates (mean/max/p95 across time) + so the metrics show as trackable line charts, not just per-step Media images. + log_timeline_plots: log the per-worker matplotlib timeline IMAGE plots. These are heavy + (one figure per metric per step) and slow to upload/render; set False to rely on the + lightweight scalars instead. """ for generation_metric in generation_logger_metrics.keys(): - logger.log_plot_per_worker_timeline_metrics( - generation_logger_metrics[generation_metric], - step=step, - prefix="generation_metrics", - name=generation_metric, - timeline_interval=timeline_interval, - ) + per_worker = generation_logger_metrics[generation_metric] + # Heavy per-step image plots (optional -> can dominate wandb render time). + if log_timeline_plots: + logger.log_plot_per_worker_timeline_metrics( + per_worker, + step=step, + prefix="generation_metrics", + name=generation_metric, + timeline_interval=timeline_interval, + ) + # Optional per-engine scalar aggregates -> line charts trackable across steps. + if log_scalars: + for dp_idx, series in per_worker.items(): + if series: + arr = np.asarray(series, dtype=float) + logger.log_metrics( + { + f"{generation_metric}/engine_{dp_idx}/mean": float( + arr.mean() + ), + f"{generation_metric}/engine_{dp_idx}/max": float( + arr.max() + ), + f"{generation_metric}/engine_{dp_idx}/p95": float( + np.percentile(arr, 95) + ), + }, + step, + prefix="generation_metrics", + ) diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index b2f2962c64..513630b482 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -916,6 +916,10 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: "num_pending_samples": {}, # dp_idx -> list[int] "kv_cache_usage_perc": {}, # dp_idx -> list[float] "generation_tokens": {}, # dp_idx -> list[int] + "prefix_cache_queries": {}, # dp_idx -> list[int] (windowed delta) + "prefix_cache_hits": {}, # dp_idx -> list[int] (windowed delta) + "prefix_cache_hit_rate": {}, # dp_idx -> list[float] (windowed) + "num_preemptions": {}, # dp_idx -> list[int] (windowed delta) } for dp_idx, stats in zip(dp_indices, results): @@ -935,6 +939,15 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: generation_tokens = stats.get("generation_tokens") if generation_tokens: vllm_logger_metrics["generation_tokens"][dp_idx] = generation_tokens + for key in ( + "prefix_cache_queries", + "prefix_cache_hits", + "prefix_cache_hit_rate", + "num_preemptions", + ): + val = stats.get(key) + if val: + vllm_logger_metrics[key][dp_idx] = val return vllm_logger_metrics diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index c7e09455a4..a322df97ea 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -340,16 +340,53 @@ def _start_vllm_metrics_logger(self) -> None: stop_event = threading.Event() self._vllm_metrics_logger_stop_event = stop_event + # Gauges (instantaneous readings) -- logged as sampled. self.inflight_batch_sizes: list[int] = [] self.num_pending_samples: list[int] = [] self.kv_cache_usage_perc: list[float] = [] + # vLLM exposes generation_tokens / prefix_cache_queries / prefix_cache_hits / + # num_preemptions as CUMULATIVE counters (monotonic from engine start). We log per-interval + # DELTAS rather than the lifetime totals so each step's series reflects that step's activity + # (a lifetime total would just ramp with runtime and never be step-comparable). + # prefix_cache_hit_rate is derived from the query/hit deltas. _prev_counters holds the last + # raw cumulative readings and PERSISTS across clear_vllm_logger_metrics() so the deltas stay + # correct across steps; _windowed_delta() clamps a negative delta (engine restart -> counter + # reset to 0) to 0 and re-baselines. self.generation_tokens: list[int] = [] + self.prefix_cache_queries: list[int] = [] + self.prefix_cache_hits: list[int] = [] + self.num_preemptions: list[int] = [] + self.prefix_cache_hit_rate: list[float] = [] + self._prev_counters: dict[str, int] = { + "generation_tokens": 0, + "prefix_cache_queries": 0, + "prefix_cache_hits": 0, + "num_preemptions": 0, + } + + def _windowed_delta(raw: int, counter_name: str) -> int: + """Per-interval delta of a cumulative vLLM counter. + + Returns raw - previous-reading and re-baselines. A negative delta means the + engine restarted (counter reset to 0), so clamp to 0 instead of logging a large + negative spike. _prev_counters persist across clear_vllm_logger_metrics() so + deltas stay correct across steps. + """ + delta = raw - self._prev_counters[counter_name] + self._prev_counters[counter_name] = raw + return delta if delta >= 0 else 0 def _logger_loop(): # Delay a little to let engine settle time.sleep(min(2.0, interval_s)) while True: try: + # Gauges are appended as sampled; cumulative counters are captured raw here + # and converted to per-interval deltas after the snapshot pass. + raw_generation_tokens: int | None = None + raw_prefix_cache_queries: int | None = None + raw_prefix_cache_hits: int | None = None + raw_num_preemptions: int | None = None for m in get_metrics_snapshot(): with self._vllm_metrics_lock: if isinstance(m, Gauge): @@ -364,7 +401,39 @@ def _logger_loop(): self.kv_cache_usage_perc.append(float(m.value)) elif isinstance(m, Counter): if m.name == "vllm:generation_tokens": - self.generation_tokens.append(int(m.value)) + raw_generation_tokens = int(m.value) + elif m.name == "vllm:prefix_cache_queries": + raw_prefix_cache_queries = int(m.value) + elif m.name == "vllm:prefix_cache_hits": + raw_prefix_cache_hits = int(m.value) + elif m.name == "vllm:num_preemptions": + raw_num_preemptions = int(m.value) + # Convert cumulative counters to per-interval deltas (step-comparable), and + # derive the windowed prefix-cache hit-rate from the same query/hit deltas. + with self._vllm_metrics_lock: + if raw_generation_tokens is not None: + self.generation_tokens.append( + _windowed_delta( + raw_generation_tokens, "generation_tokens" + ) + ) + dq = dh = None + if raw_prefix_cache_queries is not None: + dq = _windowed_delta( + raw_prefix_cache_queries, "prefix_cache_queries" + ) + self.prefix_cache_queries.append(dq) + if raw_prefix_cache_hits is not None: + dh = _windowed_delta( + raw_prefix_cache_hits, "prefix_cache_hits" + ) + self.prefix_cache_hits.append(dh) + if raw_num_preemptions is not None: + self.num_preemptions.append( + _windowed_delta(raw_num_preemptions, "num_preemptions") + ) + if dq is not None and dh is not None and dq > 0: + self.prefix_cache_hit_rate.append(dh / dq) except Exception: print( "⚠️[vLLM Metric Logger] Exception in vLLM metrics logger", @@ -393,6 +462,10 @@ def get_vllm_logger_metrics(self) -> dict[str, Any]: "num_pending_samples": copy.deepcopy(self.num_pending_samples), "kv_cache_usage_perc": copy.deepcopy(self.kv_cache_usage_perc), "generation_tokens": copy.deepcopy(self.generation_tokens), + "prefix_cache_queries": copy.deepcopy(self.prefix_cache_queries), + "prefix_cache_hits": copy.deepcopy(self.prefix_cache_hits), + "prefix_cache_hit_rate": copy.deepcopy(self.prefix_cache_hit_rate), + "num_preemptions": copy.deepcopy(self.num_preemptions), } return metric @@ -404,7 +477,14 @@ def clear_vllm_logger_metrics(self) -> None: self.inflight_batch_sizes = [] self.num_pending_samples = [] self.kv_cache_usage_perc = [] + # NOTE: clear only the per-step series; do NOT reset _prev_counters (handled in + # _windowed_delta) — they track the monotonic cumulative counters across steps so the + # next step's deltas (and the hit-rate) stay correct after a clear. self.generation_tokens = [] + self.prefix_cache_queries = [] + self.prefix_cache_hits = [] + self.prefix_cache_hit_rate = [] + self.num_preemptions = [] async def post_init_async(self): self.vllm_device_ids = await self.report_device_id_async()