Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions workers/common/pipeline_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from common.config import settings
from common.logger import logger, task_log
from common.memory import free_gpu_memory, gpu_memory_usage
from common.prompt_caching import clear_global_prompt_cache, enable_prompt_caching
from utils.utils import time_info_decorator

torch.backends.cuda.matmul.allow_tf32 = True # Enable TF32 for faster matrix multiplications
Expand Down Expand Up @@ -118,6 +119,7 @@ def clear(self):

def clear_global_pipeline_cache():
    """Empty the global pipeline cache and the global prompt embeddings cache."""
    global_pipeline_cache.clear()
    # Cached prompt embeddings are dropped together with the cached pipelines.
    clear_global_prompt_cache()


def decorator_global_pipeline_cache(func):
Expand All @@ -129,7 +131,7 @@ def wrapper(*args, **kwargs):
return wrapper


def optimize_pipeline(pipe, offload=True, vae_tiling=True):
def optimize_pipeline(pipe, offload=True, vae_tiling=True, apply_prompt_caching=True):
# Override the safety checker
def dummy_safety_checker(images, **kwargs):
return images, [False] * len(images)
Expand All @@ -153,6 +155,10 @@ def dummy_safety_checker(images, **kwargs):
if hasattr(pipe, "disable_safety_checker"):
pipe.safety_checker = dummy_safety_checker

# Apply generic prompt caching to the optimized pipeline
if apply_prompt_caching:
enable_prompt_caching(pipe)

return pipe


Expand Down Expand Up @@ -190,6 +196,7 @@ def get_quantized_model(
model_class,
target_precision: Literal[4, 8, 16] = 8,
torch_dtype=torch.bfloat16,
device_map_cpu: bool = False, # bits and bytes will force loading on cuda if not specified
):
"""
Load a quantized model component if available locally; otherwise, load original,
Expand All @@ -201,14 +208,18 @@ def get_quantized_model(
model_class (class): The HF model class to load (e.g., WanTransformer3DModel).
target_precision (Literal[4, 8, 16]): Target precision for quantization.
torch_dtype (torch.dtype): Dtype to use when loading.
device_map_cpu (bool): Whether to force loading on CPU.

Returns:
model instance
"""
args = {}
if device_map_cpu:
args["device_map"] = "cpu"

if target_precision == 16:
logger.debug(f"Quantization disabled for {model_id} subfolder {subfolder}")
return model_class.from_pretrained(model_id, subfolder=subfolder, torch_dtype=torch_dtype)
return model_class.from_pretrained(model_id, subfolder=subfolder, torch_dtype=torch_dtype, **args)

load_in_4bit = target_precision == 4
quant_dir = get_quant_dir(model_id, subfolder, load_in_4bit=load_in_4bit)
Expand All @@ -225,6 +236,9 @@ def get_quantized_model(
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=False, # NOTE test this out
)

# load on CPU first to avoid double swapping during load
args["device_map"] = "cpu"
else: # 8-bit quantization
# torchao currently seems the best fit for 8-bit, as it still supports offloading
quant_config = TorchAoConfig("int8_weight_only")
Expand All @@ -233,7 +247,7 @@ def get_quantized_model(
try:
logger.info(f"Loading quantized model from {quant_dir}")
model = model_class.from_pretrained(
quant_dir, torch_dtype=torch_dtype, local_files_only=True, use_safetensors=use_safetensors
quant_dir, torch_dtype=torch_dtype, local_files_only=True, use_safetensors=use_safetensors, **args
)
except Exception as e:
logger.warning(f"Failed to load quantized model from {quant_dir}: {e}")
Expand Down
110 changes: 110 additions & 0 deletions workers/common/prompt_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from collections import OrderedDict
from functools import wraps
from typing import Any

import torch
from diffusers import DiffusionPipeline

from common.logger import logger

# Global cache for prompt embeddings, shared by every pipeline wrapped with
# enable_prompt_caching.
# Key is (pipeline_type, args, kwargs); values are stored on CPU by add_prompt_cache.
GLOBAL_PROMPT_CACHE: OrderedDict[Any, Any] = OrderedDict()
# Maximum number of entries; once exceeded, the least recently used entry is evicted.
MAX_PROMPT_CACHE_SIZE = 64
Comment thread
JoeGaffney marked this conversation as resolved.


def _move_to_device(obj: Any, device):
"""Recursively move tensors in a nested structure to a device."""
if isinstance(obj, torch.Tensor):
# NOTE important we must detach and clone tensors before caching them
return obj.detach().clone().to(device)
if isinstance(obj, (list, tuple)):
return type(obj)(_move_to_device(x, device) for x in obj)

# Dict support removed as encode_prompt does not return dicts
return obj


def clear_global_prompt_cache() -> None:
    """Clear the global prompt embeddings cache.

    Removes every entry from ``GLOBAL_PROMPT_CACHE`` and logs the reset at
    debug level.
    """
    GLOBAL_PROMPT_CACHE.clear()
    logger.debug("Global prompt cache cleared")


def get_prompt_cache_if_exists(cache_key):
    """
    Look up ``cache_key`` in the global prompt cache.

    On a hit the entry is marked as Most Recently Used and its stored value is
    returned; on a miss ``None`` is returned.
    Note: The caller is responsible for moving the result to the correct device.
    """
    try:
        # move_to_end doubles as the membership test: it raises KeyError on a miss.
        GLOBAL_PROMPT_CACHE.move_to_end(cache_key)
    except KeyError:
        return None

    logger.info(f"Using cached prompt embeddings for {cache_key[0]}")
    return GLOBAL_PROMPT_CACHE[cache_key]


def add_prompt_cache(cache_key, result):
    """
    Store ``result`` under ``cache_key`` in the global prompt cache.

    The stored value is a CPU copy of ``result`` so cached embeddings do not
    occupy VRAM. If the insertion pushes the cache past
    ``MAX_PROMPT_CACHE_SIZE``, the Least Recently Used entry is dropped.
    """
    # Keep the CPU copy, never the caller's own tensors.
    GLOBAL_PROMPT_CACHE[cache_key] = _move_to_device(result, "cpu")

    over_capacity = len(GLOBAL_PROMPT_CACHE) > MAX_PROMPT_CACHE_SIZE
    if over_capacity:
        # The oldest entry sits at the front of the OrderedDict.
        GLOBAL_PROMPT_CACHE.popitem(last=False)

    logger.info(
        f"Prompt cached for {cache_key[0]}. Current cache size: ({len(GLOBAL_PROMPT_CACHE)}/{MAX_PROMPT_CACHE_SIZE})"
    )


def make_hashable(obj):
    """Convert a nested argument structure into a hashable equivalent.

    Lists and tuples become tuples, dicts become a tuple of ``(key, value)``
    pairs sorted by key, and sets/frozensets become frozensets. Conversion is
    applied recursively. Any other object is returned unchanged, so the result
    is only hashable if those leaf objects are.

    Args:
        obj: Arbitrary (possibly nested) value, typically ``args`` or ``kwargs``.

    Returns:
        The converted structure.

    Raises:
        TypeError: If a dict's keys cannot be ordered against each other, or a
            set member is not hashable after conversion.
    """
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(item) for item in obj)
    if isinstance(obj, dict):
        # Sort on the key only so the converted values are never compared.
        pairs = ((key, make_hashable(value)) for key, value in obj.items())
        return tuple(sorted(pairs, key=lambda pair: pair[0]))
    if isinstance(obj, (set, frozenset)):
        # Plain sets are unhashable; frozenset keeps order-insensitive equality.
        return frozenset(make_hashable(item) for item in obj)
    return obj
Comment thread
JoeGaffney marked this conversation as resolved.


def enable_prompt_caching(pipeline: DiffusionPipeline) -> DiffusionPipeline:
    """
    Generic wrapper to cache the results of encode_prompt on any diffusers pipeline.
    Uses a global cache shared across pipeline instances to avoid redundant text encoding
    even when pipelines are reloaded.

    The instance's ``encode_prompt`` is replaced by a wrapper that looks up
    ``(pipeline class name, args, kwargs)`` in the global prompt cache. On a
    hit, a copy of the cached result is moved to the target device and
    returned. On a miss, the original method runs and its result is cached.
    Calls whose arguments cannot be hashed bypass the cache.

    Args:
        pipeline: The pipeline instance to patch in place.

    Returns:
        The same pipeline instance. It is unchanged if it has no
        ``encode_prompt`` or caching was already enabled on it.
    """
    if not hasattr(pipeline, "encode_prompt"):
        logger.warning("Pipeline does not have encode_prompt method; cannot enable prompt caching")
        return pipeline

    if hasattr(pipeline, "_prompt_cache_enabled"):
        return pipeline  # Already enabled

    original_encode_prompt = pipeline.encode_prompt
    # NOTE(review): the key identifies the pipeline by class name only, so two
    # checkpoints loaded through the same pipeline class share entries. Confirm
    # that is acceptable for models with differing text encoders.
    pipeline_identity = pipeline.__class__.__name__

    @wraps(original_encode_prompt)
    def wrapped_encode_prompt(*args, **kwargs):
        try:
            cache_key = (pipeline_identity, make_hashable(args), make_hashable(kwargs))
            # Building a tuple never hashes its members, so hash it here. Otherwise
            # an unhashable argument would only fail later at cache lookup, outside
            # this fallback.
            hash(cache_key)
        except (TypeError, ValueError):
            logger.warning("Failed to create hashable cache key; skipping prompt caching")
            return original_encode_prompt(*args, **kwargs)

        cached_result = get_prompt_cache_if_exists(cache_key)
        if cached_result is not None:
            # Move back to the target device (e.g. CUDA).
            # NOTE(review): prefer the pipeline's execution device when it exposes
            # one; with CPU offload enabled `pipeline.device` may not be the device
            # inference runs on. Verify against the diffusers version in use.
            target_device = (
                kwargs.get("device")
                or getattr(pipeline, "_execution_device", None)
                or getattr(pipeline, "device", torch.device("cuda"))
            )
            return _move_to_device(cached_result, target_device)

        result = original_encode_prompt(*args, **kwargs)

        add_prompt_cache(cache_key, result)

        return result

    # Monkey patch the instance method
    pipeline.encode_prompt = wrapped_encode_prompt
    pipeline._prompt_cache_enabled = True  # type: ignore[attr-defined]
    return pipeline
Comment thread
JoeGaffney marked this conversation as resolved.
Comment thread
JoeGaffney marked this conversation as resolved.
1 change: 1 addition & 0 deletions workers/common/text_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_mistral3_text_encoder() -> Mistral3ForConditionalGeneration:
model_class=Mistral3ForConditionalGeneration,
target_precision=16,
torch_dtype=torch.bfloat16,
device_map_cpu=True,
)


Expand Down
5 changes: 1 addition & 4 deletions workers/images/local/flux_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ def get_inpainting_pipeline(model_id):


def text_to_image_call(context: ImageContext) -> List[Path]:
pipe = AutoPipelineForText2Image.from_pipe(
get_pipeline("black-forest-labs/FLUX.1-Krea-dev"),
requires_safety_checker=False,
)
pipe = get_pipeline("black-forest-labs/FLUX.1-Krea-dev")

processed_image = pipe.__call__(
prompt=context.data.cleaned_prompt,
Expand Down
1 change: 1 addition & 0 deletions workers/images/local/flux_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_pipeline(model_id):
model_class=Flux2Transformer2DModel,
target_precision=16,
torch_dtype=torch.bfloat16,
device_map_cpu=True,
)

pipe = Flux2Pipeline.from_pretrained(
Expand Down
4 changes: 2 additions & 2 deletions workers/images/local/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_pipeline(model_id) -> QwenImagePipeline:
torch_dtype=torch.bfloat16,
)

return optimize_pipeline(pipe, offload=is_memory_exceeded(23))
return optimize_pipeline(pipe, offload=is_memory_exceeded(23), apply_prompt_caching=False)
Comment thread
JoeGaffney marked this conversation as resolved.


@decorator_global_pipeline_cache
Expand All @@ -83,7 +83,7 @@ def get_edit_pipeline(model_id) -> QwenImageEditPlusPipeline:
torch_dtype=torch.bfloat16,
)

return optimize_pipeline(pipe, offload=is_memory_exceeded(23))
return optimize_pipeline(pipe, offload=is_memory_exceeded(23), apply_prompt_caching=False)


def text_to_image_call(context: ImageContext):
Expand Down
33 changes: 13 additions & 20 deletions workers/images/local/sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

import torch
from diffusers import (
AutoPipelineForImage2Image,
AutoPipelineForText2Image,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
Expand All @@ -31,26 +30,24 @@ def get_pipeline(model_id) -> StableDiffusionXLPipeline:


@decorator_global_pipeline_cache
def get_inpainting_pipeline(model_id) -> StableDiffusionXLInpaintPipeline:
args = {}
args["variant"] = "fp16"
def get_pipeline_image_to_image(model_id) -> StableDiffusionXLImg2ImgPipeline:
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16, use_safetensors=True)

Copilot AI Jan 25, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new get_pipeline_image_to_image function is missing the variant="fp16" parameter that is present in the get_inpainting_pipeline function. This inconsistency may cause the img2img pipeline to use a different weight variant than intended. Consider adding variant="fp16" to maintain consistency, or document why this parameter is only needed for the inpainting pipeline.

Suggested change
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16, use_safetensors=True)
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16",
)

Copilot uses AI. Check for mistakes.
return optimize_pipeline(pipe, offload=is_memory_exceeded(11))


@decorator_global_pipeline_cache
def get_inpainting_pipeline(model_id) -> StableDiffusionXLInpaintPipeline:
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_safetensors=True,
**args,
variant="fp16",
Comment thread
JoeGaffney marked this conversation as resolved.
)

return optimize_pipeline(pipe, offload=is_memory_exceeded(11))


def text_to_image_call(context: ImageContext) -> List[Path]:
pipe = AutoPipelineForText2Image.from_pipe(
get_pipeline("SG161222/RealVisXL_V4.0"),
requires_safety_checker=False,
torch_dtype=torch.bfloat16,
)
pipe = get_pipeline("SG161222/RealVisXL_V4.0")

processed_image = pipe.__call__(
width=context.width,
Expand All @@ -60,30 +57,26 @@ def text_to_image_call(context: ImageContext) -> List[Path]:
num_inference_steps=35,
generator=context.generator,
guidance_scale=3.5,
callback_on_step_end=task_log_callback(35),
callback_on_step_end=task_log_callback(35), # type: ignore
).images[0]

return [context.save_output(processed_image, index=0)]


def image_to_image_call(context: ImageContext) -> List[Path]:
pipe = AutoPipelineForImage2Image.from_pipe(
get_pipeline("SG161222/RealVisXL_V4.0"),
requires_safety_checker=False,
torch_dtype=torch.bfloat16,
)
pipe = get_pipeline_image_to_image("SG161222/RealVisXL_V4.0")

processed_image = pipe.__call__(
width=context.width,
height=context.height,
prompt=context.data.cleaned_prompt,
negative_prompt=_negative_prompt_default,
image=context.color_image,
image=context.color_image, # type: ignore
num_inference_steps=35,
generator=context.generator,
strength=context.data.strength,
guidance_scale=3.5,
callback_on_step_end=task_log_callback(35),
callback_on_step_end=task_log_callback(35), # type: ignore
).images[0]

return [context.save_output(processed_image, index=0)]
Expand Down