Fix multimodal with model loaded through AutoGPTQ

This commit is contained in:
oobabooga 2023-06-06 19:42:40 -03:00
parent 3cc5ce3c42
commit f55e85e28a

View file

@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
@staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
if hasattr(shared.model.model, 'embed_tokens'):
func = shared.model.model.embed_tokens
else:
func = shared.model.model.model.embed_tokens # AutoGPTQ case
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
@staticmethod
def placeholder_embeddings() -> torch.Tensor: