diff --git a/workers/common/pipeline_helpers.py b/workers/common/pipeline_helpers.py index 76dbc7c..9665b31 100644 --- a/workers/common/pipeline_helpers.py +++ b/workers/common/pipeline_helpers.py @@ -15,6 +15,7 @@ from common.config import settings from common.logger import logger, task_log from common.memory import free_gpu_memory, gpu_memory_usage +from common.prompt_caching import clear_global_prompt_cache, enable_prompt_caching from utils.utils import time_info_decorator torch.backends.cuda.matmul.allow_tf32 = True # Enable TF32 for faster matrix multiplications @@ -118,6 +119,7 @@ def clear(self): def clear_global_pipeline_cache(): global_pipeline_cache.clear() + clear_global_prompt_cache() def decorator_global_pipeline_cache(func): @@ -129,7 +131,7 @@ def wrapper(*args, **kwargs): return wrapper -def optimize_pipeline(pipe, offload=True, vae_tiling=True): +def optimize_pipeline(pipe, offload=True, vae_tiling=True, apply_prompt_caching=True): # Override the safety checker def dummy_safety_checker(images, **kwargs): return images, [False] * len(images) @@ -153,6 +155,10 @@ def dummy_safety_checker(images, **kwargs): if hasattr(pipe, "disable_safety_checker"): pipe.safety_checker = dummy_safety_checker + # Apply generic prompt caching to the optimized pipeline + if apply_prompt_caching: + enable_prompt_caching(pipe) + return pipe @@ -190,6 +196,7 @@ def get_quantized_model( model_class, target_precision: Literal[4, 8, 16] = 8, torch_dtype=torch.bfloat16, + device_map_cpu: bool = False, # bits and bytes will force loading on cuda if not specified ): """ Load a quantized model component if available locally; otherwise, load original, @@ -201,14 +208,18 @@ def get_quantized_model( model_class (class): The HF model class to load (e.g., WanTransformer3DModel). target_precision (Literal[4, 8, 16]): Target precision for quantization. torch_dtype (torch.dtype): Dtype to use when loading. + device_map_cpu (bool): Whether to force loading on CPU. Returns: model instance """ + args = {} + if device_map_cpu: + args["device_map"] = "cpu" if target_precision == 16: logger.debug(f"Quantization disabled for {model_id} subfolder {subfolder}") - return model_class.from_pretrained(model_id, subfolder=subfolder, torch_dtype=torch_dtype) + return model_class.from_pretrained(model_id, subfolder=subfolder, torch_dtype=torch_dtype, **args) load_in_4bit = target_precision == 4 quant_dir = get_quant_dir(model_id, subfolder, load_in_4bit=load_in_4bit) @@ -225,6 +236,9 @@ def get_quantized_model( bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=False, # NOTE test this out ) + + # load on CPU first to avoid double swapping during load + args["device_map"] = "cpu" else: # 8-bit quantization # torchAO seems best fit for 8-bit currently as still supported offloading quant_config = TorchAoConfig("int8_weight_only") @@ -233,7 +247,7 @@ def get_quantized_model( try: logger.info(f"Loading quantized model from {quant_dir}") model = model_class.from_pretrained( - quant_dir, torch_dtype=torch_dtype, local_files_only=True, use_safetensors=use_safetensors + quant_dir, torch_dtype=torch_dtype, local_files_only=True, use_safetensors=use_safetensors, **args ) except Exception as e: logger.warning(f"Failed to load quantized model from {quant_dir}: {e}") diff --git a/workers/common/prompt_caching.py b/workers/common/prompt_caching.py new file mode 100644 index 0000000..44c25bd --- /dev/null +++ b/workers/common/prompt_caching.py @@ -0,0 +1,110 @@ +from collections import OrderedDict +from functools import wraps +from typing import Any + +import torch +from diffusers import DiffusionPipeline + +from common.logger import logger + +# Global cache for prompt embeddings +# Key is (pipeline_type, args, kwargs) +GLOBAL_PROMPT_CACHE: OrderedDict[Any, Any] = OrderedDict() +MAX_PROMPT_CACHE_SIZE = 64 + + +def _move_to_device(obj: Any, device): + """Recursively move tensors in a nested structure to a device.""" + if isinstance(obj, torch.Tensor): + # NOTE important we must detach and clone tensors before caching them + return obj.detach().clone().to(device) + if isinstance(obj, (list, tuple)): + return type(obj)(_move_to_device(x, device) for x in obj) + + # Dict support removed as encode_prompt does not return dicts + return obj + + +def clear_global_prompt_cache(): + """Clear the global prompt embeddings cache.""" + GLOBAL_PROMPT_CACHE.clear() + logger.debug("Global prompt cache cleared") + + +def get_prompt_cache_if_exists(cache_key): + """ + Retrieve cached result if it exists and move to Most Recently Used. + Note: The caller is responsible for moving the result to the correct device. + """ + if cache_key in GLOBAL_PROMPT_CACHE: + GLOBAL_PROMPT_CACHE.move_to_end(cache_key) + logger.info(f"Using cached prompt embeddings for {cache_key[0]}") + return GLOBAL_PROMPT_CACHE[cache_key] + return None + + +def add_prompt_cache(cache_key, result): + """ + Add a result to the global cache and manage its size. + Automatically moves the result to CPU to save VRAM. + """ + # Move to CPU for storage + cpu_result = _move_to_device(result, "cpu") + + GLOBAL_PROMPT_CACHE[cache_key] = cpu_result + if len(GLOBAL_PROMPT_CACHE) > MAX_PROMPT_CACHE_SIZE: + GLOBAL_PROMPT_CACHE.popitem(last=False) # Remove Least Recently Used + + logger.info( + f"Prompt cached for {cache_key[0]}. Current cache size: ({len(GLOBAL_PROMPT_CACHE)}/{MAX_PROMPT_CACHE_SIZE})" + ) + + +def make_hashable(obj): + if isinstance(obj, (list, tuple)): + return tuple(make_hashable(i) for i in obj) + if isinstance(obj, dict): + return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) + return obj + + +def enable_prompt_caching(pipeline: DiffusionPipeline) -> DiffusionPipeline: + """ + Generic wrapper to cache the results of encode_prompt on any diffusers pipeline. + Uses a global cache shared across pipeline instances to avoid redundant text encoding + even when pipelines are reloaded. + """ + if not hasattr(pipeline, "encode_prompt"): + logger.warning("Pipeline does not have encode_prompt method; cannot enable prompt caching") + return pipeline + + if hasattr(pipeline, "_prompt_cache_enabled"): + return pipeline # Already enabled + + original_encode_prompt = pipeline.encode_prompt + pipeline_identity = pipeline.__class__.__name__ + + @wraps(original_encode_prompt) + def wrapped_encode_prompt(*args, **kwargs): + try: + cache_key = (pipeline_identity, make_hashable(args), make_hashable(kwargs)) + except (TypeError, ValueError): + logger.warning("Failed to create hashable cache key; skipping prompt caching") + return original_encode_prompt(*args, **kwargs) + + cached_result = get_prompt_cache_if_exists(cache_key) + if cached_result is not None: + # Move back to the target device (e.g. CUDA) + target_device = kwargs.get("device") or getattr(pipeline, "device", torch.device("cuda")) + return _move_to_device(cached_result, target_device) + + result = original_encode_prompt(*args, **kwargs) + + add_prompt_cache(cache_key, result) + + return result + + # Monkey patch the instance method + pipeline.encode_prompt = wrapped_encode_prompt + pipeline._prompt_cache_enabled = True # type: ignore[attr-defined] + return pipeline diff --git a/workers/common/text_encoders.py b/workers/common/text_encoders.py index 96e7e42..aabfad5 100644 --- a/workers/common/text_encoders.py +++ b/workers/common/text_encoders.py @@ -53,6 +53,7 @@ def get_mistral3_text_encoder() -> Mistral3ForConditionalGeneration: model_class=Mistral3ForConditionalGeneration, target_precision=16, torch_dtype=torch.bfloat16, + device_map_cpu=True, ) diff --git a/workers/images/local/flux_1.py b/workers/images/local/flux_1.py index e1b06af..299032c 100644 --- a/workers/images/local/flux_1.py +++ b/workers/images/local/flux_1.py @@ -71,10 +71,7 @@ def get_inpainting_pipeline(model_id): def text_to_image_call(context: ImageContext) -> List[Path]: - pipe = AutoPipelineForText2Image.from_pipe( - get_pipeline("black-forest-labs/FLUX.1-Krea-dev"), - requires_safety_checker=False, - ) + pipe = get_pipeline("black-forest-labs/FLUX.1-Krea-dev") processed_image = pipe.__call__( prompt=context.data.cleaned_prompt, diff --git a/workers/images/local/flux_2.py b/workers/images/local/flux_2.py index f70b6fb..611fa9d 100644 --- a/workers/images/local/flux_2.py +++ b/workers/images/local/flux_2.py @@ -27,6 +27,7 @@ def get_pipeline(model_id): model_class=Flux2Transformer2DModel, target_precision=16, torch_dtype=torch.bfloat16, + device_map_cpu=True, ) pipe = Flux2Pipeline.from_pretrained( diff --git a/workers/images/local/qwen_image.py b/workers/images/local/qwen_image.py index a37ca18..01b5cf4 100644 --- a/workers/images/local/qwen_image.py +++ b/workers/images/local/qwen_image.py @@ -65,7 +65,7 @@ def get_pipeline(model_id) -> QwenImagePipeline: torch_dtype=torch.bfloat16, ) - return optimize_pipeline(pipe, offload=is_memory_exceeded(23)) + return optimize_pipeline(pipe, offload=is_memory_exceeded(23), apply_prompt_caching=False) @decorator_global_pipeline_cache @@ -83,7 +83,7 @@ def get_edit_pipeline(model_id) -> QwenImageEditPlusPipeline: torch_dtype=torch.bfloat16, ) - return optimize_pipeline(pipe, offload=is_memory_exceeded(23)) + return optimize_pipeline(pipe, offload=is_memory_exceeded(23), apply_prompt_caching=False) def text_to_image_call(context: ImageContext): diff --git a/workers/images/local/sd_xl.py b/workers/images/local/sd_xl.py index 57c0fa9..e76604b 100644 --- a/workers/images/local/sd_xl.py +++ b/workers/images/local/sd_xl.py @@ -3,8 +3,7 @@ import torch from diffusers import ( - AutoPipelineForImage2Image, - AutoPipelineForText2Image, + StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) @@ -31,26 +30,24 @@ def get_pipeline(model_id) -> StableDiffusionXLPipeline: @decorator_global_pipeline_cache -def get_inpainting_pipeline(model_id) -> StableDiffusionXLInpaintPipeline: - args = {} - args["variant"] = "fp16" +def get_pipeline_image_to_image(model_id) -> StableDiffusionXLImg2ImgPipeline: + pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16, use_safetensors=True) + return optimize_pipeline(pipe, offload=is_memory_exceeded(11)) + +@decorator_global_pipeline_cache +def get_inpainting_pipeline(model_id) -> StableDiffusionXLInpaintPipeline: pipe = StableDiffusionXLInpaintPipeline.from_pretrained( model_id, torch_dtype=torch.bfloat16, use_safetensors=True, - **args, + variant="fp16", ) - return optimize_pipeline(pipe, offload=is_memory_exceeded(11)) def text_to_image_call(context: ImageContext) -> List[Path]: - pipe = AutoPipelineForText2Image.from_pipe( - get_pipeline("SG161222/RealVisXL_V4.0"), - requires_safety_checker=False, - torch_dtype=torch.bfloat16, - ) + pipe = get_pipeline("SG161222/RealVisXL_V4.0") processed_image = pipe.__call__( width=context.width, @@ -60,30 +57,26 @@ def text_to_image_call(context: ImageContext) -> List[Path]: num_inference_steps=35, generator=context.generator, guidance_scale=3.5, - callback_on_step_end=task_log_callback(35), + callback_on_step_end=task_log_callback(35), # type: ignore ).images[0] return [context.save_output(processed_image, index=0)] def image_to_image_call(context: ImageContext) -> List[Path]: - pipe = AutoPipelineForImage2Image.from_pipe( - get_pipeline("SG161222/RealVisXL_V4.0"), - requires_safety_checker=False, - torch_dtype=torch.bfloat16, - ) + pipe = get_pipeline_image_to_image("SG161222/RealVisXL_V4.0") processed_image = pipe.__call__( width=context.width, height=context.height, prompt=context.data.cleaned_prompt, negative_prompt=_negative_prompt_default, - image=context.color_image, + image=context.color_image, # type: ignore num_inference_steps=35, generator=context.generator, strength=context.data.strength, guidance_scale=3.5, - callback_on_step_end=task_log_callback(35), + callback_on_step_end=task_log_callback(35), # type: ignore ).images[0] return [context.save_output(processed_image, index=0)]