Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions examples/train_integrations/echo_terminal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# ECHO Terminal-Agent Training

This example trains terminal agents with ECHO, an environment cross-entropy hybrid objective. ECHO combines standard policy-gradient RL with an auxiliary cross-entropy loss on terminal-output tokens observed in the same rollout.

SkyRL provides the core RL training stack, distributed worker execution, and vLLM-backed inference. This example adds the terminal-agent dataset loader, prompt formatting, tool-call parsing, Harbor-backed environment execution, rollout construction, token masks, and the optional environment-prediction loss.

## Structure

```text
examples/train_integrations/echo_terminal/
entrypoint.py # Training entrypoint
generator.py # Terminal rollout loop and SkyRL trajectory construction
dataset.py # Parquet dataset loader and prompt tokenization
harbor_environment.py # Harbor container execution wrapper
interaction.py # Rollout transcript and token-mask bookkeeping
parsers.py # Tool-call parsing
prompts.py # Terminal-agent system prompts
tools.py # Tool schemas
chat_template.py # Chat-template loading helpers
chat_templates/qwen3_xml_tool_calling.jinja
world_modeling/
config.py # ECHO config extensions
fsdp_worker.py # FSDP auxiliary-loss hook implementation
loss.py # Environment-token CE loss
trainer.py # Training-batch conversion for ECHO masks
configs/
qwen3_8b_rl.yaml # Vanilla GRPO baseline
qwen3_8b_rl_wm05.yaml # GRPO + ECHO loss, lambda=0.05
```

## Quick Start

Install SkyRL with the FSDP and Harbor dependencies:

```bash
cd SkyRL
pip install -e ".[fsdp,harbor]"
```

Edit the train and validation parquet paths in the config you want to run:

```yaml
data:
train_data:
- name: terminal_agent_train
path: /path/to/train.parquet
val_data:
- name: terminal_agent_train
path: /path/to/val.parquet
```

Set an output directory and launch the vanilla GRPO baseline:

```bash
export OUTPUT_DIR=/path/to/outputs/qwen3_8b_rl
export CONFIG_PATH=examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml
bash examples/train_integrations/echo_terminal/run_echo_terminal.sh
```

Launch ECHO with the auxiliary environment-prediction loss:

```bash
export OUTPUT_DIR=/path/to/outputs/qwen3_8b_rl_wm05
export CONFIG_PATH=examples/train_integrations/echo_terminal/configs/qwen3_8b_rl_wm05.yaml
bash examples/train_integrations/echo_terminal/run_echo_terminal.sh
```

Checkpoints are written to `${OUTPUT_DIR}/ckpts`, and SkyRL logs are written to `${OUTPUT_DIR}/skyrl_logs`.

## Design

The rollout loop is handled directly in this example rather than through Harbor's full rollout API. Harbor is used as the terminal task backend: it starts the task containers, runs shell commands, returns terminal observations, and executes verifiers. SkyRL/vLLM owns model generation so the training code has direct, batched access to generated token ids, logprobs, attention masks, sampling controls, and ECHO-specific token masks.

During training, the standard GRPO loss is computed on model-generated action tokens. When `trainer.algorithm.world_model_coeff > 0`, ECHO also computes cross entropy on selected terminal-output tokens from the same trajectory:

```text
L = L_GRPO(action tokens) + world_model_coeff * CE(terminal-output tokens)
```

Setting `world_model_coeff: 0.0` recovers the vanilla GRPO baseline. The included ECHO config uses `world_model_coeff: 0.05` and `generator.world_loss_target: env_only`, which trains on terminal environment-output tokens while leaving the RL action-token mask unchanged.
1 change: 1 addition & 0 deletions examples/train_integrations/echo_terminal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__all__ = []
97 changes: 97 additions & 0 deletions examples/train_integrations/echo_terminal/chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

from transformers import PreTrainedTokenizerBase

_PROBE_SENTINEL = "OBSERVATION_PROBE_TOKEN_ABCXYZ_12345"
_PROBE_MESSAGES: list[dict[str, str]] = [
{"role": "system", "content": "You are a helpful agent with access to bash."},
{"role": "user", "content": "Run `ls /tmp` and report the result."},
{
"role": "assistant",
"content": (
"<think>\nI should call the bash tool with the command.\n</think>\n\n"
"<tool_call>\n<function=bash>\n<parameter=command>\nls /tmp\n"
"</parameter>\n</function>\n</tool_call>"
),
},
]


def check_role_roundtrip(
tokenizer: PreTrainedTokenizerBase,
role: str,
*,
template_kwargs: dict[str, Any] | None = None,
sentinel: str = _PROBE_SENTINEL,
) -> tuple[bool, str]:
template_kwargs = template_kwargs or {}
try:
before = tokenizer.apply_chat_template(
_PROBE_MESSAGES,
add_generation_prompt=False,
tokenize=True,
return_dict=False,
**template_kwargs,
)
after = tokenizer.apply_chat_template(
_PROBE_MESSAGES + [{"role": role, "content": sentinel}],
add_generation_prompt=True,
tokenize=True,
return_dict=False,
**template_kwargs,
)
except Exception as exc:
return False, f"apply_chat_template raised {type(exc).__name__}: {exc}"

if after[: len(before)] != before:
return False, f"prefix mismatch when {role!r} was appended"
delta = after[len(before) :]
if not delta:
return False, "empty delta"
decoded = tokenizer.decode(delta, skip_special_tokens=False)
if sentinel not in decoded:
return False, f"sentinel content not found for role {role!r}"
return True, ""


def choose_obs_role(
tokenizer: PreTrainedTokenizerBase,
candidates: tuple[str, ...] = ("tool", "user"),
*,
template_kwargs: dict[str, Any] | None = None,
) -> str:
failures: list[str] = []
for role in candidates:
ok, reason = check_role_roundtrip(tokenizer, role, template_kwargs=template_kwargs)
if ok:
return role
failures.append(f" - {role!r}: {reason}")
raise RuntimeError(
f"None of {list(candidates)!r} produced a usable observation role for "
f"tokenizer {tokenizer.name_or_path!r}. Failures:\n" + "\n".join(failures)
)


def resolve_chat_template_path(template_path: str | Path) -> Path:
path = Path(template_path).expanduser()
if path.is_absolute():
candidates = [path]
else:
module_path = Path(__file__).resolve()
candidates = [
Path.cwd() / path,
module_path.parent / path,
module_path.parents[2] / path,
]

for candidate in candidates:
if candidate.exists():
return candidate
raise FileNotFoundError(f"Chat template not found at {template_path!r}; checked {candidates!r}")


def load_chat_template(tokenizer: PreTrainedTokenizerBase, template_path: str | Path) -> None:
tokenizer.chat_template = resolve_chat_template_path(template_path).read_text()
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
{%- set image_count = namespace(value=0) %}
{%- set video_count = namespace(value=0) %}
{%- macro render_content(content, do_vision_count, is_system_content=false) %}
{%- if content is string %}
{{- content }}
{%- elif content is iterable and content is not mapping %}
{%- for item in content %}
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain images.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set image_count.value = image_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Picture ' ~ image_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif 'video' in item or item.type == 'video' %}
{%- if is_system_content %}
{{- raise_exception('System message cannot contain videos.') }}
{%- endif %}
{%- if do_vision_count %}
{%- set video_count.value = video_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}
{{- 'Video ' ~ video_count.value ~ ': ' }}
{%- endif %}
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in item %}
{{- item.text }}
{%- else %}
{{- raise_exception('Unexpected item type in content.') }}
{%- endif %}
{%- endfor %}
{%- elif content is none or content is undefined %}
{{- '' }}
{%- else %}
{{- raise_exception('Unexpected content type.') }}
{%- endif %}
{%- endmacro %}
{%- if not messages %}
{{- raise_exception('No messages provided.') }}
{%- endif %}
{%- if tools and tools is iterable and tools is not mapping %}
{{- '<|im_start|>system\n' }}
{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>" }}
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{%- if content %}
{{- '\n\n' + content }}
{%- endif %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if messages[0].role == 'system' %}
{%- set content = render_content(messages[0].content, false, true)|trim %}
{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" %}
{%- set content = render_content(message.content, false)|trim %}
{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if ns.multi_step_tool %}
{{- raise_exception('No user query found in messages.') }}
{%- endif %}
{%- for message in messages %}
{%- set content = render_content(message.content, true)|trim %}
{%- if message.role == "system" %}
{%- if not loop.first %}
{{- raise_exception('System message must be at the beginning.') }}
{%- endif %}
{%- elif message.role == "user" %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- set reasoning_content = reasoning_content|trim %}
{%- if loop.index0 > ns.last_query_index %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- if loop.first %}
{%- if content|trim %}
{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- else %}
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- else %}
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- endif %}
{%- if tool_call.arguments is defined %}
{%- for args_name, args_value in tool_call.arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>\n' }}
{%- elif loop.last %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- else %}
{{- raise_exception('Unexpected message role.') }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- else %}
{{- '<think>\n' }}
{%- endif %}
{%- endif %}
Loading
Loading