Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tests/unit/environments/test_nemo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pytest
import ray
import requests
import torch
from yaml import safe_load

Expand Down Expand Up @@ -288,3 +289,84 @@ def _standardize(l: list[dict]):
return list(map(_standardize_single_result, l))

assert _standardize(expected_result) == _standardize(actual_result)


# Sentinel meaning "leave the top_logprobs key out of the request body entirely".
# None cannot play this role: test_vllm_http_logprobs_contract passes None to send an
# explicit JSON null, which is a distinct case from omitting the field.
_OMIT_TOP_LOGPROBS = object()


@pytest.mark.nemo_gym
def test_vllm_http_logprobs_contract(nemo_gym_vllm_generation):
    """Pin the vLLM OpenAI HTTP logprobs contract that NeMo-Gym capture depends on.

    NeMo-Gym's vllm_model sets logprobs=True and return_tokens_as_token_ids=True to extract
    per-token ids and logprobs for training (Gym omits top_logprobs on the capture path, so
    vLLM applies its default; Gym PR #1612 additionally pins top_logprobs=0, which is
    equivalent). vLLM computes `logprobs = top_logprobs if logprobs else None`, so omitting
    top_logprobs (default 0) or sending 0 returns logprobs, while an explicit null returns
    none and silently empties the captured token ids. This exercises the real HTTP path where
    that translation lives (the offline LLM API does not), so a vLLM bump that changes the
    contract fails here instead of silently freezing training.

    All three cases share the (expensive) vLLM fixture, so they run in a single test rather
    than as separate parametrized cases.

    Args:
        nemo_gym_vllm_generation: Fixture exposing ``dp_openai_server_base_urls`` (the first
            entry is used as the server base URL) and ``cfg`` (read for ``model_name``,
            ``temperature`` and ``top_p``).
    """
    base_url = nemo_gym_vllm_generation.dp_openai_server_base_urls[0]
    gen_cfg = nemo_gym_vllm_generation.cfg

    def _chat(top_logprobs_field):
        """POST one chat completion and return the raw ``requests`` response.

        ``top_logprobs_field`` is sent verbatim as ``top_logprobs`` (so ``None`` becomes an
        explicit JSON null) unless it is the ``_OMIT_TOP_LOGPROBS`` sentinel, in which case
        the key is left out of the body altogether.
        """
        body = {
            "model": gen_cfg["model_name"],
            "messages": [{"role": "user", "content": "Say hello."}],
            "max_tokens": 8,
            # The RL HTTP wrapper asserts these match the generation config exactly.
            "temperature": gen_cfg["temperature"],
            "top_p": gen_cfg["top_p"],
            # The fields NeMo-Gym sets to capture token ids.
            "logprobs": True,
            "return_tokens_as_token_ids": True,
        }
        if top_logprobs_field is not _OMIT_TOP_LOGPROBS:
            body["top_logprobs"] = top_logprobs_field

        # The base URL is known once the fixture is ready, but retry briefly to avoid racing
        # the very first connection to the server.
        last_exc = None
        for _ in range(30):
            try:
                return requests.post(
                    f"{base_url}/chat/completions", json=body, timeout=60
                )
            except requests.exceptions.ConnectionError as e:
                last_exc = e
                time.sleep(1)
        # Chain the last connection error so the traceback shows why the server was unreachable.
        raise AssertionError(
            f"vLLM HTTP server never became reachable: {last_exc}"
        ) from last_exc

    def _assert_has_token_ids(resp, label):
        """Assert ``resp`` succeeded and carries per-token logprobs with integer token ids."""
        resp.raise_for_status()
        # Look the fields up defensively: a null/missing logprobs payload is exactly the
        # regression this test guards against, so report it as a labeled assertion failure
        # rather than an opaque TypeError/KeyError from chained indexing.
        logprobs = resp.json()["choices"][0].get("logprobs")
        assert logprobs is not None, f"expected logprobs for {label}, got none"
        content = logprobs.get("content")
        assert content, f"expected per-token logprobs for {label}"
        # return_tokens_as_token_ids makes each token a "token_id:<int>" string; capture
        # parses these into ints, so they must all parse.
        for entry in content:
            token = entry["token"]
            try:
                int(token.removeprefix("token_id:"))
            except ValueError as e:
                raise AssertionError(
                    f"token {token!r} for {label} is not a 'token_id:<int>' string"
                ) from e

    # Omitting top_logprobs (what Gym does on the capture path; vLLM default 0) and sending 0
    # (the equivalent explicit pin) must both yield per-token logprobs whose tokens decode to ints.
    _assert_has_token_ids(_chat(_OMIT_TOP_LOGPROBS), "omitted top_logprobs")
    _assert_has_token_ids(_chat(0), "top_logprobs=0")

    # Explicit null is the divergence that motivates the Gym fix: vLLM returns no logprobs
    # (200 with logprobs=None) or rejects the request outright. Both mean capture gets
    # nothing. If a future vLLM makes null behave like 0, this fails and signals the Gym
    # workaround can be relaxed.
    null_resp = _chat(None)
    if null_resp.status_code == 200:
        assert null_resp.json()["choices"][0].get("logprobs") is None, (
            "explicit null top_logprobs now returns logprobs; "
            "the Gym workaround may no longer be needed"
        )
    else:
        # A rejection must be a client-side validation error, not an unrelated server failure
        # that would let this branch pass vacuously.
        assert 400 <= null_resp.status_code < 500, (
            "expected null top_logprobs accepted-with-None or rejected as 4xx, "
            f"got {null_resp.status_code}: {null_resp.text}"
        )
Loading