From fde7d1259f6406946ed15bca18c9964de1d77639 Mon Sep 17 00:00:00 2001 From: Vaishnavi Shrivastava Date: Thu, 28 May 2026 00:09:26 -0700 Subject: [PATCH 1/2] Add trainer hooks for auxiliary policy losses --- .../skyrl_train/inference_engines/base.py | 11 +++++ .../remote_inference_client.py | 6 ++- skyrl/backends/skyrl_train/training_batch.py | 6 ++- skyrl/backends/skyrl_train/workers/worker.py | 44 ++++++++++++++++++- .../skyrl_train/workers/worker_utils.py | 23 +++++++++- skyrl/train/dataset/replay_buffer.py | 3 ++ skyrl/train/entrypoints/main_base.py | 35 +++++++++------ skyrl/train/utils/trainer_utils.py | 24 ++++------ 8 files changed, 116 insertions(+), 36 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index 1ec593aae8..4a79b312f8 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -1,6 +1,11 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Hashable, List, Optional, TypedDict +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + if TYPE_CHECKING: from skyrl.backends.skyrl_train.weight_sync import WeightUpdateRequest from skyrl.backends.skyrl_train.weight_sync.transfer_strategy import ( @@ -29,6 +34,9 @@ class InferenceEngineInput(TypedDict): sampling_params: Optional[Dict[str, Any]] session_ids: Optional[List[Hashable]] mm_features: Optional[List[MultiModalFeatures]] + # Token-only callers can skip response detokenization in inference clients + # that otherwise perform an additional decode round trip. + skip_detokenize: NotRequired[bool] class InferenceEngineOutput(TypedDict): @@ -39,6 +47,9 @@ class InferenceEngineOutput(TypedDict): # represent the same text with tokens. Therefore, for multi-turn generation, # please use token-in-token-out to ensure correctness. # `skip_special_tokens=True` is needed because string responses do not include EOS tokens like `<|im_end|>` + # Token-only callers may set `skip_detokenize` in `InferenceEngineInput`; + # in that case `responses` may contain empty strings while `response_ids` + # remains populated. responses: List[str] response_ids: List[List[int]] stop_reasons: List[str] diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index f07fb794e0..7aed2e9709 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -384,6 +384,7 @@ async def generate( session_ids = input_batch.get("session_ids") mm_features = input_batch.get("mm_features") + skip_detokenize = bool(input_batch.get("skip_detokenize", False)) get_logprobs = sampling_params.get("logprobs") is not None # Two semaphores decouple the generate and detokenize stages: @@ -425,7 +426,10 @@ async def _throttled_detokenize(token_ids: List[int]) -> str: return (await self.detokenize([token_ids]))[0] raw_results = await asyncio.gather(*[_throttled_generate(idx) for idx in range(batch_size)]) - responses = await asyncio.gather(*[_throttled_detokenize(r["response_ids"]) for r in raw_results]) + if skip_detokenize: + responses = [""] * len(raw_results) + else: + responses = await asyncio.gather(*[_throttled_detokenize(r["response_ids"]) for r in raw_results]) rollout_expert_indices = [r.get("routed_experts") for r in raw_results] has_routed_experts = any(x is not None for x in rollout_expert_indices) diff --git a/skyrl/backends/skyrl_train/training_batch.py b/skyrl/backends/skyrl_train/training_batch.py index 9fba75d541..480a5fd266 100644 --- a/skyrl/backends/skyrl_train/training_batch.py +++ b/skyrl/backends/skyrl_train/training_batch.py @@ -502,6 +502,8 @@ def pad_training_input_batch(unpadded_batch: TrainingInputBatch, pad_size: int) unpadded_batch.metadata["pad_size"] = 0 return unpadded_batch + zero_pad_keys = set((unpadded_batch.metadata or {}).get("zero_pad_keys", [])) + # Pad each tensor depending on its type. new_tensors = {} for key, tensor in unpadded_batch.items(): @@ -513,8 +515,8 @@ def pad_training_input_batch(unpadded_batch: TrainingInputBatch, pad_size: int) assert len(tensor) > 0, f"Cannot pad empty TensorList field {key!r}" padding = TensorList([tensor[0].clone() for _ in range(pad_size)]) new_tensors[key] = TensorList.cat([tensor, padding]) - elif key == "loss_mask": - # Ensures that padding tensors don't count towards the loss + elif key == "loss_mask" or key in zero_pad_keys: + # Ensures that padding tensors don't count towards losses. additional_dims = tensor.shape[1:] padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device) new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0) diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index e5364ab043..0f9ad4b990 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -711,10 +711,16 @@ def forward_backward( all_metrics = defaultdict(list) all_loss_fn_outputs = [] # Handle separately from scalar metrics + aux_policy_loss_context = self.get_aux_policy_loss_context(data) + for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False): microbatch_weight = micro_batch_size / len(data) metrics = self._forward_backward_micro( - micro_batch, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config + micro_batch, + microbatch_weight, + loss_fn=loss_fn, + loss_fn_config=loss_fn_config, + aux_policy_loss_context=aux_policy_loss_context, ) # Extract loss_fn_outputs before reduce_metrics (it's not a scalar metric) @@ -737,12 +743,29 @@ def forward_backward( return WorkerOutput(loss_fn_outputs=all_loss_fn_outputs, metrics=result) + def get_aux_policy_loss_context(self, data: TrainingInputBatch) -> Dict[str, Any]: + return {} + + def compute_aux_policy_loss( + self, + *, + action_log_probs: torch.Tensor, + policy_loss: torch.Tensor, + experience: Experience, + loss_config: Any, + microbatch_weight: float, + grad_sum_correction_factor: float, + aux_policy_loss_context: Optional[Dict[str, Any]] = None, + ) -> tuple[Optional[torch.Tensor], Dict[str, Any]]: + return None, {} + def _forward_backward_micro( self, experience: Experience, microbatch_weight: float, loss_fn: Optional[str] = None, loss_fn_config: Optional[Dict[str, Any]] = None, + aux_policy_loss_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, float]: """ Perform forward and backward pass for one micro batch. @@ -901,6 +924,17 @@ def _forward_backward_micro( # NOTE: The KL and entropy loss terms are not pre-scaled, # so we just average them across microbatches and DP workers. loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term) * microbatch_weight + aux_policy_loss, aux_policy_metrics = self.compute_aux_policy_loss( + action_log_probs=action_log_probs, + policy_loss=policy_loss, + experience=experience, + loss_config=loss_config, + microbatch_weight=microbatch_weight, + grad_sum_correction_factor=grad_sum_correction_factor, + aux_policy_loss_context=aux_policy_loss_context, + ) + if aux_policy_loss is not None: + loss = loss + aux_policy_loss unscaled_loss = loss / grad_sum_correction_factor self.strategy.backward(loss, self.model, self.optimizer) @@ -934,6 +968,14 @@ def _forward_backward_micro( } for k, v in loss_metrics.items(): status["loss_metrics/" + k] = v + for k, v in aux_policy_metrics.items(): + if v is None: + continue + status_key = k[len("status/") :] if k.startswith("status/") else "loss_metrics/" + k + if hasattr(v, "detach"): + status[status_key] = v.detach().item() + else: + status[status_key] = v if self.cfg.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index cecaccc17e..a88e0b0fe1 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -102,8 +102,26 @@ def __next__(self) -> Experience: @staticmethod def batch_to_experience(batch: TrainingInputBatch): - # TODO (sumanthrh): other keys are not permitted right now, can go into info - # TODO: this conversion is hidden right now, might need to be surfaced in worker explicitly. + # Preserve integration-provided tensor fields without requiring SkyRL core + # to know each auxiliary objective by name. + known_fields = { + "sequences", + "action_log_probs", + "base_action_log_probs", + "values", + "returns", + "advantages", + "attention_mask", + "loss_mask", + "response_mask", + "rollout_logprobs", + "rollout_expert_indices", + "pixel_values", + "image_grid_thw", + "rewards", + "kl", + } + extras = {key: value for key, value in batch.items() if key not in known_fields and value is not None} exp = Experience( sequences=batch["sequences"], action_log_probs=batch.get("action_log_probs"), @@ -125,5 +143,6 @@ def batch_to_experience(batch: TrainingInputBatch): # Multi-modal vision fields (may be absent for text-only) pixel_values=batch.get("pixel_values"), image_grid_thw=batch.get("image_grid_thw"), + extras=extras, ) return exp diff --git a/skyrl/train/dataset/replay_buffer.py b/skyrl/train/dataset/replay_buffer.py index c1d50ad98b..6b35d9414c 100644 --- a/skyrl/train/dataset/replay_buffer.py +++ b/skyrl/train/dataset/replay_buffer.py @@ -74,6 +74,7 @@ class Experience: metadata: Optional[Dict[str, Any]] = None pixel_values: Optional[TensorList] = None image_grid_thw: Optional[TensorList] = None + extras: Optional[Dict[str, Any]] = None @torch.no_grad() def to_device(self, device: torch.device) -> None: @@ -102,6 +103,8 @@ def to_device(self, device: torch.device) -> None: self.pixel_values = self.pixel_values.to(device) if self.image_grid_thw is not None: self.image_grid_thw = self.image_grid_thw.to(device) + if self.extras is not None: + self.extras = {key: to(value, device) for key, value in self.extras.items()} def pin_memory(self): self.sequences = pin_memory(self.sequences) diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index 7d9d170a1a..804fe8b405 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -140,7 +140,7 @@ def __init__(self, cfg: SkyRLTrainConfig): """ self.cfg = cfg self.tokenizer = get_tokenizer( - self.cfg.trainer.policy.model.path, + self.get_tokenizer_path(), trust_remote_code=True, use_fast=not self.cfg.trainer.disable_fast_tokenizer, padding_side="left", @@ -155,6 +155,9 @@ def __init__(self, cfg: SkyRLTrainConfig): self._decode_server_groups = None self._inference_router = None + def get_tokenizer_path(self) -> str: + return self.cfg.trainer.policy.model.path + @staticmethod def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str: return get_config_as_yaml_str(cfg) @@ -342,19 +345,8 @@ def _get_new_inference_client(self): return client - def _setup_trainer(self): - """Setup and return the trainer. - - Instantiates the trainer and all the associated models for training. - - Returns: - RayPPOTrainer: The trainer. - """ - logger.info(self.get_cfg_as_str(self.cfg)) - os.makedirs(self.cfg.trainer.export_path, exist_ok=True) - os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True) - - if self.cfg.trainer.strategy == "fsdp": + def get_worker_classes(self): + if self.cfg.trainer.strategy in ("fsdp", "fsdp2"): from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import ( CriticWorker, PolicyWorker, @@ -368,6 +360,21 @@ def _setup_trainer(self): ) else: raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}") + return PolicyWorker, CriticWorker, RefWorker + + def _setup_trainer(self): + """Setup and return the trainer. + + Instantiates the trainer and all the associated models for training. + + Returns: + RayPPOTrainer: The trainer. + """ + logger.info(self.get_cfg_as_str(self.cfg)) + os.makedirs(self.cfg.trainer.export_path, exist_ok=True) + os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True) + + PolicyWorker, CriticWorker, RefWorker = self.get_worker_classes() # NOTE (sumanthrh): Instantiate tracker before trainer init. # We have custom validation before this step to give better error messages. diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 23b53b2819..0641502edc 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -276,7 +276,6 @@ def dump_per_dataset_eval_results( "score": concat_generator_outputs["rewards"][i], "stop_reason": concat_generator_outputs.get("stop_reasons", [None] * len(input_prompts))[i], "env_class": concat_all_envs[i], - "env_extras": concat_env_extras[i], "data_source": data_source, } f.write(json.dumps(entry, ensure_ascii=False) + "\n") @@ -550,21 +549,14 @@ def get_bad_sample_replacements(good_uids: List[str], bad_uids: List[str]) -> Li def filter_generator_output(output: GeneratorOutput, kept_indices: List[int]) -> GeneratorOutput: """Filter GeneratorOutput based on kept indices.""" - filtered = { - "prompt_token_ids": [output["prompt_token_ids"][i] for i in kept_indices], - "response_ids": [output["response_ids"][i] for i in kept_indices], - "rewards": [output["rewards"][i] for i in kept_indices], - "loss_masks": [output["loss_masks"][i] for i in kept_indices], - "stop_reasons": None, - "rollout_metrics": output.get("rollout_metrics"), - "rollout_logprobs": ( - [output["rollout_logprobs"][i] for i in kept_indices] if output["rollout_logprobs"] else None - ), - } - - if output.get("stop_reasons"): - filtered["stop_reasons"] = [output["stop_reasons"][i] for i in kept_indices] - + filtered = {} + for key, value in output.items(): + if key == "rollout_metrics": + filtered[key] = value + elif isinstance(value, list): + filtered[key] = [value[i] for i in kept_indices] + else: + filtered[key] = value return filtered From efd864b9f5ec44908780cd5f4160a6cebf191eef Mon Sep 17 00:00:00 2001 From: Vaishnavi Shrivastava Date: Thu, 28 May 2026 00:09:49 -0700 Subject: [PATCH 2/2] Add ECHO terminal-agent training example --- .../echo_terminal/README.md | 80 ++ .../echo_terminal/__init__.py | 1 + .../echo_terminal/chat_template.py | 97 ++ .../qwen3_xml_tool_calling.jinja | 154 ++++ .../echo_terminal/configs/qwen3_8b_rl.yaml | 118 +++ .../configs/qwen3_8b_rl_wm05.yaml | 118 +++ .../echo_terminal/dataset.py | 140 +++ .../echo_terminal/entrypoint.py | 227 +++++ .../echo_terminal/generator.py | 855 ++++++++++++++++++ .../echo_terminal/harbor_environment.py | 435 +++++++++ .../echo_terminal/interaction.py | 78 ++ .../echo_terminal/parsers.py | 199 ++++ .../echo_terminal/prompts.py | 88 ++ .../echo_terminal/run_echo_terminal.sh | 6 + .../train_integrations/echo_terminal/tools.py | 55 ++ .../echo_terminal/world_modeling/__init__.py | 1 + .../echo_terminal/world_modeling/config.py | 35 + .../world_modeling/fsdp_worker.py | 87 ++ .../echo_terminal/world_modeling/loss.py | 114 +++ .../echo_terminal/world_modeling/trainer.py | 235 +++++ 20 files changed, 3123 insertions(+) create mode 100644 examples/train_integrations/echo_terminal/README.md create mode 100644 examples/train_integrations/echo_terminal/__init__.py create mode 100644 examples/train_integrations/echo_terminal/chat_template.py create mode 100644 examples/train_integrations/echo_terminal/chat_templates/qwen3_xml_tool_calling.jinja create mode 100644 examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml create mode 100644 examples/train_integrations/echo_terminal/configs/qwen3_8b_rl_wm05.yaml create mode 100644 examples/train_integrations/echo_terminal/dataset.py create mode 100644 examples/train_integrations/echo_terminal/entrypoint.py create mode 100644 examples/train_integrations/echo_terminal/generator.py create mode 100644 examples/train_integrations/echo_terminal/harbor_environment.py create mode 100644 examples/train_integrations/echo_terminal/interaction.py create mode 100644 examples/train_integrations/echo_terminal/parsers.py create mode 100644 examples/train_integrations/echo_terminal/prompts.py create mode 100755 examples/train_integrations/echo_terminal/run_echo_terminal.sh create mode 100644 examples/train_integrations/echo_terminal/tools.py create mode 100644 examples/train_integrations/echo_terminal/world_modeling/__init__.py create mode 100644 examples/train_integrations/echo_terminal/world_modeling/config.py create mode 100644 examples/train_integrations/echo_terminal/world_modeling/fsdp_worker.py create mode 100644 examples/train_integrations/echo_terminal/world_modeling/loss.py create mode 100644 examples/train_integrations/echo_terminal/world_modeling/trainer.py diff --git a/examples/train_integrations/echo_terminal/README.md b/examples/train_integrations/echo_terminal/README.md new file mode 100644 index 0000000000..c77dab8a60 --- /dev/null +++ b/examples/train_integrations/echo_terminal/README.md @@ -0,0 +1,80 @@ +# ECHO Terminal-Agent Training + +This example trains terminal agents with ECHO, an environment cross-entropy hybrid objective. ECHO combines standard policy-gradient RL with an auxiliary cross-entropy loss on terminal-output tokens observed in the same rollout. + +SkyRL provides the core RL training stack, distributed worker execution, and vLLM-backed inference. This example adds the terminal-agent dataset loader, prompt formatting, tool-call parsing, Harbor-backed environment execution, rollout construction, token masks, and the optional environment-prediction loss. + +## Structure + +```text +examples/train_integrations/echo_terminal/ + entrypoint.py # Training entrypoint + generator.py # Terminal rollout loop and SkyRL trajectory construction + dataset.py # Parquet dataset loader and prompt tokenization + harbor_environment.py # Harbor container execution wrapper + interaction.py # Rollout transcript and token-mask bookkeeping + parsers.py # Tool-call parsing + prompts.py # Terminal-agent system prompts + tools.py # Tool schemas + chat_template.py # Chat-template loading helpers + chat_templates/qwen3_xml_tool_calling.jinja + world_modeling/ + config.py # ECHO config extensions + fsdp_worker.py # FSDP auxiliary-loss hook implementation + loss.py # Environment-token CE loss + trainer.py # Training-batch conversion for ECHO masks + configs/ + qwen3_8b_rl.yaml # Vanilla GRPO baseline + qwen3_8b_rl_wm05.yaml # GRPO + ECHO loss, lambda=0.05 +``` + +## Quick Start + +Install SkyRL with the FSDP and Harbor dependencies: + +```bash +cd SkyRL +pip install -e ".[fsdp,harbor]" +``` + +Edit the train and validation parquet paths in the config you want to run: + +```yaml +data: + train_data: + - name: terminal_agent_train + path: /path/to/train.parquet + val_data: + - name: terminal_agent_train + path: /path/to/val.parquet +``` + +Set an output directory and launch the vanilla GRPO baseline: + +```bash +export OUTPUT_DIR=/path/to/outputs/qwen3_8b_rl +export CONFIG_PATH=examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml +bash examples/train_integrations/echo_terminal/run_echo_terminal.sh +``` + +Launch ECHO with the auxiliary environment-prediction loss: + +```bash +export OUTPUT_DIR=/path/to/outputs/qwen3_8b_rl_wm05 +export CONFIG_PATH=examples/train_integrations/echo_terminal/configs/qwen3_8b_rl_wm05.yaml +bash examples/train_integrations/echo_terminal/run_echo_terminal.sh +``` + +Checkpoints are written to `${OUTPUT_DIR}/ckpts`, and SkyRL logs are written to `${OUTPUT_DIR}/skyrl_logs`. + +## Design + +The rollout loop is handled directly in this example rather than through Harbor's full rollout API. Harbor is used as the terminal task backend: it starts the task containers, runs shell commands, returns terminal observations, and executes verifiers. SkyRL/vLLM owns model generation so the training code has direct, batched access to generated token ids, logprobs, attention masks, sampling controls, and ECHO-specific token masks. + +During training, the standard GRPO loss is computed on model-generated action tokens. When `trainer.algorithm.world_model_coeff > 0`, ECHO also computes cross entropy on selected terminal-output tokens from the same trajectory: + +```text +L = L_GRPO(action tokens) + world_model_coeff * CE(terminal-output tokens) +``` + +Setting `world_model_coeff: 0.0` recovers the vanilla GRPO baseline. The included ECHO config uses `world_model_coeff: 0.05` and `generator.world_loss_target: env_only`, which trains on terminal environment-output tokens while leaving the RL action-token mask unchanged. diff --git a/examples/train_integrations/echo_terminal/__init__.py b/examples/train_integrations/echo_terminal/__init__.py new file mode 100644 index 0000000000..a9a2c5b3bb --- /dev/null +++ b/examples/train_integrations/echo_terminal/__init__.py @@ -0,0 +1 @@ +__all__ = [] diff --git a/examples/train_integrations/echo_terminal/chat_template.py b/examples/train_integrations/echo_terminal/chat_template.py new file mode 100644 index 0000000000..7939fd8586 --- /dev/null +++ b/examples/train_integrations/echo_terminal/chat_template.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from transformers import PreTrainedTokenizerBase + +_PROBE_SENTINEL = "OBSERVATION_PROBE_TOKEN_ABCXYZ_12345" +_PROBE_MESSAGES: list[dict[str, str]] = [ + {"role": "system", "content": "You are a helpful agent with access to bash."}, + {"role": "user", "content": "Run `ls /tmp` and report the result."}, + { + "role": "assistant", + "content": ( + "\nI should call the bash tool with the command.\n\n\n" + "\n\n\nls /tmp\n" + "\n\n" + ), + }, +] + + +def check_role_roundtrip( + tokenizer: PreTrainedTokenizerBase, + role: str, + *, + template_kwargs: dict[str, Any] | None = None, + sentinel: str = _PROBE_SENTINEL, +) -> tuple[bool, str]: + template_kwargs = template_kwargs or {} + try: + before = tokenizer.apply_chat_template( + _PROBE_MESSAGES, + add_generation_prompt=False, + tokenize=True, + return_dict=False, + **template_kwargs, + ) + after = tokenizer.apply_chat_template( + _PROBE_MESSAGES + [{"role": role, "content": sentinel}], + add_generation_prompt=True, + tokenize=True, + return_dict=False, + **template_kwargs, + ) + except Exception as exc: + return False, f"apply_chat_template raised {type(exc).__name__}: {exc}" + + if after[: len(before)] != before: + return False, f"prefix mismatch when {role!r} was appended" + delta = after[len(before) :] + if not delta: + return False, "empty delta" + decoded = tokenizer.decode(delta, skip_special_tokens=False) + if sentinel not in decoded: + return False, f"sentinel content not found for role {role!r}" + return True, "" + + +def choose_obs_role( + tokenizer: PreTrainedTokenizerBase, + candidates: tuple[str, ...] = ("tool", "user"), + *, + template_kwargs: dict[str, Any] | None = None, +) -> str: + failures: list[str] = [] + for role in candidates: + ok, reason = check_role_roundtrip(tokenizer, role, template_kwargs=template_kwargs) + if ok: + return role + failures.append(f" - {role!r}: {reason}") + raise RuntimeError( + f"None of {list(candidates)!r} produced a usable observation role for " + f"tokenizer {tokenizer.name_or_path!r}. Failures:\n" + "\n".join(failures) + ) + + +def resolve_chat_template_path(template_path: str | Path) -> Path: + path = Path(template_path).expanduser() + if path.is_absolute(): + candidates = [path] + else: + module_path = Path(__file__).resolve() + candidates = [ + Path.cwd() / path, + module_path.parent / path, + module_path.parents[2] / path, + ] + + for candidate in candidates: + if candidate.exists(): + return candidate + raise FileNotFoundError(f"Chat template not found at {template_path!r}; checked {candidates!r}") + + +def load_chat_template(tokenizer: PreTrainedTokenizerBase, template_path: str | Path) -> None: + tokenizer.chat_template = resolve_chat_template_path(template_path).read_text() diff --git a/examples/train_integrations/echo_terminal/chat_templates/qwen3_xml_tool_calling.jinja b/examples/train_integrations/echo_terminal/chat_templates/qwen3_xml_tool_calling.jinja new file mode 100644 index 0000000000..a585dec894 --- /dev/null +++ b/examples/train_integrations/echo_terminal/chat_templates/qwen3_xml_tool_calling.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml b/examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml new file mode 100644 index 0000000000..8744264a53 --- /dev/null +++ b/examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml @@ -0,0 +1,118 @@ +# Example config: vanilla RL on Qwen3-8B for the terminal-agent task. +# +# This matches the q8b e20 cluster runs used in the ECHO experiments: +# train batch size 16, 16 samples per prompt, eval/checkpoint every 20 steps, +# and evaluation only on terminal_agent_train val100. +# +# Edit data paths for your machine before running. + +data: + train_data: + - name: terminal_agent_train + path: /mnt/pvc/datasets/terminalagent/terminal-traindev/v0.1.5/qwen35_parquets/train8770_sa_q35xml_wip2.parquet + val_data: + - name: terminal_agent_train + path: /mnt/pvc/datasets/terminalagent/terminal-traindev/v0.1.5/qwen35_parquets/val100_sa_q35xml_wip2.parquet + +tokenizer_path: Qwen/Qwen3-8B +chat_template_path: examples/train_integrations/echo_terminal/chat_templates/qwen3_xml_tool_calling.jinja + +trainer: + placement: + colocate_all: true + colocate_policy_ref: true + policy_num_nodes: 1 + policy_num_gpus_per_node: 8 + ref_num_nodes: 1 + ref_num_gpus_per_node: 8 + policy: + model: + path: Qwen/Qwen3-8B + optimizer_config: + lr: 1.0e-6 + adam_betas: [0.9, 0.999] + weight_decay: 0.01 + max_grad_norm: 0.2 + scheduler: constant_with_warmup + num_warmup_steps: 20 + algorithm: + max_seq_len: 34096 + use_kl_loss: false + use_kl_in_reward: false + kl_loss_coef: 0.0 + eps_clip_low: 0.2 + eps_clip_high: 0.2 + advantage_batch_normalize: false + loss_reduction: sequence_mean + world_model_coeff: 0.0 + world_loss_normalization: full_observation_tokens + epochs: 2 + update_epochs_per_batch: 1 + max_prompt_length: 1536 + train_batch_size: 16 + policy_mini_batch_size: 16 + micro_train_batch_size_per_gpu: 1 + micro_forward_batch_size_per_gpu: 1 + eval_batch_size: 16 + eval_before_train: false + eval_interval: 20 + ckpt_interval: 20 + hf_save_interval: -1 + ckpt_path: ${oc.env:OUTPUT_DIR}/ckpts + export_path: ${oc.env:OUTPUT_DIR}/exports + log_path: ${oc.env:OUTPUT_DIR}/skyrl_logs + logger: wandb + project_name: world_model + run_name: terminal_agent_phitrain_full_scale + log_example_interval: -1 + +generator: + n_samples_per_prompt: 16 + eval_n_samples_per_prompt: 8 + step_wise_trajectories: false + parser_name: qwen35 + command_selection: first + max_turns: 16 + max_context_tokens: 16384 + max_tokens_per_generation: 2048 + max_terminal_output_chars: 50000 + terminal_output_truncation: start + add_format_warn: true + world_loss_target: env_only + thinking_handling: keep_all + verifier_timeout: 120.0 + length_penalty_coef: 0.0 + length_penalty_threshold: 20000 + correct_threshold: 0.5 + max_total_tokens: 34096 + docker_memory_mb: 1024 + docker_cpus: 1 + base_temp_dir: /tmp/tbench_b200 + agent_max_concurrency: 512 + agent_timeout: 600.0 + max_concurrent_builds: 32 + max_build_retries: 3 + dataset_num_workers: 8 + dataset_max_rows: null + sampling_params: + max_generate_length: 2048 + temperature: 0.8 + logprobs: 0 + eval_sampling_params: + max_generate_length: 2048 + temperature: 0.6 + logprobs: 0 + inference_engine: + backend: vllm + run_engines_locally: true + num_engines: 8 + weight_sync_backend: nccl + tensor_parallel_size: 1 + async_engine: true + enforce_eager: false + enable_prefix_caching: true + enable_chunked_prefill: true + gpu_memory_utilization: 0.8 + max_num_batched_tokens: 262144 + engine_init_kwargs: + swap_space: 4 diff --git a/examples/train_integrations/echo_terminal/configs/qwen3_8b_rl_wm05.yaml b/examples/train_integrations/echo_terminal/configs/qwen3_8b_rl_wm05.yaml new file mode 100644 index 0000000000..060f29002a --- /dev/null +++ b/examples/train_integrations/echo_terminal/configs/qwen3_8b_rl_wm05.yaml @@ -0,0 +1,118 @@ +# Example config: RL plus auxiliary world-modeling loss on Qwen3-8B for the terminal-agent task. +# +# This matches the q8b e20 cluster runs used in the ECHO experiments: +# train batch size 16, 16 samples per prompt, eval/checkpoint every 20 steps, +# and evaluation only on terminal_agent_train val100. +# +# Edit data paths for your machine before running. + +data: + train_data: + - name: terminal_agent_train + path: /mnt/pvc/datasets/terminalagent/terminal-traindev/v0.1.5/qwen35_parquets/train8770_sa_q35xml_wip2.parquet + val_data: + - name: terminal_agent_train + path: /mnt/pvc/datasets/terminalagent/terminal-traindev/v0.1.5/qwen35_parquets/val100_sa_q35xml_wip2.parquet + +tokenizer_path: Qwen/Qwen3-8B +chat_template_path: examples/train_integrations/echo_terminal/chat_templates/qwen3_xml_tool_calling.jinja + +trainer: + placement: + colocate_all: true + colocate_policy_ref: true + policy_num_nodes: 1 + policy_num_gpus_per_node: 8 + ref_num_nodes: 1 + ref_num_gpus_per_node: 8 + policy: + model: + path: Qwen/Qwen3-8B + optimizer_config: + lr: 1.0e-6 + adam_betas: [0.9, 0.999] + weight_decay: 0.01 + max_grad_norm: 0.2 + scheduler: constant_with_warmup + num_warmup_steps: 20 + algorithm: + max_seq_len: 34096 + use_kl_loss: false + use_kl_in_reward: false + kl_loss_coef: 0.0 + eps_clip_low: 0.2 + eps_clip_high: 0.2 + advantage_batch_normalize: false + loss_reduction: sequence_mean + world_model_coeff: 0.05 + world_loss_normalization: full_observation_tokens + epochs: 2 + update_epochs_per_batch: 1 + max_prompt_length: 1536 + train_batch_size: 16 + policy_mini_batch_size: 16 + micro_train_batch_size_per_gpu: 1 + micro_forward_batch_size_per_gpu: 1 + eval_batch_size: 16 + eval_before_train: false + eval_interval: 20 + ckpt_interval: 20 + hf_save_interval: -1 + ckpt_path: ${oc.env:OUTPUT_DIR}/ckpts + export_path: ${oc.env:OUTPUT_DIR}/exports + log_path: ${oc.env:OUTPUT_DIR}/skyrl_logs + logger: wandb + project_name: world_model + run_name: terminal_agent_phitrain_full_scale + log_example_interval: -1 + +generator: + n_samples_per_prompt: 16 + eval_n_samples_per_prompt: 8 + step_wise_trajectories: false + parser_name: qwen35 + command_selection: first + max_turns: 16 + max_context_tokens: 16384 + max_tokens_per_generation: 2048 + max_terminal_output_chars: 50000 + terminal_output_truncation: start + add_format_warn: true + world_loss_target: env_only + thinking_handling: keep_all + verifier_timeout: 120.0 + length_penalty_coef: 0.0 + length_penalty_threshold: 20000 + correct_threshold: 0.5 + max_total_tokens: 34096 + docker_memory_mb: 1024 + docker_cpus: 1 + base_temp_dir: /tmp/tbench_b200 + agent_max_concurrency: 512 + agent_timeout: 600.0 + max_concurrent_builds: 32 + max_build_retries: 3 + dataset_num_workers: 8 + dataset_max_rows: null + sampling_params: + max_generate_length: 2048 + temperature: 0.8 + logprobs: 0 + eval_sampling_params: + max_generate_length: 2048 + temperature: 0.6 + logprobs: 0 + inference_engine: + backend: vllm + run_engines_locally: true + num_engines: 8 + weight_sync_backend: nccl + tensor_parallel_size: 1 + async_engine: true + enforce_eager: false + enable_prefix_caching: true + enable_chunked_prefill: true + gpu_memory_utilization: 0.8 + max_num_batched_tokens: 262144 + engine_init_kwargs: + swap_space: 4 diff --git a/examples/train_integrations/echo_terminal/dataset.py b/examples/train_integrations/echo_terminal/dataset.py new file mode 100644 index 0000000000..587b0e7e46 --- /dev/null +++ b/examples/train_integrations/echo_terminal/dataset.py @@ -0,0 +1,140 @@ +import os +from typing import Any, List + +import datasets +from loguru import logger +from transformers import PreTrainedTokenizerBase + + +class TerminalAgentTaskDataset: + """Dataset for phitrain terminal-agent parquet files. + + Expected columns are: + - prompt: chat messages + - path: task identifier/path + - task_binary: gzipped tar bytes for the Harbor task directory + """ + + def __init__( + self, + data_files: List[str | dict[str, Any]], + tokenizer: PreTrainedTokenizerBase, + max_prompt_length: int, + prompt_key: str = "prompt", + path_key: str = "path", + task_binary_key: str = "task_binary", + num_workers: int = 8, + max_rows: int | None = None, + max_rows_per_source: int | None = None, + world_model_only_paths: list[str] | None = None, + ) -> None: + self.data_files = data_files + self.tokenizer = tokenizer + self.max_prompt_length = max_prompt_length + self.prompt_key = prompt_key + self.path_key = path_key + self.task_binary_key = task_binary_key + self.num_workers = num_workers + self.max_rows = max_rows + self.max_rows_per_source = max_rows_per_source + self.world_model_only_paths = set(world_model_only_paths or []) + self._read_files_and_tokenize() + + def _normalize_source(self, source: str | dict[str, Any]) -> tuple[str, bool, str]: + if isinstance(source, str): + return source, False, os.path.splitext(os.path.basename(source))[0] + if not isinstance(source, dict): + raise TypeError(f"Terminal-agent data source must be a path or dict, got {type(source)}") + path = source.get("path") or source.get("file") or source.get("data") + if not isinstance(path, str): + raise ValueError("Terminal-agent data source dict must include a string 'path', 'file', or 'data'.") + data_source = source.get("name") or source.get("data_source") or os.path.splitext(os.path.basename(path))[0] + if not isinstance(data_source, str): + raise ValueError("Terminal-agent data source dict 'name' or 'data_source' must be a string when provided.") + return path, bool(source.get("world_model_only", False)), data_source + + def _load_file(self, source: str) -> datasets.Dataset: + ext = os.path.splitext(source)[-1].lower() + if ext == ".parquet": + return datasets.load_dataset("parquet", data_files=source, keep_in_memory=True)["train"] + if ext in [".json", ".jsonl"]: + return datasets.load_dataset("json", data_files=source, keep_in_memory=True)["train"] + raise ValueError(f"TerminalAgentTaskDataset expects parquet/json/jsonl data, got {source!r}") + + def _read_files_and_tokenize(self) -> None: + loaded = [] + for source in self.data_files: + path, world_model_only, data_source = self._normalize_source(source) + dataset = self._load_file(path) + if self.max_rows_per_source is not None: + dataset = dataset.select(range(min(self.max_rows_per_source, len(dataset)))) + world_model_only = world_model_only or path in self.world_model_only_paths + if world_model_only: + dataset = dataset.map(lambda row: {**row, "_world_model_only": True, "_data_source": data_source}) + else: + dataset = dataset.map( + lambda row: { + **row, + "_world_model_only": bool(row.get("_world_model_only", False)), + "_data_source": data_source, + } + ) + loaded.append(dataset) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(loaded) + if self.max_rows is not None: + self.dataframe = self.dataframe.select(range(min(self.max_rows, len(self.dataframe)))) + logger.info(f"Loaded terminal-agent dataset with {len(self.dataframe)} rows") + + required = {self.prompt_key, self.path_key, self.task_binary_key} + missing = required - set(self.dataframe.column_names) + if missing: + raise ValueError(f"Terminal-agent dataset is missing required columns: {sorted(missing)}") + + def _tokenize(row: dict[str, Any]) -> dict[str, Any]: + row["prompt_token_ids"] = self.tokenizer.apply_chat_template( + row[self.prompt_key], + tokenize=True, + add_generation_prompt=True, + return_dict=False, + ) + row["prompt_token_count"] = len(row["prompt_token_ids"]) + return row + + if self.num_workers > 1: + import multiprocess + + multiprocess.set_start_method("spawn", force=True) + + self.dataframe = self.dataframe.map( + _tokenize, + num_proc=self.num_workers, + desc="Tokenizing terminal-agent prompts", + ) + self.dataframe = self.dataframe.filter( + lambda row: row["prompt_token_count"] <= self.max_prompt_length, + num_proc=self.num_workers, + desc=f"Filtering prompts longer than {self.max_prompt_length} tokens", + ) + logger.info(f"Filtered terminal-agent dataset size: {len(self.dataframe)}") + + def __getitem__(self, index: int) -> dict: + row = self.dataframe[index] + env_extras = { + "path": row[self.path_key], + "task_binary": row[self.task_binary_key], + "prompt_token_ids": row["prompt_token_ids"], + "world_model_only": bool(row.get("_world_model_only", False)), + "data_source": row.get("_data_source"), + } + return { + "prompt": row[self.prompt_key], + "env_class": "terminal_agent", + "env_extras": env_extras, + "uid": str(index), + } + + def __len__(self) -> int: + return len(self.dataframe) + + def collate_fn(self, item_list): + return item_list diff --git a/examples/train_integrations/echo_terminal/entrypoint.py b/examples/train_integrations/echo_terminal/entrypoint.py new file mode 100644 index 0000000000..80bf4be909 --- /dev/null +++ b/examples/train_integrations/echo_terminal/entrypoint.py @@ -0,0 +1,227 @@ +"""Entrypoint for ECHO terminal-agent training on SkyRL hooks.""" + +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal, Optional + +import ray +from omegaconf import OmegaConf + +from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import CriticWorker, RefWorker +from skyrl.train.config import GeneratorConfig +from skyrl.train.entrypoints.main_base import BasePPOExp +from skyrl.train.utils import validate_cfg +from skyrl.train.utils.utils import initialize_ray + +from .chat_template import load_chat_template, resolve_chat_template_path +from .dataset import TerminalAgentTaskDataset +from .generator import TerminalAgentGenerator +from .world_modeling.config import EchoSkyRLTrainConfig +from .world_modeling.fsdp_worker import PolicyWorker +from .world_modeling.trainer import EchoPPOTrainer + + +def _load_config(argv: list[str]) -> "EchoTerminalAgentSkyRLConfig": + config_path = None + remaining = [] + i = 0 + while i < len(argv): + arg = argv[i] + if arg in ("--config", "-c"): + if i + 1 >= len(argv): + raise ValueError(f"{arg} requires a config path") + config_path = argv[i + 1] + i += 2 + continue + if arg.startswith("config="): + config_path = arg.split("=", 1)[1] + i += 1 + continue + remaining.append(arg) + i += 1 + + if config_path is None: + return EchoTerminalAgentSkyRLConfig.from_cli_overrides(remaining) + + base = OmegaConf.load(Path(config_path)) + overrides = OmegaConf.from_cli(remaining) + merged = OmegaConf.merge(base, overrides) + return EchoTerminalAgentSkyRLConfig.from_dict_config(merged) + + +@dataclass +class TerminalAgentGeneratorConfig(GeneratorConfig): + max_turns: int = 16 + max_context_tokens: int = 25000 + max_tokens_per_generation: int = 2048 + max_terminal_output_chars: int = 6000 + terminal_output_truncation: str = "start_end" + max_commands_per_turn: int = 1 + command_selection: Literal["first", "last", "all"] = "first" + parser_name: Literal["xml", "hermes", "qwen35"] = "xml" + length_penalty_coef: float = 0.0 + length_penalty_threshold: int = 20000 + correct_threshold: float = 0.5 + verifier_timeout: float = 120.0 + thinking_handling: Literal["keep_all", "strip_rollout", "strip_all"] = "keep_all" + max_total_tokens: Optional[int] = None + add_format_warn: bool = False + max_world_model_tokens: Optional[int] = None + world_loss_target: Optional[Literal["full_observation", "env_only", "warning_only", "warning_plus_env"]] = None + world_loss_on_format_warn: bool = True + world_model_only_paths: list[str] = field(default_factory=list) + agent_max_concurrency: int = 32 + agent_timeout: float = 1800.0 + docker_memory_mb: Optional[int] = None + docker_cpus: Optional[int] = None + base_temp_dir: Optional[str] = None + max_concurrent_builds: int = 32 + max_build_retries: int = 3 + prompt_key: str = "prompt" + path_key: str = "path" + task_binary_key: str = "task_binary" + dataset_num_workers: int = 8 + dataset_max_rows: Optional[int] = None + eval_dataset_max_rows_per_source: Optional[int] = None + + +@dataclass +class EchoTerminalAgentSkyRLConfig(EchoSkyRLTrainConfig): + generator: TerminalAgentGeneratorConfig = field(default_factory=TerminalAgentGeneratorConfig) + tokenizer_path: Optional[str] = None + chat_template_path: Optional[str] = None + + def __post_init__(self): + super().__post_init__() + engine_init_kwargs = self.generator.inference_engine.engine_init_kwargs + + if self.chat_template_path is not None: + self.chat_template_path = str(resolve_chat_template_path(self.chat_template_path)) + configured_chat_template = engine_init_kwargs.get("chat_template") + if configured_chat_template is not None and configured_chat_template != self.chat_template_path: + raise ValueError( + "chat_template_path and generator.inference_engine.engine_init_kwargs.chat_template " + f"must match when both are set, got {self.chat_template_path!r} and " + f"{configured_chat_template!r}" + ) + engine_init_kwargs["chat_template"] = self.chat_template_path + + if self.tokenizer_path is None: + return + + configured_tokenizer = engine_init_kwargs.get("tokenizer") + if configured_tokenizer is not None and configured_tokenizer != self.tokenizer_path: + raise ValueError( + "tokenizer_path and generator.inference_engine.engine_init_kwargs.tokenizer " + f"must match when both are set, got {self.tokenizer_path!r} and {configured_tokenizer!r}" + ) + engine_init_kwargs["tokenizer"] = self.tokenizer_path + + +class EchoTerminalAgentExp(BasePPOExp): + def _ensure_chat_template_loaded(self) -> None: + template_path = getattr(self.cfg, "chat_template_path", None) + if not template_path or getattr(self, "_echo_chat_template_loaded", False): + return + load_chat_template(self.tokenizer, template_path) + self._echo_chat_template_loaded = True + + def get_tokenizer_path(self) -> str: + return self.cfg.tokenizer_path or self.cfg.trainer.policy.model.path + + def get_worker_classes(self): + if self.cfg.trainer.strategy not in ("fsdp", "fsdp2"): + raise ValueError("ECHO hook implementation currently supports only fsdp/fsdp2, not Megatron.") + return PolicyWorker, CriticWorker, RefWorker + + def get_generator(self, cfg, tokenizer, inference_engine_client): + self._ensure_chat_template_loaded() + return TerminalAgentGenerator( + generator_cfg=cfg.generator, + inference_engine_client=inference_engine_client, + tokenizer=tokenizer, + max_seq_len=cfg.trainer.algorithm.max_seq_len, + ) + + def get_trainer( + self, + cfg, + tracker, + tokenizer, + train_dataset, + eval_dataset, + inference_engine_client, + generator, + colocate_pg, + ): + return EchoPPOTrainer( + cfg=cfg, + tracker=tracker, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + inference_engine_client=inference_engine_client, + generator=generator, + colocate_pg=colocate_pg, + ) + + def get_train_dataset(self): + self._ensure_chat_template_loaded() + prompts_dataset = TerminalAgentTaskDataset( + data_files=self.cfg.data.train_data, + tokenizer=self.tokenizer, + max_prompt_length=self.cfg.trainer.max_prompt_length, + prompt_key=self.cfg.generator.prompt_key, + path_key=self.cfg.generator.path_key, + task_binary_key=self.cfg.generator.task_binary_key, + num_workers=self.cfg.generator.dataset_num_workers, + max_rows=self.cfg.generator.dataset_max_rows, + max_rows_per_source=None, + world_model_only_paths=self.cfg.generator.world_model_only_paths, + ) + assert len(prompts_dataset) >= self.cfg.trainer.train_batch_size, ( + f"dataset should be at least as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, " + f"got size {len(prompts_dataset)}" + ) + return prompts_dataset + + def get_eval_dataset(self): + self._ensure_chat_template_loaded() + if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data: + return TerminalAgentTaskDataset( + data_files=self.cfg.data.val_data, + tokenizer=self.tokenizer, + max_prompt_length=self.cfg.trainer.max_prompt_length, + prompt_key=self.cfg.generator.prompt_key, + path_key=self.cfg.generator.path_key, + task_binary_key=self.cfg.generator.task_binary_key, + num_workers=self.cfg.generator.dataset_num_workers, + max_rows=self.cfg.generator.dataset_max_rows, + max_rows_per_source=self.cfg.generator.eval_dataset_max_rows_per_source, + world_model_only_paths=[], + ) + return None + + +@ray.remote(num_cpus=1) +def skyrl_entrypoint(cfg): + exp = EchoTerminalAgentExp(cfg) + exp.run() + + +def main() -> None: + cfg = _load_config(sys.argv[1:]) + validate_cfg(cfg) + if cfg.generator.step_wise_trajectories: + raise ValueError("TerminalAgentGenerator requires generator.step_wise_trajectories=false.") + if cfg.trainer.algorithm.max_seq_len is None: + raise ValueError("trainer.algorithm.max_seq_len must be set for terminal-agent training.") + if cfg.trainer.strategy not in ("fsdp", "fsdp2"): + raise ValueError("ECHO hook implementation currently supports only fsdp/fsdp2.") + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/examples/train_integrations/echo_terminal/generator.py b/examples/train_integrations/echo_terminal/generator.py new file mode 100644 index 0000000000..f5ae529f2d --- /dev/null +++ b/examples/train_integrations/echo_terminal/generator.py @@ -0,0 +1,855 @@ +from __future__ import annotations + +import asyncio +import copy +import logging +import re +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +from tqdm import tqdm + +from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient +from skyrl.train.generators.base import GeneratorInput, GeneratorInterface, GeneratorOutput, TrajectoryID +from skyrl.train.generators.utils import get_rollout_metrics + +from .chat_template import choose_obs_role +from .interaction import AddedMessageSpan, TerminalInteraction +from .parsers import check_format_warnings, get_parser + +logger = logging.getLogger(__name__) + + +def truncate_output(output: str, max_chars: int = 6000, strategy: str = "start_end") -> str: + if len(output) <= max_chars: + return output + original_len = len(output) + if strategy == "start": + msg = f"\n[Output truncated: showing first {max_chars} of {original_len} characters]" + return output[:max_chars] + msg + if strategy == "end": + msg = f"[Output truncated: showing last {max_chars} of {original_len} characters]\n" + return msg + output[-max_chars:] + half = max_chars // 2 + msg = f"\n[Output truncated: showing first {half} and last {half} of {original_len} characters]\n" + return output[:half] + msg + output[-half:] + + +def strip_thinking(text: str) -> str: + result = re.sub(r".*?", "", text, flags=re.DOTALL) + result = re.sub(r".*$", "", result, flags=re.DOTALL) + return result.strip() + + +@dataclass +class TerminalTrajectoryOutput: + trajectory_id: TrajectoryID + prompt_token_ids: list[int] + response_ids: list[int] + loss_masks: list[int] + world_loss_masks: list[int] + world_warning_masks: list[int] + world_env_masks: list[int] + world_full_observation_count: int + rollout_logprobs: list[float] + world_model_only: bool = False + reward: float = 0.0 + correct: bool = False + stop_reason: str = "error" + metrics: dict[str, float | int] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + +class TerminalAgentGenerator(GeneratorInterface): + def __init__( + self, + generator_cfg, + inference_engine_client: InferenceEngineClient, + tokenizer, + max_seq_len: int, + ) -> None: + if getattr(generator_cfg, "step_wise_trajectories", False): + raise ValueError("TerminalAgentGenerator is non-step-wise. Set generator.step_wise_trajectories=false.") + self.generator_cfg = generator_cfg + self.inference_engine_client = inference_engine_client + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + self._parser_name = generator_cfg.parser_name + self._parser = get_parser(self._parser_name) + self._obs_role = choose_obs_role(tokenizer, candidates=("tool", "user")) + world_loss_target = generator_cfg.world_loss_target + if world_loss_target is None: + world_loss_target = "full_observation" if generator_cfg.world_loss_on_format_warn else "env_only" + if world_loss_target not in {"full_observation", "env_only", "warning_only", "warning_plus_env"}: + raise ValueError(f"Unsupported world_loss_target={generator_cfg.world_loss_target!r}") + self.world_loss_target = world_loss_target + self._im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + newline_ids = tokenizer.encode("\n", add_special_tokens=False) + self._turn_end_tokens = [self._im_end_id, newline_ids[-1]] if self._im_end_id is not None else newline_ids[-1:] + self._think_start_id: int | None = None + self._think_end_id: int | None = None + if generator_cfg.thinking_handling != "keep_all": + self._init_thinking_token_ids() + + async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False) -> GeneratorOutput: + prompts = input_batch["prompts"] + env_extras = input_batch.get("env_extras") + trajectory_ids = input_batch.get("trajectory_ids") + sampling_params = input_batch.get("sampling_params") or {} + if env_extras is None or trajectory_ids is None: + raise ValueError("TerminalAgentGenerator requires env_extras and trajectory_ids.") + if not (len(prompts) == len(env_extras) == len(trajectory_ids)): + raise ValueError("prompts, env_extras, and trajectory_ids must have the same length.") + + from .harbor_environment import HarborEnvironmentProvider + + provider = HarborEnvironmentProvider( + docker_memory_mb=self.generator_cfg.docker_memory_mb, + docker_cpus=self.generator_cfg.docker_cpus, + base_temp_dir=Path(self.generator_cfg.base_temp_dir) if self.generator_cfg.base_temp_dir else None, + max_concurrent_builds=self.generator_cfg.max_concurrent_builds, + max_build_retries=self.generator_cfg.max_build_retries, + ) + unique_envs, prompt_idx_by_instance = self._prepare_unique_envs(env_extras, trajectory_ids) + for extra, trajectory_id in zip(env_extras, trajectory_ids): + extra["_prompt_idx"] = prompt_idx_by_instance[trajectory_id.instance_id] + + outputs: list[TerminalTrajectoryOutput | None] = [None] * len(prompts) + progress = tqdm( + disable=disable_tqdm, + total=len(prompts), + desc="Generating Terminal Trajectories", + miniters=max(1, len(prompts) // 10), + mininterval=5, + ) + + try: + await provider.prepare_batch(unique_envs, num_generations=self.generator_cfg.n_samples_per_prompt) + semaphore = asyncio.Semaphore(self.generator_cfg.agent_max_concurrency) + + async def _worker(idx: int) -> None: + async with semaphore: + try: + outputs[idx] = await asyncio.wait_for( + self._run_one( + provider, + prompts[idx], + env_extras[idx], + trajectory_ids[idx], + sampling_params, + ), + timeout=self.generator_cfg.agent_timeout, + ) + except asyncio.TimeoutError: + outputs[idx] = self._failure_output( + prompts[idx], + env_extras[idx], + trajectory_ids[idx], + "timeout", + ) + except Exception as exc: + logger.exception("Terminal rollout failed for %s: %s", trajectory_ids[idx], exc) + outputs[idx] = self._failure_output( + prompts[idx], + env_extras[idx], + trajectory_ids[idx], + "error", + ) + finally: + progress.update(1) + + async with asyncio.TaskGroup() as tg: + for idx in range(len(prompts)): + tg.create_task(_worker(idx)) + finally: + progress.close() + await provider.cleanup_batch() + + return self._build_generator_output([o for o in outputs if o is not None]) + + def _prepare_unique_envs( + self, + env_extras: list[dict[str, Any]], + trajectory_ids: list[TrajectoryID], + ) -> tuple[list[dict[str, Any]], dict[str, int]]: + unique_envs: list[dict[str, Any]] = [] + prompt_idx_by_instance: dict[str, int] = {} + for extra, tid in zip(env_extras, trajectory_ids): + if tid.instance_id in prompt_idx_by_instance: + continue + prompt_idx_by_instance[tid.instance_id] = len(unique_envs) + unique_envs.append(extra) + return unique_envs, prompt_idx_by_instance + + async def _run_one( + self, + provider: HarborEnvironmentProvider, + prompt: list[dict[str, Any]], + env_extra: dict[str, Any], + trajectory_id: TrajectoryID, + sampling_params: dict[str, Any], + ) -> TerminalTrajectoryOutput: + prompt_token_ids = list(env_extra.get("prompt_token_ids") or self.tokenizer.apply_chat_template( + prompt, tokenize=True, add_generation_prompt=True, return_dict=False + )) + interaction = TerminalInteraction( + prompt_id=int(trajectory_id.instance_id) if str(trajectory_id.instance_id).isdigit() else 0, + completion_id=trajectory_id.repetition_id, + prompt_messages=copy.deepcopy(prompt), + prompt_token_ids=prompt_token_ids, + metadata={ + "path": env_extra.get("path", ""), + "data_source": env_extra.get("data_source"), + }, + ) + environment = None + setup_complete = False + try: + environment = await provider.create(env_extra) + await environment.setup() + setup_complete = True + await self._agent_loop(interaction, trajectory_id.to_string(), environment, sampling_params) + except Exception as exc: + stop_reason = "env_setup_error" if not setup_complete else "error" + logger.warning("Environment or agent failed for %s: %s", trajectory_id, exc, exc_info=True) + self._mark_failure(interaction, stop_reason, exc) + finally: + if environment is not None: + await environment.cleanup() + + self._ensure_non_empty_completion(interaction) + return TerminalTrajectoryOutput( + trajectory_id=trajectory_id, + prompt_token_ids=interaction.prompt_token_ids, + response_ids=interaction.completion_token_ids, + loss_masks=interaction.completion_masks, + world_loss_masks=interaction.completion_observation_masks, + world_warning_masks=interaction.completion_warning_masks, + world_env_masks=interaction.completion_env_output_masks, + world_full_observation_count=int(interaction.metadata.get("full_observation_body_count", 0)), + rollout_logprobs=interaction.completion_logprobs, + world_model_only=bool(env_extra.get("world_model_only", False)), + reward=interaction.reward, + correct=interaction.correct, + stop_reason=str(interaction.metadata.get("stop_reason", "error")), + metrics=interaction.metrics, + metadata=dict(interaction.metadata), + ) + + async def _agent_loop( + self, + interaction: TerminalInteraction, + session_id: str, + environment: HarborEnvironment, + sampling_params: dict[str, Any], + ) -> None: + sampling_params = dict(sampling_params) + sampling_params.setdefault("logprobs", 1) + stop_reason = "error" + t_run_start = time.monotonic() + turn_traces: list[dict[str, Any]] = [] + rollout_context_ids: list[int] | None = None + if self.generator_cfg.thinking_handling == "strip_rollout": + rollout_context_ids = list(interaction.prompt_token_ids) + + turn = -1 + for turn in range(self.generator_cfg.max_turns): + total_len = len(interaction.token_ids) + if not self._has_token_budget(interaction, 0): + stop_reason = "max_total_tokens" + break + current_len = len(rollout_context_ids) if rollout_context_ids is not None else total_len + if current_len >= self.generator_cfg.max_context_tokens: + stop_reason = "max_context_tokens" + break + + generation_context = rollout_context_ids if rollout_context_ids is not None else interaction.token_ids + max_tokens = self.generator_cfg.max_tokens_per_generation + if self.generator_cfg.max_total_tokens is not None: + max_tokens = min(max_tokens, self.generator_cfg.max_total_tokens - total_len) + if max_tokens <= 0: + stop_reason = "max_total_tokens" + break + + t_gen = time.monotonic() + token_ids, logprobs = await self._generate_tokens( + generation_context, + sampling_params, + session_id=session_id, + seed=interaction.completion_id, + max_tokens=max_tokens, + ) + generate_sec = time.monotonic() - t_gen + token_ids = token_ids[:max_tokens] + logprobs = logprobs[:max_tokens] + + if self.generator_cfg.thinking_handling == "strip_all": + add_ids, add_logprobs = self._strip_thinking_from_tokens(token_ids, logprobs) + else: + add_ids, add_logprobs = token_ids, logprobs + if not self._has_token_budget(interaction, len(add_ids)): + stop_reason = "max_total_tokens" + break + + interaction.append_assistant(add_ids, add_logprobs or [0.0] * len(add_ids), self.tokenizer) + if rollout_context_ids is not None: + clean_ids, _ = self._strip_thinking_from_tokens(token_ids) + rollout_context_ids.extend(clean_ids) + + response_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) + format_warnings, format_violations = check_format_warnings( + response_text, + tags=self._parser.format_tags, + parser_name=self._parser_name, + ) + text_for_parsing = strip_thinking(response_text) + parse_result = self._parser.parse_response(text_for_parsing) + parse_result.warnings = format_warnings + parse_result.violations = format_violations + + if parse_result.error: + turn_traces.append( + self._turn_trace(turn, generate_sec, len(token_ids), 0.0, True, False, current_len, format_violations) + ) + await self._add_observation(interaction, "", parse_result.error, rollout_context_ids) + continue + + commands = self._select_commands(parse_result.commands) + t_exec = time.monotonic() + terminal_output = await self._execute_commands(environment, commands) + exec_sec = time.monotonic() - t_exec + turn_traces.append( + self._turn_trace( + turn, + generate_sec, + len(token_ids), + exec_sec, + False, + parse_result.is_done, + current_len, + format_violations, + ) + ) + + if parse_result.is_done: + stop_reason = "done" + break + if not self._has_token_budget(interaction, 0): + stop_reason = "max_total_tokens" + break + warnings_prefix = "" + if self.generator_cfg.add_format_warn and format_warnings: + warnings_prefix = "WARNINGS:\n" + "\n".join(f"- {w}" for w in format_warnings) + "\n\n" + if not commands: + await self._add_observation( + interaction, + warnings_prefix, + "No commands provided. Please provide commands to execute or set done.", + rollout_context_ids, + ) + continue + await self._add_observation(interaction, warnings_prefix, terminal_output, rollout_context_ids) + else: + stop_reason = "max_turns" + + num_turns = min(turn + 1, self.generator_cfg.max_turns) if turn >= 0 else 0 + interaction.metadata["stop_reason"] = stop_reason + interaction.metadata["num_agent_turns"] = num_turns + + total_generate_sec = sum(t.get("generate_sec", 0.0) for t in turn_traces) + total_exec_sec = sum(t.get("exec_sec", 0.0) for t in turn_traces) + total_generate_tokens = sum(t.get("generate_tokens", 0) for t in turn_traces) + if stop_reason == "max_total_tokens": + interaction.reward = 0.0 + interaction.correct = False + verifier_sec = 0.0 + verifier_error = None + else: + t_verify = time.monotonic() + verifier_reward, verifier_error = await environment.run_verifier(timeout=self.generator_cfg.verifier_timeout) + verifier_sec = time.monotonic() - t_verify + interaction.reward = self._compute_reward(verifier_reward, interaction) + interaction.correct = verifier_reward >= self.generator_cfg.correct_threshold + + interaction.metadata["trace"] = { + "agent_run_sec": time.monotonic() - t_run_start, + "verifier_sec": verifier_sec, + "total_generate_sec": total_generate_sec, + "total_exec_sec": total_exec_sec, + "total_tokens": len(interaction.token_ids), + "total_prompt_tokens": len(interaction.prompt_token_ids), + "total_completion_tokens": len(interaction.completion_token_ids), + "total_generate_tokens": total_generate_tokens, + "turns": turn_traces, + } + metrics: dict[str, float | int] = {f"stop_reason/{stop_reason}": 1, "num_agent_turns": num_turns} + if verifier_error: + metrics[f"verifier_error/{verifier_error}"] = 1 + for trace in turn_traces: + for violation in trace.get("format_violations", []): + metrics[f"format/{violation}"] = metrics.get(f"format/{violation}", 0) + 1 + metrics["parse_errors"] = sum(1 for t in turn_traces if t.get("parse_error")) + metrics["format_violations_total"] = sum( + value for key, value in metrics.items() if key.startswith("format/") and isinstance(value, int) + ) + interaction.metrics = metrics + + async def _generate_tokens( + self, + prompt_token_ids: list[int], + sampling_params: dict[str, Any], + session_id: str, + seed: int, + max_tokens: int, + ) -> tuple[list[int], list[float]]: + params = dict(sampling_params) + params["max_tokens"] = max_tokens + params.setdefault("logprobs", 1) + params["seed"] = seed + result = await self.inference_engine_client.generate( + { + "prompts": None, + "prompt_token_ids": [prompt_token_ids], + "sampling_params": params, + "session_ids": [session_id], + "skip_detokenize": True, + } + ) + token_ids = result["response_ids"][0] + logprobs = result.get("response_logprobs") + if logprobs is None or logprobs[0] is None: + return token_ids, [0.0] * len(token_ids) + return token_ids, list(logprobs[0]) + + def _get_message_delta_tokens(self, interaction: TerminalInteraction, role: str, text: str) -> list[int]: + before_ids = self.tokenizer.apply_chat_template( + interaction.messages, + add_generation_prompt=False, + tokenize=True, + return_dict=False, + ) + after_ids = self.tokenizer.apply_chat_template( + interaction.messages + [{"role": role, "content": text}], + add_generation_prompt=False, + tokenize=True, + return_dict=False, + ) + return after_ids[len(before_ids) :] + + @staticmethod + def _common_prefix_len(a: list[int], b: list[int]) -> int: + shared = 0 + for a_tok, b_tok in zip(a, b): + if a_tok != b_tok: + break + shared += 1 + return shared + + @staticmethod + def _common_suffix_len(a: list[int], b: list[int], prefix_len: int = 0) -> int: + max_suffix = min(len(a), len(b)) - prefix_len + shared = 0 + while shared < max_suffix and a[len(a) - 1 - shared] == b[len(b) - 1 - shared]: + shared += 1 + return shared + + def _get_observation_content_spans( + self, + interaction: TerminalInteraction, + role: str, + warning_text: str, + env_text: str, + ) -> tuple[int, int, int]: + empty_delta = self._get_message_delta_tokens(interaction, role, "") + full_delta = self._get_message_delta_tokens(interaction, role, warning_text + env_text) + + prefix_len = self._common_prefix_len(empty_delta, full_delta) + suffix_len = self._common_suffix_len(empty_delta, full_delta, prefix_len=prefix_len) + content_end = len(full_delta) - suffix_len + + if not warning_text: + return prefix_len, prefix_len, content_end + + warning_delta = self._get_message_delta_tokens(interaction, role, warning_text) + warning_prefix_len = self._common_prefix_len(empty_delta, warning_delta) + warning_suffix_len = self._common_suffix_len(empty_delta, warning_delta, prefix_len=warning_prefix_len) + warning_content_end = len(warning_delta) - warning_suffix_len + + if warning_prefix_len != prefix_len: + raise ValueError( + f"Inconsistent observation-message prefix spans: empty/full={prefix_len}, " + f"empty/warning={warning_prefix_len}." + ) + + return prefix_len, warning_content_end, content_end + + def _apply_mask_indices(self, mask: list[int], indices: list[int]) -> None: + if self.generator_cfg.max_world_model_tokens is not None: + indices = indices[: self.generator_cfg.max_world_model_tokens] + for idx in indices: + mask[idx] = 1 + + async def _add_observation( + self, + interaction: TerminalInteraction, + warning_text: str, + env_text: str, + rollout_context_ids: list[int] | None, + ) -> None: + text = warning_text + env_text + if self.generator_cfg.max_total_tokens is not None: + overhead = len(self._turn_end_tokens) + 10 + remaining = self.generator_cfg.max_total_tokens - len(interaction.token_ids) - overhead + if remaining <= 0: + return + obs_token_count = len(self.tokenizer.encode(text, add_special_tokens=False)) + if obs_token_count > remaining: + ratio = remaining / obs_token_count + warning_text, env_text, text = self._truncate_observation_segments( + warning_text, + env_text, + max(0, int(len(text) * ratio * 0.9)), + ) + + turn_end = self._append_turn_end(interaction) + if rollout_context_ids is not None: + rollout_context_ids.extend(turn_end) + content_start_offset, warning_end_offset, content_end_offset = self._get_observation_content_spans( + interaction, + self._obs_role, + warning_text, + env_text, + ) + span, token_ids = self._add_message(interaction, self._obs_role, text, add_generation_prompt=True) + + warning_start = span.message_start + content_start_offset + warning_end = span.message_start + warning_end_offset + env_start = warning_end + env_end = span.message_start + content_end_offset + interaction.metadata["full_observation_body_count"] = ( + int(interaction.metadata.get("full_observation_body_count", 0)) + + max(0, content_end_offset - content_start_offset) + ) + + if not (span.message_start <= warning_start <= warning_end <= env_start <= env_end <= span.message_end): + raise ValueError( + "Invalid observation spans: " + f"target={self.world_loss_target}, offsets=({content_start_offset}, {warning_end_offset}, " + f"{content_end_offset}), span={span}" + ) + + for i in range(warning_start, warning_end): + interaction.completion_warning_masks[i] = 1 + for i in range(env_start, env_end): + interaction.completion_env_output_masks[i] = 1 + + if self.world_loss_target == "full_observation": + obs_indices = list(range(span.message_start, span.message_end)) + elif self.world_loss_target == "warning_only": + obs_indices = list(range(warning_start, warning_end)) + elif self.world_loss_target == "warning_plus_env": + obs_indices = list(range(warning_start, env_end)) + else: + obs_indices = list(range(env_start, env_end)) + self._apply_mask_indices(interaction.completion_observation_masks, obs_indices) + + if rollout_context_ids is not None: + rollout_context_ids.extend(token_ids) + return span + + def _add_message( + self, + interaction: TerminalInteraction, + role: Literal["system", "user", "tool"], + text: str, + add_generation_prompt: bool, + ) -> tuple[AddedMessageSpan, list[int]]: + before_ids = self.tokenizer.apply_chat_template( + interaction.messages, + add_generation_prompt=False, + tokenize=True, + return_dict=False, + ) + after_message_ids = self.tokenizer.apply_chat_template( + interaction.messages + [{"role": role, "content": text}], + add_generation_prompt=False, + tokenize=True, + return_dict=False, + ) + after_full_ids = self.tokenizer.apply_chat_template( + interaction.messages + [{"role": role, "content": text}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_dict=False, + ) + message_token_ids = after_message_ids[len(before_ids) :] + token_ids = after_full_ids[len(before_ids) :] + interaction.append_tokens(token_ids, [0] * len(token_ids), [0.0] * len(token_ids), role=role, text=text) + completion_start = len(interaction.completion_token_ids) - len(token_ids) + message_end = completion_start + len(message_token_ids) + return ( + AddedMessageSpan( + completion_start=completion_start, + completion_end=completion_start + len(token_ids), + message_start=completion_start, + message_end=message_end, + generation_prompt_start=message_end, + generation_prompt_end=completion_start + len(token_ids), + ), + token_ids, + ) + + def _append_turn_end(self, interaction: TerminalInteraction) -> list[int]: + tail = interaction.completion_token_ids[-len(self._turn_end_tokens) :] + if tail == self._turn_end_tokens: + return [] + if tail and self._im_end_id is not None and tail[-1] == self._im_end_id: + tokens = self._turn_end_tokens[1:] + else: + tokens = self._turn_end_tokens + interaction.completion_token_ids.extend(tokens) + interaction.completion_masks.extend([0] * len(tokens)) + interaction.completion_observation_masks.extend([0] * len(tokens)) + interaction.completion_warning_masks.extend([0] * len(tokens)) + interaction.completion_env_output_masks.extend([0] * len(tokens)) + interaction.completion_logprobs.extend([0.0] * len(tokens)) + return tokens + + async def _execute_commands(self, environment: HarborEnvironment, commands: list) -> str: + if not commands: + return "" + outputs = [] + for parsed_cmd in commands: + raw_text = parsed_cmd.raw + if parsed_cmd.error: + outputs.append(f"Command '{raw_text}' skipped due to parse error: {parsed_cmd.error}") + continue + if parsed_cmd.name != "bash": + outputs.append( + f"Command '{raw_text}' skipped due to Unknown tool '{parsed_cmd.name}'. Only 'bash' is supported." + ) + continue + cmd = parsed_cmd.arguments.get("command", "") + timeout = parsed_cmd.arguments.get("timeout", 30) + if not cmd: + outputs.append( + f"Command '{raw_text}' skipped due to no command provided. Please provide a command in correct format to execute." + ) + continue + try: + result = await environment.exec(cmd, timeout=float(timeout)) + except (RuntimeError, TimeoutError) as exc: + outputs.append(f"Command '{cmd}' timed out after {timeout} seconds.\n\n(exit_code=-1)") + logger.warning("Command failed/timed out: %r (%s)", cmd, exc) + continue + except Exception as exc: + outputs.append(f"Command '{cmd}' execution error: {exc}\n\n(exit_code=-1)") + continue + raw_output = result.stdout or "" + if result.stderr: + raw_output += f"\n{result.stderr}" + truncated = truncate_output( + raw_output, + max_chars=self.generator_cfg.max_terminal_output_chars, + strategy=self.generator_cfg.terminal_output_truncation, + ) + status = "executed successfully" if result.return_code == 0 else "failed" + outputs.append(f"Command '{cmd}' {status}. Output: {truncated}\n\n(exit_code={result.return_code})") + return "\n\n".join(outputs) + + def _select_commands(self, commands: list) -> list: + if self.generator_cfg.command_selection == "all": + return commands + if self.generator_cfg.command_selection == "last": + return commands[-self.generator_cfg.max_commands_per_turn :] + return commands[: self.generator_cfg.max_commands_per_turn] + + def _compute_reward(self, verifier_reward: float, interaction: TerminalInteraction) -> float: + total_reward = verifier_reward + completion_len = len(interaction.completion_token_ids) + if ( + self.generator_cfg.length_penalty_coef > 0 + and completion_len > self.generator_cfg.length_penalty_threshold + ): + excess = completion_len - self.generator_cfg.length_penalty_threshold + total_reward -= self.generator_cfg.length_penalty_coef * (excess / 1000) + return float(total_reward) + + def _mark_failure( + self, + interaction: TerminalInteraction, + stop_reason: str, + exc: Exception | None = None, + ) -> None: + interaction.metadata["stop_reason"] = stop_reason + if exc is not None: + interaction.metadata["error_type"] = type(exc).__name__ + interaction.metadata["error_message"] = str(exc) + interaction.reward = 0.0 + interaction.correct = False + metrics: dict[str, float | int] = {f"stop_reason/{stop_reason}": 1} + if stop_reason == "env_setup_error": + metrics["env_setup_errors"] = 1 + if exc is not None: + interaction.metadata["env_setup_error"] = str(exc) + interaction.metrics = metrics + self._ensure_non_empty_completion(interaction) + + def _failure_output( + self, + prompt: list[dict[str, Any]], + env_extra: dict[str, Any], + trajectory_id: TrajectoryID, + stop_reason: str, + ) -> TerminalTrajectoryOutput: + prompt_token_ids = list(env_extra.get("prompt_token_ids") or self.tokenizer.apply_chat_template( + prompt, tokenize=True, add_generation_prompt=True, return_dict=False + )) + eos_id = self.tokenizer.eos_token_id or 0 + return TerminalTrajectoryOutput( + trajectory_id=trajectory_id, + prompt_token_ids=prompt_token_ids, + response_ids=[eos_id], + loss_masks=[0], + world_loss_masks=[0], + world_warning_masks=[0], + world_env_masks=[0], + world_full_observation_count=0, + rollout_logprobs=[0.0], + world_model_only=bool(env_extra.get("world_model_only", False)), + reward=0.0, + correct=False, + stop_reason=stop_reason, + metrics={f"stop_reason/{stop_reason}": 1}, + metadata={ + "stop_reason": stop_reason, + "path": env_extra.get("path", ""), + "data_source": env_extra.get("data_source"), + }, + ) + + def _ensure_non_empty_completion(self, interaction: TerminalInteraction) -> None: + if interaction.completion_token_ids: + return + eos_id = self.tokenizer.eos_token_id or 0 + interaction.completion_token_ids = [eos_id] + interaction.completion_masks = [0] + interaction.completion_observation_masks = [0] + interaction.completion_warning_masks = [0] + interaction.completion_env_output_masks = [0] + interaction.completion_logprobs = [0.0] + + def _has_token_budget(self, interaction: TerminalInteraction, num_tokens: int) -> bool: + if self.generator_cfg.max_total_tokens is None: + return True + return len(interaction.token_ids) + num_tokens <= self.generator_cfg.max_total_tokens + + def _init_thinking_token_ids(self) -> None: + unk_id = self.tokenizer.unk_token_id + think_start = self.tokenizer.convert_tokens_to_ids("") + think_end = self.tokenizer.convert_tokens_to_ids("") + if think_start is not None and think_start != unk_id: + self._think_start_id = think_start + if think_end is not None and think_end != unk_id: + self._think_end_id = think_end + + def _strip_thinking_from_tokens( + self, + token_ids: list[int], + logprobs: list[float] | None = None, + ) -> tuple[list[int], list[float] | None]: + if self._think_start_id is None: + return token_ids, logprobs + clean_ids: list[int] = [] + clean_logprobs: list[float] | None = [] if logprobs is not None else None + in_thinking = False + for i, token_id in enumerate(token_ids): + if token_id == self._think_start_id: + in_thinking = True + continue + if in_thinking and self._think_end_id is not None and token_id == self._think_end_id: + in_thinking = False + continue + if not in_thinking: + clean_ids.append(token_id) + if clean_logprobs is not None: + clean_logprobs.append(logprobs[i]) + return clean_ids, clean_logprobs + + @staticmethod + def _truncate_observation_segments(warning_text: str, env_text: str, max_chars: int) -> tuple[str, str, str]: + full_text = warning_text + env_text + if len(full_text) <= max_chars: + return warning_text, env_text, full_text + truncated_full = full_text[:max_chars] + warning_len = min(len(warning_text), len(truncated_full)) + return truncated_full[:warning_len], truncated_full[warning_len:], truncated_full + + @staticmethod + def _turn_trace( + turn: int, + generate_sec: float, + generate_tokens: int, + exec_sec: float, + parse_error: bool, + is_done: bool, + context_len: int, + format_violations: list[str], + ) -> dict[str, Any]: + return { + "turn": turn, + "generate_sec": generate_sec, + "generate_tokens": generate_tokens, + "exec_sec": exec_sec, + "parse_error": parse_error, + "is_done": is_done, + "context_len": context_len, + "format_violations": format_violations, + } + + def _build_generator_output(self, trajectory_outputs: list[TerminalTrajectoryOutput]) -> GeneratorOutput: + prompt_token_ids = [t.prompt_token_ids for t in trajectory_outputs] + response_ids = [t.response_ids for t in trajectory_outputs] + rewards = [t.reward for t in trajectory_outputs] + loss_masks = [t.loss_masks for t in trajectory_outputs] + stop_reasons = [t.stop_reason for t in trajectory_outputs] + rollout_logprobs = [t.rollout_logprobs for t in trajectory_outputs] + world_loss_masks = [t.world_loss_masks for t in trajectory_outputs] + world_warning_masks = [t.world_warning_masks for t in trajectory_outputs] + world_env_masks = [t.world_env_masks for t in trajectory_outputs] + world_full_observation_counts = [t.world_full_observation_count for t in trajectory_outputs] + world_model_only = [t.world_model_only for t in trajectory_outputs] + correct = [t.correct for t in trajectory_outputs] + trajectory_metadata = [t.metadata for t in trajectory_outputs] + rollout_metrics = get_rollout_metrics(response_ids, rewards) + for key in set().union(*(t.metrics.keys() for t in trajectory_outputs)): + values = [t.metrics.get(key, 0) for t in trajectory_outputs] + rollout_metrics[f"generate/{key}"] = sum(values) + rollout_metrics["generate/avg_num_turns"] = ( + sum(t.metrics.get("num_agent_turns", 0) for t in trajectory_outputs) / len(trajectory_outputs) + if trajectory_outputs + else 0 + ) + return { + "prompt_token_ids": prompt_token_ids, + "response_ids": response_ids, + "rewards": rewards, + "loss_masks": loss_masks, + "world_loss_masks": world_loss_masks, + "world_warning_masks": world_warning_masks, + "world_env_masks": world_env_masks, + "world_full_observation_counts": world_full_observation_counts, + "world_model_only": world_model_only, + "correct": correct, + "trajectory_metadata": trajectory_metadata, + "stop_reasons": stop_reasons, + "rollout_metrics": rollout_metrics, + "rollout_logprobs": rollout_logprobs, + "trajectory_ids": [t.trajectory_id for t in trajectory_outputs], + "rollout_expert_indices": None, + "is_last_step": None, + "pixel_values": None, + "image_grid_thw": None, + } diff --git a/examples/train_integrations/echo_terminal/harbor_environment.py b/examples/train_integrations/echo_terminal/harbor_environment.py new file mode 100644 index 0000000000..6f0844e55e --- /dev/null +++ b/examples/train_integrations/echo_terminal/harbor_environment.py @@ -0,0 +1,435 @@ +import asyncio +import io +import json +import logging +import os +import shutil +import tarfile +import tempfile +import time +import uuid +from pathlib import Path, PurePosixPath +from typing import Any + +from harbor.environments.base import BaseEnvironment, ExecResult +from harbor.environments.docker.docker import DockerEnvironment +from harbor.environments.factory import EnvironmentFactory +from harbor.models.environment_type import EnvironmentType +from harbor.models.task.config import EnvironmentConfig +from harbor.models.task.task import Task +from harbor.models.trial.paths import TrialPaths +from harbor.verifier.verifier import Verifier + +logger = logging.getLogger(__name__) + + +class EnvironmentStartTimeoutError(asyncio.TimeoutError): + pass + + +class HarborEnvironment: + def __init__( + self, + task_name: str, + shared_task: Task, + shared_env_name: str, + shared_task_env_config: EnvironmentConfig, + rollout_id: str, + base_temp_dir: Path | None = None, + keep_container: bool = False, + environment_build_timeout_sec: float = 600.0, + force_build: bool = False, + use_prebuilt_image: bool = False, + ) -> None: + self.task_name = task_name + self._shared_task = shared_task + self._shared_env_name = shared_env_name + self._shared_task_env_config = shared_task_env_config + self._rollout_id = rollout_id + self._base_temp_dir = base_temp_dir + self._keep_container = keep_container + self._environment_build_timeout_sec = environment_build_timeout_sec + self._force_build = force_build + self._use_prebuilt_image = use_prebuilt_image + self._work_dir: Path | None = None + self._trial_paths: TrialPaths | None = None + self._environment: BaseEnvironment | None = None + self._is_setup = False + self._env_setup_sec: float | None = None + + async def setup(self) -> None: + if self._is_setup: + return + dir_label = self.task_name.replace("/", "_").replace(" ", "_")[:40] or "tbench" + dir_suffix = self._rollout_id or uuid.uuid4().hex[:8] + if self._base_temp_dir: + self._base_temp_dir.mkdir(parents=True, exist_ok=True) + self._work_dir = self._base_temp_dir / f"{dir_label}_r{dir_suffix}" + else: + self._work_dir = Path(tempfile.gettempdir()) / f"tbench_{dir_label}_r{dir_suffix}" + + trial_dir = self._work_dir / "trial" + self._trial_paths = TrialPaths(trial_dir) + self._trial_paths.mkdir() + + env_name = self._shared_env_name + session_id = f"{env_name[:32]}__{uuid.uuid4().hex[:7]}" + + if self._use_prebuilt_image: + task_env_config = self._shared_task_env_config + force_build = False + elif self._force_build: + task_env_config = self._shared_task_env_config + force_build = True + else: + task_env_config = self._shared_task_env_config.model_copy() + task_env_config.docker_image = f"hb__{env_name}" + self._ensure_pull_policy_override(self._shared_task.paths.environment_dir) + force_build = False + + self._environment = EnvironmentFactory.create_environment( + type=EnvironmentType.DOCKER, + environment_dir=self._shared_task.paths.environment_dir, + environment_name=self._shared_env_name, + session_id=session_id, + trial_paths=self._trial_paths, + task_env_config=task_env_config, + keep_containers=self._keep_container, + suppress_override_warnings=True, + ) + t_env = time.monotonic() + await self._start_environment_with_retry(force_build=force_build) + self._env_setup_sec = time.monotonic() - t_env + self._is_setup = True + + async def _start_environment_with_retry(self, force_build: bool = True) -> None: + last_err: BaseException | None = None + for attempt in range(1, 3): + try: + await asyncio.wait_for( + self._environment.start(force_build=force_build), + timeout=self._environment_build_timeout_sec, + ) + if force_build: + await self._tag_built_image() + return + except asyncio.TimeoutError as exc: + last_err = EnvironmentStartTimeoutError( + f"Environment start timed out after {self._environment_build_timeout_sec} seconds" + ) + except RuntimeError as exc: + last_err = exc + if attempt < 2: + await asyncio.sleep(1) + raise last_err + + async def _tag_built_image(self) -> None: + target_image = f"hb__{self._shared_env_name}" + try: + result = await self._environment._run_docker_compose_command(["images", "--format", "json"], check=False) + images = json.loads(result.stdout or "[]") + if not images: + return + source_image = f"{images[0]['Repository']}:{images[0]['Tag']}" + proc = await asyncio.subprocess.create_subprocess_exec( + "docker", + "tag", + source_image, + target_image, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + _, stderr = await proc.communicate() + if proc.returncode != 0: + logger.warning("Failed to tag %s as %s: %s", source_image, target_image, stderr.decode()) + except Exception as exc: + logger.warning("Could not query/tag container image: %s", exc) + + @staticmethod + def _ensure_pull_policy_override(environment_dir: Path) -> None: + compose_path = environment_dir / "docker-compose.yaml" + if not compose_path.exists(): + compose_path.write_text("services:\n main:\n pull_policy: never\n") + + async def exec(self, command: str, timeout: float | None = None) -> ExecResult: + if not self._is_setup: + raise RuntimeError("Environment not set up. Call setup() first.") + timeout_sec = int(timeout) if timeout is not None else None + return await self._environment.exec(command, timeout_sec=timeout_sec) + + async def run_verifier(self, timeout: float | None = None) -> tuple[float, str | None]: + if not self._is_setup: + raise RuntimeError("Environment not set up. Call setup() first.") + try: + verifier = Verifier(task=self._shared_task, trial_paths=self._trial_paths, environment=self._environment) + result = await asyncio.wait_for(verifier.verify(), timeout=timeout) if timeout is not None else await verifier.verify() + reward = result.rewards.get("reward", 0.0) if result.rewards else 0.0 + return float(reward), None if result.rewards else "no_reward" + except asyncio.TimeoutError: + logger.warning("Verifier timed out after %ss", timeout) + return 0.0, "verifier_timeout" + except Exception as exc: + logger.warning("Verifier failed: %s", exc) + return 0.0, "verifier_error" + + async def cleanup(self) -> None: + if self._environment is not None: + try: + if self._keep_container: + await self._environment.stop(delete=False) + else: + await self._environment._run_docker_compose_command(["down", "--volumes", "--remove-orphans"]) + except Exception as exc: + logger.warning("Environment stop failed: %s", exc) + self._environment = None + + if self._work_dir and self._work_dir.exists() and not self._keep_container: + shutil.rmtree(self._work_dir, ignore_errors=True) + self._work_dir = None + self._is_setup = False + + +class _SharedTaskImage: + def __init__( + self, + task_binary: bytes, + task_name: str = "", + base_temp_dir: Path | None = None, + memory_mb: int | None = None, + cpus: int | None = None, + ) -> None: + self._task_binary = task_binary + self._task_name = task_name + self._base_temp_dir = base_temp_dir + self._memory_mb = memory_mb + self._cpus = cpus + self._work_dir: Path | None = None + self._task: Task | None = None + self._env_name: str | None = None + self._task_env_config: EnvironmentConfig | None = None + self._is_setup = False + + @property + def env_name(self) -> str | None: + return self._env_name + + @property + def has_prebuilt_image(self) -> bool: + return self._is_setup and bool(self._task_env_config.docker_image) + + def setup(self) -> None: + if self._is_setup: + return + dir_suffix = uuid.uuid4().hex[:8] + dir_label = self._task_name.replace("/", "_").replace(" ", "_")[:40] or "tbench" + if self._base_temp_dir: + self._base_temp_dir.mkdir(parents=True, exist_ok=True) + self._work_dir = self._base_temp_dir / f"shared_{dir_label}_{dir_suffix}" + else: + self._work_dir = Path(tempfile.gettempdir()) / f"tbench_shared_{dir_label}_{dir_suffix}" + + task_dir = self._work_dir / dir_label + task_dir.mkdir(parents=True, exist_ok=True) + _safe_extract_tar(self._task_binary, task_dir) + self._task = Task(task_dir) + self._task_env_config = self._task.config.environment + if self._memory_mb is not None: + self._task_env_config.memory_mb = self._memory_mb + if self._cpus is not None: + self._task_env_config.cpus = self._cpus + self._env_name = self._task.name + self._is_setup = True + + def create_environment(self, rollout_id: str, force_build: bool = False) -> HarborEnvironment: + if not self._is_setup: + raise RuntimeError("_SharedTaskImage not set up. Call setup() first.") + return HarborEnvironment( + task_name=self._task_name, + shared_task=self._task, + shared_env_name=self._env_name, + shared_task_env_config=self._task_env_config, + rollout_id=rollout_id, + base_temp_dir=self._base_temp_dir, + force_build=force_build, + use_prebuilt_image=self.has_prebuilt_image, + ) + + async def build_image(self) -> None: + if not self._is_setup: + raise RuntimeError("_SharedTaskImage not set up. Call setup() first.") + build_env = self.create_environment(rollout_id="build", force_build=True) + try: + await build_env.setup() + finally: + await build_env.cleanup() + + async def pull_image(self) -> None: + if not self._is_setup: + raise RuntimeError("_SharedTaskImage not set up. Call setup() first.") + image = self._task_env_config.docker_image + proc = await asyncio.subprocess.create_subprocess_exec( + "docker", + "pull", + image, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + _, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"docker pull failed for {image!r}: {stderr.decode()}") + + async def cleanup(self) -> None: + if self._env_name and not self.has_prebuilt_image: + try: + proc = await asyncio.subprocess.create_subprocess_exec( + "docker", + "rmi", + "-f", + f"hb__{self._env_name}", + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + await proc.wait() + except Exception: + logger.debug("Failed to remove image hb__%s", self._env_name) + if self._work_dir and self._work_dir.exists(): + shutil.rmtree(self._work_dir, ignore_errors=True) + self._work_dir = None + self._is_setup = False + + +class HarborEnvironmentProvider: + def __init__( + self, + docker_memory_mb: int | None = None, + docker_cpus: int | None = None, + base_temp_dir: Path | None = None, + max_concurrent_builds: int = 32, + max_build_retries: int = 3, + ) -> None: + self._docker_memory_mb = docker_memory_mb + self._docker_cpus = docker_cpus + self._base_temp_dir = base_temp_dir + self._max_concurrent_builds = max_concurrent_builds + self._max_build_retries = max_build_retries + self._shared_images: list[_SharedTaskImage] = [] + self._failed_prompts: set[int] = set() + self._rollout_counter: dict[int, int] = {} + + async def prepare_batch(self, batch_environment_data: list[dict[str, Any]], num_generations: int) -> None: + DockerEnvironment._image_build_locks.clear() + self._shared_images = [] + self._failed_prompts = set() + self._rollout_counter = {} + + for prompt_idx, env_data in enumerate(batch_environment_data): + task_binary = env_data.get("task_binary") + task_name = env_data.get("path", "unknown") + if task_binary is None: + raise ValueError(f"prompt {prompt_idx}: environment_data must contain 'task_binary'.") + env_data["_prompt_idx"] = prompt_idx + shared_image = _SharedTaskImage( + task_binary=task_binary, + task_name=task_name, + base_temp_dir=self._base_temp_dir, + memory_mb=self._docker_memory_mb, + cpus=self._docker_cpus, + ) + shared_image.setup() + self._shared_images.append(shared_image) + + build_semaphore = asyncio.Semaphore(self._max_concurrent_builds) + + async def _with_retries(action, img: _SharedTaskImage) -> None: + async with build_semaphore: + last_err: BaseException | None = None + for attempt in range(1, self._max_build_retries + 1): + try: + await action() + return + except Exception as exc: + last_err = exc + logger.warning( + "Image build/pull attempt %s/%s failed for %s: %s", + attempt, + self._max_build_retries, + img.env_name or "?", + exc, + ) + if attempt < self._max_build_retries: + await asyncio.sleep(min(2**attempt, 10)) + raise last_err + + work = [] + for idx, img in enumerate(self._shared_images): + action = img.pull_image if img.has_prebuilt_image else img.build_image + work.append((idx, _with_retries(action, img))) + results = await asyncio.gather(*[task for _, task in work], return_exceptions=True) + for (prompt_idx, _), result in zip(work, results): + if isinstance(result, BaseException): + logger.error("Image build/pull failed for prompt %s: %s", prompt_idx, result) + self._failed_prompts.add(prompt_idx) + + async def create(self, environment_data: dict[str, Any]) -> HarborEnvironment: + prompt_idx = environment_data.get("_prompt_idx", 0) + if prompt_idx in self._failed_prompts: + raise RuntimeError(f"Docker image build failed for prompt {prompt_idx}.") + if prompt_idx >= len(self._shared_images): + raise RuntimeError(f"prompt_idx {prompt_idx} out of range.") + rollout_num = self._rollout_counter.get(prompt_idx, 0) + self._rollout_counter[prompt_idx] = rollout_num + 1 + return self._shared_images[prompt_idx].create_environment(rollout_id=f"{prompt_idx}_{rollout_num}") + + async def cleanup_batch(self) -> None: + for shared_image in self._shared_images: + try: + await shared_image.cleanup() + except Exception: + logger.exception("[shared-image-cleanup] Failed") + self._shared_images = [] + self._failed_prompts = set() + self._rollout_counter = {} + + +def _sanitize_tar_member_name(name: str) -> str: + p = PurePosixPath(name) + parts = [part for part in p.parts if part not in ("..", ".", "")] + while parts and parts[0] == "/": + parts.pop(0) + return str(PurePosixPath(*parts)) if parts else "" + + +def _is_within(base: Path, target: Path) -> bool: + try: + return os.path.commonpath([str(base.resolve()), str(target.resolve())]) == str(base.resolve()) + except Exception: + return False + + +def _safe_extract_tar(archive_bytes: bytes, dest_dir: Path) -> None: + dest_dir.mkdir(parents=True, exist_ok=True) + buf = io.BytesIO(archive_bytes) + with tarfile.open(fileobj=buf, mode="r:*") as tf: + for member in tf.getmembers(): + member_name = _sanitize_tar_member_name(member.name) + if not member_name or member_name.endswith("/"): + dir_path = member_name.rstrip("/") + if dir_path: + (dest_dir / dir_path).mkdir(parents=True, exist_ok=True) + continue + if ".snapshot" in PurePosixPath(member_name).parts: + continue + target = (dest_dir / member_name).resolve() + if not _is_within(dest_dir, target): + raise RuntimeError(f"Unsafe path in archive: {member.name}") + target.parent.mkdir(parents=True, exist_ok=True) + if member.isfile(): + with tf.extractfile(member) as src: + if src is None: + continue + with open(target, "wb") as dst: + dst.write(src.read()) + if member.mode & 0o111: + target.chmod(target.stat().st_mode | 0o111) + elif member.isdir(): + target.mkdir(parents=True, exist_ok=True) diff --git a/examples/train_integrations/echo_terminal/interaction.py b/examples/train_integrations/echo_terminal/interaction.py new file mode 100644 index 0000000000..c9abe352e2 --- /dev/null +++ b/examples/train_integrations/echo_terminal/interaction.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class AddedMessageSpan: + completion_start: int + completion_end: int + message_start: int + message_end: int + generation_prompt_start: int + generation_prompt_end: int + + +@dataclass +class TerminalInteraction: + prompt_id: int + completion_id: int + prompt_messages: list[dict[str, Any]] + prompt_token_ids: list[int] + metadata: dict[str, Any] = field(default_factory=dict) + completion_messages: list[dict[str, Any]] = field(default_factory=list) + completion_token_ids: list[int] = field(default_factory=list) + completion_masks: list[int] = field(default_factory=list) + completion_observation_masks: list[int] = field(default_factory=list) + completion_warning_masks: list[int] = field(default_factory=list) + completion_env_output_masks: list[int] = field(default_factory=list) + completion_logprobs: list[float] = field(default_factory=list) + reward: float = 0.0 + correct: bool = False + metrics: dict[str, float | int] = field(default_factory=dict) + + @property + def messages(self) -> list[dict[str, Any]]: + return self.prompt_messages + self.completion_messages + + @property + def token_ids(self) -> list[int]: + return self.prompt_token_ids + self.completion_token_ids + + def append_assistant(self, token_ids: list[int], logprobs: list[float], tokenizer) -> str: + text = tokenizer.decode(token_ids, skip_special_tokens=True) + self.completion_messages.append({"role": "assistant", "content": text}) + self.completion_token_ids.extend(token_ids) + self.completion_masks.extend([1] * len(token_ids)) + self.completion_observation_masks.extend([0] * len(token_ids)) + self.completion_warning_masks.extend([0] * len(token_ids)) + self.completion_env_output_masks.extend([0] * len(token_ids)) + self.completion_logprobs.extend(logprobs) + return text + + def append_tokens( + self, + token_ids: list[int], + masks: list[int], + logprobs: list[float], + role: str, + text: str, + observation_masks: list[int] | None = None, + warning_masks: list[int] | None = None, + env_output_masks: list[int] | None = None, + ) -> None: + if len(token_ids) != len(masks) or len(token_ids) != len(logprobs): + raise ValueError("token_ids, masks, and logprobs must have the same length") + for name, extra_masks in ( + ("observation_masks", observation_masks), + ("warning_masks", warning_masks), + ("env_output_masks", env_output_masks), + ): + if extra_masks is not None and len(extra_masks) != len(token_ids): + raise ValueError(f"token_ids and {name} must have the same length") + self.completion_messages.append({"role": role, "content": text}) + self.completion_token_ids.extend(token_ids) + self.completion_masks.extend(masks) + self.completion_observation_masks.extend(observation_masks or [0] * len(token_ids)) + self.completion_warning_masks.extend(warning_masks or [0] * len(token_ids)) + self.completion_env_output_masks.extend(env_output_masks or [0] * len(token_ids)) + self.completion_logprobs.extend(logprobs) diff --git a/examples/train_integrations/echo_terminal/parsers.py b/examples/train_integrations/echo_terminal/parsers.py new file mode 100644 index 0000000000..aac81c0997 --- /dev/null +++ b/examples/train_integrations/echo_terminal/parsers.py @@ -0,0 +1,199 @@ +import json +import re +from dataclasses import dataclass, field + + +@dataclass +class ParsedCommand: + name: str = "" + arguments: dict = field(default_factory=dict) + error: str = "" + raw: str = "" + + +@dataclass +class ParseResult: + thinking: str = "" + commands: list[ParsedCommand] = field(default_factory=list) + is_done: bool = False + error: str = "" + warnings: list[str] = field(default_factory=list) + violations: list[str] = field(default_factory=list) + + +def check_format_warnings( + raw_response: str, + tags: list[str] | None = None, + parser_name: str = "xml", + thinking_enabled: bool = True, +) -> tuple[list[str], list[str]]: + if not raw_response: + return [], [] + if thinking_enabled and parser_name == "qwen35": + raw_response = "\n" + raw_response + + tags = tags or ["command", "action"] + all_tags = ["think"] + [t for t in tags if t != "think"] + tag_alternation = "|".join(re.escape(t) for t in all_tags) + tag_re = re.compile(rf"<(/?)({tag_alternation})>") + + warnings: list[str] = [] + violations: list[str] = [] + + if thinking_enabled: + think_opens = len(re.findall(r"", raw_response)) + if think_opens == 0: + warnings.append("Missing block") + violations.append("missing_think") + elif think_opens > 1: + warnings.append(f"Found {think_opens} blocks, expected 1") + violations.append("multiple_think") + + for tag in all_tags: + opens = len(re.findall(f"<{re.escape(tag)}>", raw_response)) + closes = len(re.findall(f"", raw_response)) + if opens > closes: + warnings.append(f"Unclosed <{tag}> tag") + violations.append("unclosed_tag") + if re.findall(rf"<{re.escape(tag)}>\s*", raw_response): + warnings.append(f"Empty <{tag}> block") + violations.append("empty_block") + + open_tag = None + for match in tag_re.finditer(raw_response): + is_close = match.group(1) == "/" + tag_name = match.group(2) + if is_close: + if open_tag == tag_name: + open_tag = None + elif open_tag is not None: + warnings.append(f"<{tag_name}> nested inside unclosed <{open_tag}>") + violations.append("nested_tags") + break + else: + open_tag = tag_name + + first_open = re.search(rf"<(?:{tag_alternation})>", raw_response) + if first_open and raw_response[: first_open.start()].strip(): + warnings.append("Text detected before first XML tag") + violations.append("text_outside_tags") + else: + between_re = rf"(.*?)<(?:{tag_alternation})>" + for match in re.finditer(between_re, raw_response, re.DOTALL): + if match.group(1).strip(): + warnings.append("Text detected between XML tags") + violations.append("text_outside_tags") + break + + last_close = None + for match in re.finditer(rf"", raw_response): + last_close = match + if last_close and raw_response[last_close.end() :].strip(): + warnings.append("Text detected after final XML tag") + violations.append("text_outside_tags") + + return warnings, violations + + +class XMLParser: + format_tags = ["command", "action"] + + def parse_response(self, response: str | None) -> ParseResult: + if not response: + return ParseResult(error="Empty response") + think_match = re.search(r"(.*?)", response, re.DOTALL) + thinking = think_match.group(1).strip() if think_match else "" + command_matches = re.findall(r"(.*?)", response, re.DOTALL) + commands = [ + ParsedCommand(name="bash", arguments={"command": cmd.strip()}) for cmd in command_matches if cmd.strip() + ] + is_done = bool(re.search(r"\s*done\s*", response, re.IGNORECASE)) + if not commands and not is_done: + return ParseResult(thinking=thinking, error="No or done found in response.") + return ParseResult(thinking=thinking, commands=commands, is_done=is_done) + + +class HFHermesParser: + format_tags = ["tool_call"] + _TOOL_CALL_RE = re.compile(r"(.*?)", re.DOTALL) + _ANSWER_RE = re.compile(r"(.*?)", re.DOTALL) + + def parse_response(self, response: str | None, extract_answer_tags_for_done: bool = False) -> ParseResult: + if not response: + return ParseResult(thinking="", error="Empty response") + think_match = re.search(r"(.*?)", response, re.DOTALL) + thinking = think_match.group(1).strip() if think_match else "" + is_done = bool(self._ANSWER_RE.search(response)) if extract_answer_tags_for_done else False + matches = self._TOOL_CALL_RE.findall(response) + if not matches: + if is_done: + return ParseResult(thinking=thinking, is_done=True) + return ParseResult(thinking=thinking, error="No found in response.") + commands = [] + for raw in matches: + error = "" + try: + payload = json.loads(raw.strip()) + except json.JSONDecodeError: + commands.append(ParsedCommand(error="Invalid JSON in tool_call", raw=raw.strip())) + continue + if isinstance(payload, dict): + name = payload.get("name", "") + args = payload.get("arguments", {}) + else: + name, args, error = "unknown", {}, "Expected JSON object with 'name' and 'arguments' fields" + if not isinstance(args, dict): + args, error = {}, "Expected 'arguments' field to be a JSON object" + if not extract_answer_tags_for_done and name == "done": + is_done = True + elif name: + commands.append(ParsedCommand(name=name, arguments=args, error=error, raw=raw.strip())) + if not commands and not is_done: + return ParseResult(thinking=thinking, error="No valid tool calls or done signal found.") + return ParseResult(thinking=thinking, commands=commands, is_done=is_done) + + +class Qwen35XMLParser: + format_tags = ["tool_call"] + _TOOL_CALL_RE = re.compile(r"(.*?)(?:|$)", re.DOTALL) + _FUNCTION_RE = re.compile(r"", re.DOTALL) + _PARAM_RE = re.compile(r"\n?(.*?)\n?", re.DOTALL) + _ANSWER_RE = re.compile(r"(.*?)", re.DOTALL) + + def parse_response(self, response: str | None, extract_answer_tags_for_done: bool = False) -> ParseResult: + if not response: + return ParseResult(thinking="", error="Empty response") + think_match = re.search(r"(.*?)", response, re.DOTALL) + thinking = think_match.group(1).strip() if think_match else "" + matches = self._TOOL_CALL_RE.findall(response) + if not matches: + return ParseResult(thinking=thinking, error="No found in response.") + is_done = bool(self._ANSWER_RE.search(response)) if extract_answer_tags_for_done else False + commands = [] + for raw in matches: + func_match = self._FUNCTION_RE.search(raw) + if not func_match: + commands.append(ParsedCommand(error="Missing tag", raw=raw.strip())) + continue + name = func_match.group(1) + if not extract_answer_tags_for_done and name == "done": + is_done = True + continue + params = {m.group(1): m.group(2).strip() for m in self._PARAM_RE.finditer(raw)} + if params: + commands.append(ParsedCommand(name=name, arguments=params, raw=raw.strip())) + else: + commands.append(ParsedCommand(name=name, arguments={}, error="No parameters found", raw=raw.strip())) + if not commands and not is_done: + return ParseResult(thinking=thinking, error="No valid tool calls or done signal found.") + return ParseResult(thinking=thinking, commands=commands, is_done=is_done) + + +def get_parser(parser_name: str) -> HFHermesParser | Qwen35XMLParser | XMLParser: + if parser_name == "hermes": + return HFHermesParser() + if parser_name == "xml": + return XMLParser() + if parser_name == "qwen35": + return Qwen35XMLParser() + raise ValueError(f"Unknown parser_name {parser_name!r}. Must be 'hermes', 'xml', or 'qwen35'.") diff --git a/examples/train_integrations/echo_terminal/prompts.py b/examples/train_integrations/echo_terminal/prompts.py new file mode 100644 index 0000000000..6ba9c4db86 --- /dev/null +++ b/examples/train_integrations/echo_terminal/prompts.py @@ -0,0 +1,88 @@ +from typing import Any + +XML_SYSTEM_PROMPT = """You are a highly capable Linux terminal agent operating strictly via a single-shell-command interface. +Goal: Complete the user's task. + +Detailed Instructions: +- Output exactly one of the following per turn after you think in the tags: + 1) THE_SINGLE_SHELL_COMMAND + XOR (XOR means you can only respond with one of the two) + 2) done +- Don't use interactive commands and confirmations; use non-interactive flags. +- Prefer simple, robust CLI tools; write files explicitly when needed. +- If you believe the task is solved, emit done. +- You should run commands interactively to see the output and then write the command. Don't just pipe the commands. +- Only your first command in command tags will be executed. So don't respond with multiple commands. +- Verify your solution once you are done. Eg: you can use cat to see the input and the output. +- Do not just write long bash scripts. Write the commands that you would write in a terminal. +- Only respond with one of ... or done after you think in the tags. +- Plan and simulate your actions in tags before you respond with .... +""".strip() + +HF_HERMES_INSTRUCTION_PREFIX = """You are given a task description and your goal is to solve the task by using shell commands and python code. +Start each response with a ... section where you analyze the current state based on the terminal output and describe your plan for the next steps. Then, provide your commands using bash tool calls that specify the commands to execute. When you determine that the task is complete, use the done tool call to indicate completion. + +Required tools: +- "bash": Call this to execute the specified bash command in the terminal and return the output to you in the next turn as terminal_output. You can use this to iteratively solve the task by analyzing the terminal output after each command. Example: \n{"name": "bash", "arguments": {"command": "...", "timeout": ...}}\n + +Optional tools: +- "done": Call this when you have verified the task is solved. Example: \n{"name": "done", "arguments": {}}\n + +IMPORTANT: +- Only use or to structure your response as described above. +- Exactly one ... block should be present. Failure to follow the response format will lead to parsing errors. +- Bash tool calls will be executed sequentially in the order they appear in the response. +- Each command will be executed in the same shell environment, so you can rely on side effects (e.g. file creation) across commands in the same batch. +- Each shell command string will be used completely verbatim. Write commands exactly as you want them sent to the terminal. +- Do not include extra whitespace before or after the commands unless it's part of the intended command +- The done tool can be included optionally at the end of the response if you determine that the task is complete. If included, it should be the last tool_call in the response. If not included, it is assumed to be false and the agent will continue with another turn. +- The tool_call JSON must be valid - use proper escaping for quotes and special characters within strings""" + +QWEN35_INSTRUCTION_PREFIX = """You are given a task description and your goal is to solve the task by using shell commands and python code. +Start each response with a ... section where you analyze the current state based on the terminal output and describe your plan for the next steps. Then, provide your commands using bash tool calls that specify the commands to execute. When you determine that the task is complete, use the done tool call to indicate completion. + +Required tools: +- "bash": Call this to execute the specified bash command in the terminal and return the output to you in the next turn as terminal_output. You can use this to iteratively solve the task by analyzing the terminal output after each command. Example: \n\n\n....\n\n\n....\n\n\n + +Optional tools: +- "done": Call this when you have verified the task is solved. Example: \n\n\n + +IMPORTANT: +- Only use or to structure your response as described above. +- Exactly one ... block should be present. Failure to follow the response format will lead to parsing errors. +- Bash tool calls will be executed sequentially in the order they appear in the response. +- Each command will be executed in the same shell environment, so you can rely on side effects (e.g. file creation) across commands in the same batch. +- Each shell command string will be used completely verbatim. Write commands exactly as you want them sent to the terminal. +- Do not include extra whitespace before or after the commands unless it's part of the intended command +- The done tool can be included optionally at the end of the response if you determine that the task is complete. If included, it should be the last tool_call in the response. If not included, it is assumed to be false and the agent will continue with another turn. +- You may use the bash tool to verify and validate that the task is complete before calling the done tool. Verify that commands were executed successfully and that the expected output is present before calling done. The done tool should ONLY be called once you are confident that the task is fully solved.""" + +_INSTRUCTION_PREFIX_REGISTRY = { + "hermes": HF_HERMES_INSTRUCTION_PREFIX, + "qwen35": QWEN35_INSTRUCTION_PREFIX, +} + + +def get_system_prompt(parser_name: str = "xml") -> str: + if parser_name != "xml": + raise ValueError("Only the xml parser has a raw system prompt.") + return XML_SYSTEM_PROMPT + + +def format_task_prompt( + instruction: str, + parser_name: str = "xml", + tokenizer: Any = None, + add_instruction_prefix: bool = False, +) -> str | list[dict]: + if parser_name in ["hermes", "qwen35"]: + from .tools import get_augmented_system_content + + if tokenizer is None: + raise ValueError("tokenizer is required for hermes or qwen35 parser") + user_content = instruction + if add_instruction_prefix: + user_content = _INSTRUCTION_PREFIX_REGISTRY.get(parser_name, "") + "\n\nTask: " + user_content + return [{"role": "system", "content": get_augmented_system_content(tokenizer, parser_name)}, {"role": "user", "content": user_content}] + + return get_system_prompt(parser_name) + "\n\nTask:\n" + instruction diff --git a/examples/train_integrations/echo_terminal/run_echo_terminal.sh b/examples/train_integrations/echo_terminal/run_echo_terminal.sh new file mode 100755 index 0000000000..9f65dec4a5 --- /dev/null +++ b/examples/train_integrations/echo_terminal/run_echo_terminal.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +set -euo pipefail + +CONFIG_PATH=${CONFIG_PATH:-examples/train_integrations/echo_terminal/configs/qwen3_8b_rl.yaml} +export PYTHONPATH="${PYTHONPATH:-}:$(pwd)" +python -m examples.train_integrations.echo_terminal.entrypoint --config "$CONFIG_PATH" diff --git a/examples/train_integrations/echo_terminal/tools.py b/examples/train_integrations/echo_terminal/tools.py new file mode 100644 index 0000000000..60f305799a --- /dev/null +++ b/examples/train_integrations/echo_terminal/tools.py @@ -0,0 +1,55 @@ +import re +from typing import Any + +TOOL_SCHEMAS = [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Execute a bash command in the terminal.", + "parameters": { + "type": "object", + "properties": { + "command": {"type": "string", "description": "The bash command to execute"}, + "timeout": {"type": "integer", "description": "Timeout in seconds (default 30)"}, + }, + "required": ["command"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "done", + "description": "Mark the task as complete. Call this when you have verified the task is solved.", + "parameters": {"type": "object", "properties": {}}, + }, + }, +] + +_SYSTEM_PROMPT_REGISTRY = { + "hermes": ( + "You are a highly capable Linux terminal agent. " + "Complete the user's task by running commands and verifying the result. " + "When the task is complete, call done." + ), + "qwen35": ( + "You are a highly capable Linux terminal agent. " + "Complete the user's task by running commands and verifying the result. " + "When the task is complete, call done." + ), +} + + +def get_augmented_system_content(tokenizer: Any, parser_name: str = "hermes") -> str: + system_msg = _SYSTEM_PROMPT_REGISTRY.get(parser_name, "") + rendered = tokenizer.apply_chat_template( + [{"role": "system", "content": system_msg}, {"role": "user", "content": ""}], + tools=TOOL_SCHEMAS, + tokenize=False, + add_generation_prompt=False, + ) + match = re.search(r"<\|im_start\|>system\n(.*?)<\|im_end\|>", rendered, re.DOTALL) + if match: + return match.group(1) + return system_msg diff --git a/examples/train_integrations/echo_terminal/world_modeling/__init__.py b/examples/train_integrations/echo_terminal/world_modeling/__init__.py new file mode 100644 index 0000000000..a9a2c5b3bb --- /dev/null +++ b/examples/train_integrations/echo_terminal/world_modeling/__init__.py @@ -0,0 +1 @@ +__all__ = [] diff --git a/examples/train_integrations/echo_terminal/world_modeling/config.py b/examples/train_integrations/echo_terminal/world_modeling/config.py new file mode 100644 index 0000000000..74746902fb --- /dev/null +++ b/examples/train_integrations/echo_terminal/world_modeling/config.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass, field +from typing import Any, List, Literal, Optional + +from skyrl.train.config import AlgorithmConfig, DataConfig, SkyRLTrainConfig, TrainerConfig + + +@dataclass +class EchoDataConfig(DataConfig): + train_data: List[str | dict[str, Any]] = field(default_factory=list) + val_data: List[str | dict[str, Any]] = field(default_factory=list) + + +@dataclass +class EchoAlgorithmConfig(AlgorithmConfig): + world_model_coeff: float = 0.0 + world_loss_normalization: Literal["selected_tokens", "full_observation_tokens"] = "full_observation_tokens" + world_model_coeff_schedule: Literal["constant", "linear_decay", "step_decay"] = "constant" + world_model_coeff_end: float = 0.0 + world_model_coeff_transition_step: int = 0 + world_model_coeff_steps: Optional[List[int]] = None + world_model_coeff_values: Optional[List[float]] = None + wm_filter_min_valid_tool_call_pct: Optional[float] = None + wm_filter_min_parse_clean_pct: Optional[float] = None + wm_filter_min_correct_pct: Optional[float] = None + + +@dataclass +class EchoTrainerConfig(TrainerConfig): + algorithm: EchoAlgorithmConfig = field(default_factory=EchoAlgorithmConfig) + + +@dataclass +class EchoSkyRLTrainConfig(SkyRLTrainConfig): + data: EchoDataConfig = field(default_factory=EchoDataConfig) + trainer: EchoTrainerConfig = field(default_factory=EchoTrainerConfig) diff --git a/examples/train_integrations/echo_terminal/world_modeling/fsdp_worker.py b/examples/train_integrations/echo_terminal/world_modeling/fsdp_worker.py new file mode 100644 index 0000000000..868de36e17 --- /dev/null +++ b/examples/train_integrations/echo_terminal/world_modeling/fsdp_worker.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import Any, Dict, Optional + +import ray +import torch + +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import FSDPPolicyWorkerBase +from skyrl.train.dataset.replay_buffer import Experience + +from .loss import compute_world_model_loss, get_scheduled_coeff + + +class EchoFSDPPolicyWorkerBase(FSDPPolicyWorkerBase): + """FSDP policy worker that adds ECHO's auxiliary world-modeling loss.""" + + def get_aux_policy_loss_context(self, data: TrainingInputBatch) -> Dict[str, Any]: + world_loss_mask = data.get("world_loss_mask") + token_denominator = None + if world_loss_mask is not None: + token_denominator = float(world_loss_mask.sum().item() + 1e-3) + return { + "world_loss_weight": 1.0 / max(len(data), 1), + "world_token_normalization_denominator": token_denominator, + } + + def compute_aux_policy_loss( + self, + *, + action_log_probs: torch.Tensor, + policy_loss: torch.Tensor, + experience: Experience, + loss_config: Any, + microbatch_weight: float, + grad_sum_correction_factor: float, + aux_policy_loss_context: Optional[Dict[str, Any]] = None, + ): + extras = experience.extras or {} + metadata = experience.metadata or {} + global_step = int(metadata.get("world_model_schedule_step", metadata.get("global_step", 0))) + total_training_steps = int(metadata.get("total_training_steps", 0)) + world_model_coeff = get_scheduled_coeff( + initial=loss_config.world_model_coeff, + final=loss_config.world_model_coeff_end, + step=global_step, + total_steps=total_training_steps, + schedule=loss_config.world_model_coeff_schedule, + transition_step=loss_config.world_model_coeff_transition_step, + steps=loss_config.world_model_coeff_steps, + values=loss_config.world_model_coeff_values, + ) + loss_config = replace(loss_config, world_model_coeff=world_model_coeff) + context = aux_policy_loss_context or {} + world_loss_scaled, world_metrics = compute_world_model_loss( + action_log_probs, + extras.get("world_loss_mask"), + config=loss_config, + warning_mask=extras.get("world_warning_mask"), + env_mask=extras.get("world_env_mask"), + full_observation_count=extras.get("world_full_observation_count"), + token_normalization_denominator=context.get("world_token_normalization_denominator"), + ) + if world_loss_scaled is None: + coeff_tensor = torch.tensor(float(world_model_coeff), device=action_log_probs.device) + return None, { + "status/world_model_coeff": coeff_tensor, + } + + if loss_config.loss_reduction in ("sequence_mean", "seq_mean_token_sum_norm"): + aux_loss = world_loss_scaled * context.get("world_loss_weight", microbatch_weight) + else: + aux_loss = world_loss_scaled + + world_metrics = dict(world_metrics) + world_metrics["status/world_model_coeff"] = torch.tensor( + float(world_model_coeff), device=action_log_probs.device + ) + if world_metrics.get("world_loss_scaled") is not None: + world_metrics["world_policy_loss_ratio"] = ( + world_metrics["world_loss_scaled"].detach().abs() / (policy_loss.detach().abs() + 1e-8) + ) + return aux_loss, world_metrics + + +PolicyWorker = ray.remote(num_gpus=1)(EchoFSDPPolicyWorkerBase) diff --git a/examples/train_integrations/echo_terminal/world_modeling/loss.py b/examples/train_integrations/echo_terminal/world_modeling/loss.py new file mode 100644 index 0000000000..c5a61b24ce --- /dev/null +++ b/examples/train_integrations/echo_terminal/world_modeling/loss.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import Any, Optional, Tuple, Union + +import torch + + +def get_scheduled_coeff( + *, + initial: float, + final: float, + step: int, + total_steps: int, + schedule: str = "constant", + transition_step: int = 0, + steps: Optional[list[int]] = None, + values: Optional[list[float]] = None, +) -> float: + if schedule == "constant": + return initial + if schedule == "linear_decay": + t = min(step / max(total_steps, 1), 1.0) + return initial + (final - initial) * t + if schedule == "step_decay": + if steps is not None and values is not None: + coeff = values[0] + for s, v in zip(steps, values): + if step >= s: + coeff = v + return coeff + return initial if step < transition_step else final + return initial + + +def compute_world_model_loss( + action_log_probs: torch.Tensor, + world_loss_mask: Optional[torch.Tensor], + config: Any, + warning_mask: Optional[torch.Tensor] = None, + env_mask: Optional[torch.Tensor] = None, + full_observation_count: Optional[torch.Tensor] = None, + token_normalization_denominator: Optional[Union[float, torch.Tensor]] = None, +) -> Tuple[Optional[torch.Tensor], dict[str, Optional[torch.Tensor]]]: + if world_loss_mask is None or config.world_model_coeff <= 0: + return None, {} + + device = action_log_probs.device + world_loss_mask = world_loss_mask.to(device).float() + world_ce = -action_log_probs + selected_tokens = world_loss_mask.sum() + + if selected_tokens <= 0: + zero = action_log_probs.sum() * 0.0 + return zero, { + "world_loss_unscaled": zero.detach(), + "world_loss_scaled": zero.detach(), + "world_policy_loss_ratio": None, + "world_tokens_selected": selected_tokens.detach(), + "world_tokens_warning": None, + "world_tokens_env": None, + "world_tokens_full_observation": selected_tokens.detach(), + "world_ce_selected_per_token": zero.detach(), + "world_ce_warning_per_token": None, + "world_ce_env_per_token": None, + } + + if config.loss_reduction == "sequence_mean": + if config.world_loss_normalization == "full_observation_tokens" and full_observation_count is not None: + denom = full_observation_count.to(device).float() + 1e-3 + else: + denom = (world_loss_mask.sum(dim=-1, keepdim=True) + 1e-3).expand_as(world_loss_mask) + world_loss_unscaled = ((world_ce * world_loss_mask) / denom).sum() + full_obs_counts = torch.where(world_loss_mask > 0, denom, torch.zeros_like(denom)) + world_tokens_full_observation = full_obs_counts.max(dim=1).values.sum() + elif config.loss_reduction == "seq_mean_token_sum_norm": + if config.max_seq_len is None: + raise ValueError("max_seq_len must be set when using seq_mean_token_sum_norm world loss.") + world_loss_unscaled = (world_ce * world_loss_mask).sum() / config.max_seq_len + world_tokens_full_observation = selected_tokens + else: + if token_normalization_denominator is None: + token_denominator = selected_tokens + 1e-3 + elif isinstance(token_normalization_denominator, torch.Tensor): + token_denominator = token_normalization_denominator.to(device).float() + else: + token_denominator = torch.tensor(float(token_normalization_denominator), device=device) + world_loss_unscaled = (world_ce * world_loss_mask).sum() / token_denominator + world_tokens_full_observation = selected_tokens + + world_loss_scaled = config.world_model_coeff * world_loss_unscaled + selected_den = selected_tokens + 1e-3 + metrics: dict[str, Optional[torch.Tensor]] = { + "world_loss_unscaled": world_loss_unscaled, + "world_loss_scaled": world_loss_scaled, + "world_policy_loss_ratio": None, + "world_tokens_selected": selected_tokens, + "world_tokens_warning": None, + "world_tokens_env": None, + "world_tokens_full_observation": world_tokens_full_observation, + "world_ce_selected_per_token": (world_ce * world_loss_mask).sum() / selected_den, + "world_ce_warning_per_token": None, + "world_ce_env_per_token": None, + } + if warning_mask is not None: + warning_mask = warning_mask.to(device).float() + warning_tokens = warning_mask.sum() + metrics["world_tokens_warning"] = warning_tokens + metrics["world_ce_warning_per_token"] = (world_ce * warning_mask).sum() / (warning_tokens + 1e-3) + if env_mask is not None: + env_mask = env_mask.to(device).float() + env_tokens = env_mask.sum() + metrics["world_tokens_env"] = env_tokens + metrics["world_ce_env_per_token"] = (world_ce * env_mask).sum() / (env_tokens + 1e-3) + return world_loss_scaled, metrics diff --git a/examples/train_integrations/echo_terminal/world_modeling/trainer.py b/examples/train_integrations/echo_terminal/world_modeling/trainer.py new file mode 100644 index 0000000000..f66d223f83 --- /dev/null +++ b/examples/train_integrations/echo_terminal/world_modeling/trainer.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch + +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.train.generators.base import GeneratorOutput +from skyrl.train.generators.utils import get_metrics_from_generator_output, merge_stepwise_output +from skyrl.train.trainer import RayPPOTrainer +from skyrl.train.utils import Timer +from skyrl.train.utils.trainer_utils import zero_variance_filter +from loguru import logger + + +class EchoPPOTrainer(RayPPOTrainer): + """Ray PPO trainer extension that carries world-model masks as generic extras.""" + + def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch: + training_input = super().convert_to_training_input(generator_output, uids) + response_length = int(training_input.metadata["response_length"]) + response_ids = generator_output["response_ids"] + pad_size = int(training_input.metadata.get("pad_size", 0) or 0) + + def right_align(mask_lists: Optional[List[List[int]]], name: str) -> Optional[torch.Tensor]: + if mask_lists is None: + return None + if len(mask_lists) != len(response_ids): + raise AssertionError(f"{name} must have one row per response") + tensor = torch.zeros(len(response_ids), response_length, dtype=torch.float) + for i, mask in enumerate(mask_lists): + if len(mask) != len(response_ids[i]): + raise AssertionError( + f"{name} must match response length for sample {i}, " + f"got {len(mask)} and {len(response_ids[i])}" + ) + tensor[i, response_length - len(mask) :] = torch.tensor(mask, dtype=torch.float) + if pad_size: + tensor = torch.cat([tensor, torch.zeros(pad_size, response_length, dtype=tensor.dtype)], dim=0) + return tensor + + world_loss_mask = right_align(generator_output.get("world_loss_masks"), "world_loss_masks") + world_warning_mask = right_align(generator_output.get("world_warning_masks"), "world_warning_masks") + world_env_mask = right_align(generator_output.get("world_env_masks"), "world_env_masks") + + counts = generator_output.get("world_full_observation_counts") + world_full_observation_count = None + if counts is not None: + if len(counts) != len(response_ids): + raise AssertionError("world_full_observation_counts must have one value per response") + world_full_observation_count = torch.zeros(len(response_ids), response_length, dtype=torch.float) + for i, count in enumerate(counts): + world_full_observation_count[i, response_length - len(response_ids[i]) :] = float(count) + if pad_size: + world_full_observation_count = torch.cat( + [ + world_full_observation_count, + torch.zeros(pad_size, response_length, dtype=world_full_observation_count.dtype), + ], + dim=0, + ) + + training_input["world_loss_mask"] = world_loss_mask + training_input["world_warning_mask"] = world_warning_mask + training_input["world_env_mask"] = world_env_mask + training_input["world_full_observation_count"] = world_full_observation_count + + zero_pad_keys = set(training_input.metadata.get("zero_pad_keys", [])) + zero_pad_keys.update( + { + "world_loss_mask", + "world_warning_mask", + "world_env_mask", + "world_full_observation_count", + } + ) + training_input.metadata["zero_pad_keys"] = sorted(zero_pad_keys) + return training_input + + def postprocess_generator_output( + self, generator_output: GeneratorOutput, uids: List[str] + ) -> Tuple[GeneratorOutput, List[str]]: + rewards = generator_output["rewards"] + if rewards and not isinstance(rewards[0], list): + self._apply_world_model_only(generator_output) + self._apply_world_model_filter(generator_output, uids) + return super().postprocess_generator_output(generator_output, uids) + + def _apply_world_model_only(self, generator_output: GeneratorOutput) -> None: + world_model_only = generator_output.get("world_model_only") + if not world_model_only: + return + for i, wm_only in enumerate(world_model_only): + if wm_only: + generator_output["loss_masks"][i] = [0] * len(generator_output["loss_masks"][i]) + + def _apply_world_model_filter(self, generator_output: GeneratorOutput, uids: List[str]) -> None: + cfg = self.cfg.trainer.algorithm + filters_active = ( + cfg.wm_filter_min_valid_tool_call_pct is not None + or cfg.wm_filter_min_parse_clean_pct is not None + or cfg.wm_filter_min_correct_pct is not None + ) + if not filters_active: + return + if ( + generator_output.get("trajectory_metadata") is None + or generator_output.get("world_loss_masks") is None + or generator_output.get("correct") is None + ): + return + + total = passed_quality = kept = correct_kept = 0 + dropped_low_turns = dropped_by_tool_call = dropped_by_parse_clean = dropped_by_correctness = 0 + prompts_with_zero_kept = 0 + kept_reward_sum = kept_parse_err_per_turn_sum = kept_num_turns_sum = 0.0 + dropped_reward_sum = dropped_parse_err_per_turn_sum = dropped_num_turns_sum = 0.0 + kept_stats_n = dropped_stats_n = 0 + + uid_to_indices: dict[str, list[int]] = defaultdict(list) + for i, uid in enumerate(uids): + uid_to_indices[uid].append(i) + + rewards = generator_output["rewards"] + metadata = generator_output["trajectory_metadata"] + correct = generator_output["correct"] + + for indices in uid_to_indices.values(): + total += len(indices) + filter_result: dict[int, str] = {} + rollout_stats: dict[int, tuple[float, float, int, bool]] = {} + for idx in indices: + meta = metadata[idx] or {} + turns = (meta.get("trace") or {}).get("turns") or [] + n_turns = int(meta.get("num_agent_turns", 0) or 0) + if n_turns <= 0: + filter_result[idx] = "low_turns" + continue + parse_errors = sum(1 for turn in turns if turn.get("parse_error")) + clean_turns = sum( + 1 for turn in turns if not turn.get("parse_error") and not turn.get("format_violations") + ) + valid_tc_pct = (n_turns - parse_errors) / n_turns + parse_clean_pct = clean_turns / n_turns + is_correct = bool(correct[idx]) + rollout_stats[idx] = (float(rewards[idx]), parse_errors / n_turns, n_turns, is_correct) + if cfg.wm_filter_min_valid_tool_call_pct is not None and valid_tc_pct < cfg.wm_filter_min_valid_tool_call_pct: + filter_result[idx] = "tool_call_drop" + elif cfg.wm_filter_min_parse_clean_pct is not None and parse_clean_pct < cfg.wm_filter_min_parse_clean_pct: + filter_result[idx] = "parse_clean_drop" + else: + filter_result[idx] = "quality_passed" + + survivors = [(idx, rollout_stats[idx][3]) for idx in indices if filter_result[idx] == "quality_passed"] + passed_quality += len(survivors) + if cfg.wm_filter_min_correct_pct is not None: + theta = cfg.wm_filter_min_correct_pct + correct_indices = [idx for idx, is_correct in survivors if is_correct] + incorrect_indices = [idx for idx, is_correct in survivors if not is_correct] + if theta >= 1.0 or not correct_indices: + kept_indices = set(correct_indices) + else: + max_incorrect = int(len(correct_indices) * (1.0 - theta) / theta) + kept_indices = set(correct_indices) | set(incorrect_indices[:max_incorrect]) + else: + kept_indices = {idx for idx, _ in survivors} + + kept += len(kept_indices) + correct_kept += sum(1 for idx, is_correct in survivors if is_correct and idx in kept_indices) + if not kept_indices: + prompts_with_zero_kept += 1 + + for idx in indices: + result = filter_result[idx] + if result == "low_turns": + dropped_low_turns += 1 + continue + score, parse_err_pt, n_turns, _ = rollout_stats[idx] + if result == "tool_call_drop": + dropped_by_tool_call += 1 + in_kept = False + elif result == "parse_clean_drop": + dropped_by_parse_clean += 1 + in_kept = False + else: + in_kept = idx in kept_indices + if not in_kept: + dropped_by_correctness += 1 + if in_kept: + kept_reward_sum += score + kept_parse_err_per_turn_sum += parse_err_pt + kept_num_turns_sum += n_turns + kept_stats_n += 1 + else: + dropped_reward_sum += score + dropped_parse_err_per_turn_sum += parse_err_pt + dropped_num_turns_sum += n_turns + dropped_stats_n += 1 + + for idx in indices: + if idx in kept_indices: + continue + for key in ("world_loss_masks", "world_warning_masks", "world_env_masks"): + if generator_output.get(key) is not None: + generator_output[key][idx] = [0] * len(generator_output[key][idx]) + + def avg(total_value: float, count: int) -> float: + return total_value / count if count else 0.0 + + self.all_metrics.update( + { + "train/wm_filter_total": float(total), + "train/wm_filter_passed_quality": float(passed_quality), + "train/wm_filter_kept": float(kept), + "train/wm_filter_correct_kept": float(correct_kept), + "train/wm_filter_kept_frac": (kept / total) if total else 0.0, + "train/wm_filter_dropped_low_turns": float(dropped_low_turns), + "train/wm_filter_dropped_by_tool_call": float(dropped_by_tool_call), + "train/wm_filter_dropped_by_parse_clean": float(dropped_by_parse_clean), + "train/wm_filter_dropped_by_correctness": float(dropped_by_correctness), + "train/wm_filter_kept_avg_reward": avg(kept_reward_sum, kept_stats_n), + "train/wm_filter_kept_avg_parse_errors_per_turn": avg(kept_parse_err_per_turn_sum, kept_stats_n), + "train/wm_filter_kept_avg_num_turns": avg(kept_num_turns_sum, kept_stats_n), + "train/wm_filter_dropped_avg_reward": avg(dropped_reward_sum, dropped_stats_n), + "train/wm_filter_dropped_avg_parse_errors_per_turn": avg(dropped_parse_err_per_turn_sum, dropped_stats_n), + "train/wm_filter_dropped_avg_num_turns": avg(dropped_num_turns_sum, dropped_stats_n), + "train/wm_filter_prompts_with_zero_kept": float(prompts_with_zero_kept), + } + ) + + def train_critic_and_policy(self, data: TrainingInputBatch): + data.metadata["world_model_schedule_step"] = max(self.global_step - 1, 0) + data.metadata["total_training_steps"] = self.total_training_steps or 0 + return super().train_critic_and_policy(data)