Labels: ⚡ PEFT, ✨ enhancement, 🏋 SFT
Description
Feature request
Hello,
When fine-tuning MoE models with `output_router_logits=True`, I’ve found it useful to report per-expert token throughput statistics. This information is extremely valuable for identifying which experts are most active in each layer, and therefore the best candidates for adding LoRA adapters later on. I implemented this functionality myself; although it isn’t fully optimized, it has worked well for my use case.
```python
from typing import Any

import torch
import torch.nn as nn

from trl import SFTTrainer


class SFTTrainerWithExpertsLogging(SFTTrainer):
    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: torch.Tensor | None = None,
    ):
        (loss, outputs) = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )
        mode = "train" if self.model.training else "eval"
        # Compute expert usage statistics from the router logits
        # (requires the model to be loaded with output_router_logits=True)
        if hasattr(outputs, "router_logits") and outputs.router_logits is not None:
            labels = inputs.get("labels")
            mask = labels != -100 if labels is not None else None
            with torch.no_grad():
                for layer_idx, router_logit in enumerate(outputs.router_logits):
                    # Top-1 expert per token; additional top-k choices are not counted
                    expert_indices = router_logit.argmax(dim=-1).flatten()
                    if mask is not None:
                        expert_indices = expert_indices[mask.flatten()]
                    num_experts = router_logit.shape[-1]
                    # Count tokens per expert locally
                    expert_counts = torch.bincount(expert_indices, minlength=num_experts)[:num_experts]
                    # Sum across all processes
                    if self.accelerator.use_distributed:
                        torch.distributed.all_reduce(expert_counts, op=torch.distributed.ReduceOp.SUM)
                    # Normalize by the number of label tokens in the batch
                    denom = num_items_in_batch if num_items_in_batch is not None else expert_counts.sum()
                    expert_percentages = expert_counts / denom
                    for expert_idx in range(num_experts):
                        metric_key = f"expert_usage_layer_{layer_idx}_expert_{expert_idx}"
                        self._metrics[mode][metric_key].append(expert_percentages[expert_idx].item())
        # Log the load-balancing auxiliary loss when the model returns one
        if hasattr(outputs, "aux_loss") and outputs.aux_loss is not None:
            aux_loss = self.accelerator.gather_for_metrics(outputs.aux_loss).mean().item()
            self._metrics[mode]["aux_loss"].append(aux_loss)
        return (loss, outputs) if return_outputs else loss
```

Motivation
When fine-tuning MoE models, it’s often inefficient to attach LoRA adapters to every expert layer. By collecting statistics on which experts are actually being activated, we can make much better use of PEFT methods and selectively target only the most impactful experts.
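As a rough illustration of how these statistics could drive adapter placement, the sketch below selects the most-used experts and builds a `target_modules` list for PEFT. The usage numbers, `top_k`, and the module path pattern are all placeholders; a Mixtral-style layout (`model.layers.{L}.block_sparse_moe.experts.{E}.w1/w2/w3`) is assumed and would need to be adjusted for other architectures.

```python
# Hypothetical sketch: route the logged usage statistics into a LoRA config
# that targets only the busiest experts. All names below are placeholders.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# e.g. {(layer_idx, expert_idx): mean usage fraction}, aggregated from the
# expert_usage_layer_{L}_expert_{E} metrics logged during a short SFT run
expert_usage = {(0, 3): 0.22, (0, 7): 0.18, (1, 1): 0.25}
top_k = 2  # how many experts to equip with LoRA adapters
most_used = sorted(expert_usage, key=expert_usage.get, reverse=True)[:top_k]

# Mixtral-style expert module names; other MoE models name these differently
target_modules = [
    f"model.layers.{layer}.block_sparse_moe.experts.{expert}.{proj}"
    for layer, expert in most_used
    for proj in ("w1", "w2", "w3")
]

peft_model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, target_modules=target_modules))
peft_model.print_trainable_parameters()
```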
Your contribution
See the code above.
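For reference, here is a hypothetical usage sketch of the subclass above; the model name, dataset, and `SFTConfig` values are placeholders, and the model must be loaded with `output_router_logits=True` so the router logits reach `compute_loss`.

```python
# Hypothetical usage sketch; model name, dataset, and SFTConfig values are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,  # expose router_logits (and aux_loss) in the model outputs
)
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainerWithExpertsLogging(
    model=model,
    args=SFTConfig(output_dir="moe-sft", logging_steps=10),
    train_dataset=dataset,
)
trainer.train()
# expert_usage_layer_{L}_expert_{E} and aux_loss then appear alongside the usual training metrics
```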