model: Minicpmo (#3023)
This commit is contained in:
@@ -13,8 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.managers.mm_utils import embed_image_inputs
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -136,7 +136,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
|
||||
return inputs
|
||||
|
||||
def get_sglang_model(self):
|
||||
model_runner = ModelRunner(
|
||||
self.model_runner = ModelRunner(
|
||||
model_config=ModelConfig(self.model_path, model_override_args="{}"),
|
||||
mem_fraction_static=0.8,
|
||||
gpu_id=0,
|
||||
@@ -148,7 +148,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
|
||||
disable_cuda_graph=True,
|
||||
),
|
||||
)
|
||||
return model_runner.model
|
||||
return self.model_runner.model
|
||||
|
||||
|
||||
class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
||||
@@ -165,10 +165,13 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
||||
cls.chat_template = "minicpmv"
|
||||
|
||||
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
cls.model = AutoModel.from_pretrained(
|
||||
cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
|
||||
).eval()
|
||||
cls.model.to(cls.device)
|
||||
cls.hf_model = (
|
||||
AutoModel.from_pretrained(
|
||||
cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
|
||||
)
|
||||
.eval()
|
||||
.to(cls.device)
|
||||
)
|
||||
|
||||
async def test_vlm_embedding_output(self):
|
||||
"""
|
||||
@@ -184,7 +187,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
||||
"pixel_values": inputs.pixel_values,
|
||||
"tgt_sizes": inputs.tgt_sizes,
|
||||
}
|
||||
(hf_output, _) = self.model.get_vllm_embedding(
|
||||
(hf_output, _) = self.hf_model.get_vllm_embedding(
|
||||
model_inputs,
|
||||
)
|
||||
hf_output = hf_output.squeeze(0)
|
||||
@@ -192,14 +195,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
||||
# sglang
|
||||
model = self.get_sglang_model()
|
||||
input_ids = inputs["input_ids"].to(self.device).flatten()
|
||||
sglang_output = embed_image_inputs(
|
||||
image_input=ImageInputs(
|
||||
sglang_output = embed_mm_inputs(
|
||||
mm_input=MultimodalInputs(
|
||||
pixel_values=inputs["pixel_values"][0],
|
||||
tgt_sizes=inputs["tgt_sizes"][0],
|
||||
),
|
||||
input_ids=input_ids,
|
||||
input_embedding=model.get_input_embeddings(),
|
||||
image_embedding_func=model.get_image_features,
|
||||
mm_data_embedding_func=model.get_image_features,
|
||||
placeholder_token_ids=[
|
||||
self.processor.tokenizer.unk_token_id,
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user