Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ policy:
kv_cache_dtype: "auto"
expose_http_server: true
skip_tokenizer_init: false
# tool_parser_plugin: ??? # This is set to the path for Nemotron Nano v2
tool_parser_plugin: nemo_rl/models/generation/vllm/tool_parsers/nemotron_json.py
http_server_serving_chat_kwargs:
# Workplace assistant uses 26 tools, so we enable auto_tools.
# For Nemotron Nano v2, we use the dedicated `nemotron_json` tool parser
Expand Down
125 changes: 125 additions & 0 deletions nemo_rl/models/generation/vllm/tool_parsers/nemotron_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from collections.abc import Sequence
from typing import Any

from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ExtractedToolCallInformation,
ToolParser,
ToolParserManager,
)

logger = init_logger(__name__)


@ToolParserManager.register_module("nemotron_json")
class NemotronJSONToolParser(ToolParser):
    """Nemotron Nano v2 non-streaming tool parser for vLLM.

    The parser looks for a ``<TOOLCALL>...</TOOLCALL>`` block in the model
    output, reads its payload as a JSON list of objects carrying ``name`` and
    ``arguments`` keys, and converts every well-formed entry into a
    ``ToolCall``. Malformed entries are skipped; if nothing usable remains the
    output is returned unchanged as plain content.
    """

    def __init__(self, tokenizer: Any, tools: Any | None = None) -> None:
        """Initialise the base parser and the tool-call delimiters.

        Args:
            tokenizer: Tokenizer forwarded to ``ToolParser``.
            tools: Optional tool definitions, forwarded to ``ToolParser`` when
                its constructor accepts a second argument.
        """
        try:
            super().__init__(tokenizer, tools)
        except TypeError:
            # Retry with the one-argument form; presumably for ToolParser
            # versions whose constructor does not take `tools` - TODO confirm.
            super().__init__(tokenizer)
        self.tool_call_start_token = "<TOOLCALL>"
        self.tool_call_end_token = "</TOOLCALL>"
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)

    @staticmethod
    def _no_tool_calls(model_output: str) -> ExtractedToolCallInformation:
        """Return the result used whenever no tool call could be extracted."""
        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

    @staticmethod
    def _build_tool_call(entry: Any) -> ToolCall | None:
        """Convert one decoded JSON entry into a ``ToolCall``.

        Returns ``None`` when the entry is not an object, lacks a string
        ``name`` or an ``arguments`` key, or is rejected by the vLLM models.
        """
        if not isinstance(entry, dict):
            return None
        name = entry.get("name")
        if not isinstance(name, str) or "arguments" not in entry:
            return None

        arguments = entry["arguments"]
        if not isinstance(arguments, str):
            # FunctionCall is given a string here: serialise dicts as before,
            # and also lists/numbers/null so they are not passed through raw.
            arguments = json.dumps(arguments, ensure_ascii=False)

        try:
            return ToolCall(
                type="function",
                function=FunctionCall(name=name, arguments=arguments),
            )
        except (TypeError, ValueError):
            # Model validation failures must not escape the parser.
            logger.debug("Skipping invalid Nemotron tool call.", exc_info=True)
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streamed) model output.

        Args:
            model_output: Full text generated by the model.
            request: The originating chat completion request (unused here).

        Returns:
            ``ExtractedToolCallInformation`` with ``tools_called=True`` and the
            text preceding the tool-call block as ``content`` (``None`` if that
            text is empty) when at least one valid call was parsed; otherwise
            ``tools_called=False`` with the untouched ``model_output``.
        """
        if self.tool_call_start_token not in model_output:
            return self._no_tool_calls(model_output)

        tool_call_matches = self.tool_call_regex.findall(model_output)
        if not tool_call_matches:
            # Start token present but no closing token: nothing to parse.
            return self._no_tool_calls(model_output)

        # NOTE(review): only the first <TOOLCALL> block is parsed, while the
        # content below is cut at the *last* start token. With several blocks
        # the two disagree - confirm the intended behaviour.
        str_tool_calls = tool_call_matches[0].strip()
        # Tolerate payloads that omit the surrounding list brackets.
        if not str_tool_calls.startswith("["):
            str_tool_calls = "[" + str_tool_calls
        if not str_tool_calls.endswith("]"):
            str_tool_calls += "]"

        try:
            decoded_calls = json.loads(str_tool_calls)
        except (json.JSONDecodeError, TypeError):
            logger.debug("Failed to parse Nemotron tool call.", exc_info=True)
            return self._no_tool_calls(model_output)

        tool_calls = []
        for entry in decoded_calls:
            tool_call = self._build_tool_call(entry)
            if tool_call is not None:
                tool_calls.append(tool_call)

        if not tool_calls:
            return self._no_tool_calls(model_output)

        content = model_output[: model_output.rfind(self.tool_call_start_token)]
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content or None,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Streaming extraction is not implemented for this parser.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Tool calling is not supported in streaming mode.")
15 changes: 11 additions & 4 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,15 +722,22 @@ async def create_chat_completion(
generator = await openai_serving_chat.create_chat_completion(
request, raw_request
)
except VLLMValidationError as e:
except (ValueError, VLLMValidationError) as e:
# vLLM 0.20 raises VLLMValidationError for prompts exceeding
# max_model_len during tokenization, instead of returning an
# ErrorResponse. Convert to HTTP 400 so the Gym proxy can
# detect context-length overflow and handle it gracefully.
# ErrorResponse. Our post-tokenization clamp can raise a local
# ValueError for the same condition after prefix replacement.
# Convert those cases to HTTP 400 so the Gym proxy can detect
# context-length overflow and handle it gracefully.
message = str(e)
if isinstance(e, ValueError) and not (
"max_model_len" in message or "maximum context length" in message
):
raise
return JSONResponse(
content={
"error": {
"message": str(e),
"message": message,
"type": "invalid_request_error",
"code": 400,
}
Expand Down
5 changes: 5 additions & 0 deletions ray.sub
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,12 @@ echo "All workers connected!"
# This driver process is responsible for launching a job on the Ray cluster
CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID | grep -oP 'WorkDir=\K[^ ]+' | head -1)
if [[ -n "$COMMAND" ]]; then
set +e
srun --no-container-mount-home --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND"
driver_exit_code=$?
set -e
touch "$LOG_DIR/ENDED"
exit "$driver_exit_code"
else
echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
cat <<EOF >$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh
Expand Down
59 changes: 58 additions & 1 deletion tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import importlib.util
import json
import os
import sys
import types
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import pytest
import ray
Expand Down Expand Up @@ -344,6 +345,62 @@ def test_vllm_async_http_server_loads_reasoning_parser_plugin(monkeypatch):
assert "reasoning_parser_plugin" not in openai_serving_chat.instances[0].kwargs


def _setup_fake_openai_chat_completion_route(monkeypatch):
    """Wire a worker to the fake vLLM OpenAI modules and return its chat route.

    Returns a ``(route, serving_chat)`` pair: the handler registered under
    ``/v1/chat/completions`` on the fake app, and the first instance recorded
    by the fake serving-chat class.
    """
    fake_modules = _install_fake_vllm_openai_modules(monkeypatch)
    _tool_parser_manager, _reasoning_parser_manager, openai_serving_chat = fake_modules

    # Allocate the worker without running its __init__.
    worker = VllmAsyncGenerationWorkerImpl.__new__(VllmAsyncGenerationWorkerImpl)
    worker.cfg = {
        "temperature": 1.0,
        "top_p": 1.0,
        "vllm_cfg": {"http_server_serving_chat_kwargs": {}},
    }
    worker.llm = MagicMock(model_config="model-config", renderer="renderer")

    engine_args = MagicMock()
    engine_args.create_model_config.return_value = MagicMock(
        served_model_name="served-model",
        model="model-path",
    )
    worker.llm_async_engine_args = engine_args

    fake_app = _FakeFastAPIApp()
    worker._setup_vllm_openai_api_server(fake_app)

    chat_route = next(
        handler
        for route_path, handler in fake_app.routes
        if route_path == "/v1/chat/completions"
    )
    return chat_route, openai_serving_chat.instances[0]


def test_vllm_async_chat_completion_maps_context_value_error_to_400(monkeypatch):
    """A context-length ``ValueError`` is reported as an HTTP 400 error body."""
    route, openai_serving_chat = _setup_fake_openai_chat_completion_route(monkeypatch)

    overflow_error = ValueError(
        "Prompt length (8551) fills or exceeds max_model_len (8192). "
        "No room for output tokens."
    )
    openai_serving_chat.create_chat_completion = AsyncMock(side_effect=overflow_error)

    chat_request = types.SimpleNamespace(top_k=None, temperature=1.0, top_p=1.0)
    raw_request = MagicMock()
    response = asyncio.run(route(chat_request, raw_request))

    assert response.status_code == 400

    error = json.loads(response.body.decode())["error"]
    assert error["code"] == 400
    assert error["type"] == "invalid_request_error"
    assert "max_model_len" in error["message"]


def test_vllm_async_chat_completion_reraises_unrelated_value_error(monkeypatch):
    """A ``ValueError`` unrelated to context length propagates to the caller."""
    route, openai_serving_chat = _setup_fake_openai_chat_completion_route(monkeypatch)

    failure_text = "unexpected internal validation failure"
    openai_serving_chat.create_chat_completion = AsyncMock(
        side_effect=ValueError(failure_text)
    )

    chat_request = types.SimpleNamespace(top_k=None, temperature=1.0, top_p=1.0)
    raw_request = MagicMock()
    with pytest.raises(ValueError, match=failure_text):
        asyncio.run(route(chat_request, raw_request))


def test_nano_v3_reasoning_parser_swaps_reasoning_when_thinking_disabled(
monkeypatch,
):
Expand Down
Loading
Loading