[Feat] Add llava qwen, llava mistral (#419)
Co-authored-by: Bo Li <drluodian@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e0ae5d42ec
commit
664287b2a7
117
examples/usage/llava/http_llama3_llava_test.py
Normal file
117
examples/usage/llava/http_llama3_llava_test.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
||||||
|
# Installing latest sglang.
|
||||||
|
|
||||||
|
# Endpoint Service CLI:
|
||||||
|
# python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --tokenizer-path lmms-lab/llama3-llava-next-8b-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
|
||||||
|
|
||||||
|
python3 http_llama3_llava_test.py
|
||||||
|
|
||||||
|
Output:
|
||||||
|
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from llava.conversation import (
|
||||||
|
default_conversation,
|
||||||
|
conv_templates,
|
||||||
|
SeparatorStyle,
|
||||||
|
conv_llava_llama_3,
|
||||||
|
conv_qwen,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` and return the decoded JSON reply.

    Args:
        url: Full endpoint URL to post to.
        data: JSON-serializable request payload.
        delay: Seconds to sleep before the request is sent.
    """
    await asyncio.sleep(delay)
    # A fresh client session is opened per call and closed on exit.
    async with aiohttp.ClientSession() as client, client.post(url, json=data) as reply:
        return await reply.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent(args):
    """Send the captioning request via ``send_request`` and print each reply's text.

    Args:
        args: Parsed CLI namespace providing ``host`` and ``port``.
    """
    base_url = f"{args.host}:{args.port}"

    # Wrap the raw question in the llama-3 LLaVA conversation template.
    conversation = copy.deepcopy(conv_llava_llama_3)
    conversation.append_message(
        role="user",
        message="<image>\nPlease generate caption towards this image.",
    )
    templated_prompt = conversation.get_prompt()

    payload = {
        "text": templated_prompt,
        "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
        "sampling_params": {
            "max_new_tokens": 1024,
            "temperature": 0,
            "top_p": 1.0,
            "presence_penalty": 2,
            "frequency_penalty": 2,
            "stop": "<|eot_id|>",
        },
    }
    # Coroutines are created here and only start running inside gather().
    pending = [send_request(base_url + "/generate", payload) for _ in range(1)]

    for reply in await asyncio.gather(*pending):
        print(reply["text"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming(args):
    """Stream a caption from the ``/generate`` endpoint and print it as it grows.

    Each ``data:`` line of the response is parsed as JSON; only the part of
    ``text`` that has not been printed yet is written, so the caption appears
    incrementally on one line.

    Args:
        args: Parsed CLI namespace providing ``host`` and ``port``.
    """
    url = f"{args.host}:{args.port}"
    prompt = "<image>\nPlease generate caption towards this image."
    conv_template = copy.deepcopy(conv_llava_llama_3)
    conv_template.append_message(role="user", message=prompt)
    prompt_with_template = conv_template.get_prompt()
    pload = {
        "text": prompt_with_template,
        "sampling_params": {
            "max_new_tokens": 1024,
            "temperature": 0,
            "top_p": 1.0,
            "presence_penalty": 2,
            "frequency_penalty": 2,
            "stop": "<|eot_id|>",
        },
        "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
        "stream": True,
    }

    prev = 0
    # A streamed response keeps its connection checked out until it is fully
    # consumed or closed. The loop stops at "[DONE]" without draining the
    # body, so the context manager is what releases the connection.
    with requests.post(url + "/generate", json=pload, stream=True) as response:
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            # Skip keep-alive blank lines and anything that is not a data event.
            if not chunk.startswith("data:"):
                continue
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            # Print only the newly generated suffix.
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI: where the sglang server is listening.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    args = cli.parse_args()

    # Run the non-streaming request first, then the streaming one.
    asyncio.run(test_concurrent(args))
    test_streaming(args)
|
||||||
117
examples/usage/llava/http_qwen_llava_test.py
Normal file
117
examples/usage/llava/http_qwen_llava_test.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
||||||
|
# Installing latest sglang.
|
||||||
|
|
||||||
|
# Endpoint Service CLI:
|
||||||
|
# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
|
||||||
|
|
||||||
|
python3 http_qwen_llava_test.py
|
||||||
|
|
||||||
|
Output:
|
||||||
|
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from llava.conversation import (
|
||||||
|
default_conversation,
|
||||||
|
conv_templates,
|
||||||
|
SeparatorStyle,
|
||||||
|
conv_llava_llama_3,
|
||||||
|
conv_qwen,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` and return the decoded JSON reply.

    Args:
        url: Full endpoint URL to post to.
        data: JSON-serializable request payload.
        delay: Seconds to sleep before the request is sent.
    """
    await asyncio.sleep(delay)
    # A fresh client session is opened per call and closed on exit.
    async with aiohttp.ClientSession() as client, client.post(url, json=data) as reply:
        return await reply.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent(args):
    """Send the captioning request via ``send_request`` and print each reply's text.

    Args:
        args: Parsed CLI namespace providing ``host`` and ``port``.
    """
    base_url = f"{args.host}:{args.port}"

    # Wrap the raw question in the Qwen LLaVA conversation template.
    conversation = copy.deepcopy(conv_qwen)
    conversation.append_message(
        role="user",
        message="<image>\nPlease generate caption towards this image.",
    )
    templated_prompt = conversation.get_prompt()

    payload = {
        "text": templated_prompt,
        "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
        "sampling_params": {
            "max_new_tokens": 1024,
            "temperature": 0,
            "top_p": 1.0,
            "presence_penalty": 2,
            "frequency_penalty": 2,
            "stop": "<|im_end|>",
        },
    }
    # Coroutines are created here and only start running inside gather().
    pending = [send_request(base_url + "/generate", payload) for _ in range(1)]

    for reply in await asyncio.gather(*pending):
        print(reply["text"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming(args):
    """Stream a caption from the ``/generate`` endpoint and print it as it grows.

    Each ``data:`` line of the response is parsed as JSON; only the part of
    ``text`` that has not been printed yet is written, so the caption appears
    incrementally on one line.

    Args:
        args: Parsed CLI namespace providing ``host`` and ``port``.
    """
    url = f"{args.host}:{args.port}"
    prompt = "<image>\nPlease generate caption towards this image."
    conv_template = copy.deepcopy(conv_qwen)
    conv_template.append_message(role="user", message=prompt)
    prompt_with_template = conv_template.get_prompt()
    pload = {
        "text": prompt_with_template,
        "sampling_params": {
            "max_new_tokens": 1024,
            "temperature": 0,
            "top_p": 1.0,
            "presence_penalty": 2,
            "frequency_penalty": 2,
            "stop": "<|im_end|>",
        },
        "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
        "stream": True,
    }

    prev = 0
    # A streamed response keeps its connection checked out until it is fully
    # consumed or closed. The loop stops at "[DONE]" without draining the
    # body, so the context manager is what releases the connection.
    with requests.post(url + "/generate", json=pload, stream=True) as response:
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            # Skip keep-alive blank lines and anything that is not a data event.
            if not chunk.startswith("data:"):
                continue
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            # Print only the newly generated suffix.
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI: where the sglang server is listening.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    args = cli.parse_args()

    # The non-streaming check is left disabled in this script:
    # asyncio.run(test_concurrent(args))
    test_streaming(args)
|
||||||
88
examples/usage/llava/srt_llava_next_test.py
Normal file
88
examples/usage/llava/srt_llava_next_test.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
Usage: python3 srt_llava_next_test.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.srt.utils import load_image
|
||||||
|
from sglang.lang.chat_template import get_chat_template
|
||||||
|
|
||||||
|
from PIL import ImageFile
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images
|
||||||
|
|
||||||
|
@sgl.function
def image_qa(s, image, question):
    """Append a user turn (image followed by the question) and generate ``answer``."""
    user_turn = sgl.user(sgl.image(image) + question)
    assistant_turn = sgl.assistant(sgl.gen("answer"))
    s += user_turn
    s += assistant_turn
|
||||||
|
|
||||||
|
|
||||||
|
def single():
    """Ask one question about the sample image and print the answer."""
    picture = load_image(
        "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    )
    result = image_qa.run(
        image=picture,
        question="What is this?",
        max_new_tokens=512,
    )
    print(result["answer"], "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def stream():
    """Request a short caption for the sample image and print it as it streams."""
    picture = load_image(
        "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    )
    run_options = {
        "image": picture,
        "question": "Please generate short caption for this image.",
        "max_new_tokens": 512,
        "temperature": 0,
        "stream": True,
    }
    state = image_qa.run(**run_options)

    # Emit each piece of the "answer" variable as soon as it is available.
    for piece in state.text_iter("answer"):
        print(piece, end="", flush=True)
    print()
|
||||||
|
|
||||||
|
|
||||||
|
def batch():
    """Run the same question twice as one batch and print both answers."""
    picture = load_image(
        "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    )
    # Two identical requests sharing the loaded image.
    batch_inputs = [
        {"image": picture, "question": "What is this?"} for _ in range(2)
    ]
    results = image_qa.run_batch(batch_inputs, max_new_tokens=512)
    for result in results:
        print(result["answer"], "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import multiprocessing as mp

    mp.set_start_method("spawn", force=True)
    runtime = sgl.Runtime(
        model_path="lmms-lab/llama3-llava-next-8b",
        tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer",
    )
    # Everything after the runtime is created runs under try/finally so the
    # runtime is shut down even when one of the demos raises.
    try:
        runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
        # Alternative: the Qwen-based 72B model. Create the runtime with
        #     model_path="lmms-lab/llava-next-72b",
        #     tokenizer_path="lmms-lab/llavanext-qwen-tokenizer",
        # and set
        #     runtime.endpoint.chat_template = get_chat_template("chatml-llava")
        sgl.set_default_backend(runtime)
        print(f"chat template: {runtime.endpoint.chat_template.name}")

        # Or you can use API models
        # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
        # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

        # Run a single request
        print("\n========== single ==========\n")
        single()

        # Stream output
        print("\n========== stream ==========\n")
        stream()

        # Run a batch of requests
        print("\n========== batch ==========\n")
        batch()
    finally:
        runtime.shutdown()
|
||||||
@@ -328,4 +328,4 @@ def monkey_path_clip_vision_embed_forward():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = LlavaLlamaForCausalLM
|
EntryClass = LlavaLlamaForCausalLM
|
||||||
347
python/sglang/srt/models/llava_mistral.py
Normal file
347
python/sglang/srt/models/llava_mistral.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig
|
||||||
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
hf_model_weights_iterator,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||||
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.mm_utils import (
|
||||||
|
get_anyres_image_grid_shape,
|
||||||
|
unpad_image,
|
||||||
|
unpad_image_shape,
|
||||||
|
)
|
||||||
|
from sglang.srt.models.mistral import MistralForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaMistralForCausalLM(nn.Module):
    """LLaVA-style model: CLIP vision tower + projector + Mistral language model.

    The vision tower and the image-related attributes (``image_size``,
    ``patch_size``, ``image_feature_len``, ``mm_patch_merge_type`` ...) are
    only populated by :meth:`load_weights`; the other methods read them and
    therefore must not be called before weights are loaded.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projector and language model and fill in missing config fields.

        Args:
            config: Model configuration. Fields that are missing
                (``vision_config``, ``text_config``, ``projector_hidden_act``,
                ``image_token_index``) are given defaults in place.
            quant_config: Optional quantization config forwarded to the
                language model.
        """
        super().__init__()
        self.config = config
        # Loaded lazily in load_weights().
        self.vision_tower = None
        # NOTE(review): the positional argument below presumably lands in the
        # config's first constructor parameter rather than being treated as a
        # model path; only ``hidden_size`` (overwritten right after) is set
        # explicitly here — confirm against the transformers config classes.
        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)

        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = MistralConfig(self.config._name_or_path)

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"

        if getattr(self.config, "image_token_index", None) is None:
            self.config.image_token_index = 32000

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
        # The "unpad" merge types append a learned newline embedding to image rows.
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        """Replace the image token in ``input_ids`` with image-feature placeholders.

        Args:
            input_ids: List of token ids containing ``config.image_token_index``.
            pad_value: List of ids that is repeated to fill the placeholder span.
            pt_shape: Shape of the pixel-value array; ``pt_shape[0]`` is read
                when the merge type starts with ``"spatial"``.
            image_size: Original image size, passed to the anyres/unpad helpers.

        Returns:
            ``(new_input_ids, offset)`` where ``offset`` is the index of the
            image token in the original ``input_ids``.

        Raises:
            ValueError: If the image token is not present (from ``list.index``).
        """
        new_image_feature_len = self.image_feature_len
        # now only support spatial_unpad + anyres
        if self.mm_patch_merge_type.startswith("spatial"):
            height = width = self.num_patches_per_side
            if pt_shape[0] > 1:
                if self.image_aspect_ratio == "anyres":
                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                        image_size,
                        self.image_grid_pinpoints,
                        self.vision_tower.config.image_size,
                    )
                if "unpad" in self.mm_patch_merge_type:
                    h = num_patch_height * height
                    w = num_patch_width * width
                    new_h, new_w = unpad_image_shape(h, w, image_size)
                    # One extra column per row for the image_newline embedding.
                    new_image_feature_len += new_h * (new_w + 1)

        # Repeat pad_value often enough to cover the placeholder span; the
        # result is truncated to the exact length below.
        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and projector on ``pixel_values``.

        Returns:
            The projected features of the hidden state selected by
            ``self.vision_feature_layer``.

        Raises:
            ValueError: If ``self.vision_feature_select_strategy`` is not one
                of ``"default"``, ``"patch"`` or ``"full"``.
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        # NOTE(review): load_weights() accepts "patch" and "cls_patch", while
        # this method accepts "default", "patch" and "full" — "cls_patch"
        # would be rejected here; confirm which set is intended.
        if self.vision_feature_select_strategy in ["default", "patch"]:
            # Drop the first token along dim 1.
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            pass
        else:
            # Report the attribute that was actually tested; the config object
            # is not guaranteed to carry a field of this name.
            raise ValueError(
                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
            )
        image_features = self.multi_modal_projector(selected_image_feature)

        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run the language model, splicing image features into the embeddings.

        In EXTEND mode the token embeddings are computed first, and for every
        request that needs vision the encoded image features overwrite the
        placeholder span starting at ``extend_start_loc[i] + image_offsets[i]``.
        In DECODE mode the language model is called directly.

        Args:
            input_ids: Token ids for the batch.
            positions: Token positions for the batch.
            input_metadata: Batch metadata (forward mode, batch size, extend
                start locations).
            pixel_values: Per-request pixel arrays, ``None`` for text-only
                requests.
            image_sizes: Per-request original image sizes.
            image_offsets: Per-request offset of the image placeholder.

        Returns:
            The language model output. For forward modes other than EXTEND and
            DECODE nothing is returned.
        """
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            # A request needs vision when the position of its first extended
            # token is below image_feature_len.
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to subtract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                    # ndim=4
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                if self.mm_patch_merge_type.startswith("spatial"):
                    new_image_features = []
                    for image_idx, image_feature in enumerate(image_features):
                        if image_feature.shape[0] > 1:
                            # First entry is the base image, the rest are grid patches.
                            base_image_feature = image_feature[0]
                            image_feature = image_feature[1:]
                            height = width = self.num_patches_per_side
                            assert height * width == base_image_feature.shape[0]
                            if self.image_aspect_ratio == "anyres":
                                (
                                    num_patch_width,
                                    num_patch_height,
                                ) = get_anyres_image_grid_shape(
                                    image_sizes[image_idx],
                                    self.image_grid_pinpoints,
                                    self.vision_tower.config.image_size,
                                )
                                image_feature = image_feature.view(
                                    num_patch_height, num_patch_width, height, width, -1
                                )
                            else:
                                raise NotImplementedError()
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = image_feature.permute(
                                    4, 0, 2, 1, 3
                                ).contiguous()
                                image_feature = image_feature.flatten(1, 2).flatten(
                                    2, 3
                                )
                                image_feature = unpad_image(
                                    image_feature, image_sizes[image_idx]
                                )
                                # Append the newline embedding as one extra
                                # entry along the last dimension.
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[
                                            :, None, None
                                        ].expand(*image_feature.shape[:-1], 1),
                                    ),
                                    dim=-1,
                                )
                                image_feature = image_feature.flatten(1, 2).transpose(
                                    0, 1
                                )
                            else:
                                image_feature = image_feature.permute(
                                    0, 2, 1, 3, 4
                                ).contiguous()
                                image_feature = image_feature.flatten(0, 3)
                            image_feature = torch.cat(
                                (base_image_feature, image_feature), dim=0
                            )
                        else:
                            image_feature = image_feature[0]
                            if "unpad" in self.mm_patch_merge_type:
                                image_feature = torch.cat(
                                    (
                                        image_feature,
                                        self.language_model.model.image_newline[None],
                                    ),
                                    dim=0,
                                )
                        new_image_features.append(image_feature)
                    image_features = new_image_features

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                # pt indexes the filtered image_features list, i the full batch.
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    try:
                        input_embeds[
                            start_idx
                            + image_offsets[i] : start_idx
                            + image_offsets[i]
                            + pad_len
                        ] = image_features[pt]
                    except RuntimeError as e:
                        # Best effort: report the mismatch and continue with
                        # the text embeddings left in place.
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load the vision tower, projector and language-model weights.

        Also derives the image-related attributes read by the other methods
        and installs the patched CLIP embedding forward.

        Args:
            model_name_or_path: Checkpoint name or path.
            cache_dir: Optional cache directory forwarded to the weight iterator.
            load_format: Weight format forwarded to the weight iterator.
            revision: Optional checkpoint revision.

        Raises:
            ValueError: If ``config.mm_vision_select_feature`` is neither
                ``"patch"`` nor ``"cls_patch"``.
        """
        # load clip vision model by cfg['mm_vision_tower']:
        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            # The attribute is vision_feature_select_strategy; the previous
            # name (select_feature) does not exist on this class and turned
            # this error into an AttributeError.
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                # Map checkpoint names onto this module's parameter names.
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        """Number of patches along one image side (``image_size // patch_size``)."""
        return self.image_size // self.patch_size
|
||||||
|
|
||||||
|
|
||||||
|
# Set to False after the patch-embedding layer has been moved to CPU once.
first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Replacement ``forward`` for the CLIP vision embeddings.

    The patch-embedding convolution is run on CPU in float32 and its output
    is moved back to CUDA in half precision before class and position
    embeddings are added.
    """
    global first_call

    num_images = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False

    cpu_pixels = pixel_values.to(dtype=torch.float32, device="cpu")
    patches = self.patch_embedding(cpu_pixels).cuda().half()
    patches = patches.flatten(2).transpose(1, 2)

    # Prepend the class embedding to every image's patch sequence.
    cls_tokens = self.class_embedding.expand(num_images, 1, -1)
    tokens = torch.cat([cls_tokens, patches], dim=1)
    return tokens + self.position_embedding(self.position_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_path_clip_vision_embed_forward():
    """Replace ``CLIPVisionEmbeddings.forward`` with ``clip_vision_embed_forward``."""
    import transformers

    embeddings_cls = transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
    # Patches the class itself, so every instance picks up the new forward.
    embeddings_cls.forward = clip_vision_embed_forward
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level alias for the model class; presumably the name the model
# loader looks up to find this implementation — confirm against the loader.
EntryClass = LlavaMistralForCausalLM
|
||||||
347
python/sglang/srt/models/llava_qwen.py
Normal file
347
python/sglang/srt/models/llava_qwen.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config
|
||||||
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
hf_model_weights_iterator,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||||
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.mm_utils import (
|
||||||
|
get_anyres_image_grid_shape,
|
||||||
|
unpad_image,
|
||||||
|
unpad_image_shape,
|
||||||
|
)
|
||||||
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaQwenForCausalLM(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlavaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.vision_tower = None
|
||||||
|
if getattr(self.config, "vision_config", None) is None:
|
||||||
|
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
|
||||||
|
|
||||||
|
if getattr(self.config, "text_config", None) is None:
|
||||||
|
self.config.text_config = Qwen2Config(self.config._name_or_path)
|
||||||
|
|
||||||
|
self.config.vision_config.hidden_size = config.mm_hidden_size
|
||||||
|
self.config.text_config.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
if getattr(self.config, "projector_hidden_act", None) is None:
|
||||||
|
self.config.projector_hidden_act = "gelu"
|
||||||
|
|
||||||
|
if getattr(self.config, "image_token_index", None) is None:
|
||||||
|
self.config.image_token_index = 151646
|
||||||
|
|
||||||
|
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||||
|
self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
|
||||||
|
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||||
|
self.language_model.model.image_newline = nn.Parameter(
|
||||||
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||||
|
)
|
||||||
|
|
||||||
|
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
|
||||||
|
new_image_feature_len = self.image_feature_len
|
||||||
|
# now only support spatial_unpad + anyres
|
||||||
|
if self.mm_patch_merge_type.startswith("spatial"):
|
||||||
|
height = width = self.num_patches_per_side
|
||||||
|
if pt_shape[0] > 1:
|
||||||
|
if self.image_aspect_ratio == "anyres":
|
||||||
|
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||||
|
image_size,
|
||||||
|
self.image_grid_pinpoints,
|
||||||
|
self.vision_tower.config.image_size,
|
||||||
|
)
|
||||||
|
if "unpad" in self.mm_patch_merge_type:
|
||||||
|
h = num_patch_height * height
|
||||||
|
w = num_patch_width * width
|
||||||
|
new_h, new_w = unpad_image_shape(h, w, image_size)
|
||||||
|
new_image_feature_len += new_h * (new_w + 1)
|
||||||
|
|
||||||
|
pad_ids = pad_value * (
|
||||||
|
(new_image_feature_len + len(pad_value)) // len(pad_value)
|
||||||
|
)
|
||||||
|
offset = input_ids.index(self.config.image_token_index)
|
||||||
|
# old_len + pad_len - 1, because we need to remove image_token_id
|
||||||
|
new_input_ids = (
|
||||||
|
input_ids[:offset]
|
||||||
|
+ pad_ids[:new_image_feature_len]
|
||||||
|
+ input_ids[offset + 1 :]
|
||||||
|
)
|
||||||
|
return new_input_ids, offset
|
||||||
|
|
||||||
|
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
|
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||||
|
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
|
||||||
|
|
||||||
|
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
|
||||||
|
if self.vision_feature_select_strategy in ["default", "patch"]:
|
||||||
|
selected_image_feature = selected_image_feature[:, 1:]
|
||||||
|
elif self.vision_feature_select_strategy == "full":
|
||||||
|
selected_image_feature = selected_image_feature
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
||||||
|
)
|
||||||
|
image_features = self.multi_modal_projector(selected_image_feature)
|
||||||
|
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.Tensor,
    input_metadata: InputMetadata,
    pixel_values: Optional[List[Optional[np.array]]] = None,
    image_sizes: Optional[List[List[int]]] = None,
    image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
    """Run the language model, splicing image features into the prompt embeddings.

    In ``ForwardMode.EXTEND`` the token embeddings are computed first; for every
    request that still needs its image (see ``need_vision`` below) the encoded
    image features overwrite the placeholder slice of ``input_embeds`` before
    the language model is called. In ``ForwardMode.DECODE`` the call is forwarded
    to the language model unchanged. Any other mode falls through and returns
    ``None``.

    Args:
        input_ids: Token ids, embedded with the language model's ``embed_tokens``.
        positions: Position tensor, also indexed at each request's start offset.
        input_metadata: Batch metadata; ``forward_mode``, ``batch_size`` and
            ``extend_start_loc`` are read here.
        pixel_values: One entry per request; ``None`` marks a request without an
            image. Indexed unconditionally in EXTEND mode.
        image_sizes: One entry per request, forwarded to
            ``get_anyres_image_grid_shape`` / ``unpad_image``.
        image_offsets: Per-request offset of the image placeholder relative to
            the request's start in the flattened batch.

    Returns:
        Whatever ``self.language_model`` returns for the batch.
    """
    if input_metadata.forward_mode == ForwardMode.EXTEND:
        bs = input_metadata.batch_size

        # Embed text input
        input_embeds = self.language_model.model.embed_tokens(input_ids)

        # Embed vision input
        need_vision = (
            (positions[input_metadata.extend_start_loc] < self.image_feature_len)
            .cpu()
            .numpy()
        )
        # FIXME: We need to subtract the length of the system prompt
        has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
        need_vision = need_vision & has_pixel

        if need_vision.any():
            # Keep only the requests that actually need their image encoded.
            pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
            image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

            ########## Encode Image ########

            if pixel_values[0].ndim == 4:
                # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
                # ndim=4
                # (The original concatenated the batch twice and threw the first
                # result away; a single concatenation is enough.)
                concat_images = torch.tensor(
                    np.concatenate(pixel_values, axis=0),
                    device=self.vision_tower.device,
                )
                image_features = self.encode_images(concat_images)
                # Undo the concatenation so each request gets its own patches back.
                split_sizes = [image.shape[0] for image in pixel_values]
                image_features = torch.split(image_features, split_sizes, dim=0)
                # hd image_features: BS, num_patch, 576, 4096
            else:
                # normal pixel: BS, C=3, H=336, W=336
                pixel_values = torch.tensor(
                    np.array(pixel_values), device=self.vision_tower.device
                )
                image_features = self.encode_images(pixel_values)
                # image_features: BS, 576, 4096

            if self.mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        # First entry is handled as the base image feature, the
                        # rest are the grid patches.
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if self.image_aspect_ratio == "anyres":
                            (
                                num_patch_width,
                                num_patch_height,
                            ) = get_anyres_image_grid_shape(
                                image_sizes[image_idx],
                                self.image_grid_pinpoints,
                                self.vision_tower.config.image_size,
                            )
                            image_feature = image_feature.view(
                                num_patch_height, num_patch_width, height, width, -1
                            )
                        else:
                            raise NotImplementedError()
                        if "unpad" in self.mm_patch_merge_type:
                            image_feature = image_feature.permute(
                                4, 0, 2, 1, 3
                            ).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(
                                2, 3
                            )
                            image_feature = unpad_image(
                                image_feature, image_sizes[image_idx]
                            )
                            # Append the image_newline embedding as one extra
                            # column along the last dimension.
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.language_model.model.image_newline[
                                        :, None, None
                                    ].expand(*image_feature.shape[:-1], 1),
                                ),
                                dim=-1,
                            )
                            image_feature = image_feature.flatten(1, 2).transpose(
                                0, 1
                            )
                        else:
                            image_feature = image_feature.permute(
                                0, 2, 1, 3, 4
                            ).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat(
                            (base_image_feature, image_feature), dim=0
                        )
                    else:
                        image_feature = image_feature[0]
                        if "unpad" in self.mm_patch_merge_type:
                            image_feature = torch.cat(
                                (
                                    image_feature,
                                    self.language_model.model.image_newline[None],
                                ),
                                dim=0,
                            )
                    new_image_features.append(image_feature)
                image_features = new_image_features

            extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
            # `pt` walks the filtered feature list; `i` walks the full batch.
            pt = 0
            for i in range(bs):
                if not need_vision[i]:
                    continue

                start_idx = extend_start_loc_cpu[i]
                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                dim = input_embeds.shape[1]
                assert (
                    pad_dim == dim
                ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                # Fill in the placeholder for the image
                try:
                    input_embeds[
                        start_idx
                        + image_offsets[i] : start_idx
                        + image_offsets[i]
                        + pad_len
                    ] = image_features[pt]
                except RuntimeError as e:
                    # Best effort: report the failed copy and keep going with
                    # the embeddings as they are.
                    print(f"RuntimeError in llava image encoding: {e}")
                    print(input_embeds.shape)
                    print(start_idx, image_offsets[i])
                pt += 1

        return self.language_model(
            input_ids, positions, input_metadata, input_embeds=input_embeds
        )
    elif input_metadata.forward_mode == ForwardMode.DECODE:
        return self.language_model(input_ids, positions, input_metadata)
|
||||||
|
|
||||||
|
def load_weights(
    self,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
):
    """Load the CLIP vision tower, the projector weights and the language model.

    Besides loading weights this also sets the vision-related attributes used
    by the rest of the class (``vision_feature_layer``,
    ``vision_feature_select_strategy``, ``image_size``, ``patch_size``,
    ``mm_patch_merge_type``, ``image_aspect_ratio``, ``image_grid_pinpoints``,
    ``image_feature_len``) and finally installs the patched CLIP embedding
    forward.

    Args:
        model_name_or_path: Checkpoint location passed to the weight iterator
            and to the language model's own ``load_weights``.
        cache_dir: Forwarded unchanged to the weight iterator.
        load_format: Forwarded unchanged to the weight iterator.
        revision: Forwarded unchanged to the weight iterator.

    Raises:
        ValueError: If ``config.mm_vision_select_feature`` is neither
            ``"patch"`` nor ``"cls_patch"``.
    """
    # load clip vision model by cfg['mm_vision_tower']:
    #   huggingface_name or path_of_clip_relative_to_llava_model_dir
    vision_path = self.config.mm_vision_tower
    self.vision_tower = CLIPVisionModel.from_pretrained(
        vision_path, torch_dtype=torch.float16
    ).cuda()
    self.vision_tower.eval()

    self.vision_feature_layer = self.config.mm_vision_select_layer
    self.vision_feature_select_strategy = self.config.mm_vision_select_feature
    self.image_size = self.vision_tower.config.image_size
    self.patch_size = self.vision_tower.config.patch_size

    self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
    self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
    self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

    self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
    if self.vision_feature_select_strategy == "patch":
        pass
    elif self.vision_feature_select_strategy == "cls_patch":
        # One extra feature on top of the patch grid.
        self.image_feature_len += 1
    else:
        # `self.select_feature` is never assigned on this class; report the
        # attribute that was actually checked so the ValueError is reachable.
        raise ValueError(
            f"Unexpected select feature: {self.vision_feature_select_strategy}"
        )

    # load mm_projector
    projector_weights = {
        "model.mm_projector.0": "multi_modal_projector.linear_1",
        "model.mm_projector.2": "multi_modal_projector.linear_2",
        "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
    }
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in hf_model_weights_iterator(
        model_name_or_path, cache_dir, load_format, revision
    ):
        # FIXME: why projector weights read two times?
        if "projector" in name or "vision_tower" in name:
            # Translate checkpoint prefixes to this module's parameter names.
            for weight_name, param_name in projector_weights.items():
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
            # A name that maps to no parameter raises KeyError here.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

    # load language model
    self.language_model.load_weights(
        model_name_or_path, cache_dir, load_format, revision
    )

    monkey_path_clip_vision_embed_forward()
|
||||||
|
|
||||||
|
@property
def num_patches_per_side(self):
    """Return ``self.image_size`` floor-divided by ``self.patch_size``."""
    side_len, patch_len = self.image_size, self.patch_size
    return side_len // patch_len
|
||||||
|
|
||||||
|
|
||||||
|
# Kept for backward compatibility: flips to False once the patched forward has
# run. The CPU move itself is now decided per module, not by this flag.
first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Replacement ``forward`` for the CLIP vision embeddings module.

    The patch-embedding convolution is executed on CPU in float32; its output
    is then moved to the device and dtype of ``self.class_embedding`` before
    the class embedding is prepended and the position embeddings are added.

    Args:
        pixel_values: Image batch; its first dimension is the batch size.

    Returns:
        ``[class_embedding, patch_embeddings]`` concatenated along dim 1, plus
        ``self.position_embedding(self.position_ids)``.
    """
    global first_call

    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    # Check the module itself instead of a process-wide flag, so that a second
    # embeddings instance (e.g. another loaded model) is moved as well.
    conv_weight = self.patch_embedding.weight
    if conv_weight.device.type != "cpu" or conv_weight.dtype != torch.float32:
        self.patch_embedding.cpu().float()
    first_call = False

    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    # Bring the result back next to the remaining embeddings rather than
    # assuming "current CUDA device, float16".
    patch_embeds = self.patch_embedding(pixel_values).to(
        device=self.class_embedding.device, dtype=self.class_embedding.dtype
    )

    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_path_clip_vision_embed_forward():
    """Install ``clip_vision_embed_forward`` as ``CLIPVisionEmbeddings.forward``.

    The replacement is made on the class inside ``transformers`` itself, so it
    applies to every instance of that class in the process.
    """
    import transformers

    embeddings_cls = transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
    embeddings_cls.forward = clip_vision_embed_forward
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level alias for the model class defined in this file.
# NOTE(review): presumably the name the model loader looks up to find the
# implementation — confirm against the loader code.
EntryClass = LlavaQwenForCausalLM
|
||||||
@@ -303,6 +303,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@@ -311,6 +313,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
Reference in New Issue
Block a user