[Fix] Fix llava on multi images (#1247)

This commit is contained in:
Lianmin Zheng
2024-08-28 06:33:05 -07:00
committed by GitHub
parent b1a540ec42
commit bf53bf5142
22 changed files with 272 additions and 488 deletions

View File

@@ -240,7 +240,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- Qwen / Qwen 2 / Qwen 2 MoE - Qwen / Qwen 2 / Qwen 2 MoE
- DeepSeek / DeepSeek 2 - DeepSeek / DeepSeek 2
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384` - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
- Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
- LLaVA 1.5 / 1.6 / NeXT - LLaVA 1.5 / 1.6 / NeXT
- `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3` - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`

View File

@@ -184,13 +184,9 @@ if __name__ == "__main__":
# Parse the arguments # Parse the arguments
args = parser.parse_args() args = parser.parse_args()
cur_port = args.port cur_port = args.port
cur_chunk = args.chunk_idx cur_chunk = args.chunk_idx
num_chunks = args.num_chunks num_chunks = args.num_chunks
num_frames = args.num_frames num_frames = args.num_frames
if "34b" in args.model_path.lower(): if "34b" in args.model_path.lower():
@@ -202,7 +198,6 @@ if __name__ == "__main__":
exit() exit()
model_overide_args = {} model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
model_overide_args["architectures"] = ["LlavaVidForCausalLM"] model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = args.num_frames model_overide_args["num_frames"] = args.num_frames
@@ -235,7 +230,6 @@ if __name__ == "__main__":
print(f"chat template: {runtime.endpoint.chat_template.name}") print(f"chat template: {runtime.endpoint.chat_template.name}")
# Run a single request # Run a single request
# try:
print("\n========== single ==========\n") print("\n========== single ==========\n")
root = args.video_dir root = args.video_dir
if os.path.isfile(root): if os.path.isfile(root):
@@ -257,13 +251,10 @@ if __name__ == "__main__":
) # Calculate the average processing time ) # Calculate the average processing time
print(f"Average processing time per video: {average_time:.2f} seconds") print(f"Average processing time per video: {average_time:.2f} seconds")
runtime.shutdown() runtime.shutdown()
# except Exception as e:
# print(e)
runtime.shutdown()
# # # Run a batch of requests # # Run a batch of requests
# print("\n========== batch ==========\n") # print("\n========== batch ==========\n")
# if not os.path.exists(args.save_dir): # if not os.path.exists(args.save_dir):
# os.makedirs(args.save_dir) # os.makedirs(args.save_dir)
# batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks) # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
# runtime.shutdown() # runtime.shutdown()

View File

@@ -0,0 +1,26 @@
"""Launch the inference server for Llava-video model."""
import argparse
from sglang.srt.server import ServerArgs, launch_server
def build_model_override_args(model_path, num_frames=16, mm_spatial_pool_stride=2):
    """Build the config overrides used to serve a Llava-video checkpoint.

    Args:
        model_path: Model path or name; a case-insensitive "34b" substring
            selects the 34B-specific image token index.
        num_frames: Number of video frames written to the overrides. A value
            of 32 additionally enables linear rope scaling and doubles the
            maximum lengths to 8192.
        mm_spatial_pool_stride: Spatial pooling stride written to the overrides.

    Returns:
        A dict of override keys to pass to ``launch_server``.
    """
    overrides = {
        "mm_spatial_pool_stride": mm_spatial_pool_stride,
        "architectures": ["LlavaVidForCausalLM"],
        "num_frames": num_frames,
        "model_type": "llavavid",
    }
    if num_frames == 32:
        # Twice the 4096 base length, paired with a 2x linear rope factor.
        extended_len = 4096 * 2
        overrides["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        overrides["max_sequence_length"] = extended_len
        overrides["tokenizer_model_max_length"] = extended_len
        overrides["model_max_length"] = extended_len
    if "34b" in model_path.lower():
        overrides["image_token_index"] = 64002
    return overrides


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)

    # Defaults (16 frames, stride 2) match the previously hard-coded values.
    model_overide_args = build_model_override_args(args.model_path)

    launch_server(server_args, model_overide_args, None)

View File

@@ -119,24 +119,7 @@ def get_tokenizer(
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if tokenizer_name.endswith(".json"):
return TiktokenTokenizer(tokenizer_name)
if tokenizer_name.endswith(".model"):
return SentencePieceTokenizer(tokenizer_name)
"""Gets a tokenizer for the given model name via Huggingface.""" """Gets a tokenizer for the given model name via Huggingface."""
if is_multimodal_model(tokenizer_name):
processor = get_processor(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
tokenizer = processor.tokenizer
return tokenizer
if tokenizer_mode == "slow": if tokenizer_mode == "slow":
if kwargs.get("use_fast", False): if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
@@ -199,135 +182,3 @@ def get_processor(
**kwargs, **kwargs,
) )
return processor return processor
class TiktokenTokenizer:
    """Wrap a tiktoken ``Encoding`` behind a Hugging Face-like tokenizer interface.

    The encoding is built from a JSON file that lists regular tokens, special
    tokens and (optionally) a split pattern, a default set of allowed special
    tokens and an explicit vocabulary size.
    """

    def __init__(self, tokenizer_path):
        """Load the JSON tokenizer description at ``tokenizer_path``.

        Args:
            tokenizer_path: Path to the JSON tokenizer file.

        Raises:
            AssertionError: If the file's ``word_split`` field is not ``"V1"``.
        """
        # Imported here, so these packages are only required when an instance
        # is actually constructed.
        import tiktoken
        from jinja2 import Template

        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

        # Read JSON
        name = "tmp-json"
        with open(tokenizer_path, "rb") as fin:
            tok_dict = json.load(fin)

        mergeable_ranks = {
            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in tok_dict["special_tokens"]
        }
        assert tok_dict["word_split"] == "V1"

        default_allowed_special = None
        kwargs = {
            "name": name,
            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
            "mergeable_ranks": mergeable_ranks,
            "special_tokens": special_tokens,
        }
        if "default_allowed_special" in tok_dict:
            default_allowed_special = {
                bytes(bytes_list).decode()
                for bytes_list in tok_dict["default_allowed_special"]
            }
        if "vocab_size" in tok_dict:
            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

        PAD = "<|pad|>"
        EOS = "<|eos|>"
        SEP = "<|separator|>"

        # NOTE(review): "sep" maps to the EOS string and "eos" to the SEP
        # string. Kept exactly as found; confirm whether the swap is intended.
        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}

        tokenizer = tiktoken.Encoding(**kwargs)
        tokenizer._default_allowed_special = default_allowed_special or set()
        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

        def encode_patched(
            self,
            text: str,
            *,
            allowed_special: Union[Literal["all"], AbstractSet[str]] = frozenset(),
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> List[int]:
            """Encode ``text``, always allowing the file's default special tokens.

            ``disallowed_special`` is accepted for signature compatibility but
            is not forwarded: the underlying call always receives ``()``.
            """
            if isinstance(allowed_special, (set, frozenset)):
                # Build a new set instead of using ``|=`` so that neither the
                # caller's set nor the shared default is mutated across calls.
                allowed_special = set(allowed_special) | self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=(),
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer._special_tokens[EOS]
        self.vocab_size = tokenizer.n_vocab
        self.chat_template = Template(
            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        )

    def encode(self, x, add_special_tokens=False):
        """Encode text ``x`` to token ids; ``add_special_tokens`` is ignored."""
        return self.tokenizer.encode(x)

    def decode(self, x):
        """Decode a sequence of token ids ``x`` with the wrapped tokenizer."""
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        """Decode a batch of token-id sequences.

        A flat batch of ints is treated as one single-token sequence per int.
        The two keyword arguments are accepted but not used. An empty batch
        returns an empty list.
        """
        if len(batch) == 0:
            # Indexing batch[0] below would raise IndexError on an empty batch.
            return []
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
        """Render ``messages`` with the chat template.

        Returns token ids when ``tokenize`` is true, otherwise the rendered
        string.
        """
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret
class SentencePieceTokenizer:
    """Wrap a SentencePiece processor behind a Hugging Face-like tokenizer interface."""

    def __init__(self, tokenizer_path):
        """Load the SentencePiece model file at ``tokenizer_path``."""
        # Imported here, so these packages are only required when an instance
        # is actually constructed.
        import sentencepiece as spm
        from jinja2 import Template

        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_id()
        self.vocab_size = tokenizer.vocab_size()
        self.chat_template = Template(
            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        )

    def encode(self, x, add_special_tokens=False):
        """Encode text ``x`` to token ids; ``add_special_tokens`` is ignored."""
        return self.tokenizer.encode(x)

    def decode(self, x):
        """Decode token ids ``x`` with the wrapped processor."""
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        """Decode a batch of token-id sequences.

        A flat batch of ints is treated as one single-token sequence per int.
        The two keyword arguments are accepted but not used. An empty batch
        returns an empty list.
        """
        if len(batch) == 0:
            # Indexing batch[0] below would raise IndexError on an empty batch.
            return []
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
        """Render ``messages`` with the chat template.

        Returns token ids when ``tokenize`` is true, otherwise the rendered
        string.
        """
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret

View File

@@ -55,6 +55,7 @@ class GenerateReqInput:
self.text is not None and self.input_ids is not None self.text is not None and self.input_ids is not None
): ):
raise ValueError("Either text or input_ids should be provided.") raise ValueError("Either text or input_ids should be provided.")
if ( if (
isinstance(self.sampling_params, dict) isinstance(self.sampling_params, dict)
and self.sampling_params.get("n", 1) != 1 and self.sampling_params.get("n", 1) != 1
@@ -161,10 +162,10 @@ class TokenizedGenerateReqInput:
input_ids: List[int] input_ids: List[int]
# The pixel values for input images # The pixel values for input images
pixel_values: List[float] pixel_values: List[float]
# The hash of input images # The hash values of input images
image_hash: int image_hashes: List[int]
# The image size # The image sizes
image_size: List[int] image_sizes: List[List[int]]
# The sampling parameters # The sampling parameters
sampling_params: SamplingParams sampling_params: SamplingParams
# Whether to return the logprobs # Whether to return the logprobs

View File

@@ -121,8 +121,8 @@ class Req:
# For vision input # For vision input
self.pixel_values = None self.pixel_values = None
self.image_size = None self.image_sizes = None
self.image_offset = None self.image_offsets = None
self.pad_value = None self.pad_value = None
# Prefix info # Prefix info
@@ -600,12 +600,12 @@ class ScheduleBatch:
if req.pixel_values is not None: if req.pixel_values is not None:
( (
req.origin_input_ids, req.origin_input_ids,
req.image_offset, req.image_offsets,
) = model_runner.model.pad_input_ids( ) = model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded, req.origin_input_ids_unpadded,
req.pad_value, req.pad_value,
req.pixel_values.shape, req.pixel_values,
req.image_size, req.image_sizes,
) )
jump_forward_reqs.append(req) jump_forward_reqs.append(req)

View File

@@ -23,6 +23,7 @@ import multiprocessing as mp
import os import os
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import fastapi
import numpy as np import numpy as np
import transformers import transformers
import uvloop import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code, trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args, model_overide_args=model_overide_args,
) )
self.is_generation = is_generation_model( self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding self.hf_config.architectures, self.server_args.is_embedding
) )
self.context_len = server_args.context_length or get_context_length(
if server_args.context_length is not None: self.hf_config
self.context_len = server_args.context_length )
else:
self.context_len = get_context_length(self.hf_config)
# Create tokenizer # Create tokenizer
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None self.tokenizer = self.processor = None
else: else:
if is_multimodal_model(self.model_path): if is_multimodal_model(self.hf_config.architectures):
self.processor = get_processor( self.processor = get_processor(
server_args.tokenizer_path, server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode, tokenizer_mode=server_args.tokenizer_mode,
@@ -118,6 +116,9 @@ class TokenizerManager:
) )
self.tokenizer = self.processor.tokenizer self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
# We want to parallelize the image pre-processing so we
# create an executor for it
self.executor = concurrent.futures.ProcessPoolExecutor( self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor, initializer=init_global_processor,
mp_context=mp.get_context("fork"), mp_context=mp.get_context("fork"),
@@ -134,12 +135,14 @@ class TokenizerManager:
self.to_create_loop = True self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {} self.rid_to_state: Dict[str, ReqState] = {}
# for update model weights # For update model weights
self.model_update_lock = asyncio.Lock() self.model_update_lock = asyncio.Lock()
self.model_update_result = None self.model_update_result = None
async def generate_request( async def generate_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
): ):
if self.to_create_loop: if self.to_create_loop:
self.create_handle_loop() self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
async def _handle_single_request( async def _handle_single_request(
self, self,
obj: Union[GenerateReqInput, EmbeddingReqInput], obj: Union[GenerateReqInput, EmbeddingReqInput],
request, request: Optional[fastapi.Request] = None,
index: Optional[int] = None, index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False, is_cache_for_prefill: Optional[bool] = False,
): ):
@@ -182,8 +185,8 @@ class TokenizerManager:
) )
if self.is_generation: if self.is_generation:
pixel_values, image_hash, image_size = await self._get_pixel_values( pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data obj.image_data if not_use_index else obj.image_data[index]
) )
return_logprob = ( return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index] obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
) )
if return_logprob and logprob_start_len == -1: if return_logprob and logprob_start_len == -1:
logprob_start_len = len(input_ids) - 1 logprob_start_len = len(input_ids) - 1
top_logprobs_num = ( top_logprobs_num = (
obj.top_logprobs_num obj.top_logprobs_num
if not_use_index if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:
sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0 sampling_params.max_new_tokens = 0
pixel_values, image_hash, image_size = await self._get_pixel_values( pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data[0] obj.image_data[0]
) )
return_logprob = obj.return_logprob[0] return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0] logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0] top_logprobs_num = obj.top_logprobs_num[0]
# Send to the controller
if self.is_generation: if self.is_generation:
if return_logprob and logprob_start_len == -1: if return_logprob and logprob_start_len == -1:
logprob_start_len = len(input_ids) - 1 logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
input_text, input_text,
input_ids, input_ids,
pixel_values, pixel_values,
image_hash, image_hashes,
image_size, image_sizes,
sampling_params, sampling_params,
return_logprob, return_logprob,
logprob_start_len, logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
input_ids, input_ids,
sampling_params, sampling_params,
) )
self.send_to_router.send_pyobj(tokenized_obj) self.send_to_router.send_pyobj(tokenized_obj)
# Recv results
event = asyncio.Event() event = asyncio.Event()
state = ReqState([], False, event) state = ReqState([], False, event)
self.rid_to_state[rid] = state self.rid_to_state[rid] = state
if not is_cache_for_prefill: if not is_cache_for_prefill:
async for response in self._wait_for_response( async for response in self._wait_for_response(state, obj, rid, request):
event, state, obj, rid, request
):
yield response yield response
else: else:
assert self.is_generation assert self.is_generation
await self._wait_for_cache_prefill_response(event, state, obj, rid, request) await self._wait_for_cache_prefill_response(state, obj, rid, request)
yield input_ids yield input_ids
async def _handle_batch_request( async def _handle_batch_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
): ):
batch_size = obj.batch_size batch_size = obj.batch_size
if self.is_generation: if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
if self.is_generation: if self.is_generation:
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1: if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
obj.logprob_start_len[index] = len(input_ids) - 1 obj.logprob_start_len[index] = len(input_ids) - 1
pixel_values, image_hash, image_size = await self._get_pixel_values( pixel_values, image_hashes, image_sizes = (
obj.image_data[index] await self._get_pixel_values(obj.image_data[index])
) )
tokenized_obj = TokenizedGenerateReqInput( tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
input_text, input_text,
input_ids, input_ids,
pixel_values, pixel_values,
image_hash, image_hashes,
image_size, image_sizes,
sampling_params, sampling_params,
obj.return_logprob[index], obj.return_logprob[index],
obj.logprob_start_len[index], obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:
generators.append( generators.append(
self._wait_for_response( self._wait_for_response(
event,
state, state,
obj, obj,
rid, rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators] tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
output_list = [None] * len(tasks) output_list = [None] * len(tasks)
# Recv results
while tasks: while tasks:
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
@@ -426,25 +429,18 @@ class TokenizerManager:
sampling_params.verify() sampling_params.verify()
return sampling_params return sampling_params
async def _get_pixel_values(self, image_data):
if image_data is None:
return None, None, None
else:
return await self._get_pixel_values_internal(image_data)
async def _wait_for_response( async def _wait_for_response(
self, self,
event: asyncio.Event,
state: ReqState, state: ReqState,
obj: Union[GenerateReqInput, EmbeddingReqInput], obj: Union[GenerateReqInput, EmbeddingReqInput],
rid: str, rid: str,
request, request: Optional[fastapi.Request] = None,
index: int = None, index: Optional[int] = None,
response_index: int = 0, response_index: int = 0,
): ):
while True: while True:
try: try:
await asyncio.wait_for(event.wait(), timeout=4) await asyncio.wait_for(state.event.wait(), timeout=4)
except asyncio.TimeoutError: except asyncio.TimeoutError:
if request is not None and await request.is_disconnected(): if request is not None and await request.is_disconnected():
for rid in [obj.rid] if obj.is_single else obj.rid: for rid in [obj.rid] if obj.is_single else obj.rid:
@@ -478,16 +474,15 @@ class TokenizerManager:
yield out yield out
break break
event.clear() state.event.clear()
yield out yield out
async def _wait_for_cache_prefill_response( async def _wait_for_cache_prefill_response(
self, self,
event: asyncio.Event,
state: ReqState, state: ReqState,
obj: GenerateReqInput, obj: GenerateReqInput,
rid: str, rid: str,
request, request: Optional[fastapi.Request] = None,
): ):
while True: while True:
try: try:
@@ -514,7 +509,9 @@ class TokenizerManager:
req = AbortReq(rid) req = AbortReq(rid)
self.send_to_router.send_pyobj(req) self.send_to_router.send_pyobj(req)
async def update_weights(self, obj: UpdateWeightReqInput, request): async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
):
if self.to_create_loop: if self.to_create_loop:
self.create_handle_loop() self.create_handle_loop()
@@ -659,12 +656,11 @@ class TokenizerManager:
) )
return top_logprobs return top_logprobs
async def _get_pixel_values_internal(self, image_data, aspect_ratio=None): async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
aspect_ratio = ( if not image_data:
getattr(self.hf_config, "image_aspect_ratio", None) return None, None, None
if aspect_ratio is None
else aspect_ratio aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
)
grid_pinpoints = ( grid_pinpoints = (
self.hf_config.image_grid_pinpoints self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints") if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
) )
if isinstance(image_data, list) and len(image_data) > 0: if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], [] # Multiple images
if len(image_data) > 1: if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
for img_data in image_data: for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image( pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints img_data, aspect_ratio, grid_pinpoints
) )
pixel_values.append(pixel_v) pixel_values.append(pixel_v)
image_hash.append(image_h) image_hashes.append(image_h)
image_size.append(image_s) image_sizes.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else: else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image( pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints image_data[0], aspect_ratio, grid_pinpoints
) )
image_hash = [image_hash] image_hashes = [image_hash]
image_size = [image_size] image_sizes = [image_size]
elif isinstance(image_data, str): elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size = await self._process_single_image( pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints image_data, aspect_ratio, grid_pinpoints
) )
image_hash = [image_hash] image_hashes = [image_hash]
image_size = [image_size] image_sizes = [image_size]
else: else:
pixel_values, image_hash, image_size = None, None, None raise ValueError(f"Invalid image data: {image_data}")
return pixel_values, image_hash, image_size return pixel_values, image_hashes, image_sizes
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints): async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None: if self.executor is not None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
def _process_single_image_task( def _process_single_image_task(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
processor=None,
): ):
try: try:
processor = processor or global_processor processor = processor or global_processor
image, image_size = load_image(image_data) image, image_size = load_image(image_data)
if image_size is not None: if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data) image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"] pixel_values = processor.image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)): for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
pixel_values = np.stack(pixel_values, axis=0) pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size return pixel_values, image_hash, image_size
else: else:
# It is an image
image_hash = hash(image_data) image_hash = hash(image_data)
if image_aspect_ratio == "pad": if image_aspect_ratio == "pad":
image = expand2square( image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
pixel_values = processor.image_processor(image.convert("RGB"))[ pixel_values = processor.image_processor(image.convert("RGB"))[
"pixel_values" "pixel_values"
][0] ][0]
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image( pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints image, processor.image_processor, image_grid_pinpoints
) )
else: else:
pixel_values = processor.image_processor(image)["pixel_values"][0] pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size return pixel_values, image_hash, image.size
except Exception: except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())

View File

@@ -108,7 +108,7 @@ class ModelTpServer:
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None self.tokenizer = self.processor = None
else: else:
if is_multimodal_model(server_args.model_path): if is_multimodal_model(self.model_config.hf_config.architectures):
self.processor = get_processor( self.processor = get_processor(
server_args.tokenizer_path, server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode, tokenizer_mode=server_args.tokenizer_mode,
@@ -333,26 +333,24 @@ class ModelTpServer:
if self.model_runner.is_generation: if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None: if req.pixel_values is not None:
image_hash = ( # Use image hash as fake token_ids, which is then used
hash(tuple(recv_req.image_hash)) # for prefix matching
if isinstance(recv_req.image_hash, list) image_hash = hash(tuple(recv_req.image_hashes))
else recv_req.image_hash
)
req.pad_value = [ req.pad_value = [
(image_hash) % self.model_config.vocab_size, (image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size, (image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size, (image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size, (image_hash >> 64) % self.model_config.vocab_size,
] ]
req.image_size = recv_req.image_size req.image_sizes = recv_req.image_sizes
( (
req.origin_input_ids, req.origin_input_ids,
req.image_offset, req.image_offsets,
) = self.model_runner.model.pad_input_ids( ) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded, req.origin_input_ids_unpadded,
req.pad_value, req.pad_value,
req.pixel_values.shape, req.pixel_values,
req.image_size, req.image_sizes,
) )
req.return_logprob = recv_req.return_logprob req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len req.logprob_start_len = recv_req.logprob_start_len
@@ -368,6 +366,7 @@ class ModelTpServer:
req.jump_forward_map = self.jump_forward_cache.query( req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string computed_regex_string
) )
# Init regex fsm # Init regex fsm
elif req.sampling_params.regex is not None: elif req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)

View File

@@ -16,7 +16,7 @@ limitations under the License.
"""ModelRunner runs the forward passes of the models.""" """ModelRunner runs the forward passes of the models."""
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List
import numpy as np import numpy as np
import torch import torch
@@ -58,6 +58,7 @@ class InputMetadata:
# For extend # For extend
extend_seq_lens: torch.Tensor = None extend_seq_lens: torch.Tensor = None
extend_prefix_lens: torch.Tensor = None
extend_start_loc: torch.Tensor = None extend_start_loc: torch.Tensor = None
extend_no_prefix: bool = None extend_no_prefix: bool = None
@@ -69,8 +70,8 @@ class InputMetadata:
# For multimodal # For multimodal
pixel_values: List[torch.Tensor] = None pixel_values: List[torch.Tensor] = None
image_sizes: List[List[int]] = None image_sizes: List[List[List[int]]] = None
image_offsets: List[int] = None image_offsets: List[List[int]] = None
# Triton attention backend # Triton attention backend
triton_max_seq_len: int = 0 triton_max_seq_len: int = 0
@@ -87,20 +88,8 @@ class InputMetadata:
def init_multimuldal_info(self, batch: ScheduleBatch): def init_multimuldal_info(self, batch: ScheduleBatch):
reqs = batch.reqs reqs = batch.reqs
self.pixel_values = [r.pixel_values for r in reqs] self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs] self.image_sizes = [r.image_sizes for r in reqs]
self.image_offsets = [] self.image_offsets = [r.image_offsets for r in reqs]
for r in reqs:
if isinstance(r.image_offset, list):
self.image_offsets.append(
[
(image_offset - len(r.prefix_indices))
for image_offset in r.image_offset
]
)
elif isinstance(r.image_offset, int):
self.image_offsets.append(r.image_offset - len(r.prefix_indices))
elif r.image_offset is None:
self.image_offsets.append(0)
def compute_positions(self, batch: ScheduleBatch): def compute_positions(self, batch: ScheduleBatch):
position_ids_offsets = batch.position_ids_offsets position_ids_offsets = batch.position_ids_offsets
@@ -153,6 +142,7 @@ class InputMetadata:
for i, r in enumerate(batch.reqs) for i, r in enumerate(batch.reqs)
] ]
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens) self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -238,10 +228,10 @@ class InputMetadata:
prefix_lens_cpu, prefix_lens_cpu,
flashinfer_use_ragged, flashinfer_use_ragged,
): ):
if self.forward_mode != ForwardMode.DECODE: if self.forward_mode == ForwardMode.DECODE:
prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
else:
prefix_lens = None prefix_lens = None
else:
prefix_lens = self.extend_prefix_lens
update_flashinfer_indices( update_flashinfer_indices(
self.forward_mode, self.forward_mode,

View File

@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
MLATokenToKVPool, MLATokenToKVPool,
ReqToTokenPool, ReqToTokenPool,
) )
from sglang.srt.model_config import AttentionArch from sglang.srt.model_config import AttentionArch, ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
class ModelRunner: class ModelRunner:
def __init__( def __init__(
self, self,
model_config, model_config: ModelConfig,
mem_fraction_static: float, mem_fraction_static: float,
gpu_id: int, gpu_id: int,
tp_rank: int, tp_rank: int,
@@ -85,7 +85,9 @@ class ModelRunner:
self.tp_size = tp_size self.tp_size = tp_size
self.nccl_port = nccl_port self.nccl_port = nccl_port
self.server_args = server_args self.server_args = server_args
self.is_multimodal_model = is_multimodal_model(self.model_config) self.is_multimodal_model = is_multimodal_model(
self.model_config.hf_config.architectures
)
global_server_args_dict.update( global_server_args_dict.update(
{ {
"disable_flashinfer": server_args.disable_flashinfer, "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +97,13 @@ class ModelRunner:
} }
) )
if self.is_multimodal_model:
logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95
min_per_gpu_memory = self.init_torch_distributed() min_per_gpu_memory = self.init_torch_distributed()
self.load_model() self.load_model()
self.init_memory_pool( self.init_memory_pool(
@@ -507,9 +516,9 @@ class ModelRunner:
raise Exception( raise Exception(
f"Capture cuda graph failed: {e}\n" f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n" "Possible solutions:\n"
"1. disable torch compile by not using --enable-torch-compile\n" "1. disable cuda graph by --disable-cuda-graph\n"
"2. disable cuda graph by --disable-cuda-graph\n" "2. set --mem-fraction-static to a smaller value\n"
"3. set --mem-fraction-static to a smaller value\n" "3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
) )

View File

@@ -17,7 +17,7 @@ limitations under the License.
# Adapted from # Adapted from
# https://github.com/THUDM/ChatGLM2-6B # https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn

View File

@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
hidden_states.mul_(self.config.embedding_multiplier_scale)
else: else:
hidden_states = input_embeds hidden_states = input_embeds
hidden_states.mul_(self.config.embedding_multiplier_scale)
for i in range(len(self.layers)): for i in range(len(self.layers)):
hidden_states = self.layers[i](positions, hidden_states, input_metadata) hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
return hidden_states return hidden_states
class Grok1ModelForCausalLM(nn.Module): class Grok1ForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
return hf_folder, hf_weights_files, use_safetensors return hf_folder, hf_weights_files, use_safetensors
EntryClass = Grok1ModelForCausalLM class Grok1ModelForCausalLM(Grok1ForCausalLM):
"""An alias for backward-compatbility."""
pass
EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]

View File

@@ -357,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
return return
if name.startswith("model.vision_tower") and name not in params_dict:
return
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
@@ -364,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
@@ -374,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
return return
if name.startswith("model.vision_tower") and name not in params_dict:
return
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)

View File

@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)

View File

@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
return return
if name.startswith("model.vision_tower") and name not in params_dict:
return
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
return return
if name.startswith("model.vision_tower") and name not in params_dict:
return
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)

View File

@@ -28,7 +28,6 @@ from transformers import (
LlavaConfig, LlavaConfig,
MistralConfig, MistralConfig,
Qwen2Config, Qwen2Config,
SiglipVisionConfig,
SiglipVisionModel, SiglipVisionModel,
) )
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module):
torch.empty(config.text_config.hidden_size, dtype=torch.float16) torch.empty(config.text_config.hidden_size, dtype=torch.float16)
) )
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): def pad_input_ids(
self,
input_ids: List[int],
pad_value: List[int],
pixel_values: List,
image_sizes: List[List[int]],
):
# hardcode for spatial_unpad + anyres # hardcode for spatial_unpad + anyres
image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad" image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
offset_list = [] offset_list = []
for image_s in image_size: for image_s in image_sizes:
if len(image_size) > 16: if len(image_sizes) > 16:
# 2x2 pooling with stride 2 # 2x2 pooling with stride 2
new_image_feature_len = ( new_image_feature_len = (
math.ceil(self.image_size / self.patch_size / 2) ** 2 math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module):
if input_metadata.forward_mode == ForwardMode.EXTEND: if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size bs = input_metadata.batch_size
# Embed text input # Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids) input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = ( # Whether the requests need vision inputs
(positions[input_metadata.extend_start_loc] < self.image_feature_len) max_image_offset = np.array(
.cpu() [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
.numpy()
) )
# FIXME: We need to substract the length of the system prompt start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) need_vision = start_positions <= max_image_offset
need_vision = need_vision & has_pixel
if need_vision.any(): if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
@@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module):
new_image_features.append(image_feature) new_image_features.append(image_feature)
image_features = new_image_features image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
pt = 0 pt = 0
for i in range(bs): for i in range(bs):
if not need_vision[i]: if not need_vision[i]:
continue continue
start_idx = extend_start_loc_cpu[i] start_idx = extend_start_loc_cpu[i]
pad_dim = image_features[pt].shape[-1] # 576, 4096 prefix_len = prefix_lens_cpu[i]
dim = input_embeds.shape[1]
assert ( # Multiple images
pad_dim == dim for j, image_offset in enumerate(image_offsets[i]):
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) if image_offset < prefix_len:
# Fill in the placeholder for the image continue
try:
for j, image_off in enumerate(image_offsets[i]): tmp_image_feature = image_features[pt][j]
# print("actual image_features length: ", image_features[pt][j].shape[0]) pad_len = tmp_image_feature.shape[0]
pad_len = image_features[pt][j].shape[0]
input_embeds[ left_idx = start_idx + (image_offset - prefix_len)
start_idx + image_off : start_idx + image_off + pad_len right_idx = start_idx + (image_offset - prefix_len) + pad_len
] = image_features[pt][j] try:
except RuntimeError as e: input_embeds[left_idx:right_idx] = tmp_image_feature
print(f"RuntimeError in llava image encoding: {e}") except RuntimeError as e:
print(image_features[pt].shape) print(f"RuntimeError in image encoding: {e}")
print(input_embeds.shape) print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
print(start_idx, image_offsets[i]) print(
f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
)
pt += 1 pt += 1
return self.language_model( return self.language_model(
@@ -366,8 +372,9 @@ class LlavaLlamaForCausalLM(nn.Module):
return self.language_model(input_ids, positions, input_metadata) return self.language_model(input_ids, positions, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# load clip vision model by cfg['mm_vision_tower']: # Load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir # huggingface_name or path_of_clip_relative_to_llava_model_dir
# We put the initialization here instead of __init__ to allow it being reused by other subclasses.
vision_path = self.config.mm_vision_tower vision_path = self.config.mm_vision_tower
if "clip" in vision_path: if "clip" in vision_path:
self.vision_tower = CLIPVisionModel.from_pretrained( self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,8 +429,6 @@ class LlavaLlamaForCausalLM(nn.Module):
# load language model # load language model
self.language_model.load_weights(weights) self.language_model.load_weights(weights)
monkey_path_clip_vision_embed_forward()
@property @property
def num_patches_per_side(self): def num_patches_per_side(self):
return self.image_size // self.patch_size return self.image_size // self.patch_size
@@ -495,36 +500,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
) )
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM] EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]

View File

@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
torch.empty(config.text_config.hidden_size, dtype=torch.float16) torch.empty(config.text_config.hidden_size, dtype=torch.float16)
) )
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): def pad_input_ids(
self,
input_ids: List[int],
pad_value: List[int],
pixel_values: List,
image_sizes: List[List[int]],
):
new_image_feature_len = self.image_feature_len new_image_feature_len = self.image_feature_len
# now only support spatial_unpad + anyres
# if self.mm_patch_merge_type.startswith("spatial"):
# height = width = self.num_patches_per_side
# if pt_shape[0] > 1:
# if self.image_aspect_ratio == "anyres":
# num_patch_width, num_patch_height = get_anyres_image_grid_shape(
# image_size,
# self.image_grid_pinpoints,
# self.vision_tower.config.image_size,
# )
# if "unpad" in self.mm_patch_merge_type:
# h = num_patch_height * height
# w = num_patch_width * width
# new_h, new_w = unpad_image_shape(h, w, image_size)
# new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * ( pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value) (new_image_feature_len + len(pad_value)) // len(pad_value)
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
+ pad_ids[:new_image_feature_len] + pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :] + input_ids[offset + 1 :]
) )
return new_input_ids, offset return new_input_ids, [offset]
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
if input_metadata.forward_mode == ForwardMode.EXTEND: if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size bs = input_metadata.batch_size
# Embed text input # Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids) input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input # Whether the requests need vision inputs
need_vision = ( max_image_offset = np.array(
(positions[input_metadata.extend_start_loc] < self.image_feature_len) [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
.cpu()
.numpy()
) )
# FIXME: We need to substract the length of the system prompt start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) need_vision = start_positions <= max_image_offset
need_vision = need_vision & has_pixel
if need_vision.any(): if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
########## Encode Image ######## ########## Encode Image ########
@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
new_image_features.append(image_feature.flatten(0, 1)) new_image_features.append(image_feature.flatten(0, 1))
image_features = new_image_features image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
pt = 0 pt = 0
for i in range(bs): for i in range(bs):
if not need_vision[i]: if not need_vision[i]:
continue continue
start_idx = extend_start_loc_cpu[i] start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096 prefix_len = prefix_lens_cpu[i]
dim = input_embeds.shape[1]
assert ( # Multiple images
pad_dim == dim for image_offset in image_offsets[i]:
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) if image_offset < prefix_len:
# Fill in the placeholder for the image continue
try:
input_embeds[ tmp_image_feature = image_features[pt]
start_idx pad_len = tmp_image_feature.shape[0]
+ image_offsets[i] : start_idx
+ image_offsets[i] left_idx = start_idx + (image_offset - prefix_len)
+ pad_len right_idx = start_idx + (image_offset - prefix_len) + pad_len
] = image_features[pt] try:
except RuntimeError as e: input_embeds[left_idx:right_idx] = tmp_image_feature
print(f"RuntimeError in llava image encoding: {e}") except RuntimeError as e:
print(input_embeds.shape) print(f"RuntimeError in image encoding: {e}")
print(start_idx, image_offsets[i]) print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
pt += 1 print(
f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
)
pt += 1
return self.language_model( return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds input_ids, positions, input_metadata, input_embeds=input_embeds
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
return self.language_model(input_ids, positions, input_metadata) return self.language_model(input_ids, positions, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# load clip vision model by cfg['mm_vision_tower']: # Load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir # huggingface_name or path_of_clip_relative_to_llava_model_dir
# We put the initialization here instead of __init__ to allow it being reused by other subclasses.
vision_path = self.config.mm_vision_tower vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained( self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16 vision_path, torch_dtype=torch.float16
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
# load language model # load language model
self.language_model.load_weights(weights) self.language_model.load_weights(weights)
monkey_path_clip_vision_embed_forward()
@property @property
def num_patches_per_side(self): def num_patches_per_side(self):
return self.image_size // self.patch_size return self.image_size // self.patch_size
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
EntryClass = LlavaVidForCausalLM EntryClass = LlavaVidForCausalLM

View File

@@ -312,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
@@ -319,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
@@ -329,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)

View File

@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llava import ( from sglang.srt.models.llava import LlavaLlamaForCausalLM
LlavaLlamaForCausalLM,
monkey_path_clip_vision_embed_forward,
)
class YiVLForCausalLM(LlavaLlamaForCausalLM): class YiVLForCausalLM(LlavaLlamaForCausalLM):
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
self.config._name_or_path, self.config._name_or_path,
torch_dtype=torch.float16, torch_dtype=torch.float16,
subfolder=self.vision_tower_subfolder, subfolder=self.vision_tower_subfolder,
).cuda() ).to("cuda")
self.vision_tower.eval() self.vision_tower.eval()
@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
# load language model # load language model
self.language_model.load_weights(weights) self.language_model.load_weights(weights)
monkey_path_clip_vision_embed_forward()
class YiVLMultiModalProjector(nn.Module): class YiVLMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig): def __init__(self, config: LlavaConfig):

View File

@@ -335,12 +335,12 @@ def launch_server(
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
if server_args.dp_size == 1: if server_args.dp_size == 1:
start_process = start_controller_process_single start_controller_process = start_controller_process_single
else: else:
start_process = start_controller_process_multi start_controller_process = start_controller_process_multi
proc_controller = mp.Process( proc_controller = mp.Process(
target=start_process, target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer, model_overide_args), args=(server_args, port_args, pipe_controller_writer, model_overide_args),
) )
proc_controller.start() proc_controller.start()
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if not server_args.disable_flashinfer: if not server_args.disable_flashinfer:
assert_pkg_version( assert_pkg_version(
"flashinfer", "flashinfer",
"0.1.6", "0.1.5",
"Please uninstall the old version and " "Please uninstall the old version and "
"reinstall the latest version by following the instructions " "reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.", "at https://docs.flashinfer.ai/installation.html.",

View File

@@ -26,7 +26,7 @@ import struct
import time import time
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from io import BytesIO from io import BytesIO
from typing import List, Optional from typing import List, Optional, Union
import numpy as np import numpy as np
import psutil import psutil
@@ -193,35 +193,16 @@ def allocate_init_ports(
return ret_ports[0], ret_ports[1:num_ports_needed] return ret_ports[0], ret_ports[1:num_ports_needed]
def get_int_token_logit_bias(tokenizer, vocab_size): def is_multimodal_model(model_architectures):
"""Get the logit bias for integer-only tokens.""" if (
# a bug when model's vocab size > tokenizer.vocab_size "LlavaLlamaForCausalLM" in model_architectures
if tokenizer == None: or "LlavaQwenForCausalLM" in model_architectures
return [-1e5] * vocab_size or "LlavaMistralForCausalLM" in model_architectures
vocab_size = tokenizer.vocab_size or "LlavaVidForCausalLM" in model_architectures
logit_bias = np.zeros(vocab_size, dtype=np.float32) ):
for t_id in range(vocab_size): return True
ss = tokenizer.decode([t_id]).strip() else:
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id): return False
logit_bias[t_id] = -1e5
return logit_bias
def is_multimodal_model(model):
from sglang.srt.model_config import ModelConfig
if isinstance(model, str):
model = model.lower()
return "llava" in model or "yi-vl" in model or "llava-next" in model
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return (
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
)
raise ValueError("unrecognized type")
def is_generation_model(model_architectures, is_embedding: bool = False): def is_generation_model(model_architectures, is_embedding: bool = False):
@@ -317,12 +298,14 @@ def decode_video_base64(video_base64):
) # Return an empty array and size tuple if no frames were found ) # Return an empty array and size tuple if no frames were found
def load_image(image_file): def load_image(image_file: Union[str, bytes]):
from PIL import Image from PIL import Image
image = image_size = None image = image_size = None
if image_file.startswith("http://") or image_file.startswith("https://"): if isinstance(image_file, bytes):
image = Image.open(BytesIO(image_file))
elif image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, timeout=timeout) response = requests.get(image_file, timeout=timeout)
image = Image.open(BytesIO(response.content)) image = Image.open(BytesIO(response.content))
@@ -334,8 +317,10 @@ def load_image(image_file):
elif image_file.startswith("video:"): elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "") image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file) image, image_size = decode_video_base64(image_file)
else: elif isinstance(image_file, str):
image = Image.open(BytesIO(base64.b64decode(image_file))) image = Image.open(BytesIO(base64.b64decode(image_file)))
else:
raise ValueError(f"Invalid image: {image}")
return image, image_size return image, image_size

View File

@@ -32,8 +32,6 @@ class TestOpenAIVisionServer(unittest.TestCase):
other_args=[ other_args=[
"--chat-template", "--chat-template",
"chatml-llava", "chatml-llava",
"--chunked-prefill-size",
"16384",
# "--log-requests", # "--log-requests",
], ],
) )