22 changes: 13 additions & 9 deletions benchmarks/api_server/maxtext_generator.py
@@ -34,7 +34,9 @@

 from dataclasses import dataclass, field

-from MaxText import maxengine, pyconfig, multimodal_utils
+from MaxText import maxengine, pyconfig
+from MaxText.multimodal import processor as mm_processor
+from MaxText.multimodal import utils as mm_utils
 from maxtext.utils import max_logging, max_utils

 # Set TF log level to avoid verbose startup messages.
@@ -493,23 +495,25 @@ def _build_completions(self, streams, logprobs, echo):

   def _preprocess_inputs(self, text, prefill_length, image_path):
     """Helper to preprocess a single text and optional image input."""
-    processor_output = multimodal_utils.PreprocessorOutput()
+    processor_output = mm_utils.PreprocessorOutput()
     images = None
     if self.config.use_multimodal and image_path:
-      text = multimodal_utils.reformat_prompt(
-          text, image_placeholder=self.config.image_placeholder, model_name=self.config.model_name, num_images=1
+      text = mm_processor.reformat_prompt(
+          prompt=self.config.prompt,
+          image_placeholder=self.config.image_placeholder,
+          model_name=self.config.model_name,
+          num_images=1,
       )
-      loaded_images = multimodal_utils.load_image_from_path(image_path)
-      processor_output = multimodal_utils.pre_process_image(loaded_images, model_name=self.config.model_name)
-      prefill_length -= multimodal_utils.get_image_offsets(self.config.model_name, processor_output=processor_output)
+      processor_output = mm_processor.preprocess_mm_data(self.config)
+      prefill_length -= mm_processor.get_image_offsets(self.config.model_name, processor_output=processor_output)
       images = processor_output.pixel_values

     tokens, true_length = self.tokenizer.encode(text, is_bos=not self.has_chat_template, prefill_lengths=[prefill_length])
     if self.config.use_multimodal and image_path:
-      tokens = multimodal_utils.prepare_text_for_image_fusion(
+      tokens = mm_processor.prepare_text_for_image_fusion(
           tokens, model_name=self.config.model_name, processor_output=processor_output
       )
-      true_length += multimodal_utils.get_image_offsets(self.config.model_name, processor_output=processor_output)
+      true_length += mm_processor.get_image_offsets(self.config.model_name, processor_output=processor_output)

     return tokens, true_length, images

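For context on the offset bookkeeping above (unchanged by this refactor, only re-routed through mm_processor): the prefill budget handed to the tokenizer is first reduced by the number of positions the fused image will occupy, and the same count is added back to true_length after fusion. A minimal sketch with invented numbers, not taken from the PR:

# Illustrative arithmetic only; the numbers are made up.
prefill_length = 1024  # total prefill budget for this request
image_offsets = 256    # positions the fused image occupies (per get_image_offsets)

text_budget = prefill_length - image_offsets      # 768 positions left for text tokens
true_length_text = 700                            # say the prompt tokenizes to 700 tokens
true_length = true_length_text + image_offsets    # 956 positions actually used after fusion
print(text_budget, true_length)
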
34 changes: 17 additions & 17 deletions src/MaxText/input_pipeline/_input_pipeline_utils.py
@@ -24,7 +24,9 @@
 import numpy as np
 import tensorflow as tf
 from MaxText import tokenizer
-from MaxText import multimodal_utils
+# from MaxText import multimodal_utils
+from MaxText.multimodal import processor as mm_processor
+from MaxText.multimodal import utils as mm_utils
 from maxtext.utils import max_logging

 Features = dict[str, tf.Tensor]
@@ -73,13 +75,13 @@ def reformat_prompt(example, column, image_placeholder, model_name):
     num_images = len(example["images"])
   else:
     num_images = 1
-  example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name, num_images)
+  example[column] = mm_processor.reformat_prompt(example[column], image_placeholder, model_name, num_images)
   return example


 def reformat_response(example, column, model_name):
   """reformat response for multimodal SFT"""
-  example[column] = multimodal_utils.reformat_response(example[column][0], model_name)
+  example[column] = mm_processor.reformat_response(example[column][0], model_name)
   return example

@@ -101,11 +103,11 @@ def pre_process_image_sft(example, image_column, model_name):

   def _process_image_fn(image):
     if isinstance(image, list):
-      image = [np.array(multimodal_utils.convert_to_RGB(img)) for img in image]
+      image = [np.array(mm_utils.convert_to_RGB(img)) for img in image]
     else:
-      image = np.array(multimodal_utils.convert_to_RGB(image))
+      image = np.array(mm_utils.convert_to_RGB(image))

-    image = multimodal_utils.pre_process_image(image, model_name)
+    image = mm_processor.preprocess_image_for_training(image, model_name)
     return image

   example[image_column] = _process_image_fn(example[image_column])
@@ -114,7 +116,7 @@ def _process_image_fn(image):

 def prepare_text_for_image_fusion(example, column_name, model_name):
   """prepare text for image fusion for multimodal SFT"""
-  example[column_name] = multimodal_utils.prepare_text_for_image_fusion(
+  example[column_name] = mm_processor.prepare_text_for_image_fusion(
       example[column_name], model_name, processor_output=example["images"]
   )
   return example
@@ -478,9 +480,7 @@ def _pad_text(self, x: np.ndarray, max_length: int, pad_id: int) -> np.ndarray:
     pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
     return np.pad(x, pad_amount, constant_values=pad_id)[: self.max_length]

-  def _pad_image_and_mask(
-      self, preprocessed_image: multimodal_utils.PreprocessorOutput
-  ) -> multimodal_utils.PreprocessorOutput:
+  def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -> mm_utils.PreprocessorOutput:
     """Pads the input tensors (image and mask) of a PreprocessorOutput to a maximum number of items.

     This function unifies padding logic for image tensors (standard or tiled) and
@@ -513,14 +513,14 @@ def _pad_image_and_mask(
       - The dummy images used for padding are based on the image shape for initialization
         of this model (ignoring batch size).
     """
-    if not isinstance(preprocessed_image, multimodal_utils.PreprocessorOutput):
+    if not isinstance(preprocessed_image, mm_utils.PreprocessorOutput):
       raise TypeError(f"Input must be multimodal_utils.PreprocessorOutput, but got {type(preprocessed_image)}")

     if preprocessed_image.pixel_values is None:
       raise ValueError("Input preprocessed_image must have pixel_values to pad images.")

     # Determine the maximum number of images/masks allowed.
-    image_offsets = multimodal_utils.get_image_offsets(self.model_name, preprocessed_image)
+    image_offsets = mm_processor.get_image_offsets(self.model_name, preprocessed_image)
     single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0]

     # Reserve space for at least one text token.
@@ -569,13 +569,13 @@ def _pad(tensor: np.ndarray) -> np.ndarray:
     return preprocessed_image

   def map(
-      self, element: dict[str, np.ndarray | multimodal_utils.PreprocessorOutput]
-  ) -> dict[str, np.ndarray | multimodal_utils.PreprocessorOutput]:
+      self, element: dict[str, np.ndarray | mm_utils.PreprocessorOutput]
+  ) -> dict[str, np.ndarray | mm_utils.PreprocessorOutput]:
     """map to each element"""
     data_columns = list(element.keys())
     for data_column in data_columns:
       if data_column != "images":
-        if isinstance(element[data_column], multimodal_utils.PreprocessorOutput):
+        if isinstance(element[data_column], mm_utils.PreprocessorOutput):
           raise TypeError("Only 'images' column can be of type PreprocessorOutput.")

         element[f"{data_column}_segmentation"] = element[data_column] != self.pad_id
@@ -615,7 +615,7 @@ def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
     if preprocessed_image is None:
       return element

-    if not isinstance(preprocessed_image, multimodal_utils.PreprocessorOutput):
+    if not isinstance(preprocessed_image, mm_utils.PreprocessorOutput):
       raise TypeError(f"'images' must be of type PreprocessorOutput, but got {type(preprocessed_image)}")

     output = element.copy()
@@ -646,7 +646,7 @@ class FoldImagesIntoBatch(grain.MapTransform):

   def __post_init__(self):
     """Initializes the target shape after the dataclass is created."""
-    self.target_shape = multimodal_utils.get_dummy_image_shape_for_init(self.model_name)
+    self.target_shape = mm_processor.get_dummy_image_shape_for_init(self.model_name)

   def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
     """Applies the folding transformation to the 'images' field if present."""
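For context on _pad_image_and_mask above: the per-image footprint is derived by dividing the total image offsets by the number of attached images, and the code reserves space for at least one text token. A rough sketch of that arithmetic with invented numbers; only the single_image_offset line mirrors the code shown, and the final cap formula is an assumption, not taken from this diff:

# Invented numbers for illustration only.
max_length = 1024        # sequence length budget
num_images = 2
image_offsets = 512      # total positions for the 2 attached images (get_image_offsets)

single_image_offset = image_offsets // num_images          # 256 positions per image
# Presumed cap: leave at least one position for text.
max_num_images = (max_length - 1) // single_image_offset   # 3 images would fit
print(single_image_offset, max_num_images)
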
6 changes: 3 additions & 3 deletions src/MaxText/layers/decoders.py
@@ -35,7 +35,6 @@
 from MaxText.layers import normalizations
 from MaxText.layers import quantizations
 from MaxText.layers import pipeline
-from MaxText import multimodal_utils
 from MaxText import sharding
 from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.normalizations import rms_norm
@@ -57,6 +56,7 @@
     simple_layer,
     olmo3,
 )
+from MaxText.multimodal import utils as mm_utils
 from maxtext.inference import page_manager
 from maxtext.utils import max_logging
 from maxtext.utils import max_utils
@@ -586,7 +586,7 @@ def _apply_embedding(
           "llama4-17b-128e",
           "qwen3-omni-30b-a3b",
       ]:
-        y = multimodal_utils.merge_mm_embeddings(
+        y = mm_utils.merge_mm_embeddings(
             text_embeddings=y,
             multimodal_embeddings=image_embeddings,
             mask=bidirectional_mask,
@@ -598,7 +598,7 @@

     if audio_embeddings is not None and cfg.use_audio:
       if cfg.model_name in ["qwen3-omni-30b-a3b"]:
-        y = multimodal_utils.merge_mm_embeddings(
+        y = mm_utils.merge_mm_embeddings(
            text_embeddings=y,
            multimodal_embeddings=audio_embeddings,
            mask=audio_masks,
33 changes: 8 additions & 25 deletions src/MaxText/layers/models.py
@@ -25,14 +25,14 @@
 from flax import nnx
 from MaxText.layers import initializers

-from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR
-from MaxText import multimodal_utils
+from MaxText.common_types import Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.decoders import Decoder
 from MaxText.layers.embeddings import Embed, embed_as_linen
 from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen, AudioEncoder, audio_encoder_as_linen
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen
+from MaxText.multimodal import processor
 from maxtext.inference import page_manager
 from maxtext.utils import max_utils

@@ -155,24 +155,15 @@ def __call__(

     if self.config.use_multimodal and encoder_images is not None:
       image_embeddings = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
+      bidirectional_mask = processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)

-      if self.config.decoder_block == DecoderBlockType.GEMMA3:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER
-      elif self.config.decoder_block == DecoderBlockType.LLAMA4:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.LLAMA4_PATCH_TOKEN
-      elif self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-        # Create bidirectional_mask for vision/video token merging
-        bidirectional_mask = (decoder_input_tokens == multimodal_utils.QWEN3_OMNI_IMAGE_TOKEN) | (
-            decoder_input_tokens == multimodal_utils.QWEN3_OMNI_VIDEO_TOKEN
-        )
-        # Create image/video mask for deepstack visual embedding injection
     if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
       audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)

     # Create audio mask for placeholder tokens (qwen3-omni models)
     audio_masks = None
-    if audio_embeddings is not None and self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-      audio_masks = decoder_input_tokens == multimodal_utils.QWEN3_OMNI_AUDIO_TOKEN
+    if audio_embeddings is not None:
+      audio_masks = processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)

     logits, hidden_state, kv_caches = self.decoder(
         shared_embedding=self.shared_embedding,
@@ -469,24 +460,16 @@ def __call__(
     image_embeddings = None
     if self.config.use_multimodal and encoder_images is not None:
       image_embeddings = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
-
-      if self.config.decoder_block == DecoderBlockType.GEMMA3:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER
-      elif self.config.decoder_block == DecoderBlockType.LLAMA4:
-        bidirectional_mask = decoder_input_tokens == multimodal_utils.LLAMA4_PATCH_TOKEN
-      elif self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-        bidirectional_mask = (decoder_input_tokens == multimodal_utils.QWEN3_OMNI_IMAGE_TOKEN) | (
-            decoder_input_tokens == multimodal_utils.QWEN3_OMNI_VIDEO_TOKEN
-        )
+      bidirectional_mask = processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)

     audio_embeddings = None
     if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
       audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)

     # Create audio mask for placeholder tokens (qwen3-omni models)
     audio_masks = None
-    if audio_embeddings is not None and self.config.decoder_block == DecoderBlockType.QWEN3_MOE:
-      audio_masks = decoder_input_tokens == multimodal_utils.QWEN3_OMNI_AUDIO_TOKEN
+    if audio_embeddings is not None:
+      audio_masks = processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)

     logits, hidden_state, kv_caches = self.decoder(
         shared_embedding=self.token_embedder,
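The per-model mask branches removed above are consolidated behind processor.get_bidirectional_mask_vision and processor.get_bidirectional_mask_audio. A minimal sketch of what those helpers presumably do, reconstructed from the deleted branches; the constants' values and their new home are assumptions, not taken from this PR:

# Hypothetical reconstruction; not the PR's actual implementation.
from MaxText.common_types import DecoderBlockType

# Dummy placeholder-token ids for the sketch. The real constants
# (GEMMA_TOKEN_PLACEHOLDER, LLAMA4_PATCH_TOKEN, QWEN3_OMNI_*_TOKEN)
# live in MaxText's multimodal utilities.
GEMMA_TOKEN_PLACEHOLDER = -2
LLAMA4_PATCH_TOKEN = -3
QWEN3_OMNI_IMAGE_TOKEN = -4
QWEN3_OMNI_VIDEO_TOKEN = -5
QWEN3_OMNI_AUDIO_TOKEN = -6


def get_bidirectional_mask_vision(config, decoder_input_tokens):
  """Boolean mask over image/video placeholder positions, or None."""
  if config.decoder_block == DecoderBlockType.GEMMA3:
    return decoder_input_tokens == GEMMA_TOKEN_PLACEHOLDER
  if config.decoder_block == DecoderBlockType.LLAMA4:
    return decoder_input_tokens == LLAMA4_PATCH_TOKEN
  if config.decoder_block == DecoderBlockType.QWEN3_MOE:
    return (decoder_input_tokens == QWEN3_OMNI_IMAGE_TOKEN) | (
        decoder_input_tokens == QWEN3_OMNI_VIDEO_TOKEN
    )
  return None


def get_bidirectional_mask_audio(config, decoder_input_tokens):
  """Boolean mask over audio placeholder positions (qwen3-omni), or None."""
  if config.decoder_block == DecoderBlockType.QWEN3_MOE:
    return decoder_input_tokens == QWEN3_OMNI_AUDIO_TOKEN
  return None
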
19 changes: 8 additions & 11 deletions src/MaxText/maxengine.py
@@ -43,11 +43,11 @@
 from jetstream.engine.tokenizer_pb2 import TokenizerParameters
 from jetstream.engine.tokenizer_pb2 import TokenizerType

-from MaxText import multimodal_utils
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.layers import models, quantizations
+from MaxText.multimodal import processor
 from maxtext.inference import inference_utils
 from maxtext.inference.page_manager import PageManager, PageState
 from maxtext.utils import lora_utils
@@ -324,12 +324,11 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None):

     @jax.jit
     def model_apply(_p, _rng):
-      image_shape = multimodal_utils.get_dummy_image_shape_for_init(
-          self.config.model_name, batch_size=self.config.micro_batch_size_to_train_on
-      )
-      audio_shape = multimodal_utils.get_dummy_audio_shape_for_init(
-          self.config.model_name, config=self.config, batch_size=self.config.micro_batch_size_to_train_on
+      image_shape = processor.get_dummy_image_shape_for_init(
+          model_name=self.config.model_name,
+          batch_size=self.config.micro_batch_size_to_train_on,
       )
+      audio_shape = processor.get_dummy_audio_shape_for_init(self.config)
       return self.model.apply(
           _p | {"aqt": {}},
           jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32),
@@ -1549,15 +1548,13 @@ def init(abstract_params, page_state):
           dtype=jnp.int32,
       )
       dummy_image = jnp.ones(
-          multimodal_utils.get_dummy_image_shape_for_init(
-              self.config.model_name, batch_size=self.config.micro_batch_size_to_train_on
+          processor.get_dummy_image_shape_for_init(
+              model_name=self.config.model_name, batch_size=self.config.per_device_batch_size
           ),
           dtype=jnp.int32,
       )
       dummy_audio = jnp.ones(
-          multimodal_utils.get_dummy_audio_shape_for_init(
-              self.config.model_name, config=self.config, batch_size=self.config.micro_batch_size_to_train_on
-          ),
+          processor.get_dummy_audio_shape_for_init(self.config),
           dtype=jnp.float32,
       )
       _, cache = self.model.apply(
47 changes: 0 additions & 47 deletions src/MaxText/multimodal/preprocessor.py

This file was deleted.
