From d6553810007e63e41fc56e6e433167da6da88dc2 Mon Sep 17 00:00:00 2001 From: "joe.gaffney" Date: Sat, 24 Jan 2026 22:29:46 +0000 Subject: [PATCH 1/5] wip: adding prompt caching --- workers/common/pipeline_helpers.py | 8 +- workers/common/prompt_caching.py | 118 +++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 workers/common/prompt_caching.py diff --git a/workers/common/pipeline_helpers.py b/workers/common/pipeline_helpers.py index 76dbc7c..abad1f3 100644 --- a/workers/common/pipeline_helpers.py +++ b/workers/common/pipeline_helpers.py @@ -15,6 +15,7 @@ from common.config import settings from common.logger import logger, task_log from common.memory import free_gpu_memory, gpu_memory_usage +from common.prompt_caching import clear_global_prompt_cache, enable_prompt_caching from utils.utils import time_info_decorator torch.backends.cuda.matmul.allow_tf32 = True # Enable TF32 for faster matrix multiplications @@ -118,6 +119,7 @@ def clear(self): def clear_global_pipeline_cache(): global_pipeline_cache.clear() + clear_global_prompt_cache() def decorator_global_pipeline_cache(func): @@ -129,7 +131,7 @@ def wrapper(*args, **kwargs): return wrapper -def optimize_pipeline(pipe, offload=True, vae_tiling=True): +def optimize_pipeline(pipe, offload=True, vae_tiling=True, apply_prompt_caching=True): # Override the safety checker def dummy_safety_checker(images, **kwargs): return images, [False] * len(images) @@ -153,6 +155,10 @@ def dummy_safety_checker(images, **kwargs): if hasattr(pipe, "disable_safety_checker"): pipe.safety_checker = dummy_safety_checker + # Apply generic prompt caching to the optimized pipeline + if apply_prompt_caching: + enable_prompt_caching(pipe) + return pipe diff --git a/workers/common/prompt_caching.py b/workers/common/prompt_caching.py new file mode 100644 index 0000000..832bf2b --- /dev/null +++ b/workers/common/prompt_caching.py @@ -0,0 +1,118 @@ +from collections import OrderedDict +from functools
import wraps + +import torch +from diffusers import DiffusionPipeline + +from common.logger import logger + +# Global cache for prompt embeddings +# Key is (pipeline_type, args, kwargs) +GLOBAL_PROMPT_CACHE = OrderedDict() +MAX_PROMPT_CACHE_SIZE = 64 + + +def clear_global_prompt_cache(): + """Clear the global prompt embeddings cache.""" + GLOBAL_PROMPT_CACHE.clear() + logger.debug("Global prompt cache cleared") + + +def get_prompt_cache_total_mb() -> float: + """ + Calculate the total memory usage of the prompt cache in Megabytes (MB). + Each element in GLOBAL_PROMPT_CACHE can be a Tensor or a tuple/list of Tensors. + """ + total_bytes = 0 + + def get_size(obj): + if isinstance(obj, torch.Tensor): + return obj.element_size() * obj.nelement() + if isinstance(obj, (list, tuple)): + return sum(get_size(i) for i in obj) + return 0 + + for value in GLOBAL_PROMPT_CACHE.values(): + total_bytes += get_size(value) + + return total_bytes / (1024 * 1024) + + +def get_prompt_cache_if_exists(cache_key): + """ + Retrieve cached result if it exists and move to Most Recently Used. + """ + if cache_key in GLOBAL_PROMPT_CACHE: + GLOBAL_PROMPT_CACHE.move_to_end(cache_key) + logger.info(f"Using cached prompt embeddings for {cache_key[0]}") + return GLOBAL_PROMPT_CACHE[cache_key] + return None + + +def add_prompt_cache(cache_key, result): + """ + Add a result to the global cache and manage its size. + """ + GLOBAL_PROMPT_CACHE[cache_key] = result + if len(GLOBAL_PROMPT_CACHE) > MAX_PROMPT_CACHE_SIZE: + GLOBAL_PROMPT_CACHE.popitem(last=False) # Remove Least Recently Used + + mb = get_prompt_cache_total_mb() + logger.info( + f"Prompt cached for {cache_key[0]}. 
Current cache size: {mb:.2f} MB ({len(GLOBAL_PROMPT_CACHE)}/{MAX_PROMPT_CACHE_SIZE})" + ) + + +def make_hashable(obj): + if isinstance(obj, (list, tuple)): + return tuple(make_hashable(i) for i in obj) + if isinstance(obj, dict): + return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) + return obj + + +def enable_prompt_caching(pipeline: DiffusionPipeline) -> DiffusionPipeline: + """ + Generic wrapper to cache the results of encode_prompt on any diffusers pipeline. + Uses a global cache shared across pipeline instances to avoid redundant text encoding + even when pipelines are reloaded. + """ + if not hasattr(pipeline, "encode_prompt"): + logger.warning("Pipeline does not have encode_prompt method; cannot enable prompt caching") + return pipeline + + if hasattr(pipeline, "_prompt_cache_enabled"): + logger.warning("Prompt caching already enabled on this pipeline instance") + return pipeline # Already enabled + + original_encode_prompt = pipeline.encode_prompt + pipeline_identity = pipeline.__class__.__name__ + + @wraps(original_encode_prompt) + def wrapped_encode_prompt(*args, **kwargs): + print("wrapped_encode_prompt called") + try: + # Create a cache key from identity and hashable representation of all arguments + # Identity ensures we don't use Flux embeddings for a Wan model, etc. 
+ cache_key = (pipeline_identity, make_hashable(args), make_hashable(kwargs)) + except (TypeError, ValueError): + logger.warning("Failed to create hashable cache key; skipping prompt caching") + # Fallback: if something isn't hashable, just compute normally + return original_encode_prompt(*args, **kwargs) + + cached_result = get_prompt_cache_if_exists(cache_key) + if cached_result is not None: + return cached_result + + # Compute new results (e.g., prompt_embeds, negative_prompt_embeds) + result = original_encode_prompt(*args, **kwargs) + + # Store in global cache + add_prompt_cache(cache_key, result) + + return result + + # Monkey patch the instance method + pipeline.encode_prompt = wrapped_encode_prompt + pipeline._prompt_cache_enabled = True + return pipeline From 0ac789088a8da44e13117b68ba3540618e76b730 Mon Sep 17 00:00:00 2001 From: "joe.gaffney" Date: Sat, 24 Jan 2026 23:08:26 +0000 Subject: [PATCH 2/5] fix: removed the from_pipe calls and qwen can't cache currently as there is a monkey patch fixing another bug --- workers/common/prompt_caching.py | 2 -- workers/images/local/flux_1.py | 5 +---- workers/images/local/qwen_image.py | 4 ++-- workers/images/local/sd_xl.py | 33 ++++++++++++------------------ 4 files changed, 16 insertions(+), 28 deletions(-) diff --git a/workers/common/prompt_caching.py b/workers/common/prompt_caching.py index 832bf2b..f13abdf 100644 --- a/workers/common/prompt_caching.py +++ b/workers/common/prompt_caching.py @@ -82,7 +82,6 @@ def enable_prompt_caching(pipeline: DiffusionPipeline) -> DiffusionPipeline: return pipeline if hasattr(pipeline, "_prompt_cache_enabled"): - logger.warning("Prompt caching already enabled on this pipeline instance") return pipeline # Already enabled original_encode_prompt = pipeline.encode_prompt @@ -90,7 +89,6 @@ def enable_prompt_caching(pipeline: DiffusionPipeline) -> DiffusionPipeline: @wraps(original_encode_prompt) def wrapped_encode_prompt(*args, **kwargs): - print("wrapped_encode_prompt called") try: #
Create a cache key from identity and hashable representation of all arguments # Identity ensures we don't use Flux embeddings for a Wan model, etc. diff --git a/workers/images/local/flux_1.py b/workers/images/local/flux_1.py index e1b06af..299032c 100644 --- a/workers/images/local/flux_1.py +++ b/workers/images/local/flux_1.py @@ -71,10 +71,7 @@ def get_inpainting_pipeline(model_id): def text_to_image_call(context: ImageContext) -> List[Path]: - pipe = AutoPipelineForText2Image.from_pipe( - get_pipeline("black-forest-labs/FLUX.1-Krea-dev"), - requires_safety_checker=False, - ) + pipe = get_pipeline("black-forest-labs/FLUX.1-Krea-dev") processed_image = pipe.__call__( prompt=context.data.cleaned_prompt, diff --git a/workers/images/local/qwen_image.py b/workers/images/local/qwen_image.py index a37ca18..01b5cf4 100644 --- a/workers/images/local/qwen_image.py +++ b/workers/images/local/qwen_image.py @@ -65,7 +65,7 @@ def get_pipeline(model_id) -> QwenImagePipeline: torch_dtype=torch.bfloat16, ) - return optimize_pipeline(pipe, offload=is_memory_exceeded(23)) + return optimize_pipeline(pipe, offload=is_memory_exceeded(23), apply_prompt_caching=False) @decorator_global_pipeline_cache @@ -83,7 +83,7 @@ def get_edit_pipeline(model_id) -> QwenImageEditPlusPipeline: torch_dtype=torch.bfloat16, ) - return optimize_pipeline(pipe, offload=is_memory_exceeded(23)) + return optimize_pipeline(pipe, offload=is_memory_exceeded(23), apply_prompt_caching=False) def text_to_image_call(context: ImageContext): diff --git a/workers/images/local/sd_xl.py b/workers/images/local/sd_xl.py index 57c0fa9..e76604b 100644 --- a/workers/images/local/sd_xl.py +++ b/workers/images/local/sd_xl.py @@ -3,8 +3,7 @@ import torch from diffusers import ( - AutoPipelineForImage2Image, - AutoPipelineForText2Image, + StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) @@ -31,26 +30,24 @@ def get_pipeline(model_id) -> StableDiffusionXLPipeline: 
@decorator_global_pipeline_cache -def get_inpainting_pipeline(model_id) -> StableDiffusionXLInpaintPipeline: - args = {} - args["variant"] = "fp16" +def get_pipeline_image_to_image(model_id) -> StableDiffusionXLImg2ImgPipeline: + pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16, use_safetensors=True) + return optimize_pipeline(pipe, offload=is_memory_exceeded(11)) + +@decorator_global_pipeline_cache +def get_inpainting_pipeline(model_id) -> StableDiffusionXLInpaintPipeline: pipe = StableDiffusionXLInpaintPipeline.from_pretrained( model_id, torch_dtype=torch.bfloat16, use_safetensors=True, - **args, + variant="fp16", ) - return optimize_pipeline(pipe, offload=is_memory_exceeded(11)) def text_to_image_call(context: ImageContext) -> List[Path]: - pipe = AutoPipelineForText2Image.from_pipe( - get_pipeline("SG161222/RealVisXL_V4.0"), - requires_safety_checker=False, - torch_dtype=torch.bfloat16, - ) + pipe = get_pipeline("SG161222/RealVisXL_V4.0") processed_image = pipe.__call__( width=context.width, @@ -60,30 +57,26 @@ def text_to_image_call(context: ImageContext) -> List[Path]: num_inference_steps=35, generator=context.generator, guidance_scale=3.5, - callback_on_step_end=task_log_callback(35), + callback_on_step_end=task_log_callback(35), # type: ignore ).images[0] return [context.save_output(processed_image, index=0)] def image_to_image_call(context: ImageContext) -> List[Path]: - pipe = AutoPipelineForImage2Image.from_pipe( - get_pipeline("SG161222/RealVisXL_V4.0"), - requires_safety_checker=False, - torch_dtype=torch.bfloat16, - ) + pipe = get_pipeline_image_to_image("SG161222/RealVisXL_V4.0") processed_image = pipe.__call__( width=context.width, height=context.height, prompt=context.data.cleaned_prompt, negative_prompt=_negative_prompt_default, - image=context.color_image, + image=context.color_image, # type: ignore num_inference_steps=35, generator=context.generator, strength=context.data.strength, guidance_scale=3.5, 
- callback_on_step_end=task_log_callback(35), + callback_on_step_end=task_log_callback(35), # type: ignore ).images[0] return [context.save_output(processed_image, index=0)] From 6143569350e434d3c83ae206613658902b7d1b7a Mon Sep 17 00:00:00 2001 From: "joe.gaffney" Date: Sun, 25 Jan 2026 13:01:33 +0000 Subject: [PATCH 3/5] fix: allow loading on cpu --- workers/common/pipeline_helpers.py | 12 ++++++++++-- workers/common/text_encoders.py | 1 + workers/images/local/flux_2.py | 1 + 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/workers/common/pipeline_helpers.py b/workers/common/pipeline_helpers.py index abad1f3..9665b31 100644 --- a/workers/common/pipeline_helpers.py +++ b/workers/common/pipeline_helpers.py @@ -196,6 +196,7 @@ def get_quantized_model( model_class, target_precision: Literal[4, 8, 16] = 8, torch_dtype=torch.bfloat16, + device_map_cpu: bool = False, # bits and bytes will force loading on cuda if not specified ): """ Load a quantized model component if available locally; otherwise, load original, @@ -207,14 +208,18 @@ def get_quantized_model( model_class (class): The HF model class to load (e.g., WanTransformer3DModel). target_precision (Literal[4, 8, 16]): Target precision for quantization. torch_dtype (torch.dtype): Dtype to use when loading. + device_map_cpu (bool): Whether to force loading on CPU. 
Returns: model instance """ + args = {} + if device_map_cpu: + args["device_map"] = "cpu" if target_precision == 16: logger.debug(f"Quantization disabled for {model_id} subfolder {subfolder}") - return model_class.from_pretrained(model_id, subfolder=subfolder, torch_dtype=torch_dtype) + return model_class.from_pretrained(model_id, subfolder=subfolder, torch_dtype=torch_dtype, **args) load_in_4bit = target_precision == 4 quant_dir = get_quant_dir(model_id, subfolder, load_in_4bit=load_in_4bit) @@ -231,6 +236,9 @@ def get_quantized_model( bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=False, # NOTE test this out ) + + # load on CPU first to avoid double swapping during load + args["device_map"] = "cpu" else: # 8-bit quantization # torchAO seems best fit for 8-bit currently as still supported offloading quant_config = TorchAoConfig("int8_weight_only") @@ -239,7 +247,7 @@ def get_quantized_model( try: logger.info(f"Loading quantized model from {quant_dir}") model = model_class.from_pretrained( - quant_dir, torch_dtype=torch_dtype, local_files_only=True, use_safetensors=use_safetensors + quant_dir, torch_dtype=torch_dtype, local_files_only=True, use_safetensors=use_safetensors, **args ) except Exception as e: logger.warning(f"Failed to load quantized model from {quant_dir}: {e}") diff --git a/workers/common/text_encoders.py b/workers/common/text_encoders.py index 96e7e42..aabfad5 100644 --- a/workers/common/text_encoders.py +++ b/workers/common/text_encoders.py @@ -53,6 +53,7 @@ def get_mistral3_text_encoder() -> Mistral3ForConditionalGeneration: model_class=Mistral3ForConditionalGeneration, target_precision=16, torch_dtype=torch.bfloat16, + device_map_cpu=True, ) diff --git a/workers/images/local/flux_2.py b/workers/images/local/flux_2.py index f70b6fb..611fa9d 100644 --- a/workers/images/local/flux_2.py +++ b/workers/images/local/flux_2.py @@ -27,6 +27,7 @@ def get_pipeline(model_id): model_class=Flux2Transformer2DModel, target_precision=16, 
torch_dtype=torch.bfloat16, + device_map_cpu=True, ) pipe = Flux2Pipeline.from_pretrained( From e36a6ece6f71dadbaae4071ec8453763761ea2ce Mon Sep 17 00:00:00 2001 From: "joe.gaffney" Date: Sun, 25 Jan 2026 20:44:44 +0000 Subject: [PATCH 4/5] feature: storing prompt on cpu --- workers/common/prompt_caching.py | 52 ++++++++++++++------------------ 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/workers/common/prompt_caching.py b/workers/common/prompt_caching.py index f13abdf..e557212 100644 --- a/workers/common/prompt_caching.py +++ b/workers/common/prompt_caching.py @@ -1,5 +1,6 @@ from collections import OrderedDict from functools import wraps +from typing import Any import torch from diffusers import DiffusionPipeline @@ -12,35 +13,28 @@ MAX_PROMPT_CACHE_SIZE = 64 +def _move_to_device(obj: Any, device): + """Recursively move tensors in a nested structure to a device.""" + if isinstance(obj, torch.Tensor): + # NOTE important we must detach and clone tensors before caching them + return obj.detach().clone().to(device) + if isinstance(obj, (list, tuple)): + return type(obj)(_move_to_device(x, device) for x in obj) + + # Dict support removed as encode_prompt does not return dicts + return obj + + def clear_global_prompt_cache(): """Clear the global prompt embeddings cache.""" GLOBAL_PROMPT_CACHE.clear() logger.debug("Global prompt cache cleared") -def get_prompt_cache_total_mb() -> float: - """ - Calculate the total memory usage of the prompt cache in Megabytes (MB). - Each element in GLOBAL_PROMPT_CACHE can be a Tensor or a tuple/list of Tensors. 
- """ - total_bytes = 0 - - def get_size(obj): - if isinstance(obj, torch.Tensor): - return obj.element_size() * obj.nelement() - if isinstance(obj, (list, tuple)): - return sum(get_size(i) for i in obj) - return 0 - - for value in GLOBAL_PROMPT_CACHE.values(): - total_bytes += get_size(value) - - return total_bytes / (1024 * 1024) - - def get_prompt_cache_if_exists(cache_key): """ Retrieve cached result if it exists and move to Most Recently Used. + Note: The caller is responsible for moving the result to the correct device. """ if cache_key in GLOBAL_PROMPT_CACHE: GLOBAL_PROMPT_CACHE.move_to_end(cache_key) @@ -52,14 +46,17 @@ def get_prompt_cache_if_exists(cache_key): def add_prompt_cache(cache_key, result): """ Add a result to the global cache and manage its size. + Automatically moves the result to CPU to save VRAM. """ - GLOBAL_PROMPT_CACHE[cache_key] = result + # Move to CPU for storage + cpu_result = _move_to_device(result, "cpu") + + GLOBAL_PROMPT_CACHE[cache_key] = cpu_result if len(GLOBAL_PROMPT_CACHE) > MAX_PROMPT_CACHE_SIZE: GLOBAL_PROMPT_CACHE.popitem(last=False) # Remove Least Recently Used - mb = get_prompt_cache_total_mb() logger.info( - f"Prompt cached for {cache_key[0]}. Current cache size: {mb:.2f} MB ({len(GLOBAL_PROMPT_CACHE)}/{MAX_PROMPT_CACHE_SIZE})" + f"Prompt cached for {cache_key[0]}. Current cache size: ({len(GLOBAL_PROMPT_CACHE)}/{MAX_PROMPT_CACHE_SIZE})" ) @@ -90,22 +87,19 @@ def enable_prompt_caching(pipeline: DiffusionPipeline) -> DiffusionPipeline: @wraps(original_encode_prompt) def wrapped_encode_prompt(*args, **kwargs): try: - # Create a cache key from identity and hashable representation of all arguments - # Identity ensures we don't use Flux embeddings for a Wan model, etc. 
cache_key = (pipeline_identity, make_hashable(args), make_hashable(kwargs)) except (TypeError, ValueError): logger.warning("Failed to create hashable cache key; skipping prompt caching") - # Fallback: if something isn't hashable, just compute normally return original_encode_prompt(*args, **kwargs) cached_result = get_prompt_cache_if_exists(cache_key) if cached_result is not None: - return cached_result + # Move back to the target device (e.g. CUDA) + target_device = kwargs.get("device") or getattr(pipeline, "device", torch.device("cuda")) + return _move_to_device(cached_result, target_device) - # Compute new results (e.g., prompt_embeds, negative_prompt_embeds) result = original_encode_prompt(*args, **kwargs) - # Store in global cache add_prompt_cache(cache_key, result) return result From 303ccbf61f64d0c1b434d5bb705e5a4946e311f0 Mon Sep 17 00:00:00 2001 From: "joe.gaffney" Date: Mon, 26 Jan 2026 21:33:16 +0000 Subject: [PATCH 5/5] fix: mypy issues --- workers/common/prompt_caching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/workers/common/prompt_caching.py b/workers/common/prompt_caching.py index e557212..44c25bd 100644 --- a/workers/common/prompt_caching.py +++ b/workers/common/prompt_caching.py @@ -9,7 +9,7 @@ # Global cache for prompt embeddings # Key is (pipeline_type, args, kwargs) -GLOBAL_PROMPT_CACHE = OrderedDict() +GLOBAL_PROMPT_CACHE: OrderedDict[Any, Any] = OrderedDict() MAX_PROMPT_CACHE_SIZE = 64 @@ -106,5 +106,5 @@ def wrapped_encode_prompt(*args, **kwargs): # Monkey patch the instance method pipeline.encode_prompt = wrapped_encode_prompt - pipeline._prompt_cache_enabled = True + pipeline._prompt_cache_enabled = True # type: ignore[attr-defined] return pipeline