12 changes: 11 additions & 1 deletion QEfficient/finetune/experimental/core/component_registry.py
@@ -5,7 +5,6 @@
#
# -----------------------------------------------------------------------------


import logging
from typing import Callable, Dict, Optional, Type

@@ -198,3 +197,14 @@ def list_callbacks(self) -> list[str]:

# Global registry instance
registry = ComponentRegistry()


class ComponentFactory:
@staticmethod
def create_model(model_type: str, model_name: str, **kwargs) -> any:
"""Create a model instance."""
model_class = registry.get_model(model_type)
if model_class is None:
raise ValueError(f"Unknown model: {model_type}. Available: {registry.list_models()}")
model_instance = model_class.create(model_name, **kwargs)
return model_instance
143 changes: 143 additions & 0 deletions QEfficient/finetune/experimental/core/model.py
@@ -4,3 +4,146 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type

import torch.nn as nn
from transformers import AutoTokenizer
import transformers
from transformers.utils.logging import get_logger

from QEfficient.finetune.experimental.core.component_registry import registry
from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token

Check failure on line 18 in QEfficient/finetune/experimental/core/model.py (GitHub Actions / lint): Ruff (I001) QEfficient/finetune/experimental/core/model.py:8:1: I001 Import block is un-sorted or un-formatted

logger = get_logger(__name__)


class BaseModel(nn.Module, ABC):
Contributor:
Is there any reason why the methods below are not implemented? They were given in the reference code as well.

  • parameters
  • named_parameters
  • state_dict
  • load_state_dict

Reference code: https://github.com/quic-meetkuma/LightningLLMs/blob/hf_trainer/LightningLLM/components/model.py

Contributor Author:
IMO the reference code is wrong and misses some careful observations about how PyTorch's nn.Module works. nn.Module already provides these, and because the wrapped model is registered as a submodule, state_dict / parameters work as-is. Adding trivial passthroughs would be redundant, noisy, and a future maintenance hazard. Let me know if this is unclear; I can explain offline.
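For anyone following the thread, the registration behaviour can be verified with a small standalone sketch (illustrative only, not project code):

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so
        # parameters() / state_dict() already traverse into it.
        self._model = nn.Linear(2, 2)

w = Wrapper()
print(list(w.state_dict().keys()))             # ['_model.weight', '_model.bias']
print(sum(p.numel() for p in w.parameters()))  # 6 trainable values
```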

Contributor:
Right, I missed that BaseModel inherits from nn.Module.

"""Shared skeleton for every finetunable model in the system."""

def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__()
self.model_name = model_name
self.model_kwargs: Dict[str, Any] = model_kwargs
self._model: Optional[nn.Module] = None
self._tokenizer: Any = None # HF tokenizers are not nn.Modules.
Contributor:
This variable is specific to LLMs. We cannot put it here.

Contributor Author:
Yes, I missed this part; the reference code makes the same mistake and it got copied over. The base model needs to stay model-agnostic. I'll probably move these LLM-specific methods to an LLM mixin.

Contributor:
Let us use a generic name, e.g. _preprocessor (even in the HFModel class)?


# Factory constructor: load model after __init__ finishes
@classmethod
def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel":
Contributor:
Not needed. Use the registry mechanism to instantiate BaseModel-type objects, which in turn instantiates the nn.Modules as well.

Contributor Author:
The create factory seems to me a clean way to construct objects. It is the only place that guarantees the wrapped model is actually registered on the nn.Module before anyone calls state_dict, parameters, to, or train. Dropping it and relying on bare __init__ puts us back into lazy-init land, where we unnecessarily have to call load_model() in every method, as the reference code does.

Contributor:
I understand your concern. Basically, you want to make sure there is a one-liner initialization of the model. The reference code does it in 2 lines: https://github.com/quic-meetkuma/LightningLLMs/blob/da2f6b39e8533cbd05d563ca68113828f783e73e/LightningLLM/main.py#L131C34-L131C46

I suggest moving this create into the ComponentRegistry class and creating a "create_model" method there. The reason is that that class has the sole responsibility of storing references to particular classes and creating instances of the classes the user wants. This way all the registry and instance creation is encapsulated in one place, and we only need to deal with ComponentRegistry and not the other individual classes.

Contributor Author:
The reference shared above is stale, as PR #645 didn't have create_model() added in component_registry.py.

Why was it not added?

I have added it. Even after adding it, we still need a method in BaseModel to load the model and register it on nn.Module, which will now be called through create_model(). I'm keeping the name of that method as create for now. The loading part can also be moved to create_model() in component_registry.py if that seems better.
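For context, the one-liner path being discussed would look roughly like this, assuming the ComponentFactory added in this PR and a registered "hf" model (the model name below is only an example):

```python
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory

# The factory looks up the class registered under "hf" and calls its create()
# classmethod, which runs __init__ and then load_model() before returning.
model = ComponentFactory.create_model("hf", "gpt2")
print(type(model).__name__)  # HFModel, with the wrapped nn.Module already registered
```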

obj = cls(model_name, **model_kwargs)
# load model after __init__ finishes
module = obj.load_model()
if not isinstance(module, nn.Module):
raise TypeError(f"load_model() must return nn.Module, got {type(module)}")
obj._model = module
return obj

@abstractmethod
def load_model(self) -> nn.Module:
"""Load and return the underlying torch.nn.Module."""
pass

def load_tokenizer(self) -> Any:
Contributor:
Specific to LLMs. We cannot put it here.

Contributor Author:
Replied above

Contributor:
I feel the naming should be "load_preprocessor" instead of load_tokenizer. In the HFModel class's load_preprocessor, you should actually load the tokenizer and return it.

A tokenizer is one kind of preprocessor; other models can have different preprocessors.

"""Override if the model exposes a tokenizer."""
raise NotImplementedError(f"{type(self).__name__} does not provide a tokenizer.")
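One possible shape for the model-agnostic split being discussed, purely as a sketch (names such as PreprocessorMixin and load_preprocessor are not part of this PR):

```python
from abc import ABC, abstractmethod
from typing import Any

class PreprocessorMixin(ABC):
    """Illustrative mixin: a generic preprocessor hook instead of a tokenizer-specific one."""

    _preprocessor: Any = None

    @abstractmethod
    def load_preprocessor(self) -> Any:
        """Child classes return their preprocessor (tokenizer, feature extractor, ...)."""

    @property
    def preprocessor(self) -> Any:
        # Lazily load and cache, mirroring the tokenizer property below.
        if self._preprocessor is None:
            self._preprocessor = self.load_preprocessor()
        return self._preprocessor
```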

# Lazy accessors
@property
def model(self) -> nn.Module:
if self._model is None:
raise RuntimeError("Model not loaded; use .create(...) to load.")
return self._model

@property
def tokenizer(self) -> Any:
Contributor:
This is applicable to NLP models; we can't put it in the base class. Better to create an abstract method called "preprocessor" which defines the generic preprocessing function applicable to the model. There won't be any implementation here, but the child classes should implement it. In the case of LLMs, this method should return the tokenizer.

if self._tokenizer is None:
self._tokenizer = self.load_tokenizer()
return self._tokenizer

# nn.Module API surface
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def get_input_embeddings(self):
Contributor:
This is applicable to NLP models. We can't put it in the base class.

Contributor Author:
Replied above.

Contributor:
I am a bit doubtful about this. But looking at most HF model implementations (e.g. ASR and vision), this method exists, so it is fine.

Btw, do we need this method at all? I think we can get rid of it, as it is not even used.

Contributor Author:
This method is also accessible via PreTrainedClass; the method implementation here is not helping.

What was the thought process behind adding this one in the reference code? I also think we should remove it.

Contributor Author:
I meant this method also* (just like resize_token_embeddings) is accessible via the PreTrainedModel class. Check the reply for the resize_token_embeddings one.

if hasattr(self.model, "get_input_embeddings"):
return self.model.get_input_embeddings()
logger.info(f"Model {self.model_name} does not expose input embeddings", logging.WARNING)
return None

def resize_token_embeddings(self, new_num_tokens: int) -> None:
Contributor:
This is applicable to NLP models. We can't put it in the base class.

Contributor Author:
Replied above.

Contributor:
This method looks specific to NLP. I don't find a reference to this method in the ViT or Whisper models of HF. I think the else case is taking care of it, so it is fine for now.

Contributor Author (@quic-swatia, Dec 19, 2025):
Whisper and ViT inherit the PreTrainedModel class, which has resize_token_embeddings implemented in it. The else block will take care of models that didn't define resize_token_embeddings and have not inherited PreTrainedModel.

What was the thought process behind adding this method in the reference code? I think we should remove it from here.

if hasattr(self.model, "resize_token_embeddings"):
self.model.resize_token_embeddings(new_num_tokens)
else:
logger.info(f"Model {self.model_name} cannot resize token embeddings", logging.WARNING)

# optional
def to(self, *args, **kwargs):
self.model.to(*args, **kwargs)
return self

def train(self, mode: bool = True):
self.model.train(mode)
return super().train(mode)

def eval(self):
return self.train(False)


@registry.model("hf")
class HFModel(BaseModel):
Contributor:
How can I get the PEFT config file from this model?

Any reason to avoid adding that method?

The reference code had that functionality for some reason: https://github.com/quic-meetkuma/LightningLLMs/blob/hf_trainer/LightningLLM/components/model.py#L272

Contributor Author:
The PEFT configuration should not reside within model classes. Even if it had to be included, it belongs in BaseModel rather than HFModel, since the PEFT library supports custom models. Refer to the PEFT documentation: https://huggingface.co/docs/peft/developer_guides/custom_models. But the reference code incorrectly places it in HFModel.

Contributor:
Yes, it makes more sense to put it in the base class. Please do the needful.

Contributor:
We need to be able to load the model from a configuration as well, mainly for testing purposes. In integration tests we will not load an entire model consisting of 32 layers; we will only load the same model with 2 or 4 layers and do further testing. For that purpose the config should be used to load the model. Check the Hugging Face documentation on how to do that.
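A rough sketch of config-based loading for tests, assuming a Hugging Face causal LM (the gpt2 config and the reduced sizes are arbitrary examples):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Build a scaled-down config instead of downloading full pretrained weights.
config = AutoConfig.from_pretrained("gpt2")
config.n_layer = 2   # gpt2-style configs use n_layer; many others use num_hidden_layers
config.n_head = 2
config.n_embd = 64

# from_config creates a randomly initialized model with the reduced architecture,
# which is enough for shape and integration checks.
tiny_model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in tiny_model.parameters()))
```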

"""HuggingFace-backed model with optional quantization."""

def __init__(
self,
model_name: str,
auto_class_name: str = "AutoModelForCausalLM",
*,
tokenizer_name: Optional[str] = None,
**model_kwargs: Any,
) -> None:
super().__init__(model_name, **model_kwargs)
self.tokenizer_name = tokenizer_name or model_name
self.auto_class: Type = self._resolve_auto_class(auto_class_name)

@staticmethod
def _resolve_auto_class(auto_class_name: str) -> Type:
if not hasattr(transformers, auto_class_name):
candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel"))
raise ValueError(
f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}"
)
return getattr(transformers, auto_class_name)

# def _build_quant_config(self) -> Optional[BitsAndBytesConfig]:
# if not self.model_kwargs.get("load_in_4bit"):
# return None
# return BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"),
# bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16),
# bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True),
# )

def configure_model_kwargs(self) -> Dict[str, Any]:
"""Hook for subclasses to tweak HF `.from_pretrained` kwargs."""
extra = dict(self.model_kwargs)
# extra["quantization_config"] = self._build_quant_config()
return extra

def load_model(self) -> nn.Module:
logger.info(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")

return self.auto_class.from_pretrained(
self.model_name,
**self.configure_model_kwargs(),
Contributor:
Don't pass the result of a function call directly. Keep it explicit in a variable and then unpack the dict here. That way it will be easier to pinpoint the error, if any.

Contributor:
Here you are unpacking all the kwargs. If any extra kwargs are given, does the self.auto_class.from_pretrained method accept and discard them? If not, it will surely throw an error. Please check and correct it if needed.

)

def load_tokenizer(self) -> AutoTokenizer:
"""Load Hugging Face tokenizer."""
logger.info(f"Loading tokenizer '{self.tokenizer_name}'")
Contributor:
The logger should only log on rank zero. Check the other logger references as well.

You already have reference code: https://github.com/quic-meetkuma/LightningLLMs/blob/hf_trainer/LightningLLM/components/model.py#L264
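A minimal rank-zero logging helper could look like the sketch below (an assumption about the eventual utility, not this PR's code):

```python
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)

def log_rank_zero(message: str, level: int = logging.INFO) -> None:
    """Log only on the global rank-0 process to avoid duplicate lines under DDP."""
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    logger.log(level, message)
```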

tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
insert_pad_token(tokenizer)
return tokenizer
152 changes: 152 additions & 0 deletions QEfficient/finetune/experimental/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import pytest
Contributor:
Copyright header missing.

Contributor Author:
Updated in the latest revision.

import torch
import torch.nn as nn
from unittest import mock

import transformers

Check failure on line 13 in QEfficient/finetune/experimental/tests/test_model.py (GitHub Actions / lint): Ruff (F401) QEfficient/finetune/experimental/tests/test_model.py:13:8: F401 `transformers` imported but unused
from QEfficient.finetune.experimental.core import model
from QEfficient.finetune.experimental.core.model import BaseModel, HFModel
from QEfficient.finetune.experimental.core.component_registry import registry
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory

Check failure on line 17 in QEfficient/finetune/experimental/tests/test_model.py (GitHub Actions / lint): Ruff (I001) QEfficient/finetune/experimental/tests/test_model.py:8:1: I001 Import block is un-sorted or un-formatted


class TestMockModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 2)

def forward(self, x):
return self.linear(x)


@registry.model("testcustom")
class TestCustomModel(BaseModel):
def __init__(self, model_name):
super().__init__(model_name)
print("init of custom class")

def load_model(self) -> nn.Module:
return TestMockModel()

def load_tokenizer(self):
return "dummy-tokenizer"


Contributor:
Where are the PEFT-related test cases?

Contributor Author:
There is no PEFT code, hence no test cases.

Replied above for PEFT code.

Contributor:
We need to add test code for the PEFT models. Add it in the base class, make an instance of HFModel with a PEFT config, and check whether the model is modified with the PEFT changes or not.
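A possible shape for such a test, reusing the fixtures defined in this file; it assumes peft is installed and applies LoRA directly to the wrapped module rather than through any project-level hook:

```python
def test_peft_lora_wraps_registered_model():
    peft = pytest.importorskip("peft")

    m = ComponentFactory.create_model("testcustom", "dummy")
    # "linear" matches the attribute name inside TestMockModel above.
    lora_cfg = peft.LoraConfig(r=4, lora_alpha=8, target_modules=["linear"])
    peft_model = peft.get_peft_model(m.model, lora_cfg)

    # LoRA should inject adapter weights and freeze the base parameters.
    assert any("lora_" in name for name, _ in peft_model.named_parameters())
    trainable = [n for n, p in peft_model.named_parameters() if p.requires_grad]
    assert trainable and all("lora_" in n for n in trainable)
```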

# BaseModel tests
Contributor:
Add tests that register a child class (of BaseModel) in the registry and fetch the correct instance of the class upon passing the required parameters.
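Along those lines, a registration round-trip test could look roughly like this (assuming registry.get_model and registry.list_models behave as used in ComponentFactory.create_model):

```python
def test_registry_round_trip():
    # The @registry.model("testcustom") decorator above should have stored the class.
    assert registry.get_model("testcustom") is TestCustomModel
    assert "testcustom" in registry.list_models()

    # Creating through the factory should return an instance of that class
    # with the requested parameters applied.
    m = ComponentFactory.create_model("testcustom", "dummy")
    assert isinstance(m, TestCustomModel)
    assert m.model_name == "dummy"
```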

def test_model_property_errors_if_not_created():
m = TestCustomModel("dummy")
with pytest.raises(RuntimeError):
_ = m.model # must call .create()


def test_create_builds_and_registers():
m = ComponentFactory.create_model("testcustom", "dummy")
# inner model exists and registered
assert "_model" in m._modules
assert isinstance(m.model, TestMockModel)
# forward works
out = m(torch.zeros(1, 2))
assert out.shape == (1, 2)


def test_tokenizer_lazy_loading():
m = ComponentFactory.create_model("testcustom", "dummy")
assert m._tokenizer is None
tok = m.tokenizer
assert tok == "dummy-tokenizer"
assert m._tokenizer == tok
Contributor:
What are you trying to accomplish?

First you check whether m._tokenizer is None. Then you assign it to the tok variable. Then you check whether it matches the "dummy-tokenizer" string. Then you again try to match the same m._tokenizer contents against the tok variable.

Contributor Author:
If you look carefully, they are different: tokenizer is a property and _tokenizer is a private variable.

Contributor:
OK, so you are trying to make sure that the "tokenizer" property correctly returns the _tokenizer member variable. That makes sense.



def test_to_moves_inner_and_returns_self():
m = ComponentFactory.create_model("testcustom", "dummy")
with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to:
ret = m.to("cuda:0")
mocked_to.assert_called_once_with(m.model, "cuda:0")
Contributor:
Better to check the weight placement on that device.

assert ret is m
Contributor:
What does this mean? This is not needed.
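If a GPU is available in the test environment, the placement itself could be asserted instead of mocking; a sketch:

```python
def test_to_moves_weights_to_device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    m = ComponentFactory.create_model("testcustom", "dummy")
    m.to("cuda:0")
    # Every parameter of the wrapped module should now live on the GPU.
    assert all(p.device.type == "cuda" for p in m.parameters())
```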



def test_train_eval_sync_flags():
m = ComponentFactory.create_model("testcustom", "dummy")
m.eval()
assert m.training is False
assert m.model.training is False
m.train()
assert m.training is True
assert m.model.training is True


def test_resize_token_embeddings_and_get_input_embeddings_warn(monkeypatch):
m = ComponentFactory.create_model("testcustom", "dummy")

# resize_token_embeddings: underlying model lacks the method, should warn and not raise
with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log:
Contributor:
logger.info should be changed to logger.log_rank_zero based on earlier comments.

m.resize_token_embeddings(10)
mocked_log.assert_called_once()
Contributor:
This is too brute-force. Better to manipulate the embedding and check whether the embedding shape has changed or not.


# get_input_embeddings: underlying model lacks method, should warn and return None
with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log:
assert m.get_input_embeddings() is None
Contributor:
Create a proper dummy model which returns some embeddings rather than None. Your test should use the HFModel class instead of some dummy class.

mocked_log.assert_called_once()
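Following the suggestions above, a behavior-level variant could swap in a module that actually owns embeddings and assert on the resulting shape (EmbedModel below is illustrative only, not part of this PR):

```python
def test_resize_token_embeddings_changes_shape():
    class EmbedModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)

        def get_input_embeddings(self):
            return self.embed

        def resize_token_embeddings(self, new_num_tokens):
            old = self.embed
            self.embed = nn.Embedding(new_num_tokens, old.embedding_dim)
            self.embed.weight.data[: old.num_embeddings] = old.weight.data

    m = ComponentFactory.create_model("testcustom", "dummy")
    m._model = EmbedModel()  # swap in an embedding-capable module for this test only

    m.resize_token_embeddings(16)
    assert m.get_input_embeddings().weight.shape == (16, 4)
```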


def test_state_dict_contains_inner_params():
m = ComponentFactory.create_model("testcustom", "dummy")
sd = m.state_dict()
# should contain params from TestMockModel.linear
assert any("linear.weight" in k for k in sd)
assert any("linear.bias" in k for k in sd)


# HFModel tests
def test_hfmodel_invalid_auto_class_raises():
with pytest.raises(ValueError):
ComponentFactory.create_model("hf", "hf-name", auto_class_name="AutoDoesNotExist")


def test_hfmodel_loads_auto_and_tokenizer(monkeypatch):
# fake HF Auto class
class FakeAuto(nn.Module):
@classmethod
def from_pretrained(cls, name, **kwargs):
inst = cls()
inst.loaded = (name, kwargs)
return inst

def forward(self, x):
return x

fake_tok = mock.Mock()

# Monkeypatch transformer classes used in HFModel
monkeypatch.setattr(
"QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM",
FakeAuto,
raising=False,
)
monkeypatch.setattr(
model,
"AutoTokenizer",
mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)),
)
monkeypatch.setattr(
"QEfficient.finetune.experimental.core.model.insert_pad_token",
mock.Mock(),
raising=False,
)
m = ComponentFactory.create_model("hf", "hf-name")
m = HFModel.create("hf-name")
Contributor:
No need to call an individual class's create method. There is a reason to use the registry functionality.

Contributor Author:
The purpose of create is different from the component registry's. Explained in another comment.

Contributor:
Move it to ComponentFactory and then instantiate from there.

Contributor Author:
Replied above.

assert isinstance(m.model, FakeAuto)

# load tokenizer
tok = m.load_tokenizer()

Check failure on line 148 in QEfficient/finetune/experimental/tests/test_model.py (GitHub Actions / lint): Ruff (F841) QEfficient/finetune/experimental/tests/test_model.py:148:5: F841 Local variable `tok` is assigned to but never used

# tokenizer was loaded and pad token inserted
model.AutoTokenizer.from_pretrained.assert_called_once_with("hf-name")
Contributor:
Don't write tests filled with assert_called_once, etc. We have written such tests in the past, but that was not appropriate; it was a makeshift arrangement because of the monolithic structure of the code. Write extensive and proper tests. If a function has made changes to the model's structure, then use that to test rather than counting how many times the function gets called.

model.insert_pad_token.assert_called_once_with(fake_tok)
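As an illustration of the behavior-first style requested above, the tokenizer path could instead be asserted on its end state; the sketch below assumes network access and uses the public hf-internal-testing/tiny-random-gpt2 checkpoint purely as an example:

```python
def test_load_tokenizer_provides_pad_token():
    m = ComponentFactory.create_model("hf", "hf-internal-testing/tiny-random-gpt2")
    tok = m.load_tokenizer()
    # insert_pad_token() is expected to leave the tokenizer with a usable pad token,
    # so assert on the tokenizer state rather than on mock call counts.
    assert tok.pad_token is not None
    assert tok("hello", padding="max_length", max_length=4)["input_ids"]
```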