[QEff. Finetune]: Adding base class and HF class #658
base: ft_experimental
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,3 +4,146 @@ | |
| # SPDX-License-Identifier: BSD-3-Clause | ||
| # | ||
| # ----------------------------------------------------------------------------- | ||
|
|
||
| import logging | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Dict, Optional, Type | ||
|
|
||
| import torch.nn as nn | ||
| from transformers import AutoTokenizer | ||
| import transformers | ||
| from transformers.utils.logging import get_logger | ||
|
|
||
| from QEfficient.finetune.experimental.core.component_registry import registry | ||
| from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| class BaseModel(nn.Module, ABC): | ||
| """Shared skeleton for every finetunable model in the system.""" | ||
|
|
||
| def __init__(self, model_name: str, **model_kwargs: Any) -> None: | ||
| super().__init__() | ||
| self.model_name = model_name | ||
| self.model_kwargs: Dict[str, Any] = model_kwargs | ||
| self._model: Optional[nn.Module] = None | ||
| self._tokenizer: Any = None # HF tokenizers are not nn.Modules. | ||
|
Contributor
This variable is specific to LLMs. We cannot put it here.
Contributor (Author)
Yes, I missed this part; the reference code makes the same mistake and it got copied over. The base model needs to stay model-agnostic. I'll probably move these LLM-specific methods into an LLM mixin.
Contributor
Let us use a generic name, e.g. _preprocessor (even in the HFModel class)? |
||
|
|
||
| # Factory constructor: load model after __init__ finishes | ||
| @classmethod | ||
| def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel": | ||
|
Contributor
Not needed. Use the registry mechanism to instantiate BaseModel-type objects, which in turn instantiates the nn.Modules as well.
Contributor (Author)
The create factory seems to me a clean way to construct objects. It is the only place that guarantees the wrapped model is actually registered with nn.Module before anyone calls state_dict, parameters, to, or train. Dropping it and relying on bare __init__ puts us back into lazy-init land, where we unnecessarily have to call load_model() in every method, as the reference code does.
Contributor
I understand your concern: basically, you want a one-liner initialization of the model. The reference code does it in two lines: https://github.com/quic-meetkuma/LightningLLMs/blob/da2f6b39e8533cbd05d563ca68113828f783e73e/LightningLLM/main.py#L131C34-L131C46. I suggest moving this create into the ComponentRegistry class and adding a "create_model" method there. That class has the sole responsibility of storing references to particular classes and creating instances of the classes the user wants. This way all registry and instance creation is encapsulated in one place, and we only need to deal with ComponentRegistry and not the individual classes.
Contributor (Author)
The reference shared above is stale, as PR #645 didn't have create_model() added in component_registry.py. Why was it not added? I have added it. Even after adding it, we would still need a method in BaseModel to load the model and register it with nn.Module, which is now called through create_model(). I'm keeping the name of that method as create for now. The loading part can also be moved to create_model() in component_registry.py if that seems better. |
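For concreteness, a minimal sketch of the arrangement being discussed, assuming the @registry.model(...) decorator fills a name-to-class mapping (the attribute and method names here are illustrative, not the actual API in this PR):

```python
import torch.nn as nn


class ComponentFactory:
    """Illustrative factory: holds registered BaseModel subclasses and builds loaded instances."""

    _models: dict = {}  # name -> BaseModel subclass, populated by @registry.model(...)

    @classmethod
    def create_model(cls, registry_name: str, model_name: str, **model_kwargs) -> nn.Module:
        model_cls = cls._models[registry_name]               # fetch the registered subclass
        return model_cls.create(model_name, **model_kwargs)  # BaseModel.create loads and registers the nn.Module
```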
||
| obj = cls(model_name, **model_kwargs) | ||
| # load model after __init__ finishes | ||
| module = obj.load_model() | ||
| if not isinstance(module, nn.Module): | ||
| raise TypeError(f"load_model() must return nn.Module, got {type(module)}") | ||
| obj._model = module | ||
| return obj | ||
|
|
||
| @abstractmethod | ||
| def load_model(self) -> nn.Module: | ||
| """Load and return the underlying torch.nn.Module.""" | ||
| pass | ||
|
|
||
| def load_tokenizer(self) -> Any: | ||
|
Contributor
Specific to LLMs. We cannot put it here.
Contributor (Author)
Replied above.
Contributor
I feel the naming should be "load_preprocessor" instead of load_tokenizer. In the HFModel class's load_preprocessor, you should actually load the tokenizer and return it. A tokenizer is one kind of preprocessor; other models can have different preprocessors. |
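A rough sketch of the suggested rename, heavily simplified from the classes in this diff (this illustrates the proposal; it is not code in the PR):

```python
from abc import ABC, abstractmethod
from typing import Any

import torch.nn as nn
from transformers import AutoTokenizer


class BaseModel(nn.Module, ABC):
    @abstractmethod
    def load_preprocessor(self) -> Any:
        """Return the model-specific preprocessor (tokenizer, feature extractor, image processor, ...)."""


class HFModel(BaseModel):
    def load_preprocessor(self) -> Any:
        # For LLMs the preprocessor is the tokenizer; tokenizer_name is set in __init__ (omitted here).
        return AutoTokenizer.from_pretrained(self.tokenizer_name)
```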
||
| """Override if the model exposes a tokenizer.""" | ||
| raise NotImplementedError(f"{type(self).__name__} does not provide a tokenizer.") | ||
|
|
||
| # Lazy accessors | ||
| @property | ||
| def model(self) -> nn.Module: | ||
| if self._model is None: | ||
| raise RuntimeError("Model not loaded; use .create(...) to load.") | ||
| return self._model | ||
|
|
||
| @property | ||
| def tokenizer(self) -> Any: | ||
|
Contributor
This is applicable to NLP models; we can't put it in the base class. Better to create an abstract method called "preprocessor" that defines the generic preprocessing hook for the model. There won't be any implementation here, but the child classes should implement it. In the case of LLMs, this method should return the tokenizer. |
||
| if self._tokenizer is None: | ||
| self._tokenizer = self.load_tokenizer() | ||
| return self._tokenizer | ||
|
|
||
| # nn.Module API surface | ||
| def forward(self, *args, **kwargs): | ||
| return self.model(*args, **kwargs) | ||
|
|
||
| def get_input_embeddings(self): | ||
|
Contributor
This is applicable to NLP models. We can't put it in the base class.
Contributor (Author)
Replied above.
Contributor
I am a bit doubtful on this, but looking at most HF model implementations (e.g. ASR and vision) this method exists, so it is fine. By the way, do we need this method at all? I think we can get rid of it, as it is not even used.
Contributor (Author)
This method is also accessible via the PreTrained class, so implementing it here does not help. What was the thought process behind adding it in the reference code? I also think we should remove it.
Contributor (Author)
I meant that this method, too (just like resize_token_embeddings), is accessible via the PreTrainedModel class. See the reply on the resize_token_embeddings comment. |
||
| if hasattr(self.model, "get_input_embeddings"): | ||
| return self.model.get_input_embeddings() | ||
| logger.info(f"Model {self.model_name} does not expose input embeddings", logging.WARNING) | ||
| return None | ||
|
|
||
| def resize_token_embeddings(self, new_num_tokens: int) -> None: | ||
|
Contributor
This is applicable to NLP models. We can't put it in the base class.
Contributor (Author)
Replied above.
Contributor
This method looks specific to NLP; I don't find a reference to it in HF's ViT or Whisper models. I think the else case takes care of it, so it is fine for now.
Contributor (Author)
Whisper and ViT inherit from the PreTrainedModel class, which has resize_token_embeddings implemented. The else block will take care of models that don't define resize_token_embeddings and don't inherit from PreTrainedModel. What was the thought process behind adding this method in the reference code? I think we should remove it from here. |
||
| if hasattr(self.model, "resize_token_embeddings"): | ||
| self.model.resize_token_embeddings(new_num_tokens) | ||
| else: | ||
| logger.info(f"Model {self.model_name} cannot resize token embeddings", logging.WARNING) | ||
|
|
||
| # optional | ||
| def to(self, *args, **kwargs): | ||
| self.model.to(*args, **kwargs) | ||
| return self | ||
|
|
||
| def train(self, mode: bool = True): | ||
| self.model.train(mode) | ||
| return super().train(mode) | ||
|
|
||
| def eval(self): | ||
| return self.train(False) | ||
|
|
||
|
|
||
| @registry.model("hf") | ||
| class HFModel(BaseModel): | ||
|
Contributor
How can I get the PEFT config from this model? Any reason to avoid adding that method? The reference code had that functionality: https://github.com/quic-meetkuma/LightningLLMs/blob/hf_trainer/LightningLLM/components/model.py#L272
Contributor (Author)
The PEFT configuration should not reside within the model classes. Even if it had to be included, it belongs in BaseModel rather than HFModel, since the PEFT library supports custom models; see the PEFT documentation: https://huggingface.co/docs/peft/developer_guides/custom_models. The reference code incorrectly places it in HFModel.
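For reference, a minimal sketch of the custom-model usage described in the linked PEFT guide: LoRA targets submodules by name, so it can wrap any nn.Module, which is why the hook need not be specific to HFModel (the toy class and parameter values below are illustrative).

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyNet(nn.Module):
    """Toy custom model, standing in for any BaseModel subclass."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


lora_cfg = LoraConfig(target_modules=["proj"], r=4, lora_alpha=8)
peft_model = get_peft_model(TinyNet(), lora_cfg)  # wraps the custom model with LoRA adapters
peft_model.print_trainable_parameters()
```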
Contributor
Yes, it makes more sense to put it in the base class. Please do the needful.
Contributor
We also need to support loading the model from a configuration. That is mainly for testing: in integration tests we will not load an entire model with 32 layers, but the same model with only 2 or 4 layers, and test against that. For that purpose the config should be used to load the model; check the Hugging Face documentation on how to do that. |
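A sketch of the config-driven loading referred to above; gpt2 is only an example checkpoint, and the layer-count attribute differs per architecture (n_layer for GPT-2, num_hidden_layers for Llama-style configs):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Build a tiny, randomly initialized variant of the architecture for integration tests.
config = AutoConfig.from_pretrained("gpt2")
config.n_layer = 2                                   # GPT-2's layer-count attribute
tiny_model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in tiny_model.parameters()))
```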
||
| """HuggingFace-backed model with optional quantization.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_name: str, | ||
| auto_class_name: str = "AutoModelForCausalLM", | ||
| *, | ||
| tokenizer_name: Optional[str] = None, | ||
| **model_kwargs: Any, | ||
| ) -> None: | ||
| super().__init__(model_name, **model_kwargs) | ||
| self.tokenizer_name = tokenizer_name or model_name | ||
| self.auto_class: Type = self._resolve_auto_class(auto_class_name) | ||
|
|
||
| @staticmethod | ||
| def _resolve_auto_class(auto_class_name: str) -> Type: | ||
| if not hasattr(transformers, auto_class_name): | ||
| candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel")) | ||
| raise ValueError( | ||
| f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}" | ||
| ) | ||
| return getattr(transformers, auto_class_name) | ||
|
|
||
| # def _build_quant_config(self) -> Optional[BitsAndBytesConfig]: | ||
| # if not self.model_kwargs.get("load_in_4bit"): | ||
| # return None | ||
| # return BitsAndBytesConfig( | ||
| # load_in_4bit=True, | ||
| # bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"), | ||
| # bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16), | ||
| # bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True), | ||
| # ) | ||
|
|
||
| def configure_model_kwargs(self) -> Dict[str, Any]: | ||
| """Hook for subclasses to tweak HF `.from_pretrained` kwargs.""" | ||
| extra = dict(self.model_kwargs) | ||
| # extra["quantization_config"] = self._build_quant_config() | ||
| return extra | ||
|
|
||
| def load_model(self) -> nn.Module: | ||
| logger.info(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}") | ||
|
|
||
| return self.auto_class.from_pretrained( | ||
| self.model_name, | ||
| **self.configure_model_kwargs(), | ||
|
Contributor
Don't pass the result of a function call directly. Keep it explicit in a variable and then unpack the dict here; that way it is easier to pinpoint the error, if any.
Contributor
Here you are unpacking all the kwargs. If any extra kwargs are given, does self.auto_class.from_pretrained accept and discard them? If not, it will throw an error. Please check and correct it if needed. |
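A sketch of the explicit-variable style requested above, written as a restructured HFModel.load_model (same behavior as the diff; note that from_pretrained is not guaranteed to silently discard unrecognized keys, so unexpected kwargs may still raise depending on the model class):

```python
def load_model(self) -> nn.Module:
    hf_kwargs = self.configure_model_kwargs()  # keep the kwargs inspectable for debugging
    logger.info(f"from_pretrained kwargs for '{self.model_name}': {hf_kwargs}")
    return self.auto_class.from_pretrained(self.model_name, **hf_kwargs)
```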
||
| ) | ||
|
|
||
| def load_tokenizer(self) -> AutoTokenizer: | ||
| """Load Hugging Face tokenizer.""" | ||
| logger.info(f"Loading tokenizer '{self.tokenizer_name}'") | ||
|
Contributor
The logger should only log on rank zero. Check the other logger references as well. You already have reference code: https://github.com/quic-meetkuma/LightningLLMs/blob/hf_trainer/LightningLLM/components/model.py#L264 |
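A minimal sketch of such a rank-zero logging helper (the name log_rank_zero matches the later comments in this review; this is not the referenced implementation):

```python
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


def log_rank_zero(msg: str, level: int = logging.INFO) -> None:
    """Emit the message only from rank 0 when torch.distributed is initialized."""
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    logger.log(level, msg)
```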
||
| tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) | ||
| insert_pad_token(tokenizer) | ||
| return tokenizer | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| # ----------------------------------------------------------------------------- | ||
| # | ||
| # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
| # | ||
| # ----------------------------------------------------------------------------- | ||
|
|
||
| import pytest | ||
|
Contributor
Copyright header missing.
Contributor (Author)
Updated in the latest revision. |
||
| import torch | ||
| import torch.nn as nn | ||
| from unittest import mock | ||
|
|
||
| import transformers | ||
| from QEfficient.finetune.experimental.core import model | ||
| from QEfficient.finetune.experimental.core.model import BaseModel, HFModel | ||
| from QEfficient.finetune.experimental.core.component_registry import registry | ||
| from QEfficient.finetune.experimental.core.component_registry import ComponentFactory | ||
|
|
||
|
|
||
| class TestMockModel(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.linear = nn.Linear(2, 2) | ||
|
|
||
| def forward(self, x): | ||
| return self.linear(x) | ||
|
|
||
|
|
||
| @registry.model("testcustom") | ||
| class TestCustomModel(BaseModel): | ||
| def __init__(self, model_name): | ||
| super().__init__(model_name) | ||
| print("init of custom class") | ||
|
|
||
| def load_model(self) -> nn.Module: | ||
| return TestMockModel() | ||
|
|
||
| def load_tokenizer(self): | ||
| return "dummy-tokenizer" | ||
|
|
||
|
|
||
|
Contributor
Where are the PEFT-related test cases?
Contributor (Author)
There is no PEFT code, hence no test cases. Replied above regarding the PEFT code.
Contributor
We need to add test code for the PEFT models. Add it to the base class, make an instance of HFModel with a PEFT config, and check whether the model is modified with the PEFT changes or not. |
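A sketch of what such a test could look like once PEFT support lands; the peft_config keyword is hypothetical, since this PR does not define it yet:

```python
from peft import LoraConfig


def test_peft_wrapping_adds_lora_parameters():
    # Hypothetical API: create_model forwards peft_config and wraps the loaded model with PEFT.
    lora_cfg = LoraConfig(target_modules=["linear"], r=4, lora_alpha=8)
    m = ComponentFactory.create_model("testcustom", "dummy", peft_config=lora_cfg)
    lora_params = [name for name, _ in m.named_parameters() if "lora_" in name]
    assert lora_params, "expected LoRA parameters after applying the PEFT config"
```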
||
| # BaseModel tests | ||
|
Contributor
Add tests that register a child class of BaseModel in the registry and fetch the correct instance of the class when the required parameters are passed. |
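For example, a registry round-trip test along these lines (reusing the TestCustomModel registered above) would cover that:

```python
def test_registry_returns_registered_subclass():
    m = ComponentFactory.create_model("testcustom", "dummy")
    assert isinstance(m, TestCustomModel)
    assert isinstance(m, BaseModel)
    assert m.model_name == "dummy"
```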
||
| def test_model_property_errors_if_not_created(): | ||
| m = TestCustomModel("dummy") | ||
| with pytest.raises(RuntimeError): | ||
| _ = m.model # must call .create() | ||
|
|
||
|
|
||
| def test_create_builds_and_registers(): | ||
| m = ComponentFactory.create_model("testcustom", "dummy") | ||
| # inner model exists and registered | ||
| assert "_model" in m._modules | ||
| assert isinstance(m.model, TestMockModel) | ||
| # forward works | ||
| out = m(torch.zeros(1, 2)) | ||
| assert out.shape == (1, 2) | ||
|
|
||
|
|
||
| def test_tokenizer_lazy_loading(): | ||
| m = ComponentFactory.create_model("testcustom", "dummy") | ||
| assert m._tokenizer is None | ||
| tok = m.tokenizer | ||
| assert tok == "dummy-tokenizer" | ||
| assert m._tokenizer == tok | ||
|
Contributor
What are you trying to accomplish? First you check whether m._tokenizer is None, then you assign it to the tok variable, then you check whether it matches the "dummy-tokenizer" string, and then you match the same m._tokenizer contents against tok again.
Contributor (Author)
If you look carefully, they are different: tokenizer is a property and _tokenizer is a private variable.
Contributor
OK, so you are making sure that the "tokenizer" property correctly returns the _tokenizer member variable. That makes sense. |
||
|
|
||
|
|
||
| def test_to_moves_inner_and_returns_self(): | ||
| m = ComponentFactory.create_model("testcustom", "dummy") | ||
| with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to: | ||
| ret = m.to("cuda:0") | ||
| mocked_to.assert_called_once_with(m.model, "cuda:0") | ||
|
Contributor
Better to check the weights placement on that device. |
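A placement check along these lines might do it (it needs a CUDA device, so it would have to be skipped on CPU-only runners):

```python
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
def test_to_moves_parameters_to_device():
    m = ComponentFactory.create_model("testcustom", "dummy")
    m = m.to("cuda:0")
    assert all(p.device.type == "cuda" for p in m.parameters())
```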
||
| assert ret is m | ||
|
Contributor
What does this mean? This is not needed. |
||
|
|
||
|
|
||
| def test_train_eval_sync_flags(): | ||
| m = ComponentFactory.create_model("testcustom", "dummy") | ||
| m.eval() | ||
| assert m.training is False | ||
| assert m.model.training is False | ||
| m.train() | ||
| assert m.training is True | ||
| assert m.model.training is True | ||
|
|
||
|
|
||
| def test_resize_token_embeddings_and_get_input_embeddings_warn(monkeypatch): | ||
| m = ComponentFactory.create_model("testcustom", "dummy") | ||
|
|
||
| # resize_token_embeddings: underlying model lacks the method, should warn and not raise | ||
| with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log: | ||
|
Contributor
logger.info should be changed to logger.log_rank_zero based on earlier comments. |
||
| m.resize_token_embeddings(10) | ||
| mocked_log.assert_called_once() | ||
|
Contributor
This is too brute-force. Better to manipulate the embedding and check whether the embedding shape has changed or not. |
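A shape-based version might look like the following; it assumes an HF-backed model whose resize_token_embeddings is the real transformers implementation (gpt2 is only an example checkpoint, so its weights must be available to the test environment):

```python
def test_resize_token_embeddings_changes_vocab_size():
    m = ComponentFactory.create_model("hf", "gpt2")
    old_vocab = m.get_input_embeddings().weight.shape[0]
    m.resize_token_embeddings(old_vocab + 8)
    assert m.get_input_embeddings().weight.shape[0] == old_vocab + 8
```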
||
|
|
||
| # get_input_embeddings: underlying model lacks method, should warn and return None | ||
| with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log: | ||
| assert m.get_input_embeddings() is None | ||
|
Contributor
Create a proper dummy model that returns some embeddings rather than None. Your test should use the HFModel class instead of a dummy class. |
||
| mocked_log.assert_called_once() | ||
|
|
||
|
|
||
| def test_state_dict_contains_inner_params(): | ||
| m = ComponentFactory.create_model("testcustom", "dummy") | ||
| sd = m.state_dict() | ||
| # should contain params from TestMockModel.linear | ||
| assert any("linear.weight" in k for k in sd) | ||
| assert any("linear.bias" in k for k in sd) | ||
|
|
||
|
|
||
| # HFModel tests | ||
| def test_hfmodel_invalid_auto_class_raises(): | ||
| with pytest.raises(ValueError): | ||
| ComponentFactory.create_model("hf", "hf-name", auto_class_name="AutoDoesNotExist") | ||
|
|
||
|
|
||
| def test_hfmodel_loads_auto_and_tokenizer(monkeypatch): | ||
| # fake HF Auto class | ||
| class FakeAuto(nn.Module): | ||
| @classmethod | ||
| def from_pretrained(cls, name, **kwargs): | ||
| inst = cls() | ||
| inst.loaded = (name, kwargs) | ||
| return inst | ||
|
|
||
| def forward(self, x): | ||
| return x | ||
|
|
||
| fake_tok = mock.Mock() | ||
|
|
||
| # Monkeypatch transformer classes used in HFModel | ||
| monkeypatch.setattr( | ||
| "QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM", | ||
| FakeAuto, | ||
| raising=False, | ||
| ) | ||
| monkeypatch.setattr( | ||
| model, | ||
| "AutoTokenizer", | ||
| mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)), | ||
| ) | ||
| monkeypatch.setattr( | ||
| "QEfficient.finetune.experimental.core.model.insert_pad_token", | ||
| mock.Mock(), | ||
| raising=False, | ||
| ) | ||
| m = ComponentFactory.create_model("hf", "hf-name") | ||
| m = HFModel.create("hf-name") | ||
|
Contributor
No need to call the individual class's create method. There is a reason to use the registry functionality.
Contributor (Author)
The purpose of create is different from component_registry. Explained in another comment.
Contributor
Move it to ComponentFactory and then instantiate from there.
Contributor (Author)
Replied above. |
||
| assert isinstance(m.model, FakeAuto) | ||
|
|
||
| # load tokenizer | ||
| tok = m.load_tokenizer() | ||
|
|
||
| # tokenizer was loaded and pad token inserted | ||
| model.AutoTokenizer.from_pretrained.assert_called_once_with("hf-name") | ||
|
Contributor
Don't write tests filled with assert_called_once and the like. We have written such tests in the past, but that was not appropriate; it was a makeshift arrangement forced by the monolithic structure of the code. Write extensive, proper tests. If a function changes the model's structure, test for that change rather than counting how many times the function gets called. |
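An effect-based alternative could assert on the tokenizer state instead of call counts, assuming insert_pad_token assigns a pad token when one is missing (gpt2 is only an example checkpoint with no default pad token, and its tokenizer files must be available):

```python
def test_load_tokenizer_sets_pad_token():
    m = ComponentFactory.create_model("hf", "gpt2")
    tok = m.load_tokenizer()
    assert tok.pad_token is not None
```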
||
| model.insert_pad_token.assert_called_once_with(fake_tok) | ||
Is there any reason why the methods below are not implemented? They were given in the reference code as well.
Reference code: https://github.com/quic-meetkuma/LightningLLMs/blob/hf_trainer/LightningLLM/components/model.py
IMO the reference code is wrong and misses some careful observations about how PyTorch's nn.Module works. nn.Module already provides these, and because the wrapped model is registered, state_dict / parameters work as-is. Adding trivial passthroughs would be redundant, noisy, and a future maintenance hazard. Let me know if you have difficulty understanding this; I can explain offline.
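A tiny demonstration of that point (not part of the PR): assigning a submodule as an attribute registers it, so the wrapper's state_dict and parameters already include it.

```python
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self._model = nn.Linear(2, 2)  # registered automatically by nn.Module.__setattr__


w = Wrapper()
print(list(w.state_dict().keys()))             # ['_model.weight', '_model.bias']
print(sum(p.numel() for p in w.parameters()))  # 6
```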
Right, I missed that BaseModel inherits from nn.Module.