feat: Support VLM in reference_hf (#2726)
Signed-off-by: Ce Gao <gaocegege@hotmail.com>
@@ -25,12 +25,89 @@ I'm going to the
 import argparse
 
+import requests
+from PIL import Image
+
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import (
+    AutoModelForCausalLM, AutoProcessor, AutoModelForImageTextToText
+)
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
 
 
+@torch.no_grad()
+def vlm_text_with_image(args):
+    # Load the processor and model for ImageTextToText tasks
+    processor = AutoProcessor.from_pretrained(
+        args.model_path, trust_remote_code=True)
+    model = AutoModelForImageTextToText.from_pretrained(
+        args.model_path,
+        torch_dtype=args.dtype,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        trust_remote_code=True,
+    )
+
+    torch.cuda.set_device(0)
+
+    # List of image URLs to process
+    image_urls = [
+        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
+    ]
+
+    # Conversation template for the processor
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                },
+                {
+                    "type": "text",
+                    "text": "Describe this image."
+                }
+            ]
+        }
+    ]
+
+    max_new_tokens = args.max_new_tokens
+
+    for i, url in enumerate(image_urls):
+        # Load the image from the URL
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        # Apply the chat template to the text prompt.
+        # Note that not all processors support chat templates;
+        # LLaVA and Qwen are two that do.
+        if not hasattr(processor, "apply_chat_template"):
+            raise ValueError("The processor does not support chat templates.")
+        text_prompt = processor.apply_chat_template(
+            conversation, add_generation_prompt=True)
+
+        # Prepare inputs for the model
+        inputs = processor(text=[text_prompt], images=[image],
+                           return_tensors="pt").to("cuda:0")
+
+        # Generate output from the model
+        output_ids = model.generate(
+            **inputs, do_sample=False, max_new_tokens=max_new_tokens
+        )
+        output_str = processor.decode(output_ids[0])
+
+        # Get the logits from the model's forward pass
+        outputs = model.forward(**inputs)
+        logits = outputs.logits[0, -1, :]
+
+        print(f"\n========== Image {i} ==========")
+        print("prefill logits (final)", logits)
+        # TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
+        # making it cluttered and difficult to read.
+        # These tokens should be removed or cleaned up for better readability.
+        print(output_str)
+
+
 @torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
@@ -108,7 +185,11 @@ if __name__ == "__main__":
     parser.add_argument("--dtype", type=str, default="float16")
+    parser.add_argument("--model-type", type=str, default="text")
 
     args = parser.parse_args()
 
-    normal_text(args)
-    # synthetic_tokens(args)
+    if args.model_type == "vlm":
+        vlm_text_with_image(args)
+    else:
+        normal_text(args)
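
A possible follow-up for the TODO in vlm_text_with_image (an assumption, not part of this commit): Hugging Face processors forward decode kwargs to the underlying tokenizer, so the padding tokens could likely be dropped at decode time.

    # Hypothetical cleanup for the TODO; skip_special_tokens is a standard
    # Hugging Face decode kwarg that strips tokens such as <|image_pad|>.
    output_str = processor.decode(output_ids[0], skip_special_tokens=True)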
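With this change, the script can be pointed at a vision-language checkpoint. A minimal sketch of an invocation, assuming the script filename from the commit title and a Qwen2-VL checkpoint as a placeholder (any image-text-to-text model whose processor supports chat templates should work):

    # Placeholder model path; no specific checkpoint is named in this commit.
    python reference_hf.py \
        --model-path Qwen/Qwen2-VL-2B-Instruct \
        --model-type vlm

Omitting --model-type (or passing "text") keeps the previous behavior and runs normal_text.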