From 664287b2a787ff774b6ce9529b2a784e304ee38c Mon Sep 17 00:00:00 2001 From: Kaichen Zhang - NTU Date: Tue, 14 May 2024 13:17:50 +0800 Subject: [PATCH] [Feat] Add llava qwen, llava mistral (#419) Co-authored-by: Bo Li --- .../usage/llava/http_llama3_llava_test.py | 117 ++++++ examples/usage/llava/http_qwen_llava_test.py | 117 ++++++ examples/usage/llava/srt_llava_next_test.py | 88 +++++ python/sglang/srt/models/llava.py | 2 +- python/sglang/srt/models/llava_mistral.py | 347 ++++++++++++++++++ python/sglang/srt/models/llava_qwen.py | 347 ++++++++++++++++++ python/sglang/srt/models/qwen2.py | 4 + 7 files changed, 1021 insertions(+), 1 deletion(-) create mode 100644 examples/usage/llava/http_llama3_llava_test.py create mode 100644 examples/usage/llava/http_qwen_llava_test.py create mode 100644 examples/usage/llava/srt_llava_next_test.py create mode 100644 python/sglang/srt/models/llava_mistral.py create mode 100644 python/sglang/srt/models/llava_qwen.py diff --git a/examples/usage/llava/http_llama3_llava_test.py b/examples/usage/llava/http_llama3_llava_test.py new file mode 100644 index 000000000..2f95c6542 --- /dev/null +++ b/examples/usage/llava/http_llama3_llava_test.py @@ -0,0 +1,117 @@ +""" +Usage: +# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git +# Installing latest sglang. + +# Endpoint Service CLI: +# python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --tokenizer-path lmms-lab/llama3-llava-next-8b-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4 + +python3 http_llama3_llava_test.py + +Output: +"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment." +""" + +import argparse +import asyncio +import json +import time +import copy + +import aiohttp +import requests + +from llava.conversation import ( + default_conversation, + conv_templates, + SeparatorStyle, + conv_llava_llama_3, + conv_qwen, +) + + +async def send_request(url, data, delay=0): + await asyncio.sleep(delay) + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as resp: + output = await resp.json() + return output + + +async def test_concurrent(args): + url = f"{args.host}:{args.port}" + + prompt = "\nPlease generate caption towards this image." + conv_template = copy.deepcopy(conv_llava_llama_3) + conv_template.append_message(role="user", message=prompt) + prompt_with_template = conv_template.get_prompt() + response = [] + for i in range(1): + response.append( + send_request( + url + "/generate", + { + "text": prompt_with_template, + "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg", + "sampling_params": { + "max_new_tokens": 1024, + "temperature": 0, + "top_p": 1.0, + "presence_penalty": 2, + "frequency_penalty": 2, + "stop": "<|eot_id|>", + }, + }, + ) + ) + + rets = await asyncio.gather(*response) + for ret in rets: + print(ret["text"]) + + +def test_streaming(args): + url = f"{args.host}:{args.port}" + prompt = "\nPlease generate caption towards this image." 
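+    # Fill in the user turn of llava-next's llama-3 conversation template and render the final prompt string.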
+ conv_template = copy.deepcopy(conv_llava_llama_3) + conv_template.append_message(role="user", message=prompt) + prompt_with_template = conv_template.get_prompt() + pload = { + "text": prompt_with_template, + "sampling_params": { + "max_new_tokens": 1024, + "temperature": 0, + "top_p": 1.0, + "presence_penalty": 2, + "frequency_penalty": 2, + "stop": "<|eot_id|>", + }, + "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg", + "stream": True, + } + response = requests.post( + url + "/generate", + json=pload, + stream=True, + ) + + prev = 0 + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = json.loads(chunk[5:].strip("\n")) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + asyncio.run(test_concurrent(args)) + test_streaming(args) diff --git a/examples/usage/llava/http_qwen_llava_test.py b/examples/usage/llava/http_qwen_llava_test.py new file mode 100644 index 000000000..1495c0dfb --- /dev/null +++ b/examples/usage/llava/http_qwen_llava_test.py @@ -0,0 +1,117 @@ +""" +Usage: +# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git +# Installing latest sglang. + +# Endpoint Service CLI: +# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4 + +python3 http_qwen_llava_test.py + +Output: +"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants." +""" + +import argparse +import asyncio +import json +import time +import copy + +import aiohttp +import requests + +from llava.conversation import ( + default_conversation, + conv_templates, + SeparatorStyle, + conv_llava_llama_3, + conv_qwen, +) + + +async def send_request(url, data, delay=0): + await asyncio.sleep(delay) + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as resp: + output = await resp.json() + return output + + +async def test_concurrent(args): + url = f"{args.host}:{args.port}" + + prompt = "\nPlease generate caption towards this image." + conv_template = copy.deepcopy(conv_qwen) + conv_template.append_message(role="user", message=prompt) + prompt_with_template = conv_template.get_prompt() + response = [] + for i in range(1): + response.append( + send_request( + url + "/generate", + { + "text": prompt_with_template, + "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg", + "sampling_params": { + "max_new_tokens": 1024, + "temperature": 0, + "top_p": 1.0, + "presence_penalty": 2, + "frequency_penalty": 2, + "stop": "<|im_end|>", + }, + }, + ) + ) + + rets = await asyncio.gather(*response) + for ret in rets: + print(ret["text"]) + + +def test_streaming(args): + url = f"{args.host}:{args.port}" + prompt = "\nPlease generate caption towards this image." 
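+    # Fill in the user turn of the Qwen (chatml) conversation template and render the final prompt string.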
+ conv_template = copy.deepcopy(conv_qwen) + conv_template.append_message(role="user", message=prompt) + prompt_with_template = conv_template.get_prompt() + pload = { + "text": prompt_with_template, + "sampling_params": { + "max_new_tokens": 1024, + "temperature": 0, + "top_p": 1.0, + "presence_penalty": 2, + "frequency_penalty": 2, + "stop": "<|im_end|>", + }, + "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg", + "stream": True, + } + response = requests.post( + url + "/generate", + json=pload, + stream=True, + ) + + prev = 0 + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = json.loads(chunk[5:].strip("\n")) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + # asyncio.run(test_concurrent(args)) + test_streaming(args) diff --git a/examples/usage/llava/srt_llava_next_test.py b/examples/usage/llava/srt_llava_next_test.py new file mode 100644 index 000000000..8a7a6267d --- /dev/null +++ b/examples/usage/llava/srt_llava_next_test.py @@ -0,0 +1,88 @@ +""" +Usage: python3 srt_example_llava.py +""" + +import sglang as sgl +from sglang.srt.utils import load_image +from sglang.lang.chat_template import get_chat_template + +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images + +@sgl.function +def image_qa(s, image, question): + s += sgl.user(sgl.image(image) + question) + s += sgl.assistant(sgl.gen("answer")) + + +def single(): + image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" + pil_image = load_image(image_url) + state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512) + print(state["answer"], "\n") + + +def stream(): + image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" + pil_image = load_image(image_url) + state = image_qa.run( + image=pil_image, + question="Please generate short caption for this image.", + max_new_tokens=512, + temperature=0, + stream=True, + ) + + for out in state.text_iter("answer"): + print(out, end="", flush=True) + print() + + +def batch(): + image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" + pil_image = load_image(image_url) + states = image_qa.run_batch( + [ + {"image": pil_image, "question": "What is this?"}, + {"image": pil_image, "question": "What is this?"}, + ], + max_new_tokens=512, + ) + for s in states: + print(s["answer"], "\n") + + +if __name__ == "__main__": + import multiprocessing as mp + + mp.set_start_method("spawn", force=True) + runtime = sgl.Runtime( + model_path="lmms-lab/llama3-llava-next-8b", + tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer", + ) + runtime.endpoint.chat_template = get_chat_template("llama-3-instruct") + # runtime = sgl.Runtime( + # model_path="lmms-lab/llava-next-72b", + # tokenizer_path="lmms-lab/llavanext-qwen-tokenizer", + # ) + # runtime.endpoint.chat_template = get_chat_template("chatml-llava") + sgl.set_default_backend(runtime) + print(f"chat template: {runtime.endpoint.chat_template.name}") + + # Or you can use API models + # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview")) + # 
sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision")) + + # Run a single request + print("\n========== single ==========\n") + single() + + # Stream output + print("\n========== stream ==========\n") + stream() + + # Run a batch of requests + print("\n========== batch ==========\n") + batch() + + runtime.shutdown() diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index abce92061..d423541dd 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -328,4 +328,4 @@ def monkey_path_clip_vision_embed_forward(): ) -EntryClass = LlavaLlamaForCausalLM +EntryClass = LlavaLlamaForCausalLM \ No newline at end of file diff --git a/python/sglang/srt/models/llava_mistral.py b/python/sglang/srt/models/llava_mistral.py new file mode 100644 index 000000000..d14617cbd --- /dev/null +++ b/python/sglang/srt/models/llava_mistral.py @@ -0,0 +1,347 @@ +"""Inference-only LLaVa model compatible with HuggingFace weights.""" + +from typing import List, Optional + +import numpy as np +import torch +from torch import nn +from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig +from transformers.models.llava.modeling_llava import LlavaMultiModalProjector +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from sglang.srt.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) + +from sglang.srt.managers.router.infer_batch import ForwardMode +from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.mm_utils import ( + get_anyres_image_grid_shape, + unpad_image, + unpad_image_shape, +) +from sglang.srt.models.mistral import MistralForCausalLM + + +class LlavaMistralForCausalLM(nn.Module): + def __init__( + self, + config: LlavaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.vision_tower = None + if getattr(self.config, "vision_config", None) is None: + self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower) + + if getattr(self.config, "text_config", None) is None: + self.config.text_config = MistralConfig(self.config._name_or_path) + + self.config.vision_config.hidden_size = config.mm_hidden_size + self.config.text_config.hidden_size = config.hidden_size + + if getattr(self.config, "projector_hidden_act", None) is None: + self.config.projector_hidden_act = "gelu" + + if getattr(self.config, "image_token_index", None) is None: + self.config.image_token_index = 32000 + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.language_model = MistralForCausalLM(config, quant_config=quant_config) + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.language_model.model.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, dtype=torch.float16) + ) + + def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + new_image_feature_len = self.image_feature_len + # now only support spatial_unpad + anyres + if self.mm_patch_merge_type.startswith("spatial"): + height = width = self.num_patches_per_side + if pt_shape[0] > 1: + if self.image_aspect_ratio == "anyres": + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_size, + self.image_grid_pinpoints, + self.vision_tower.config.image_size, + ) + if "unpad" in self.mm_patch_merge_type: + h = num_patch_height * height + w = num_patch_width * width + new_h, new_w = unpad_image_shape(h, w, image_size) + 
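+                    # Unpadded grid adds new_h rows of new_w patch tokens, plus one image-newline token per row.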
new_image_feature_len += new_h * (new_w + 1) + + pad_ids = pad_value * ( + (new_image_feature_len + len(pad_value)) // len(pad_value) + ) + offset = input_ids.index(self.config.image_token_index) + # old_len + pad_len - 1, because we need to remove image_token_id + new_input_ids = ( + input_ids[:offset] + + pad_ids[:new_image_feature_len] + + input_ids[offset + 1 :] + ) + return new_input_ids, offset + + def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated. + + selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer] + if self.vision_feature_select_strategy in ["default", "patch"]: + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + image_features = self.multi_modal_projector(selected_image_feature) + + return image_features + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + pixel_values: Optional[List[Optional[np.array]]] = None, + image_sizes: Optional[List[List[int]]] = None, + image_offsets: Optional[List[int]] = None, + ) -> torch.Tensor: + if input_metadata.forward_mode == ForwardMode.EXTEND: + bs = input_metadata.batch_size + + # Embed text input + input_embeds = self.language_model.model.embed_tokens(input_ids) + + # Embed vision input + need_vision = ( + (positions[input_metadata.extend_start_loc] < self.image_feature_len) + .cpu() + .numpy() + ) + # FIXME: We need to substract the length of the system prompt + has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) + need_vision = need_vision & has_pixel + + if need_vision.any(): + pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] + image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]] + + ########## Encode Image ######## + + if pixel_values[0].ndim == 4: + # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images + np.concatenate(pixel_values, axis=0) + # ndim=4 + concat_images = torch.tensor( + np.concatenate(pixel_values, axis=0), + device=self.vision_tower.device, + ) + image_features = self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + # hd image_features: BS, num_patch, 576, 4096 + else: + # normal pixel: BS, C=3, H=336, W=336 + pixel_values = torch.tensor( + np.array(pixel_values), device=self.vision_tower.device + ) + image_features = self.encode_images(pixel_values) + # image_features: BS, 576, 4096 + + if self.mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.num_patches_per_side + assert height * width == base_image_feature.shape[0] + if self.image_aspect_ratio == "anyres": + ( + num_patch_width, + num_patch_height, + ) = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.image_grid_pinpoints, + self.vision_tower.config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, 
height, width, -1 + ) + else: + raise NotImplementedError() + if "unpad" in self.mm_patch_merge_type: + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten( + 2, 3 + ) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + image_feature = torch.cat( + ( + image_feature, + self.language_model.model.image_newline[ + :, None, None + ].expand(*image_feature.shape[:-1], 1), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose( + 0, 1 + ) + else: + image_feature = image_feature.permute( + 0, 2, 1, 3, 4 + ).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + if "unpad" in self.mm_patch_merge_type: + image_feature = torch.cat( + ( + image_feature, + self.language_model.model.image_newline[None], + ), + dim=0, + ) + new_image_features.append(image_feature) + image_features = new_image_features + + extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + pt = 0 + for i in range(bs): + if not need_vision[i]: + continue + + start_idx = extend_start_loc_cpu[i] + pad_len, pad_dim = image_features[pt].shape # 576, 4096 + dim = input_embeds.shape[1] + assert ( + pad_dim == dim + ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) + # Fill in the placeholder for the image + try: + input_embeds[ + start_idx + + image_offsets[i] : start_idx + + image_offsets[i] + + pad_len + ] = image_features[pt] + except RuntimeError as e: + print(f"RuntimeError in llava image encoding: {e}") + print(input_embeds.shape) + print(start_idx, image_offsets[i]) + pt += 1 + + return self.language_model( + input_ids, positions, input_metadata, input_embeds=input_embeds + ) + elif input_metadata.forward_mode == ForwardMode.DECODE: + return self.language_model(input_ids, positions, input_metadata) + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + # load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + vision_path = self.config.mm_vision_tower + self.vision_tower = CLIPVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + self.vision_tower.eval() + + self.vision_feature_layer = self.config.mm_vision_select_layer + self.vision_feature_select_strategy = self.config.mm_vision_select_feature + self.image_size = self.vision_tower.config.image_size + self.patch_size = self.vision_tower.config.patch_size + + self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None) + + self.image_feature_len = int((self.image_size / self.patch_size) ** 2) + if self.vision_feature_select_strategy == "patch": + pass + elif self.vision_feature_select_strategy == "cls_patch": + self.image_feature_len += 1 + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + + # load mm_projector + projector_weights = { + "model.mm_projector.0": "multi_modal_projector.linear_1", + "model.mm_projector.2": "multi_modal_projector.linear_2", + "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). 
+ } + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + # FIXME: why projector weights read two times? + if "projector" in name or "vision_tower" in name: + for weight_name, param_name in projector_weights.items(): + if weight_name in name: + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + # load language model + self.language_model.load_weights( + model_name_or_path, cache_dir, load_format, revision + ) + + monkey_path_clip_vision_embed_forward() + + @property + def num_patches_per_side(self): + return self.image_size // self.patch_size + + +first_call = True + + +def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + + # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. + global first_call + if first_call: + self.patch_embedding.cpu().float() + first_call = False + pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") + patch_embeds = self.patch_embedding(pixel_values).cuda().half() + + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +def monkey_path_clip_vision_embed_forward(): + import transformers + + setattr( + transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, + "forward", + clip_vision_embed_forward, + ) + + +EntryClass = LlavaMistralForCausalLM diff --git a/python/sglang/srt/models/llava_qwen.py b/python/sglang/srt/models/llava_qwen.py new file mode 100644 index 000000000..c73ba7e95 --- /dev/null +++ b/python/sglang/srt/models/llava_qwen.py @@ -0,0 +1,347 @@ +"""Inference-only LLaVa model compatible with HuggingFace weights.""" + +from typing import List, Optional + +import numpy as np +import torch +from torch import nn +from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config +from transformers.models.llava.modeling_llava import LlavaMultiModalProjector +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from sglang.srt.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) + +from sglang.srt.managers.router.infer_batch import ForwardMode +from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.mm_utils import ( + get_anyres_image_grid_shape, + unpad_image, + unpad_image_shape, +) +from sglang.srt.models.qwen2 import Qwen2ForCausalLM + + +class LlavaQwenForCausalLM(nn.Module): + def __init__( + self, + config: LlavaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.vision_tower = None + if getattr(self.config, "vision_config", None) is None: + self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower) + + if getattr(self.config, "text_config", None) is None: + self.config.text_config = Qwen2Config(self.config._name_or_path) + + self.config.vision_config.hidden_size = config.mm_hidden_size + self.config.text_config.hidden_size = config.hidden_size + + if getattr(self.config, "projector_hidden_act", None) is None: + self.config.projector_hidden_act = "gelu" + + if getattr(self.config, 
"image_token_index", None) is None: + self.config.image_token_index = 151646 + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config) + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.language_model.model.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, dtype=torch.float16) + ) + + def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + new_image_feature_len = self.image_feature_len + # now only support spatial_unpad + anyres + if self.mm_patch_merge_type.startswith("spatial"): + height = width = self.num_patches_per_side + if pt_shape[0] > 1: + if self.image_aspect_ratio == "anyres": + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_size, + self.image_grid_pinpoints, + self.vision_tower.config.image_size, + ) + if "unpad" in self.mm_patch_merge_type: + h = num_patch_height * height + w = num_patch_width * width + new_h, new_w = unpad_image_shape(h, w, image_size) + new_image_feature_len += new_h * (new_w + 1) + + pad_ids = pad_value * ( + (new_image_feature_len + len(pad_value)) // len(pad_value) + ) + offset = input_ids.index(self.config.image_token_index) + # old_len + pad_len - 1, because we need to remove image_token_id + new_input_ids = ( + input_ids[:offset] + + pad_ids[:new_image_feature_len] + + input_ids[offset + 1 :] + ) + return new_input_ids, offset + + def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated. + + selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer] + if self.vision_feature_select_strategy in ["default", "patch"]: + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + image_features = self.multi_modal_projector(selected_image_feature) + + return image_features + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + pixel_values: Optional[List[Optional[np.array]]] = None, + image_sizes: Optional[List[List[int]]] = None, + image_offsets: Optional[List[int]] = None, + ) -> torch.Tensor: + if input_metadata.forward_mode == ForwardMode.EXTEND: + bs = input_metadata.batch_size + + # Embed text input + input_embeds = self.language_model.model.embed_tokens(input_ids) + + # Embed vision input + need_vision = ( + (positions[input_metadata.extend_start_loc] < self.image_feature_len) + .cpu() + .numpy() + ) + # FIXME: We need to substract the length of the system prompt + has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) + need_vision = need_vision & has_pixel + + if need_vision.any(): + pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] + image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]] + + ########## Encode Image ######## + + if pixel_values[0].ndim == 4: + # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images + np.concatenate(pixel_values, axis=0) + # ndim=4 + concat_images = torch.tensor( + np.concatenate(pixel_values, axis=0), + device=self.vision_tower.device, + ) + image_features = 
self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + # hd image_features: BS, num_patch, 576, 4096 + else: + # normal pixel: BS, C=3, H=336, W=336 + pixel_values = torch.tensor( + np.array(pixel_values), device=self.vision_tower.device + ) + image_features = self.encode_images(pixel_values) + # image_features: BS, 576, 4096 + + if self.mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.num_patches_per_side + assert height * width == base_image_feature.shape[0] + if self.image_aspect_ratio == "anyres": + ( + num_patch_width, + num_patch_height, + ) = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.image_grid_pinpoints, + self.vision_tower.config.image_size, + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + else: + raise NotImplementedError() + if "unpad" in self.mm_patch_merge_type: + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten( + 2, 3 + ) + image_feature = unpad_image( + image_feature, image_sizes[image_idx] + ) + image_feature = torch.cat( + ( + image_feature, + self.language_model.model.image_newline[ + :, None, None + ].expand(*image_feature.shape[:-1], 1), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose( + 0, 1 + ) + else: + image_feature = image_feature.permute( + 0, 2, 1, 3, 4 + ).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat( + (base_image_feature, image_feature), dim=0 + ) + else: + image_feature = image_feature[0] + if "unpad" in self.mm_patch_merge_type: + image_feature = torch.cat( + ( + image_feature, + self.language_model.model.image_newline[None], + ), + dim=0, + ) + new_image_features.append(image_feature) + image_features = new_image_features + + extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + pt = 0 + for i in range(bs): + if not need_vision[i]: + continue + + start_idx = extend_start_loc_cpu[i] + pad_len, pad_dim = image_features[pt].shape # 576, 4096 + dim = input_embeds.shape[1] + assert ( + pad_dim == dim + ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) + # Fill in the placeholder for the image + try: + input_embeds[ + start_idx + + image_offsets[i] : start_idx + + image_offsets[i] + + pad_len + ] = image_features[pt] + except RuntimeError as e: + print(f"RuntimeError in llava image encoding: {e}") + print(input_embeds.shape) + print(start_idx, image_offsets[i]) + pt += 1 + + return self.language_model( + input_ids, positions, input_metadata, input_embeds=input_embeds + ) + elif input_metadata.forward_mode == ForwardMode.DECODE: + return self.language_model(input_ids, positions, input_metadata) + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + # load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + vision_path = self.config.mm_vision_tower + self.vision_tower = CLIPVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + self.vision_tower.eval() + + self.vision_feature_layer = 
self.config.mm_vision_select_layer + self.vision_feature_select_strategy = self.config.mm_vision_select_feature + self.image_size = self.vision_tower.config.image_size + self.patch_size = self.vision_tower.config.patch_size + + self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None) + + self.image_feature_len = int((self.image_size / self.patch_size) ** 2) + if self.vision_feature_select_strategy == "patch": + pass + elif self.vision_feature_select_strategy == "cls_patch": + self.image_feature_len += 1 + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + + # load mm_projector + projector_weights = { + "model.mm_projector.0": "multi_modal_projector.linear_1", + "model.mm_projector.2": "multi_modal_projector.linear_2", + "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). + } + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + # FIXME: why projector weights read two times? + if "projector" in name or "vision_tower" in name: + for weight_name, param_name in projector_weights.items(): + if weight_name in name: + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + # load language model + self.language_model.load_weights( + model_name_or_path, cache_dir, load_format, revision + ) + + monkey_path_clip_vision_embed_forward() + + @property + def num_patches_per_side(self): + return self.image_size // self.patch_size + + +first_call = True + + +def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + + # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. + global first_call + if first_call: + self.patch_embedding.cpu().float() + first_call = False + pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") + patch_embeds = self.patch_embedding(pixel_values).cuda().half() + + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +def monkey_path_clip_vision_embed_forward(): + import transformers + + setattr( + transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, + "forward", + clip_vision_embed_forward, + ) + + +EntryClass = LlavaQwenForCausalLM diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 45e5371e7..f0ad5d9bf 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -303,6 +303,8 @@ class Qwen2ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -311,6 +313,8 @@ class Qwen2ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. 
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)