model(vlm): pixtral (#5084)

This commit is contained in:
Kiv Chen
2025-05-13 00:16:10 -07:00
committed by GitHub
parent b2e95f62b4
commit 5380cd7ea3
16 changed files with 1125 additions and 39 deletions

View File

@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
@@ -211,7 +213,12 @@ class HFRunner:
# Load the model and tokenizer
if self.model_type == "generation":
self.base_model = AutoModelForCausalLM.from_pretrained(
config = AutoConfig.from_pretrained(model_path)
if model_archs := getattr(config, "architectures"):
model_cls = getattr(transformers, model_archs[0])
else:
model_cls = AutoModelForCausalLM
self.base_model = model_cls.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=self.trust_remote_code,