[Fix] Fix llava on multi images (#1247)
This commit is contained in:
@@ -240,7 +240,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
- Qwen / Qwen 2 / Qwen 2 MoE
|
- Qwen / Qwen 2 / Qwen 2 MoE
|
||||||
- DeepSeek / DeepSeek 2
|
- DeepSeek / DeepSeek 2
|
||||||
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
|
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
|
||||||
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384`
|
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
|
||||||
- Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
|
- Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
|
||||||
- LLaVA 1.5 / 1.6 / NeXT
|
- LLaVA 1.5 / 1.6 / NeXT
|
||||||
- `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`
|
- `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`
|
||||||
|
|||||||
@@ -184,13 +184,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Parse the arguments
|
# Parse the arguments
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
cur_port = args.port
|
cur_port = args.port
|
||||||
|
|
||||||
cur_chunk = args.chunk_idx
|
cur_chunk = args.chunk_idx
|
||||||
|
|
||||||
num_chunks = args.num_chunks
|
num_chunks = args.num_chunks
|
||||||
|
|
||||||
num_frames = args.num_frames
|
num_frames = args.num_frames
|
||||||
|
|
||||||
if "34b" in args.model_path.lower():
|
if "34b" in args.model_path.lower():
|
||||||
@@ -202,7 +198,6 @@ if __name__ == "__main__":
|
|||||||
exit()
|
exit()
|
||||||
|
|
||||||
model_overide_args = {}
|
model_overide_args = {}
|
||||||
|
|
||||||
model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
|
model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
|
||||||
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
|
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
|
||||||
model_overide_args["num_frames"] = args.num_frames
|
model_overide_args["num_frames"] = args.num_frames
|
||||||
@@ -235,7 +230,6 @@ if __name__ == "__main__":
|
|||||||
print(f"chat template: {runtime.endpoint.chat_template.name}")
|
print(f"chat template: {runtime.endpoint.chat_template.name}")
|
||||||
|
|
||||||
# Run a single request
|
# Run a single request
|
||||||
# try:
|
|
||||||
print("\n========== single ==========\n")
|
print("\n========== single ==========\n")
|
||||||
root = args.video_dir
|
root = args.video_dir
|
||||||
if os.path.isfile(root):
|
if os.path.isfile(root):
|
||||||
@@ -257,13 +251,10 @@ if __name__ == "__main__":
|
|||||||
) # Calculate the average processing time
|
) # Calculate the average processing time
|
||||||
print(f"Average processing time per video: {average_time:.2f} seconds")
|
print(f"Average processing time per video: {average_time:.2f} seconds")
|
||||||
runtime.shutdown()
|
runtime.shutdown()
|
||||||
# except Exception as e:
|
|
||||||
# print(e)
|
|
||||||
runtime.shutdown()
|
|
||||||
|
|
||||||
# # # Run a batch of requests
|
# # Run a batch of requests
|
||||||
# print("\n========== batch ==========\n")
|
# print("\n========== batch ==========\n")
|
||||||
# if not os.path.exists(args.save_dir):
|
# if not os.path.exists(args.save_dir):
|
||||||
# os.makedirs(args.save_dir)
|
# os.makedirs(args.save_dir)
|
||||||
# batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks)
|
# batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
|
||||||
# runtime.shutdown()
|
# runtime.shutdown()
|
||||||
|
|||||||
26
python/sglang/launch_server_llavavid.py
Normal file
26
python/sglang/launch_server_llavavid.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""Launch the inference server for Llava-video model."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from sglang.srt.server import ServerArgs, launch_server
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
ServerArgs.add_cli_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
server_args = ServerArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
model_overide_args = {}
|
||||||
|
model_overide_args["mm_spatial_pool_stride"] = 2
|
||||||
|
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
|
||||||
|
model_overide_args["num_frames"] = 16
|
||||||
|
model_overide_args["model_type"] = "llavavid"
|
||||||
|
if model_overide_args["num_frames"] == 32:
|
||||||
|
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
|
||||||
|
model_overide_args["max_sequence_length"] = 4096 * 2
|
||||||
|
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
|
||||||
|
model_overide_args["model_max_length"] = 4096 * 2
|
||||||
|
if "34b" in args.model_path.lower():
|
||||||
|
model_overide_args["image_token_index"] = 64002
|
||||||
|
|
||||||
|
launch_server(server_args, model_overide_args, None)
|
||||||
@@ -119,24 +119,7 @@ def get_tokenizer(
|
|||||||
tokenizer_revision: Optional[str] = None,
|
tokenizer_revision: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||||
if tokenizer_name.endswith(".json"):
|
|
||||||
return TiktokenTokenizer(tokenizer_name)
|
|
||||||
|
|
||||||
if tokenizer_name.endswith(".model"):
|
|
||||||
return SentencePieceTokenizer(tokenizer_name)
|
|
||||||
|
|
||||||
"""Gets a tokenizer for the given model name via Huggingface."""
|
"""Gets a tokenizer for the given model name via Huggingface."""
|
||||||
if is_multimodal_model(tokenizer_name):
|
|
||||||
processor = get_processor(
|
|
||||||
tokenizer_name,
|
|
||||||
*args,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
tokenizer_revision=tokenizer_revision,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
tokenizer = processor.tokenizer
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
if tokenizer_mode == "slow":
|
if tokenizer_mode == "slow":
|
||||||
if kwargs.get("use_fast", False):
|
if kwargs.get("use_fast", False):
|
||||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
@@ -199,135 +182,3 @@ def get_processor(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
class TiktokenTokenizer:
|
|
||||||
def __init__(self, tokenizer_path):
|
|
||||||
import tiktoken
|
|
||||||
from jinja2 import Template
|
|
||||||
|
|
||||||
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
|
||||||
|
|
||||||
# Read JSON
|
|
||||||
name = "tmp-json"
|
|
||||||
with open(tokenizer_path, "rb") as fin:
|
|
||||||
tok_dict = json.load(fin)
|
|
||||||
|
|
||||||
mergeable_ranks = {
|
|
||||||
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
|
|
||||||
}
|
|
||||||
special_tokens = {
|
|
||||||
bytes(item["bytes"]).decode(): item["token"]
|
|
||||||
for item in tok_dict["special_tokens"]
|
|
||||||
}
|
|
||||||
assert tok_dict["word_split"] == "V1"
|
|
||||||
|
|
||||||
default_allowed_special = None
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"name": name,
|
|
||||||
"pat_str": tok_dict.get("pat_str", PAT_STR_B),
|
|
||||||
"mergeable_ranks": mergeable_ranks,
|
|
||||||
"special_tokens": special_tokens,
|
|
||||||
}
|
|
||||||
if "default_allowed_special" in tok_dict:
|
|
||||||
default_allowed_special = set(
|
|
||||||
[
|
|
||||||
bytes(bytes_list).decode()
|
|
||||||
for bytes_list in tok_dict["default_allowed_special"]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if "vocab_size" in tok_dict:
|
|
||||||
kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
|
|
||||||
|
|
||||||
PAD = "<|pad|>"
|
|
||||||
EOS = "<|eos|>"
|
|
||||||
SEP = "<|separator|>"
|
|
||||||
|
|
||||||
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
|
|
||||||
|
|
||||||
tokenizer = tiktoken.Encoding(**kwargs)
|
|
||||||
tokenizer._default_allowed_special = default_allowed_special or set()
|
|
||||||
tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
|
|
||||||
|
|
||||||
def encode_patched(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
*,
|
|
||||||
allowed_special: Union[
|
|
||||||
Literal["all"], AbstractSet[str]
|
|
||||||
] = set(), # noqa: B006
|
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
|
||||||
) -> List[int]:
|
|
||||||
if isinstance(allowed_special, set):
|
|
||||||
allowed_special |= self._default_allowed_special
|
|
||||||
return tiktoken.Encoding.encode(
|
|
||||||
self,
|
|
||||||
text,
|
|
||||||
allowed_special=allowed_special,
|
|
||||||
disallowed_special=(),
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer.encode = functools.partial(encode_patched, tokenizer)
|
|
||||||
|
|
||||||
# Convert to HF interface
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.eos_token_id = tokenizer._special_tokens[EOS]
|
|
||||||
self.vocab_size = tokenizer.n_vocab
|
|
||||||
self.chat_template = Template(
|
|
||||||
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode(self, x, add_special_tokens=False):
|
|
||||||
return self.tokenizer.encode(x)
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
return self.tokenizer.decode(x)
|
|
||||||
|
|
||||||
def batch_decode(
|
|
||||||
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
|
|
||||||
):
|
|
||||||
if isinstance(batch[0], int):
|
|
||||||
batch = [[x] for x in batch]
|
|
||||||
return self.tokenizer.decode_batch(batch)
|
|
||||||
|
|
||||||
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
|
|
||||||
ret = self.chat_template.render(
|
|
||||||
messages=messages, add_generation_prompt=add_generation_prompt
|
|
||||||
)
|
|
||||||
return self.encode(ret) if tokenize else ret
|
|
||||||
|
|
||||||
|
|
||||||
class SentencePieceTokenizer:
|
|
||||||
def __init__(self, tokenizer_path):
|
|
||||||
import sentencepiece as spm
|
|
||||||
from jinja2 import Template
|
|
||||||
|
|
||||||
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
|
|
||||||
|
|
||||||
# Convert to HF interface
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.eos_token_id = tokenizer.eos_id()
|
|
||||||
self.vocab_size = tokenizer.vocab_size()
|
|
||||||
self.chat_template = Template(
|
|
||||||
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode(self, x, add_special_tokens=False):
|
|
||||||
return self.tokenizer.encode(x)
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
return self.tokenizer.decode(x)
|
|
||||||
|
|
||||||
def batch_decode(
|
|
||||||
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
|
|
||||||
):
|
|
||||||
if isinstance(batch[0], int):
|
|
||||||
batch = [[x] for x in batch]
|
|
||||||
return self.tokenizer.decode(batch)
|
|
||||||
|
|
||||||
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
|
|
||||||
ret = self.chat_template.render(
|
|
||||||
messages=messages, add_generation_prompt=add_generation_prompt
|
|
||||||
)
|
|
||||||
return self.encode(ret) if tokenize else ret
|
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ class GenerateReqInput:
|
|||||||
self.text is not None and self.input_ids is not None
|
self.text is not None and self.input_ids is not None
|
||||||
):
|
):
|
||||||
raise ValueError("Either text or input_ids should be provided.")
|
raise ValueError("Either text or input_ids should be provided.")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(self.sampling_params, dict)
|
isinstance(self.sampling_params, dict)
|
||||||
and self.sampling_params.get("n", 1) != 1
|
and self.sampling_params.get("n", 1) != 1
|
||||||
@@ -161,10 +162,10 @@ class TokenizedGenerateReqInput:
|
|||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
# The pixel values for input images
|
# The pixel values for input images
|
||||||
pixel_values: List[float]
|
pixel_values: List[float]
|
||||||
# The hash of input images
|
# The hash values of input images
|
||||||
image_hash: int
|
image_hashes: List[int]
|
||||||
# The image size
|
# The image sizes
|
||||||
image_size: List[int]
|
image_sizes: List[List[int]]
|
||||||
# The sampling parameters
|
# The sampling parameters
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
# Whether to return the logprobs
|
# Whether to return the logprobs
|
||||||
|
|||||||
@@ -121,8 +121,8 @@ class Req:
|
|||||||
|
|
||||||
# For vision input
|
# For vision input
|
||||||
self.pixel_values = None
|
self.pixel_values = None
|
||||||
self.image_size = None
|
self.image_sizes = None
|
||||||
self.image_offset = None
|
self.image_offsets = None
|
||||||
self.pad_value = None
|
self.pad_value = None
|
||||||
|
|
||||||
# Prefix info
|
# Prefix info
|
||||||
@@ -600,12 +600,12 @@ class ScheduleBatch:
|
|||||||
if req.pixel_values is not None:
|
if req.pixel_values is not None:
|
||||||
(
|
(
|
||||||
req.origin_input_ids,
|
req.origin_input_ids,
|
||||||
req.image_offset,
|
req.image_offsets,
|
||||||
) = model_runner.model.pad_input_ids(
|
) = model_runner.model.pad_input_ids(
|
||||||
req.origin_input_ids_unpadded,
|
req.origin_input_ids_unpadded,
|
||||||
req.pad_value,
|
req.pad_value,
|
||||||
req.pixel_values.shape,
|
req.pixel_values,
|
||||||
req.image_size,
|
req.image_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
jump_forward_reqs.append(req)
|
jump_forward_reqs.append(req)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import multiprocessing as mp
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import fastapi
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import transformers
|
import transformers
|
||||||
import uvloop
|
import uvloop
|
||||||
@@ -96,21 +97,18 @@ class TokenizerManager:
|
|||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
model_overide_args=model_overide_args,
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.is_generation = is_generation_model(
|
self.is_generation = is_generation_model(
|
||||||
self.hf_config.architectures, self.server_args.is_embedding
|
self.hf_config.architectures, self.server_args.is_embedding
|
||||||
)
|
)
|
||||||
|
self.context_len = server_args.context_length or get_context_length(
|
||||||
if server_args.context_length is not None:
|
self.hf_config
|
||||||
self.context_len = server_args.context_length
|
)
|
||||||
else:
|
|
||||||
self.context_len = get_context_length(self.hf_config)
|
|
||||||
|
|
||||||
# Create tokenizer
|
# Create tokenizer
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
else:
|
else:
|
||||||
if is_multimodal_model(self.model_path):
|
if is_multimodal_model(self.hf_config.architectures):
|
||||||
self.processor = get_processor(
|
self.processor = get_processor(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
@@ -118,6 +116,9 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
self.tokenizer = self.processor.tokenizer
|
self.tokenizer = self.processor.tokenizer
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
# We want to parallelize the image pre-processing so we
|
||||||
|
# create an executor for it
|
||||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||||
initializer=init_global_processor,
|
initializer=init_global_processor,
|
||||||
mp_context=mp.get_context("fork"),
|
mp_context=mp.get_context("fork"),
|
||||||
@@ -134,12 +135,14 @@ class TokenizerManager:
|
|||||||
self.to_create_loop = True
|
self.to_create_loop = True
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
|
||||||
# for update model weights
|
# For update model weights
|
||||||
self.model_update_lock = asyncio.Lock()
|
self.model_update_lock = asyncio.Lock()
|
||||||
self.model_update_result = None
|
self.model_update_result = None
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
|
self,
|
||||||
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
@@ -160,7 +163,7 @@ class TokenizerManager:
|
|||||||
async def _handle_single_request(
|
async def _handle_single_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
request,
|
request: Optional[fastapi.Request] = None,
|
||||||
index: Optional[int] = None,
|
index: Optional[int] = None,
|
||||||
is_cache_for_prefill: Optional[bool] = False,
|
is_cache_for_prefill: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
@@ -182,8 +185,8 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||||
obj.image_data
|
obj.image_data if not_use_index else obj.image_data[index]
|
||||||
)
|
)
|
||||||
return_logprob = (
|
return_logprob = (
|
||||||
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
||||||
@@ -195,7 +198,6 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
if return_logprob and logprob_start_len == -1:
|
if return_logprob and logprob_start_len == -1:
|
||||||
logprob_start_len = len(input_ids) - 1
|
logprob_start_len = len(input_ids) - 1
|
||||||
|
|
||||||
top_logprobs_num = (
|
top_logprobs_num = (
|
||||||
obj.top_logprobs_num
|
obj.top_logprobs_num
|
||||||
if not_use_index
|
if not_use_index
|
||||||
@@ -238,13 +240,14 @@ class TokenizerManager:
|
|||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||||
sampling_params.max_new_tokens = 0
|
sampling_params.max_new_tokens = 0
|
||||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||||
obj.image_data[0]
|
obj.image_data[0]
|
||||||
)
|
)
|
||||||
return_logprob = obj.return_logprob[0]
|
return_logprob = obj.return_logprob[0]
|
||||||
logprob_start_len = obj.logprob_start_len[0]
|
logprob_start_len = obj.logprob_start_len[0]
|
||||||
top_logprobs_num = obj.top_logprobs_num[0]
|
top_logprobs_num = obj.top_logprobs_num[0]
|
||||||
|
|
||||||
|
# Send to the controller
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if return_logprob and logprob_start_len == -1:
|
if return_logprob and logprob_start_len == -1:
|
||||||
logprob_start_len = len(input_ids) - 1
|
logprob_start_len = len(input_ids) - 1
|
||||||
@@ -253,8 +256,8 @@ class TokenizerManager:
|
|||||||
input_text,
|
input_text,
|
||||||
input_ids,
|
input_ids,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
image_hash,
|
image_hashes,
|
||||||
image_size,
|
image_sizes,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
return_logprob,
|
return_logprob,
|
||||||
logprob_start_len,
|
logprob_start_len,
|
||||||
@@ -268,24 +271,24 @@ class TokenizerManager:
|
|||||||
input_ids,
|
input_ids,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
|
|
||||||
|
# Recv results
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event)
|
||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
if not is_cache_for_prefill:
|
if not is_cache_for_prefill:
|
||||||
async for response in self._wait_for_response(
|
async for response in self._wait_for_response(state, obj, rid, request):
|
||||||
event, state, obj, rid, request
|
|
||||||
):
|
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
assert self.is_generation
|
assert self.is_generation
|
||||||
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
await self._wait_for_cache_prefill_response(state, obj, rid, request)
|
||||||
yield input_ids
|
yield input_ids
|
||||||
|
|
||||||
async def _handle_batch_request(
|
async def _handle_batch_request(
|
||||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
|
self,
|
||||||
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
batch_size = obj.batch_size
|
batch_size = obj.batch_size
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
@@ -340,8 +343,8 @@ class TokenizerManager:
|
|||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
|
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
|
||||||
obj.logprob_start_len[index] = len(input_ids) - 1
|
obj.logprob_start_len[index] = len(input_ids) - 1
|
||||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
pixel_values, image_hashes, image_sizes = (
|
||||||
obj.image_data[index]
|
await self._get_pixel_values(obj.image_data[index])
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
@@ -349,8 +352,8 @@ class TokenizerManager:
|
|||||||
input_text,
|
input_text,
|
||||||
input_ids,
|
input_ids,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
image_hash,
|
image_hashes,
|
||||||
image_size,
|
image_sizes,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
obj.return_logprob[index],
|
obj.return_logprob[index],
|
||||||
obj.logprob_start_len[index],
|
obj.logprob_start_len[index],
|
||||||
@@ -372,7 +375,6 @@ class TokenizerManager:
|
|||||||
|
|
||||||
generators.append(
|
generators.append(
|
||||||
self._wait_for_response(
|
self._wait_for_response(
|
||||||
event,
|
|
||||||
state,
|
state,
|
||||||
obj,
|
obj,
|
||||||
rid,
|
rid,
|
||||||
@@ -388,6 +390,7 @@ class TokenizerManager:
|
|||||||
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
|
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
|
||||||
output_list = [None] * len(tasks)
|
output_list = [None] * len(tasks)
|
||||||
|
|
||||||
|
# Recv results
|
||||||
while tasks:
|
while tasks:
|
||||||
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
@@ -426,25 +429,18 @@ class TokenizerManager:
|
|||||||
sampling_params.verify()
|
sampling_params.verify()
|
||||||
return sampling_params
|
return sampling_params
|
||||||
|
|
||||||
async def _get_pixel_values(self, image_data):
|
|
||||||
if image_data is None:
|
|
||||||
return None, None, None
|
|
||||||
else:
|
|
||||||
return await self._get_pixel_values_internal(image_data)
|
|
||||||
|
|
||||||
async def _wait_for_response(
|
async def _wait_for_response(
|
||||||
self,
|
self,
|
||||||
event: asyncio.Event,
|
|
||||||
state: ReqState,
|
state: ReqState,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
rid: str,
|
rid: str,
|
||||||
request,
|
request: Optional[fastapi.Request] = None,
|
||||||
index: int = None,
|
index: Optional[int] = None,
|
||||||
response_index: int = 0,
|
response_index: int = 0,
|
||||||
):
|
):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(event.wait(), timeout=4)
|
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if request is not None and await request.is_disconnected():
|
if request is not None and await request.is_disconnected():
|
||||||
for rid in [obj.rid] if obj.is_single else obj.rid:
|
for rid in [obj.rid] if obj.is_single else obj.rid:
|
||||||
@@ -478,16 +474,15 @@ class TokenizerManager:
|
|||||||
yield out
|
yield out
|
||||||
break
|
break
|
||||||
|
|
||||||
event.clear()
|
state.event.clear()
|
||||||
yield out
|
yield out
|
||||||
|
|
||||||
async def _wait_for_cache_prefill_response(
|
async def _wait_for_cache_prefill_response(
|
||||||
self,
|
self,
|
||||||
event: asyncio.Event,
|
|
||||||
state: ReqState,
|
state: ReqState,
|
||||||
obj: GenerateReqInput,
|
obj: GenerateReqInput,
|
||||||
rid: str,
|
rid: str,
|
||||||
request,
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -514,7 +509,9 @@ class TokenizerManager:
|
|||||||
req = AbortReq(rid)
|
req = AbortReq(rid)
|
||||||
self.send_to_router.send_pyobj(req)
|
self.send_to_router.send_pyobj(req)
|
||||||
|
|
||||||
async def update_weights(self, obj: UpdateWeightReqInput, request):
|
async def update_weights(
|
||||||
|
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||||
|
):
|
||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
|
|
||||||
@@ -659,12 +656,11 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
return top_logprobs
|
return top_logprobs
|
||||||
|
|
||||||
async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
|
async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
|
||||||
aspect_ratio = (
|
if not image_data:
|
||||||
getattr(self.hf_config, "image_aspect_ratio", None)
|
return None, None, None
|
||||||
if aspect_ratio is None
|
|
||||||
else aspect_ratio
|
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||||
)
|
|
||||||
grid_pinpoints = (
|
grid_pinpoints = (
|
||||||
self.hf_config.image_grid_pinpoints
|
self.hf_config.image_grid_pinpoints
|
||||||
if hasattr(self.hf_config, "image_grid_pinpoints")
|
if hasattr(self.hf_config, "image_grid_pinpoints")
|
||||||
@@ -673,35 +669,42 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(image_data, list) and len(image_data) > 0:
|
if isinstance(image_data, list) and len(image_data) > 0:
|
||||||
pixel_values, image_hash, image_size = [], [], []
|
# Multiple images
|
||||||
if len(image_data) > 1:
|
if len(image_data) > 1:
|
||||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||||
|
pixel_values, image_hashes, image_sizes = [], [], []
|
||||||
for img_data in image_data:
|
for img_data in image_data:
|
||||||
pixel_v, image_h, image_s = await self._process_single_image(
|
pixel_v, image_h, image_s = await self._process_single_image(
|
||||||
img_data, aspect_ratio, grid_pinpoints
|
img_data, aspect_ratio, grid_pinpoints
|
||||||
)
|
)
|
||||||
pixel_values.append(pixel_v)
|
pixel_values.append(pixel_v)
|
||||||
image_hash.append(image_h)
|
image_hashes.append(image_h)
|
||||||
image_size.append(image_s)
|
image_sizes.append(image_s)
|
||||||
pixel_values = np.stack(pixel_values, axis=0)
|
|
||||||
|
if isinstance(pixel_values[0], np.ndarray):
|
||||||
|
pixel_values = np.stack(pixel_values, axis=0)
|
||||||
else:
|
else:
|
||||||
|
# A single image
|
||||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||||
image_data[0], aspect_ratio, grid_pinpoints
|
image_data[0], aspect_ratio, grid_pinpoints
|
||||||
)
|
)
|
||||||
image_hash = [image_hash]
|
image_hashes = [image_hash]
|
||||||
image_size = [image_size]
|
image_sizes = [image_size]
|
||||||
elif isinstance(image_data, str):
|
elif isinstance(image_data, str):
|
||||||
|
# A single image
|
||||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||||
image_data, aspect_ratio, grid_pinpoints
|
image_data, aspect_ratio, grid_pinpoints
|
||||||
)
|
)
|
||||||
image_hash = [image_hash]
|
image_hashes = [image_hash]
|
||||||
image_size = [image_size]
|
image_sizes = [image_size]
|
||||||
else:
|
else:
|
||||||
pixel_values, image_hash, image_size = None, None, None
|
raise ValueError(f"Invalid image data: {image_data}")
|
||||||
|
|
||||||
return pixel_values, image_hash, image_size
|
return pixel_values, image_hashes, image_sizes
|
||||||
|
|
||||||
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
|
async def _process_single_image(
|
||||||
|
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||||
|
):
|
||||||
if self.executor is not None:
|
if self.executor is not None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return await loop.run_in_executor(
|
return await loop.run_in_executor(
|
||||||
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
|
|||||||
|
|
||||||
|
|
||||||
def _process_single_image_task(
|
def _process_single_image_task(
|
||||||
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
|
image_data: Union[str, bytes],
|
||||||
|
image_aspect_ratio: Optional[str] = None,
|
||||||
|
image_grid_pinpoints: Optional[str] = None,
|
||||||
|
processor=None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
processor = processor or global_processor
|
processor = processor or global_processor
|
||||||
image, image_size = load_image(image_data)
|
image, image_size = load_image(image_data)
|
||||||
if image_size is not None:
|
if image_size is not None:
|
||||||
|
# It is a video with multiple images
|
||||||
image_hash = hash(image_data)
|
image_hash = hash(image_data)
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"]
|
pixel_values = processor.image_processor(image)["pixel_values"]
|
||||||
for _ in range(len(pixel_values)):
|
for _ in range(len(pixel_values)):
|
||||||
@@ -745,6 +752,7 @@ def _process_single_image_task(
|
|||||||
pixel_values = np.stack(pixel_values, axis=0)
|
pixel_values = np.stack(pixel_values, axis=0)
|
||||||
return pixel_values, image_hash, image_size
|
return pixel_values, image_hash, image_size
|
||||||
else:
|
else:
|
||||||
|
# It is an image
|
||||||
image_hash = hash(image_data)
|
image_hash = hash(image_data)
|
||||||
if image_aspect_ratio == "pad":
|
if image_aspect_ratio == "pad":
|
||||||
image = expand2square(
|
image = expand2square(
|
||||||
@@ -754,13 +762,18 @@ def _process_single_image_task(
|
|||||||
pixel_values = processor.image_processor(image.convert("RGB"))[
|
pixel_values = processor.image_processor(image.convert("RGB"))[
|
||||||
"pixel_values"
|
"pixel_values"
|
||||||
][0]
|
][0]
|
||||||
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
|
elif image_aspect_ratio == "anyres" or (
|
||||||
|
image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
|
||||||
|
):
|
||||||
pixel_values = process_anyres_image(
|
pixel_values = process_anyres_image(
|
||||||
image, processor.image_processor, image_grid_pinpoints
|
image, processor.image_processor, image_grid_pinpoints
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||||
pixel_values = pixel_values.astype(np.float16)
|
|
||||||
|
if isinstance(pixel_values, np.ndarray):
|
||||||
|
pixel_values = pixel_values.astype(np.float16)
|
||||||
|
|
||||||
return pixel_values, image_hash, image.size
|
return pixel_values, image_hash, image.size
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class ModelTpServer:
|
|||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
else:
|
else:
|
||||||
if is_multimodal_model(server_args.model_path):
|
if is_multimodal_model(self.model_config.hf_config.architectures):
|
||||||
self.processor = get_processor(
|
self.processor = get_processor(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
@@ -333,26 +333,24 @@ class ModelTpServer:
|
|||||||
if self.model_runner.is_generation:
|
if self.model_runner.is_generation:
|
||||||
req.pixel_values = recv_req.pixel_values
|
req.pixel_values = recv_req.pixel_values
|
||||||
if req.pixel_values is not None:
|
if req.pixel_values is not None:
|
||||||
image_hash = (
|
# Use image hash as fake token_ids, which is then used
|
||||||
hash(tuple(recv_req.image_hash))
|
# for prefix matching
|
||||||
if isinstance(recv_req.image_hash, list)
|
image_hash = hash(tuple(recv_req.image_hashes))
|
||||||
else recv_req.image_hash
|
|
||||||
)
|
|
||||||
req.pad_value = [
|
req.pad_value = [
|
||||||
(image_hash) % self.model_config.vocab_size,
|
(image_hash) % self.model_config.vocab_size,
|
||||||
(image_hash >> 16) % self.model_config.vocab_size,
|
(image_hash >> 16) % self.model_config.vocab_size,
|
||||||
(image_hash >> 32) % self.model_config.vocab_size,
|
(image_hash >> 32) % self.model_config.vocab_size,
|
||||||
(image_hash >> 64) % self.model_config.vocab_size,
|
(image_hash >> 64) % self.model_config.vocab_size,
|
||||||
]
|
]
|
||||||
req.image_size = recv_req.image_size
|
req.image_sizes = recv_req.image_sizes
|
||||||
(
|
(
|
||||||
req.origin_input_ids,
|
req.origin_input_ids,
|
||||||
req.image_offset,
|
req.image_offsets,
|
||||||
) = self.model_runner.model.pad_input_ids(
|
) = self.model_runner.model.pad_input_ids(
|
||||||
req.origin_input_ids_unpadded,
|
req.origin_input_ids_unpadded,
|
||||||
req.pad_value,
|
req.pad_value,
|
||||||
req.pixel_values.shape,
|
req.pixel_values,
|
||||||
req.image_size,
|
req.image_sizes,
|
||||||
)
|
)
|
||||||
req.return_logprob = recv_req.return_logprob
|
req.return_logprob = recv_req.return_logprob
|
||||||
req.logprob_start_len = recv_req.logprob_start_len
|
req.logprob_start_len = recv_req.logprob_start_len
|
||||||
@@ -368,6 +366,7 @@ class ModelTpServer:
|
|||||||
req.jump_forward_map = self.jump_forward_cache.query(
|
req.jump_forward_map = self.jump_forward_cache.query(
|
||||||
computed_regex_string
|
computed_regex_string
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init regex fsm
|
# Init regex fsm
|
||||||
elif req.sampling_params.regex is not None:
|
elif req.sampling_params.regex is not None:
|
||||||
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -58,6 +58,7 @@ class InputMetadata:
|
|||||||
|
|
||||||
# For extend
|
# For extend
|
||||||
extend_seq_lens: torch.Tensor = None
|
extend_seq_lens: torch.Tensor = None
|
||||||
|
extend_prefix_lens: torch.Tensor = None
|
||||||
extend_start_loc: torch.Tensor = None
|
extend_start_loc: torch.Tensor = None
|
||||||
extend_no_prefix: bool = None
|
extend_no_prefix: bool = None
|
||||||
|
|
||||||
@@ -69,8 +70,8 @@ class InputMetadata:
|
|||||||
|
|
||||||
# For multimodal
|
# For multimodal
|
||||||
pixel_values: List[torch.Tensor] = None
|
pixel_values: List[torch.Tensor] = None
|
||||||
image_sizes: List[List[int]] = None
|
image_sizes: List[List[List[int]]] = None
|
||||||
image_offsets: List[int] = None
|
image_offsets: List[List[int]] = None
|
||||||
|
|
||||||
# Trition attention backend
|
# Trition attention backend
|
||||||
triton_max_seq_len: int = 0
|
triton_max_seq_len: int = 0
|
||||||
@@ -87,20 +88,8 @@ class InputMetadata:
|
|||||||
def init_multimuldal_info(self, batch: ScheduleBatch):
|
def init_multimuldal_info(self, batch: ScheduleBatch):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
self.pixel_values = [r.pixel_values for r in reqs]
|
self.pixel_values = [r.pixel_values for r in reqs]
|
||||||
self.image_sizes = [r.image_size for r in reqs]
|
self.image_sizes = [r.image_sizes for r in reqs]
|
||||||
self.image_offsets = []
|
self.image_offsets = [r.image_offsets for r in reqs]
|
||||||
for r in reqs:
|
|
||||||
if isinstance(r.image_offset, list):
|
|
||||||
self.image_offsets.append(
|
|
||||||
[
|
|
||||||
(image_offset - len(r.prefix_indices))
|
|
||||||
for image_offset in r.image_offset
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif isinstance(r.image_offset, int):
|
|
||||||
self.image_offsets.append(r.image_offset - len(r.prefix_indices))
|
|
||||||
elif r.image_offset is None:
|
|
||||||
self.image_offsets.append(0)
|
|
||||||
|
|
||||||
def compute_positions(self, batch: ScheduleBatch):
|
def compute_positions(self, batch: ScheduleBatch):
|
||||||
position_ids_offsets = batch.position_ids_offsets
|
position_ids_offsets = batch.position_ids_offsets
|
||||||
@@ -153,6 +142,7 @@ class InputMetadata:
|
|||||||
for i, r in enumerate(batch.reqs)
|
for i, r in enumerate(batch.reqs)
|
||||||
]
|
]
|
||||||
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
|
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
|
||||||
|
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||||
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
||||||
@@ -238,10 +228,10 @@ class InputMetadata:
|
|||||||
prefix_lens_cpu,
|
prefix_lens_cpu,
|
||||||
flashinfer_use_ragged,
|
flashinfer_use_ragged,
|
||||||
):
|
):
|
||||||
if self.forward_mode != ForwardMode.DECODE:
|
if self.forward_mode == ForwardMode.DECODE:
|
||||||
prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
|
|
||||||
else:
|
|
||||||
prefix_lens = None
|
prefix_lens = None
|
||||||
|
else:
|
||||||
|
prefix_lens = self.extend_prefix_lens
|
||||||
|
|
||||||
update_flashinfer_indices(
|
update_flashinfer_indices(
|
||||||
self.forward_mode,
|
self.forward_mode,
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_config import AttentionArch
|
from sglang.srt.model_config import AttentionArch, ModelConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config,
|
model_config: ModelConfig,
|
||||||
mem_fraction_static: float,
|
mem_fraction_static: float,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
@@ -85,7 +85,9 @@ class ModelRunner:
|
|||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.nccl_port = nccl_port
|
self.nccl_port = nccl_port
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
self.is_multimodal_model = is_multimodal_model(
|
||||||
|
self.model_config.hf_config.architectures
|
||||||
|
)
|
||||||
global_server_args_dict.update(
|
global_server_args_dict.update(
|
||||||
{
|
{
|
||||||
"disable_flashinfer": server_args.disable_flashinfer,
|
"disable_flashinfer": server_args.disable_flashinfer,
|
||||||
@@ -95,6 +97,13 @@ class ModelRunner:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
logger.info(
|
||||||
|
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
||||||
|
)
|
||||||
|
server_args.chunked_prefill_size = None
|
||||||
|
server_args.mem_fraction_static *= 0.95
|
||||||
|
|
||||||
min_per_gpu_memory = self.init_torch_distributed()
|
min_per_gpu_memory = self.init_torch_distributed()
|
||||||
self.load_model()
|
self.load_model()
|
||||||
self.init_memory_pool(
|
self.init_memory_pool(
|
||||||
@@ -507,9 +516,9 @@ class ModelRunner:
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture cuda graph failed: {e}\n"
|
f"Capture cuda graph failed: {e}\n"
|
||||||
"Possible solutions:\n"
|
"Possible solutions:\n"
|
||||||
"1. disable torch compile by not using --enable-torch-compile\n"
|
"1. disable cuda graph by --disable-cuda-graph\n"
|
||||||
"2. disable cuda graph by --disable-cuda-graph\n"
|
"2. set --mem-fraction-static to a smaller value\n"
|
||||||
"3. set --mem-fraction-static to a smaller value\n"
|
"3. disable torch compile by not using --enable-torch-compile\n"
|
||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/THUDM/ChatGLM2-6B
|
# https://github.com/THUDM/ChatGLM2-6B
|
||||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|||||||
@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
||||||
else:
|
else:
|
||||||
hidden_states = input_embeds
|
hidden_states = input_embeds
|
||||||
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
|
||||||
|
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
||||||
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Grok1ModelForCausalLM(nn.Module):
|
class Grok1ForCausalLM(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
|
|||||||
return hf_folder, hf_weights_files, use_safetensors
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Grok1ModelForCausalLM
|
class Grok1ModelForCausalLM(Grok1ForCausalLM):
|
||||||
|
"""An alias for backward-compatbility."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
|
||||||
|
|||||||
@@ -357,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
return
|
return
|
||||||
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
|
return
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
@@ -364,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@@ -374,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
return
|
return
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
return
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
return
|
return
|
||||||
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
|
return
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
return
|
return
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
return
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from transformers import (
|
|||||||
LlavaConfig,
|
LlavaConfig,
|
||||||
MistralConfig,
|
MistralConfig,
|
||||||
Qwen2Config,
|
Qwen2Config,
|
||||||
SiglipVisionConfig,
|
|
||||||
SiglipVisionModel,
|
SiglipVisionModel,
|
||||||
)
|
)
|
||||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
@@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||||
)
|
)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
|
def pad_input_ids(
|
||||||
|
self,
|
||||||
|
input_ids: List[int],
|
||||||
|
pad_value: List[int],
|
||||||
|
pixel_values: List,
|
||||||
|
image_sizes: List[List[int]],
|
||||||
|
):
|
||||||
# hardcode for spatial_unpad + anyres
|
# hardcode for spatial_unpad + anyres
|
||||||
image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
|
image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
|
||||||
offset_list = []
|
offset_list = []
|
||||||
for image_s in image_size:
|
for image_s in image_sizes:
|
||||||
if len(image_size) > 16:
|
if len(image_sizes) > 16:
|
||||||
# 2x2 pooling with stride 2
|
# 2x2 pooling with stride 2
|
||||||
new_image_feature_len = (
|
new_image_feature_len = (
|
||||||
math.ceil(self.image_size / self.patch_size / 2) ** 2
|
math.ceil(self.image_size / self.patch_size / 2) ** 2
|
||||||
@@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||||
bs = input_metadata.batch_size
|
bs = input_metadata.batch_size
|
||||||
|
|
||||||
# Embed text input
|
# Embed text inputs
|
||||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
# Embed vision input
|
|
||||||
need_vision = (
|
# Whether the requests need vision inputs
|
||||||
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
|
max_image_offset = np.array(
|
||||||
.cpu()
|
[max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
|
||||||
.numpy()
|
|
||||||
)
|
)
|
||||||
# FIXME: We need to substract the length of the system prompt
|
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||||
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
|
need_vision = start_positions <= max_image_offset
|
||||||
need_vision = need_vision & has_pixel
|
|
||||||
|
|
||||||
if need_vision.any():
|
if need_vision.any():
|
||||||
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
||||||
@@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
new_image_features.append(image_feature)
|
new_image_features.append(image_feature)
|
||||||
image_features = new_image_features
|
image_features = new_image_features
|
||||||
|
|
||||||
|
# Fill in the placeholder for the image
|
||||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||||
|
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
|
||||||
pt = 0
|
pt = 0
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
if not need_vision[i]:
|
if not need_vision[i]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
start_idx = extend_start_loc_cpu[i]
|
start_idx = extend_start_loc_cpu[i]
|
||||||
pad_dim = image_features[pt].shape[-1] # 576, 4096
|
prefix_len = prefix_lens_cpu[i]
|
||||||
dim = input_embeds.shape[1]
|
|
||||||
assert (
|
# Multiple images
|
||||||
pad_dim == dim
|
for j, image_offset in enumerate(image_offsets[i]):
|
||||||
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
|
if image_offset < prefix_len:
|
||||||
# Fill in the placeholder for the image
|
continue
|
||||||
try:
|
|
||||||
for j, image_off in enumerate(image_offsets[i]):
|
tmp_image_feature = image_features[pt][j]
|
||||||
# print("actual image_features length: ", image_features[pt][j].shape[0])
|
pad_len = tmp_image_feature.shape[0]
|
||||||
pad_len = image_features[pt][j].shape[0]
|
|
||||||
input_embeds[
|
left_idx = start_idx + (image_offset - prefix_len)
|
||||||
start_idx + image_off : start_idx + image_off + pad_len
|
right_idx = start_idx + (image_offset - prefix_len) + pad_len
|
||||||
] = image_features[pt][j]
|
try:
|
||||||
except RuntimeError as e:
|
input_embeds[left_idx:right_idx] = tmp_image_feature
|
||||||
print(f"RuntimeError in llava image encoding: {e}")
|
except RuntimeError as e:
|
||||||
print(image_features[pt].shape)
|
print(f"RuntimeError in image encoding: {e}")
|
||||||
print(input_embeds.shape)
|
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
|
||||||
print(start_idx, image_offsets[i])
|
print(
|
||||||
|
f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
|
||||||
|
)
|
||||||
pt += 1
|
pt += 1
|
||||||
|
|
||||||
return self.language_model(
|
return self.language_model(
|
||||||
@@ -366,8 +372,9 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
return self.language_model(input_ids, positions, input_metadata)
|
return self.language_model(input_ids, positions, input_metadata)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
# load clip vision model by cfg['mm_vision_tower']:
|
# Load clip vision model by cfg['mm_vision_tower']:
|
||||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||||
|
# We put the initialization here instead of __init__ to allow it being reused by other subclasses.
|
||||||
vision_path = self.config.mm_vision_tower
|
vision_path = self.config.mm_vision_tower
|
||||||
if "clip" in vision_path:
|
if "clip" in vision_path:
|
||||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||||
@@ -422,8 +429,6 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|||||||
# load language model
|
# load language model
|
||||||
self.language_model.load_weights(weights)
|
self.language_model.load_weights(weights)
|
||||||
|
|
||||||
monkey_path_clip_vision_embed_forward()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_patches_per_side(self):
|
def num_patches_per_side(self):
|
||||||
return self.image_size // self.patch_size
|
return self.image_size // self.patch_size
|
||||||
@@ -495,36 +500,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
first_call = True
|
|
||||||
|
|
||||||
|
|
||||||
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
|
||||||
batch_size = pixel_values.shape[0]
|
|
||||||
|
|
||||||
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
|
|
||||||
global first_call
|
|
||||||
if first_call:
|
|
||||||
self.patch_embedding.cpu().float()
|
|
||||||
first_call = False
|
|
||||||
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
|
|
||||||
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
|
|
||||||
|
|
||||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
|
||||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
|
||||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def monkey_path_clip_vision_embed_forward():
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
setattr(
|
|
||||||
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
|
|
||||||
"forward",
|
|
||||||
clip_vision_embed_forward,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
||||||
|
|||||||
@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.mm_utils import (
|
|
||||||
get_anyres_image_grid_shape,
|
|
||||||
unpad_image,
|
|
||||||
unpad_image_shape,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||||
|
|
||||||
@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||||
)
|
)
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
|
def pad_input_ids(
|
||||||
|
self,
|
||||||
|
input_ids: List[int],
|
||||||
|
pad_value: List[int],
|
||||||
|
pixel_values: List,
|
||||||
|
image_sizes: List[List[int]],
|
||||||
|
):
|
||||||
new_image_feature_len = self.image_feature_len
|
new_image_feature_len = self.image_feature_len
|
||||||
# now only support spatial_unpad + anyres
|
|
||||||
# if self.mm_patch_merge_type.startswith("spatial"):
|
|
||||||
# height = width = self.num_patches_per_side
|
|
||||||
# if pt_shape[0] > 1:
|
|
||||||
# if self.image_aspect_ratio == "anyres":
|
|
||||||
# num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
|
||||||
# image_size,
|
|
||||||
# self.image_grid_pinpoints,
|
|
||||||
# self.vision_tower.config.image_size,
|
|
||||||
# )
|
|
||||||
# if "unpad" in self.mm_patch_merge_type:
|
|
||||||
# h = num_patch_height * height
|
|
||||||
# w = num_patch_width * width
|
|
||||||
# new_h, new_w = unpad_image_shape(h, w, image_size)
|
|
||||||
# new_image_feature_len += new_h * (new_w + 1)
|
|
||||||
|
|
||||||
pad_ids = pad_value * (
|
pad_ids = pad_value * (
|
||||||
(new_image_feature_len + len(pad_value)) // len(pad_value)
|
(new_image_feature_len + len(pad_value)) // len(pad_value)
|
||||||
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
+ pad_ids[:new_image_feature_len]
|
+ pad_ids[:new_image_feature_len]
|
||||||
+ input_ids[offset + 1 :]
|
+ input_ids[offset + 1 :]
|
||||||
)
|
)
|
||||||
return new_input_ids, offset
|
return new_input_ids, [offset]
|
||||||
|
|
||||||
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||||
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||||
bs = input_metadata.batch_size
|
bs = input_metadata.batch_size
|
||||||
|
|
||||||
# Embed text input
|
# Embed text inputs
|
||||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
|
||||||
# Embed vision input
|
# Whether the requests need vision inputs
|
||||||
need_vision = (
|
max_image_offset = np.array(
|
||||||
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
|
[max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
|
||||||
.cpu()
|
|
||||||
.numpy()
|
|
||||||
)
|
)
|
||||||
# FIXME: We need to substract the length of the system prompt
|
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||||
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
|
need_vision = start_positions <= max_image_offset
|
||||||
need_vision = need_vision & has_pixel
|
|
||||||
|
|
||||||
if need_vision.any():
|
if need_vision.any():
|
||||||
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
||||||
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
|
|
||||||
|
|
||||||
########## Encode Image ########
|
########## Encode Image ########
|
||||||
|
|
||||||
@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
new_image_features.append(image_feature.flatten(0, 1))
|
new_image_features.append(image_feature.flatten(0, 1))
|
||||||
image_features = new_image_features
|
image_features = new_image_features
|
||||||
|
|
||||||
|
# Fill in the placeholder for the image
|
||||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||||
|
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
|
||||||
pt = 0
|
pt = 0
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
if not need_vision[i]:
|
if not need_vision[i]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
start_idx = extend_start_loc_cpu[i]
|
start_idx = extend_start_loc_cpu[i]
|
||||||
pad_len, pad_dim = image_features[pt].shape # 576, 4096
|
prefix_len = prefix_lens_cpu[i]
|
||||||
dim = input_embeds.shape[1]
|
|
||||||
assert (
|
# Multiple images
|
||||||
pad_dim == dim
|
for image_offset in image_offsets[i]:
|
||||||
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
|
if image_offset < prefix_len:
|
||||||
# Fill in the placeholder for the image
|
continue
|
||||||
try:
|
|
||||||
input_embeds[
|
tmp_image_feature = image_features[pt]
|
||||||
start_idx
|
pad_len = tmp_image_feature.shape[0]
|
||||||
+ image_offsets[i] : start_idx
|
|
||||||
+ image_offsets[i]
|
left_idx = start_idx + (image_offset - prefix_len)
|
||||||
+ pad_len
|
right_idx = start_idx + (image_offset - prefix_len) + pad_len
|
||||||
] = image_features[pt]
|
try:
|
||||||
except RuntimeError as e:
|
input_embeds[left_idx:right_idx] = tmp_image_feature
|
||||||
print(f"RuntimeError in llava image encoding: {e}")
|
except RuntimeError as e:
|
||||||
print(input_embeds.shape)
|
print(f"RuntimeError in image encoding: {e}")
|
||||||
print(start_idx, image_offsets[i])
|
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
|
||||||
pt += 1
|
print(
|
||||||
|
f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
|
||||||
|
)
|
||||||
|
pt += 1
|
||||||
|
|
||||||
return self.language_model(
|
return self.language_model(
|
||||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||||
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
return self.language_model(input_ids, positions, input_metadata)
|
return self.language_model(input_ids, positions, input_metadata)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
# load clip vision model by cfg['mm_vision_tower']:
|
# Load clip vision model by cfg['mm_vision_tower']:
|
||||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||||
|
# We put the initialization here instead of __init__ to allow it being reused by other subclasses.
|
||||||
vision_path = self.config.mm_vision_tower
|
vision_path = self.config.mm_vision_tower
|
||||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||||
vision_path, torch_dtype=torch.float16
|
vision_path, torch_dtype=torch.float16
|
||||||
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
|
|||||||
# load language model
|
# load language model
|
||||||
self.language_model.load_weights(weights)
|
self.language_model.load_weights(weights)
|
||||||
|
|
||||||
monkey_path_clip_vision_embed_forward()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_patches_per_side(self):
|
def num_patches_per_side(self):
|
||||||
return self.image_size // self.patch_size
|
return self.image_size // self.patch_size
|
||||||
|
|
||||||
|
|
||||||
first_call = True
|
|
||||||
|
|
||||||
|
|
||||||
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
|
||||||
batch_size = pixel_values.shape[0]
|
|
||||||
|
|
||||||
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
|
|
||||||
global first_call
|
|
||||||
if first_call:
|
|
||||||
self.patch_embedding.cpu().float()
|
|
||||||
first_call = False
|
|
||||||
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
|
|
||||||
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
|
|
||||||
|
|
||||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
|
||||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
|
||||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def monkey_path_clip_vision_embed_forward():
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
setattr(
|
|
||||||
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
|
|
||||||
"forward",
|
|
||||||
clip_vision_embed_forward,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = LlavaVidForCausalLM
|
EntryClass = LlavaVidForCausalLM
|
||||||
|
|||||||
@@ -312,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
@@ -319,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@@ -329,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
|
|||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.models.llava import (
|
from sglang.srt.models.llava import LlavaLlamaForCausalLM
|
||||||
LlavaLlamaForCausalLM,
|
|
||||||
monkey_path_clip_vision_embed_forward,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||||
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
self.config._name_or_path,
|
self.config._name_or_path,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
subfolder=self.vision_tower_subfolder,
|
subfolder=self.vision_tower_subfolder,
|
||||||
).cuda()
|
).to("cuda")
|
||||||
|
|
||||||
self.vision_tower.eval()
|
self.vision_tower.eval()
|
||||||
|
|
||||||
@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
|||||||
# load language model
|
# load language model
|
||||||
self.language_model.load_weights(weights)
|
self.language_model.load_weights(weights)
|
||||||
|
|
||||||
monkey_path_clip_vision_embed_forward()
|
|
||||||
|
|
||||||
|
|
||||||
class YiVLMultiModalProjector(nn.Module):
|
class YiVLMultiModalProjector(nn.Module):
|
||||||
def __init__(self, config: LlavaConfig):
|
def __init__(self, config: LlavaConfig):
|
||||||
|
|||||||
@@ -335,12 +335,12 @@ def launch_server(
|
|||||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||||
|
|
||||||
if server_args.dp_size == 1:
|
if server_args.dp_size == 1:
|
||||||
start_process = start_controller_process_single
|
start_controller_process = start_controller_process_single
|
||||||
else:
|
else:
|
||||||
start_process = start_controller_process_multi
|
start_controller_process = start_controller_process_multi
|
||||||
|
|
||||||
proc_controller = mp.Process(
|
proc_controller = mp.Process(
|
||||||
target=start_process,
|
target=start_controller_process,
|
||||||
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
|
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
|
||||||
)
|
)
|
||||||
proc_controller.start()
|
proc_controller.start()
|
||||||
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if not server_args.disable_flashinfer:
|
if not server_args.disable_flashinfer:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"0.1.6",
|
"0.1.5",
|
||||||
"Please uninstall the old version and "
|
"Please uninstall the old version and "
|
||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import struct
|
|||||||
import time
|
import time
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import psutil
|
import psutil
|
||||||
@@ -193,35 +193,16 @@ def allocate_init_ports(
|
|||||||
return ret_ports[0], ret_ports[1:num_ports_needed]
|
return ret_ports[0], ret_ports[1:num_ports_needed]
|
||||||
|
|
||||||
|
|
||||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
def is_multimodal_model(model_architectures):
|
||||||
"""Get the logit bias for integer-only tokens."""
|
if (
|
||||||
# a bug when model's vocab size > tokenizer.vocab_size
|
"LlavaLlamaForCausalLM" in model_architectures
|
||||||
if tokenizer == None:
|
or "LlavaQwenForCausalLM" in model_architectures
|
||||||
return [-1e5] * vocab_size
|
or "LlavaMistralForCausalLM" in model_architectures
|
||||||
vocab_size = tokenizer.vocab_size
|
or "LlavaVidForCausalLM" in model_architectures
|
||||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
):
|
||||||
for t_id in range(vocab_size):
|
return True
|
||||||
ss = tokenizer.decode([t_id]).strip()
|
else:
|
||||||
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
|
return False
|
||||||
logit_bias[t_id] = -1e5
|
|
||||||
|
|
||||||
return logit_bias
|
|
||||||
|
|
||||||
|
|
||||||
def is_multimodal_model(model):
|
|
||||||
from sglang.srt.model_config import ModelConfig
|
|
||||||
|
|
||||||
if isinstance(model, str):
|
|
||||||
model = model.lower()
|
|
||||||
return "llava" in model or "yi-vl" in model or "llava-next" in model
|
|
||||||
|
|
||||||
if isinstance(model, ModelConfig):
|
|
||||||
model_path = model.path.lower()
|
|
||||||
return (
|
|
||||||
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError("unrecognized type")
|
|
||||||
|
|
||||||
|
|
||||||
def is_generation_model(model_architectures, is_embedding: bool = False):
|
def is_generation_model(model_architectures, is_embedding: bool = False):
|
||||||
@@ -317,12 +298,14 @@ def decode_video_base64(video_base64):
|
|||||||
) # Return an empty array and size tuple if no frames were found
|
) # Return an empty array and size tuple if no frames were found
|
||||||
|
|
||||||
|
|
||||||
def load_image(image_file):
|
def load_image(image_file: Union[str, bytes]):
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
image = image_size = None
|
image = image_size = None
|
||||||
|
|
||||||
if image_file.startswith("http://") or image_file.startswith("https://"):
|
if isinstance(image_file, bytes):
|
||||||
|
image = Image.open(BytesIO(image_file))
|
||||||
|
elif image_file.startswith("http://") or image_file.startswith("https://"):
|
||||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
||||||
response = requests.get(image_file, timeout=timeout)
|
response = requests.get(image_file, timeout=timeout)
|
||||||
image = Image.open(BytesIO(response.content))
|
image = Image.open(BytesIO(response.content))
|
||||||
@@ -334,8 +317,10 @@ def load_image(image_file):
|
|||||||
elif image_file.startswith("video:"):
|
elif image_file.startswith("video:"):
|
||||||
image_file = image_file.replace("video:", "")
|
image_file = image_file.replace("video:", "")
|
||||||
image, image_size = decode_video_base64(image_file)
|
image, image_size = decode_video_base64(image_file)
|
||||||
else:
|
elif isinstance(image_file, str):
|
||||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid image: {image}")
|
||||||
|
|
||||||
return image, image_size
|
return image, image_size
|
||||||
|
|
||||||
|
|||||||
@@ -32,8 +32,6 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
|||||||
other_args=[
|
other_args=[
|
||||||
"--chat-template",
|
"--chat-template",
|
||||||
"chatml-llava",
|
"chatml-llava",
|
||||||
"--chunked-prefill-size",
|
|
||||||
"16384",
|
|
||||||
# "--log-requests",
|
# "--log-requests",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user