Add ECHO terminal agent training integration #1716
Code Review
This pull request introduces an integration for training terminal agents with ECHO (Environment Cross-Entropy Hybrid Objective) under examples/train_integrations/echo_terminal/, adding custom dataset loading, rollout generation, tool-call parsing, and Harbor-backed environment execution. It also extends the SkyRL core to support auxiliary policy losses, custom tensor padding, and token-only inference optimizations. The review identified several critical issues: a potential ZeroDivisionError in the world-model filter when theta is 0.0, a performance bottleneck where Docker images are repeatedly rebuilt and destroyed on every batch, a potential loss explosion when full_observation_count is 0, a potential AttributeError in worker.py if aux_policy_metrics is None, global side effects from forcing the 'spawn' multiprocessing start method in the dataset constructor, and potential JSONDecodeError or maintainability issues when querying Docker images in harbor_environment.py.
if cfg.wm_filter_min_correct_pct is not None:
    theta = cfg.wm_filter_min_correct_pct
    correct_indices = [idx for idx, is_correct in survivors if is_correct]
    incorrect_indices = [idx for idx, is_correct in survivors if not is_correct]
    if theta >= 1.0 or not correct_indices:
        kept_indices = set(correct_indices)
    else:
        max_incorrect = int(len(correct_indices) * (1.0 - theta) / theta)
        kept_indices = set(correct_indices) | set(incorrect_indices[:max_incorrect])
If cfg.wm_filter_min_correct_pct (theta) is set to 0.0, this line will raise a ZeroDivisionError. Consider adding a check to ensure theta > 0 before performing the division, or handle theta <= 0.0 gracefully.
Suggested change:
if cfg.wm_filter_min_correct_pct is not None:
    theta = cfg.wm_filter_min_correct_pct
    correct_indices = [idx for idx, is_correct in survivors if is_correct]
    incorrect_indices = [idx for idx, is_correct in survivors if not is_correct]
    if theta >= 1.0 or not correct_indices:
        kept_indices = set(correct_indices)
    elif theta <= 0.0:
        kept_indices = {idx for idx, _ in survivors}
    else:
        max_incorrect = int(len(correct_indices) * (1.0 - theta) / theta)
        kept_indices = set(correct_indices) | set(incorrect_indices[:max_incorrect])
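The theta handling can be checked in isolation. The sketch below wraps the same branch logic in a standalone function; the name `filter_survivors` and the sample data are hypothetical and not part of the PR. The bound on incorrect samples follows from requiring correct / (correct + incorrect) >= theta.

```python
def filter_survivors(survivors, theta):
    """Keep all correct samples plus enough incorrect ones that the
    fraction of correct samples stays at or above theta.

    `survivors` is a list of (index, is_correct) pairs.
    """
    correct = [idx for idx, ok in survivors if ok]
    incorrect = [idx for idx, ok in survivors if not ok]
    if theta >= 1.0 or not correct:
        # Only correct samples may be kept (possibly none).
        return set(correct)
    if theta <= 0.0:
        # No minimum correct fraction requested: keep everything
        # and avoid dividing by zero below.
        return {idx for idx, _ in survivors}
    # c / (c + k) >= theta  <=>  k <= c * (1 - theta) / theta
    max_incorrect = int(len(correct) * (1.0 - theta) / theta)
    return set(correct) | set(incorrect[:max_incorrect])


# Hypothetical batch: two correct and three incorrect samples.
survivors = [(0, True), (1, False), (2, True), (3, False), (4, False)]
print(filter_survivors(survivors, 0.5))
print(filter_survivors(survivors, 0.0))
```

With theta = 0.5 and two correct samples, at most two incorrect samples are kept; theta = 0.0 keeps the whole batch instead of raising.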
provider = HarborEnvironmentProvider(
    docker_memory_mb=self.generator_cfg.docker_memory_mb,
    docker_cpus=self.generator_cfg.docker_cpus,
    base_temp_dir=Path(self.generator_cfg.base_temp_dir) if self.generator_cfg.base_temp_dir else None,
    max_concurrent_builds=self.generator_cfg.max_concurrent_builds,
    max_build_retries=self.generator_cfg.max_build_retries,
)
Instantiating HarborEnvironmentProvider inside generate and calling cleanup_batch (which deletes the built Docker images via docker rmi) on every batch means Docker images are rebuilt and destroyed on every single PPO generation step. If tasks are reused across iterations, this introduces a massive performance bottleneck. Consider persisting the provider or caching the built Docker images across generation calls so they do not need to be rebuilt and deleted repeatedly.
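One way to persist the provider is to build it lazily once and reuse it across generation calls, deferring cleanup to the end of training. The sketch below is generic and makes no claim about the actual `HarborEnvironmentProvider` API: `ProviderCache`, `make_provider`, and the stand-in provider object are all hypothetical.

```python
class ProviderCache:
    """Create an expensive provider once and reuse it across calls."""

    def __init__(self, factory):
        self._factory = factory  # zero-argument callable that builds the provider
        self._provider = None

    def get(self):
        # Build on first use only; later calls return the same instance.
        if self._provider is None:
            self._provider = self._factory()
        return self._provider

    def close(self, cleanup):
        """Run `cleanup(provider)` once, e.g. when training finishes."""
        if self._provider is not None:
            cleanup(self._provider)
            self._provider = None


# Stand-in factory that records how often a provider is constructed.
builds = []

def make_provider():
    builds.append(1)
    return object()

cache = ProviderCache(make_provider)
for _ in range(3):  # three simulated generation steps
    provider = cache.get()
print(len(builds))
```

Whether images can safely outlive a batch depends on how tasks are reused and on disk budget, so this is a starting point rather than a drop-in fix.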
if config.world_loss_normalization == "full_observation_tokens" and full_observation_count is not None:
    denom = full_observation_count.to(device).float() + 1e-3
If full_observation_count is 0 (e.g., if no observations were added), denom will be 1e-3. Dividing by such a small value can cause the loss to explode (scaled by 1000x). It is safer to clamp the denominator to a minimum of 1.0 using torch.clamp(..., min=1.0).
Suggested change:
if config.world_loss_normalization == "full_observation_tokens" and full_observation_count is not None:
    denom = torch.clamp(full_observation_count.to(device).float(), min=1.0)
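The effect of the two denominators can be illustrated with plain floats. This is a simplified sketch, not the PR's tensor code: `normalized_loss` and the `mode` values are hypothetical, and `max(..., 1.0)` stands in for `torch.clamp(..., min=1.0)`.

```python
def normalized_loss(loss_sum, count, mode):
    """Divide a summed loss by an observation-token count.

    mode="epsilon" mimics `count + 1e-3`;
    mode="clamp" mimics clamping the count to a minimum of 1.0.
    """
    if mode == "epsilon":
        denom = count + 1e-3
    else:
        denom = max(float(count), 1.0)
    return loss_sum / denom


# With no observation tokens, the epsilon form amplifies any residual loss.
print(normalized_loss(0.5, 0, "epsilon"))
print(normalized_loss(0.5, 0, "clamp"))
# With a realistic count the two forms are nearly identical.
print(normalized_loss(0.5, 200, "epsilon"))
print(normalized_loss(0.5, 200, "clamp"))
```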
}
for k, v in loss_metrics.items():
    status["loss_metrics/" + k] = v
for k, v in aux_policy_metrics.items():
If aux_policy_metrics is returned as None by a custom implementation of compute_aux_policy_loss, calling .items() on it will raise an AttributeError. Consider using (aux_policy_metrics or {}).items() to be more defensive.
Suggested change:
for k, v in (aux_policy_metrics or {}).items():
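A small sketch of the defensive pattern, assuming a hypothetical helper `merge_metrics` and a hypothetical key prefix (the prefix used for auxiliary metrics is not shown in the diff excerpt):

```python
def merge_metrics(status, metrics, prefix):
    """Copy metrics into status under a prefix, treating None as empty."""
    for k, v in (metrics or {}).items():
        status[prefix + k] = v
    return status


status = {"loss_metrics/policy_loss": 0.12}
merge_metrics(status, None, "aux/")               # a hook that returned None
merge_metrics(status, {"world_ce": 1.5}, "aux/")  # a hook that returned metrics
print(status)
```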
if self.num_workers > 1:
    import multiprocess

    multiprocess.set_start_method("spawn", force=True)
Forcing the multiprocessing start method to 'spawn' globally with force=True inside a dataset constructor can have unintended side effects on other libraries (like Ray or PyTorch distributed) running in the same process space. It is generally safer to set this at the application entrypoint or let the underlying libraries manage their start methods.
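An alternative that avoids the global side effect is to request a local context with `get_context("spawn")` and create pools or processes from it. The sketch below uses the standard-library `multiprocessing` module; it assumes the `multiprocess` fork used in the PR exposes the same `get_context` API, which should be verified. The helper name `make_spawn_pool` is hypothetical.

```python
import multiprocessing as mp


def make_spawn_pool(num_workers):
    """Return a worker pool that uses 'spawn' without changing the
    process-wide default start method."""
    ctx = mp.get_context("spawn")
    return ctx.Pool(processes=num_workers)


# The global start method is left untouched by get_context().
before = mp.get_start_method(allow_none=True)
ctx = mp.get_context("spawn")
after = mp.get_start_method(allow_none=True)
print(ctx.get_start_method(), before == after)
```

Other libraries in the same process then keep whatever start method they configured themselves.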
result = await self._environment._run_docker_compose_command(["images", "--format", "json"], check=False)
images = json.loads(result.stdout or "[]")
Calling the protected method _run_docker_compose_command on self._environment is a maintainability issue. Additionally, json.loads(result.stdout or '[]') can raise a JSONDecodeError if the command output is not valid JSON (e.g., due to Docker daemon errors or unexpected output format). Consider wrapping this in a try-except block or using a public API if available.
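A tolerant parser for the command output might look like the following. This is a sketch only: `parse_compose_images` is a hypothetical helper, the sample payloads are made up, and returning an empty list on malformed output is one possible policy (logging or re-raising may be preferable in practice).

```python
import json


def parse_compose_images(stdout):
    """Parse JSON image listings, returning [] for empty or invalid output."""
    if not stdout or not stdout.strip():
        return []
    try:
        parsed = json.loads(stdout)
    except json.JSONDecodeError:
        # e.g. a daemon error message printed instead of JSON
        return []
    # Normalize a single JSON object to a one-element list.
    return parsed if isinstance(parsed, list) else [parsed]


print(parse_compose_images('[{"name": "task-a"}]'))
print(parse_compose_images("Cannot connect to the Docker daemon"))
print(parse_compose_images(None))
```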
Add ECHO terminal agent training integration
This PR adds an ECHO terminal agent training integration to SkyRL.
ECHO combines standard policy-gradient RL with an auxiliary cross-entropy loss on terminal-output tokens observed in the same rollout. The implementation keeps the core SkyRL changes small by adding generic hooks for auxiliary policy losses and extra trajectory tensors, then implements the terminal-agent integration under examples/train_integrations/echo_terminal.
Paper: https://arxiv.org/abs/2605.24517
Blog: https://x.com/DimitrisPapail/status/2056368948870811746
What this adds
examples/train_integrations/echo_terminal integration.
Core SkyRL hooks
The framework changes are generic and no-op by default:
- BasePPOExp.get_tokenizer_path().
- BasePPOExp.get_worker_classes().
- PolicyWorkerBase.
- Experience.extras for integration-provided tensors.
- skip_detokenize.
Design
This integration uses Harbor for terminal task execution: creating task containers, running shell commands, returning observations, and executing verifiers. Generation remains inside SkyRL/vLLM rather than delegating the full rollout loop to Harbor. This gives the trainer direct batched access to generated token ids, logprobs, masks, sampling controls, and ECHO-specific terminal-output masks needed for the auxiliary loss.
Setting:
recovers the vanilla GRPO baseline. Positive values enable the ECHO auxiliary environment-prediction loss.
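The relationship between the two regimes can be sketched with plain Python lists. Everything here is illustrative: `echo_loss`, `aux_coef`, and the per-token values are hypothetical names and data, and averaging over observation tokens is just one possible normalization, not necessarily the one the integration uses.

```python
def echo_loss(pg_loss, token_nll, obs_mask, aux_coef):
    """Combine a policy-gradient loss with a masked cross-entropy term.

    token_nll: per-token negative log-likelihoods for one rollout
    obs_mask:  1 for terminal-output (observation) tokens, 0 otherwise
    aux_coef:  weight of the auxiliary term; 0 leaves pg_loss unchanged
    """
    n_obs = sum(obs_mask)
    # Mean NLL over observation tokens; the count is floored at 1
    # so an empty mask contributes zero instead of dividing by zero.
    ce = sum(n * m for n, m in zip(token_nll, obs_mask)) / max(n_obs, 1)
    return pg_loss + aux_coef * ce


token_nll = [0.2, 1.0, 3.0, 0.4]
obs_mask = [0, 1, 1, 0]  # the middle two tokens are terminal output
print(echo_loss(0.5, token_nll, obs_mask, 0.0))  # baseline: policy loss only
print(echo_loss(0.5, token_nll, obs_mask, 0.1))  # with the auxiliary term
```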
Validation
We have run training jobs using this integration, including both vanilla GRPO and ECHO configs. The jobs successfully exercise rollout generation, terminal execution, policy training, checkpointing, and evaluation.