SFTTrainer: Logging Experts Statistics #4611

@shon-otmazgin

Description

Feature request

Hello,

When fine-tuning MoE models with output_router_logits=True, I’ve found it useful to report per-expert token throughput statistics. This information is extremely valuable for identifying which experts are most active and are therefore the best candidates for adding LoRA adapters later on. I implemented this functionality myself, and although it isn’t fully optimized, it has worked well for my use case.

from typing import Any

import torch
import torch.distributed
from torch import nn
from trl import SFTTrainer


class SFTTrainerWithExpertsLogging(SFTTrainer):
    def compute_loss(
            self,
            model: nn.Module,
            inputs: dict[str, torch.Tensor | Any],
            return_outputs: bool = False,
            num_items_in_batch: torch.Tensor | None = None,
    ):
        (loss, outputs) = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )

        mode = "train" if self.model.training else "eval"

        # Compute expert usage statistics from the router logits that the model
        # emits when it is loaded with output_router_logits=True.
        if hasattr(outputs, "router_logits") and outputs.router_logits is not None:
            labels = inputs.get("labels")
            # Only count tokens that contribute to the loss (labels != -100).
            mask = labels.flatten() != -100 if labels is not None else None

            with torch.no_grad():
                for layer_idx, router_logits in enumerate(outputs.router_logits):
                    num_experts = router_logits.shape[-1]
                    # Top-1 routing decision per token; for top-k routers this
                    # records only each token's highest-scoring expert.
                    expert_indices = router_logits.argmax(dim=-1).flatten()
                    if mask is not None:
                        expert_indices = expert_indices[mask]

                    # Count tokens per expert locally
                    expert_counts = torch.bincount(expert_indices, minlength=num_experts)

                    # Sum across all processes
                    if self.accelerator.use_distributed:
                        torch.distributed.all_reduce(expert_counts, op=torch.distributed.ReduceOp.SUM)

                    # Normalize by the number of loss tokens in the global batch.
                    # num_items_in_batch can be None depending on the transformers
                    # version, so fall back to the tokens counted in this layer.
                    total = num_items_in_batch if num_items_in_batch is not None else expert_counts.sum().clamp(min=1)
                    expert_percentages = expert_counts / total
                    for expert_idx in range(num_experts):
                        metric_key = f"expert_usage_layer_{layer_idx}_expert_{expert_idx}"
                        self._metrics[mode][metric_key].append(expert_percentages[expert_idx].item())

            # Also log the load-balancing auxiliary loss when the model returns one.
            if hasattr(outputs, "aux_loss") and outputs.aux_loss is not None:
                aux_loss = self.accelerator.gather_for_metrics(outputs.aux_loss).mean().item()
                self._metrics[mode]["aux_loss"].append(aux_loss)

        return (loss, outputs) if return_outputs else loss
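
For context, here is a minimal usage sketch. The checkpoint and dataset names are placeholders; any MoE checkpoint that supports output_router_logits (e.g. a Mixtral-style model) should work the same way.

from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig

# Placeholder checkpoint/dataset; any MoE model that supports
# output_router_logits works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,  # required so compute_loss sees router_logits
)
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainerWithExpertsLogging(
    model=model,
    args=SFTConfig(output_dir="mixtral-sft"),
    train_dataset=dataset,
)
trainer.train()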

Motivation

When fine-tuning MoE models, it’s often inefficient to attach LoRA adapters to every expert layer. By collecting statistics on which experts are actually being activated, we can make much better use of PEFT methods and selectively target only the most impactful experts.
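
To make the downstream use concrete, here is a hedged sketch of how the logged percentages could be turned into a LoRA target list. The usage threshold, the metric-averaging step, and the Mixtral-style module names (block_sparse_moe.experts.{i}.w1/w2/w3) are all illustrative assumptions, not part of the proposed feature.

from peft import LoraConfig

def lora_targets_from_usage(expert_usage: dict[str, float], threshold: float = 0.05) -> list[str]:
    """Select experts whose average token share exceeds `threshold`."""
    targets = []
    for key, usage in expert_usage.items():
        if usage < threshold:
            continue
        # Key format produced above: "expert_usage_layer_{layer}_expert_{expert}".
        parts = key.split("_")
        layer, expert = parts[3], parts[5]
        # Mixtral-style expert projection names; adjust per architecture (assumption).
        for proj in ("w1", "w2", "w3"):
            targets.append(f"model.layers.{layer}.block_sparse_moe.experts.{expert}.{proj}")
    return targets

# `avg_usage` stands in for the logged metrics averaged over training steps.
avg_usage = {"expert_usage_layer_0_expert_3": 0.21, "expert_usage_layer_0_expert_5": 0.02}
peft_config = LoraConfig(target_modules=lora_targets_from_usage(avg_usage))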

Your contribution

See the code above.
