model(vlm): pixtral (#5084)
This commit is contained in:
@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
@@ -211,7 +213,12 @@ class HFRunner:
|
||||
|
||||
# Load the model and tokenizer
|
||||
if self.model_type == "generation":
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
if model_archs := getattr(config, "architectures"):
|
||||
model_cls = getattr(transformers, model_archs[0])
|
||||
else:
|
||||
model_cls = AutoModelForCausalLM
|
||||
self.base_model = model_cls.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
|
||||
Reference in New Issue
Block a user