[QEff. Finetune]: Adding base class and HF class #658
base: ft_experimental
Conversation
Signed-off-by: Swati Allabadi <[email protected]>
a663197 to defab15
quic-meetkuma left a comment
Please use your reference code wisely. The implementation lacks main functionalities. The tests are not extensive and are implemented in a naive manner. Please correct them as well.
Reference code: https://github.com/quic-meetkuma/LightningLLMs/blob/hf_trainer/LightningLLM/components/model.py
| logger.info(f"Model {self.model_name} does not expose input embeddings", logging.WARNING) | ||
| return None | ||
|
|
||
| def resize_token_embeddings(self, new_num_tokens: int) -> None: |
This is applicable for NLP models. We can't put it in BaseClass.
Replied above.
This method looks specific to NLP. I don't find a reference to this method in the ViT or Whisper models in HF. I think the else case is taking care of it, so it is fine for now.
def forward(self, *args, **kwargs):
    return self.model(*args, **kwargs)

def get_input_embeddings(self):
This is applicable for NLP models. We can't put it in BaseClass.
Replied above.
I am a bit doubtful about this, but looking at most of the HF models' implementations (e.g. ASR and vision), this method exists. So it is fine.
By the way, do we need this method at all? I think we can get rid of it, as it is not even used.
    return self._model

@property
def tokenizer(self) -> Any:
This is applicable for NLP models; we can't put it in BaseClass. Better to create an abstract method called "preprocessor" which defines a generic preprocessing function applicable to the model. There won't be any implementation here, but the child classes should implement it. In the case of LLMs, this method should return the tokenizer.
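For instance, a minimal sketch of that abstraction could look like this (the load_preprocessor name and docstrings follow the suggestion above, not the current PR code):

from abc import ABC, abstractmethod
from typing import Any

import torch.nn as nn

class BaseModel(ABC):
    @abstractmethod
    def load_model(self) -> nn.Module:
        """Create and return the underlying torch.nn.Module."""
        pass

    @abstractmethod
    def load_preprocessor(self) -> Any:
        """Return the model's preprocessor: a tokenizer for LLMs, a feature
        extractor for ASR, an image processor for vision models."""
        pass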
| """Create and return the underlying torch.nn.Module.""" | ||
| ... | ||
|
|
||
| def load_tokenizer(self) -> Any: |
Specific to LLMs. We cannot put it here.
Replied above
What I feel is that the name should be "load_preprocessor" instead of load_tokenizer. In the HFModel class's load_preprocessor, you should actually load the tokenizer and return it.
The tokenizer is one kind of preprocessor; other models can have different preprocessors.
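Roughly, the HFModel side could then look like this (a sketch only; it assumes the load_preprocessor name proposed above and the self.model_name attribute seen in this diff, and the pad-token handling mirrors what the existing tests expect):

from typing import Any

from transformers import AutoTokenizer

class HFModel(BaseModel):
    def load_preprocessor(self) -> Any:
        # For text models the preprocessor is the tokenizer; vision or ASR
        # subclasses could return an image processor or feature extractor instead.
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer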
@abstractmethod
def load_model(self) -> nn.Module:
    """Create and return the underlying torch.nn.Module."""
    ...
use "pass" as it is explicit.
# get_input_embeddings: underlying model lacks method, should warn and return None
with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log:
    assert m.get_input_embeddings() is None
Create a proper dummy model which returns some embeddings rather than None. Your test should use the HFModel class instead of some dummy class.
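Something along these lines would be closer (a sketch; HFModel.create comes from this PR's tests, while TinyBackbone and the load_model stubbing are illustrative assumptions):

import torch.nn as nn

class TinyBackbone(nn.Module):
    """Dummy underlying model that actually exposes input embeddings."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(16, 8)

    def get_input_embeddings(self):
        return self.embed

def test_get_input_embeddings_returns_embeddings(monkeypatch):
    # Assumption: HFModel obtains its backbone via load_model; stubbing it
    # avoids a real download while still exercising the HFModel wrapper.
    monkeypatch.setattr(HFModel, "load_model", lambda self: TinyBackbone())
    m = HFModel.create("hf-name")
    emb = m.get_input_embeddings()
    assert isinstance(emb, nn.Embedding)
    assert emb.num_embeddings == 16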
    raising=False,
)

m = HFModel.create("hf-name")
No need to call the individual class's create method. There is a reason we use the registry functionality.
The purpose of create is different from that of component_registry. Explained in another comment.
Move it to ComponentFactory and then instantiate from there.
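i.e. resolve the model through the factory in the test, roughly (a sketch; ComponentFactory's actual interface is not shown in this PR, so the names below are illustrative):

# "hf" is the key registered via @registry.model("hf") in this PR.
model_cfg = {"type": "hf", "model_name": "hf-name"}
m = ComponentFactory.create_model(model_cfg)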
tok = m.load_tokenizer()

# tokenizer was loaded and pad token inserted
model.AutoTokenizer.from_pretrained.assert_called_once_with("hf-name")
Don't write such tests filled with assert_called_once etc. We have written such tests in the past, but that was not appropriate; it was a makeshift arrangement because of the monolithic structure of the code. Write extensive and proper tests. If a function has made some changes to the model's structure, then use that to test rather than counting how many times the function gets called.
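For example, a behaviour-based version of the tokenizer test could assert on the resulting model/tokenizer state instead of call counts (a sketch; in practice the model would be a tiny config-built one, as suggested further below):

def test_load_tokenizer_inserts_pad_token_and_resizes_embeddings():
    m = HFModel.create("hf-name")
    tok = m.load_tokenizer()

    # Assert on the observable effects, not on how many times a mock was called.
    assert tok.pad_token is not None
    m.resize_token_embeddings(len(tok))
    assert m.get_input_embeddings().num_embeddings == len(tok)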
def load_tokenizer(self):
    return "dummy-tokenizer"
Where are the PEFT-related test cases?
There is no PEFT code, hence no test cases.
Replied above for PEFT code.
We need to add test code for the PEFT models. Add it in the BaseClass, make an instance of HFModel with a PEFT config, and check whether the model is modified with the PEFT changes or not.
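Something like this would cover it (a sketch; the peft_config argument to create is an assumption, since the PEFT wiring is not in this PR yet):

from peft import LoraConfig

def test_hf_model_applies_peft_config():
    peft_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
    m = HFModel.create("hf-name", peft_config=peft_cfg)  # assumed keyword

    # Verify the wrapped model was actually modified with LoRA adapters.
    lora_params = [n for n, _ in m.model.named_parameters() if "lora_" in n]
    assert lora_params, "model was not modified with PEFT/LoRA layers"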
@registry.model("hf")
class HFModel(BaseModel):
We need to load the model based on a configuration as well. That is mainly for testing purposes. In integration tests we will not load an entire model consisting of 32 layers; we will only load the same model with 2 or 4 layers and do further testing. For that purpose, a config should be used to load the model. Check the Hugging Face documentation on how to do that.
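A config-driven load in Hugging Face looks roughly like this (the model name is only illustrative):

from transformers import AutoConfig, AutoModelForCausalLM

cfg = AutoConfig.from_pretrained("gpt2")              # any architecture; name is illustrative
cfg.num_hidden_layers = 2                             # shrink the model for integration tests
tiny_model = AutoModelForCausalLM.from_config(cfg)    # randomly initialised 2-layer variant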