Describe the bug
Hi, I might be misunderstanding something here; I'm not very experienced with submitting issues to Diffusers yet, so please feel free to correct me if I've overlooked anything.
When using DPMSolverMultistepScheduler with the combination

```python
DPMSolverMultistepScheduler(
    beta_schedule="squaredcos_cap_v2",
    use_karras_sigmas=True,
)
```

calling set_timesteps(20) produces duplicate timesteps, which causes the multistep solver's internal index tracking to drift and eventually leads to an index-out-of-bounds error during sampling.
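For context, here is a minimal NumPy sketch of where I believe the duplicates come from. It is my own reconstruction, not the library code: I assume the scheduler derives training sigmas from alphas_cumprod, resamples them on the Karras et al. (2022) rho=7 ramp, and converts each Karras sigma back to a discrete timestep by interpolating in log-sigma space and rounding (the k-diffusion approach). The exact duplicated values may differ slightly from the scheduler's output, but the clustering at the top of the schedule shows the same failure mode:

```python
import math

import numpy as np

T = 1000

# squaredcos_cap_v2: cosine alpha-bar schedule, betas capped at 0.999
def alpha_bar(t):
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

betas = np.array(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)]
)
alphas_cumprod = np.cumprod(1.0 - betas)
sigmas = np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)  # ascending in t

# Karras ramp: 20 sigmas evenly spaced in sigma**(1/rho), sigma_max first
rho = 7.0
ramp = np.linspace(0, 1, 20)
max_inv, min_inv = sigmas[-1] ** (1 / rho), sigmas[0] ** (1 / rho)
karras_sigmas = (max_inv + ramp * (min_inv - max_inv)) ** rho  # descending

# Map each Karras sigma back to a discrete timestep via piecewise-linear
# interpolation in log-sigma space, then round (assumption: this mirrors
# what the scheduler does)
timesteps = np.interp(np.log(karras_sigmas), np.log(sigmas), np.arange(T)).round()
print(timesteps.astype(int))  # expect a run of identical values at the start

# log(sigma) jumps by ~3.5 between t=998 and t=999 because beta_999 is
# capped at 0.999 (alphas_cumprod drops by 1000x in one step), so the whole
# top of the Karras ramp lands inside that single interval and rounds to
# (nearly) the same timestep -> duplicates
print("last log-sigma gaps:", np.diff(np.log(sigmas))[-3:])
```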
Reproduction
```python
import diffusers
import torch
from diffusers import DPMSolverMultistepScheduler

print("Diffusers version:", diffusers.__version__)
print("Torch version:", torch.__version__)

scheduler = DPMSolverMultistepScheduler(
    beta_schedule="squaredcos_cap_v2",
    use_karras_sigmas=True,
)
scheduler.set_timesteps(20)
print("Timesteps:", scheduler.timesteps)

latents = torch.randn(1, 4, 64, 64)
noise_pred = torch.randn_like(latents)

for i, t in enumerate(scheduler.timesteps):
    print(f"Step {i}: t={t.item()}")
    try:
        result = scheduler.step(noise_pred, t, latents)
        latents = result.prev_sample
    except Exception as e:
        print(f"ERROR: {e}")
        break
```

Logs
```
Diffusers version: 0.35.2
Torch version: 2.9.0+cu128
Timesteps: tensor([998, 998, 998, 998, 998, 998, 998, 998, 998, 997, 996, 994, 989, 978,
        949, 867, 623, 221, 33, 0])
Step 0: t=998
Step 1: t=998
Step 2: t=998
Step 3: t=998
Step 4: t=998
Step 5: t=998
Step 6: t=998
Step 7: t=998
Step 8: t=998
Step 9: t=997
Step 10: t=996
Step 11: t=994
Step 12: t=989
Step 13: t=978
Step 14: t=949
Step 15: t=867
Step 16: t=623
Step 17: t=221
Step 18: t=33
Step 19: t=0
ERROR: index 21 is out of bounds for dimension 0 with size 21
```
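Until the root cause is addressed, a cheap guard right after set_timesteps at least turns the silent index drift into an explicit failure. This is only a detection sketch (my own workaround idea, not an API of the library); disabling use_karras_sigmas or picking another beta_schedule sidesteps the crash but changes sampling behaviour, so I haven't assumed either is acceptable:

```python
import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    beta_schedule="squaredcos_cap_v2",
    use_karras_sigmas=True,
)
scheduler.set_timesteps(20)

# A multistep solver assumes strictly decreasing timesteps; duplicates mean
# the schedule is degenerate and step() will eventually index past the end
# of its sigma array.
if len(torch.unique(scheduler.timesteps)) != len(scheduler.timesteps):
    raise ValueError(
        f"set_timesteps produced duplicate timesteps ({scheduler.timesteps.tolist()}); "
        "adjust beta_schedule or use_karras_sigmas to avoid the degenerate schedule."
    )
```

System Info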
- 🤗 Diffusers version: 0.35.2
- Platform: Linux-6.8.0-88-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.12.3
- PyTorch version (GPU?): 2.9.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.35.3
- Transformers version: 4.57.1
- Accelerate version: 1.10.1
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4070 Ti SUPER, 16376 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
No response