Skip to content

Add ECHO terminal agent training integration#1716

Open
vshrivas wants to merge 2 commits into
NovaSky-AI:mainfrom
vshrivas:msft/echo-rl
Open

Add ECHO terminal agent training integration#1716
vshrivas wants to merge 2 commits into
NovaSky-AI:mainfrom
vshrivas:msft/echo-rl

Conversation

@vshrivas
Copy link
Copy Markdown

Add ECHO terminal agent training integration

This PR adds an ECHO terminal agent training integration to SkyRL.
ECHO combines standard policy-gradient RL with an auxiliary cross-entropy loss on terminal-output tokens observed in the same rollout. The implementation keeps the core SkyRL changes small by adding generic hooks for auxiliary policy losses and extra trajectory tensors, then implements the terminal-agent integration under examples/train_integrations/echo_terminal.

Paper: https://arxiv.org/abs/2605.24517
Blog: https://x.com/DimitrisPapail/status/2056368948870811746

What this adds

  • A new examples/train_integrations/echo_terminal integration.
  • A terminal-agent rollout harness that manages multi-turn model/tool interaction against Harbor-backed task containers.
  • XML tool-call parsing for shell commands and final-answer actions.
  • Optional ECHO world-modeling loss:
    L = L_GRPO(action tokens) + world_model_coeff * CE(terminal-output tokens)
    
  • Qwen3-8B configs for vanilla GRPO and GRPO + ECHO.
  • A Qwen3 XML tool-calling chat template used by the integration.

Core SkyRL hooks

The framework changes are generic and no-op by default:

  • Custom tokenizer path hook via BasePPOExp.get_tokenizer_path().
  • Custom worker-class hook via BasePPOExp.get_worker_classes().
  • Optional auxiliary policy loss hook on PolicyWorkerBase.
  • Optional Experience.extras for integration-provided tensors.
  • Metadata-driven zero-padding for auxiliary loss masks.
  • Optional token-only inference via skip_detokenize.
  • Generalized filtering of generator outputs so integrations can preserve custom output fields.

Design

This integration uses Harbor for terminal task execution: creating task containers, running shell commands, returning observations, and executing verifiers. Generation remains inside SkyRL/vLLM rather than delegating the full rollout loop to Harbor. This gives the trainer direct batched access to generated token ids, logprobs, masks, sampling controls, and ECHO-specific terminal-output masks needed for the auxiliary loss.

Setting:

trainer:
  algorithm:
    world_model_coeff: 0.0

recovers the vanilla GRPO baseline. Positive values enable the ECHO auxiliary environment-prediction loss.

Validation

We have run training jobs using this integration, including both vanilla GRPO and ECHO configs. The jobs successfully exercise rollout generation, terminal execution, policy training, checkpointing, and evaluation.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an integration for training terminal agents with ECHO (Environment Cross-Entropy Hybrid Objective) under examples/train_integrations/echo_terminal/, adding custom dataset loading, rollout generation, tool-call parsing, and Harbor-backed environment execution. It also extends the SkyRL core to support auxiliary policy losses, custom tensor padding, and token-only inference optimizations. The review identified several critical issues: a potential ZeroDivisionError in the world-model filter when theta is 0.0, a performance bottleneck where Docker images are repeatedly rebuilt and destroyed on every batch, a potential loss explosion when full_observation_count is 0, a potential AttributeError in worker.py if aux_policy_metrics is None, global side effects from forcing the 'spawn' multiprocessing start method in the dataset constructor, and potential JSONDecodeError or maintainability issues when querying Docker images in harbor_environment.py.

Comment on lines +157 to +165
if cfg.wm_filter_min_correct_pct is not None:
theta = cfg.wm_filter_min_correct_pct
correct_indices = [idx for idx, is_correct in survivors if is_correct]
incorrect_indices = [idx for idx, is_correct in survivors if not is_correct]
if theta >= 1.0 or not correct_indices:
kept_indices = set(correct_indices)
else:
max_incorrect = int(len(correct_indices) * (1.0 - theta) / theta)
kept_indices = set(correct_indices) | set(incorrect_indices[:max_incorrect])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If cfg.wm_filter_min_correct_pct (theta) is set to 0.0, this line will raise a ZeroDivisionError. Consider adding a check to ensure theta > 0 before performing the division, or handle theta <= 0.0 gracefully.

Suggested change
if cfg.wm_filter_min_correct_pct is not None:
theta = cfg.wm_filter_min_correct_pct
correct_indices = [idx for idx, is_correct in survivors if is_correct]
incorrect_indices = [idx for idx, is_correct in survivors if not is_correct]
if theta >= 1.0 or not correct_indices:
kept_indices = set(correct_indices)
else:
max_incorrect = int(len(correct_indices) * (1.0 - theta) / theta)
kept_indices = set(correct_indices) | set(incorrect_indices[:max_incorrect])
if cfg.wm_filter_min_correct_pct is not None:
theta = cfg.wm_filter_min_correct_pct
correct_indices = [idx for idx, is_correct in survivors if is_correct]
incorrect_indices = [idx for idx, is_correct in survivors if not is_correct]
if theta >= 1.0 or not correct_indices:
kept_indices = set(correct_indices)
elif theta <= 0.0:
kept_indices = {idx for idx, _ in survivors}
else:
max_incorrect = int(len(correct_indices) * (1.0 - theta) / theta)
kept_indices = set(correct_indices) | set(incorrect_indices[:max_incorrect])

Comment on lines +108 to +114
provider = HarborEnvironmentProvider(
docker_memory_mb=self.generator_cfg.docker_memory_mb,
docker_cpus=self.generator_cfg.docker_cpus,
base_temp_dir=Path(self.generator_cfg.base_temp_dir) if self.generator_cfg.base_temp_dir else None,
max_concurrent_builds=self.generator_cfg.max_concurrent_builds,
max_build_retries=self.generator_cfg.max_build_retries,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Instantiating HarborEnvironmentProvider inside generate and calling cleanup_batch (which deletes the built Docker images via docker rmi) on every batch means Docker images are rebuilt and destroyed on every single PPO generation step. If tasks are reused across iterations, this introduces a massive performance bottleneck. Consider persisting the provider or caching the built Docker images across generation calls so they do not need to be rebuilt and deleted repeatedly.

Comment on lines +68 to +69
if config.world_loss_normalization == "full_observation_tokens" and full_observation_count is not None:
denom = full_observation_count.to(device).float() + 1e-3
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If full_observation_count is 0 (e.g., if no observations were added), denom will be 1e-3. Dividing by such a small value can cause the loss to explode (scaled by 1000x). It is safer to clamp the denominator to a minimum of 1.0 using torch.clamp(..., min=1.0).

Suggested change
if config.world_loss_normalization == "full_observation_tokens" and full_observation_count is not None:
denom = full_observation_count.to(device).float() + 1e-3
if config.world_loss_normalization == "full_observation_tokens" and full_observation_count is not None:
denom = torch.clamp(full_observation_count.to(device).float(), min=1.0)

}
for k, v in loss_metrics.items():
status["loss_metrics/" + k] = v
for k, v in aux_policy_metrics.items():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If aux_policy_metrics is returned as None by a custom implementation of compute_aux_policy_loss, calling .items() on it will raise an AttributeError. Consider using (aux_policy_metrics or {}).items() to be more defensive.

Suggested change
for k, v in aux_policy_metrics.items():
for k, v in (aux_policy_metrics or {}).items():

Comment on lines +103 to +106
if self.num_workers > 1:
import multiprocess

multiprocess.set_start_method("spawn", force=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Forcing the multiprocessing start method to 'spawn' globally with force=True inside a dataset constructor can have unintended side effects on other libraries (like Ray or PyTorch distributed) running in the same process space. It is generally safer to set this at the application entrypoint or let the underlying libraries manage their start methods.

Comment on lines +129 to +130
result = await self._environment._run_docker_compose_command(["images", "--format", "json"], check=False)
images = json.loads(result.stdout or "[]")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling the protected method _run_docker_compose_command on self._environment is a maintainability issue. Additionally, json.loads(result.stdout or '[]') can raise a JSONDecodeError if the command output is not valid JSON (e.g., due to Docker daemon errors or unexpected output format). Consider wrapping this in a try-except block or using a public API if available.

@erictang000 erictang000 self-assigned this May 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants