[Fix] Fix llava on multi images (#1247)
python/sglang/launch_server_llavavid.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+"""Launch the inference server for Llava-video model."""
+
+import argparse
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+
+    model_overide_args = {}
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    launch_server(server_args, model_overide_args, None)
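Note (not part of the diff): a minimal launch example, assuming ServerArgs registers the usual --model-path and --port flags; the model id is illustrative.

    python -m sglang.launch_server_llavavid --model-path lmms-lab/LLaVA-NeXT-Video-7B --port 30000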
@@ -119,24 +119,7 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    if tokenizer_name.endswith(".json"):
-        return TiktokenTokenizer(tokenizer_name)
-
-    if tokenizer_name.endswith(".model"):
-        return SentencePieceTokenizer(tokenizer_name)
-
-    """Gets a tokenizer for the given model name via Huggingface."""
-    if is_multimodal_model(tokenizer_name):
-        processor = get_processor(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            tokenizer_revision=tokenizer_revision,
-            **kwargs,
-        )
-        tokenizer = processor.tokenizer
-        return tokenizer

     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
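Note (not part of the diff): with the multimodal branch removed from get_tokenizer, the processor-vs-tokenizer choice now lives at the call site, as in the TokenizerManager hunk further down. An abridged sketch of that caller-side logic; the else branch is an assumption about the non-multimodal path:

    if is_multimodal_model(self.hf_config.architectures):
        self.processor = get_processor(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )
        self.tokenizer = self.processor.tokenizer
    else:
        self.tokenizer = get_tokenizer(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )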
@@ -199,135 +182,3 @@ def get_processor(
|
||||
**kwargs,
|
||||
)
|
||||
return processor
|
||||
|
||||
|
||||
class TiktokenTokenizer:
|
||||
def __init__(self, tokenizer_path):
|
||||
import tiktoken
|
||||
from jinja2 import Template
|
||||
|
||||
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||
|
||||
# Read JSON
|
||||
name = "tmp-json"
|
||||
with open(tokenizer_path, "rb") as fin:
|
||||
tok_dict = json.load(fin)
|
||||
|
||||
mergeable_ranks = {
|
||||
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
|
||||
}
|
||||
special_tokens = {
|
||||
bytes(item["bytes"]).decode(): item["token"]
|
||||
for item in tok_dict["special_tokens"]
|
||||
}
|
||||
assert tok_dict["word_split"] == "V1"
|
||||
|
||||
default_allowed_special = None
|
||||
|
||||
kwargs = {
|
||||
"name": name,
|
||||
"pat_str": tok_dict.get("pat_str", PAT_STR_B),
|
||||
"mergeable_ranks": mergeable_ranks,
|
||||
"special_tokens": special_tokens,
|
||||
}
|
||||
if "default_allowed_special" in tok_dict:
|
||||
default_allowed_special = set(
|
||||
[
|
||||
bytes(bytes_list).decode()
|
||||
for bytes_list in tok_dict["default_allowed_special"]
|
||||
]
|
||||
)
|
||||
if "vocab_size" in tok_dict:
|
||||
kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
|
||||
|
||||
PAD = "<|pad|>"
|
||||
EOS = "<|eos|>"
|
||||
SEP = "<|separator|>"
|
||||
|
||||
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
|
||||
|
||||
tokenizer = tiktoken.Encoding(**kwargs)
|
||||
tokenizer._default_allowed_special = default_allowed_special or set()
|
||||
tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
|
||||
|
||||
def encode_patched(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
allowed_special: Union[
|
||||
Literal["all"], AbstractSet[str]
|
||||
] = set(), # noqa: B006
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
) -> List[int]:
|
||||
if isinstance(allowed_special, set):
|
||||
allowed_special |= self._default_allowed_special
|
||||
return tiktoken.Encoding.encode(
|
||||
self,
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=(),
|
||||
)
|
||||
|
||||
tokenizer.encode = functools.partial(encode_patched, tokenizer)
|
||||
|
||||
# Convert to HF interface
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = tokenizer._special_tokens[EOS]
|
||||
self.vocab_size = tokenizer.n_vocab
|
||||
self.chat_template = Template(
|
||||
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
|
||||
)
|
||||
|
||||
def encode(self, x, add_special_tokens=False):
|
||||
return self.tokenizer.encode(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.tokenizer.decode(x)
|
||||
|
||||
def batch_decode(
|
||||
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
|
||||
):
|
||||
if isinstance(batch[0], int):
|
||||
batch = [[x] for x in batch]
|
||||
return self.tokenizer.decode_batch(batch)
|
||||
|
||||
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
|
||||
ret = self.chat_template.render(
|
||||
messages=messages, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return self.encode(ret) if tokenize else ret
|
||||
|
||||
|
||||
class SentencePieceTokenizer:
|
||||
def __init__(self, tokenizer_path):
|
||||
import sentencepiece as spm
|
||||
from jinja2 import Template
|
||||
|
||||
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
|
||||
|
||||
# Convert to HF interface
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = tokenizer.eos_id()
|
||||
self.vocab_size = tokenizer.vocab_size()
|
||||
self.chat_template = Template(
|
||||
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
|
||||
)
|
||||
|
||||
def encode(self, x, add_special_tokens=False):
|
||||
return self.tokenizer.encode(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.tokenizer.decode(x)
|
||||
|
||||
def batch_decode(
|
||||
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
|
||||
):
|
||||
if isinstance(batch[0], int):
|
||||
batch = [[x] for x in batch]
|
||||
return self.tokenizer.decode(batch)
|
||||
|
||||
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
|
||||
ret = self.chat_template.render(
|
||||
messages=messages, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return self.encode(ret) if tokenize else ret
|
||||
|
||||
@@ -55,6 +55,7 @@ class GenerateReqInput:
|
||||
self.text is not None and self.input_ids is not None
|
||||
):
|
||||
raise ValueError("Either text or input_ids should be provided.")
|
||||
|
||||
if (
|
||||
isinstance(self.sampling_params, dict)
|
||||
and self.sampling_params.get("n", 1) != 1
|
||||
@@ -161,10 +162,10 @@ class TokenizedGenerateReqInput:
     input_ids: List[int]
     # The pixel values for input images
     pixel_values: List[float]
-    # The hash of input images
-    image_hash: int
-    # The image size
-    image_size: List[int]
+    # The hash values of input images
+    image_hashes: List[int]
+    # The image sizes
+    image_sizes: List[List[int]]
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -121,8 +121,8 @@ class Req:

         # For vision input
         self.pixel_values = None
-        self.image_size = None
-        self.image_offset = None
+        self.image_sizes = None
+        self.image_offsets = None
         self.pad_value = None

         # Prefix info
@@ -600,12 +600,12 @@ class ScheduleBatch:
                 if req.pixel_values is not None:
                     (
                         req.origin_input_ids,
-                        req.image_offset,
+                        req.image_offsets,
                     ) = model_runner.model.pad_input_ids(
                         req.origin_input_ids_unpadded,
                         req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
+                        req.pixel_values,
+                        req.image_sizes,
                     )

                     jump_forward_reqs.append(req)
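Note (not part of the diff): with the plural fields, a request that carries two images now keeps one entry per image; the values below are hypothetical.

    image_hashes = [8234561907, 4415063320]   # one hash per input image
    image_sizes = [[1024, 768], [336, 336]]   # one (width, height) pair per image
    # pixel_values holds the corresponding per-image arrays, stacked when their shapes match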
@@ -23,6 +23,7 @@ import multiprocessing as mp
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import numpy as np
|
||||
import transformers
|
||||
import uvloop
|
||||
@@ -96,21 +97,18 @@ class TokenizerManager:
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
model_overide_args=model_overide_args,
|
||||
)
|
||||
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, self.server_args.is_embedding
|
||||
)
|
||||
|
||||
if server_args.context_length is not None:
|
||||
self.context_len = server_args.context_length
|
||||
else:
|
||||
self.context_len = get_context_length(self.hf_config)
|
||||
self.context_len = server_args.context_length or get_context_length(
|
||||
self.hf_config
|
||||
)
|
||||
|
||||
# Create tokenizer
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = self.processor = None
|
||||
else:
|
||||
if is_multimodal_model(self.model_path):
|
||||
if is_multimodal_model(self.hf_config.architectures):
|
||||
self.processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
@@ -118,6 +116,9 @@ class TokenizerManager:
|
||||
)
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# We want to parallelize the image pre-processing so we
|
||||
# create an executor for it
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
@@ -134,12 +135,14 @@ class TokenizerManager:
|
||||
self.to_create_loop = True
|
||||
self.rid_to_state: Dict[str, ReqState] = {}
|
||||
|
||||
# for update model weights
|
||||
# For update model weights
|
||||
self.model_update_lock = asyncio.Lock()
|
||||
self.model_update_result = None
|
||||
|
||||
async def generate_request(
|
||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
@@ -160,7 +163,7 @@ class TokenizerManager:
|
||||
async def _handle_single_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
is_cache_for_prefill: Optional[bool] = False,
|
||||
):
|
||||
@@ -182,8 +185,8 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
if self.is_generation:
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data
|
||||
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||
obj.image_data if not_use_index else obj.image_data[index]
|
||||
)
|
||||
return_logprob = (
|
||||
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
||||
@@ -195,7 +198,6 @@ class TokenizerManager:
|
||||
)
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
|
||||
top_logprobs_num = (
|
||||
obj.top_logprobs_num
|
||||
if not_use_index
|
||||
@@ -238,13 +240,14 @@ class TokenizerManager:
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||
sampling_params.max_new_tokens = 0
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||
obj.image_data[0]
|
||||
)
|
||||
return_logprob = obj.return_logprob[0]
|
||||
logprob_start_len = obj.logprob_start_len[0]
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
|
||||
# Send to the controller
|
||||
if self.is_generation:
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
@@ -253,8 +256,8 @@ class TokenizerManager:
|
||||
input_text,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
image_hash,
|
||||
image_size,
|
||||
image_hashes,
|
||||
image_sizes,
|
||||
sampling_params,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
@@ -268,24 +271,24 @@ class TokenizerManager:
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
|
||||
# Recv results
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event)
|
||||
self.rid_to_state[rid] = state
|
||||
if not is_cache_for_prefill:
|
||||
async for response in self._wait_for_response(
|
||||
event, state, obj, rid, request
|
||||
):
|
||||
async for response in self._wait_for_response(state, obj, rid, request):
|
||||
yield response
|
||||
else:
|
||||
assert self.is_generation
|
||||
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
||||
await self._wait_for_cache_prefill_response(state, obj, rid, request)
|
||||
yield input_ids
|
||||
|
||||
async def _handle_batch_request(
|
||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
batch_size = obj.batch_size
|
||||
if self.is_generation:
|
||||
@@ -340,8 +343,8 @@ class TokenizerManager:
|
||||
if self.is_generation:
|
||||
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
|
||||
obj.logprob_start_len[index] = len(input_ids) - 1
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data[index]
|
||||
pixel_values, image_hashes, image_sizes = (
|
||||
await self._get_pixel_values(obj.image_data[index])
|
||||
)
|
||||
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
@@ -349,8 +352,8 @@ class TokenizerManager:
|
||||
input_text,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
image_hash,
|
||||
image_size,
|
||||
image_hashes,
|
||||
image_sizes,
|
||||
sampling_params,
|
||||
obj.return_logprob[index],
|
||||
obj.logprob_start_len[index],
|
||||
@@ -372,7 +375,6 @@ class TokenizerManager:
|
||||
|
||||
generators.append(
|
||||
self._wait_for_response(
|
||||
event,
|
||||
state,
|
||||
obj,
|
||||
rid,
|
||||
@@ -388,6 +390,7 @@ class TokenizerManager:
|
||||
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
|
||||
output_list = [None] * len(tasks)
|
||||
|
||||
# Recv results
|
||||
while tasks:
|
||||
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
@@ -426,25 +429,18 @@ class TokenizerManager:
|
||||
sampling_params.verify()
|
||||
return sampling_params
|
||||
|
||||
async def _get_pixel_values(self, image_data):
|
||||
if image_data is None:
|
||||
return None, None, None
|
||||
else:
|
||||
return await self._get_pixel_values_internal(image_data)
|
||||
|
||||
async def _wait_for_response(
|
||||
self,
|
||||
event: asyncio.Event,
|
||||
state: ReqState,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
rid: str,
|
||||
request,
|
||||
index: int = None,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
response_index: int = 0,
|
||||
):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=4)
|
||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
for rid in [obj.rid] if obj.is_single else obj.rid:
|
||||
@@ -478,16 +474,15 @@ class TokenizerManager:
|
||||
yield out
|
||||
break
|
||||
|
||||
event.clear()
|
||||
state.event.clear()
|
||||
yield out
|
||||
|
||||
async def _wait_for_cache_prefill_response(
|
||||
self,
|
||||
event: asyncio.Event,
|
||||
state: ReqState,
|
||||
obj: GenerateReqInput,
|
||||
rid: str,
|
||||
request,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
while True:
|
||||
try:
|
||||
@@ -514,7 +509,9 @@ class TokenizerManager:
|
||||
req = AbortReq(rid)
|
||||
self.send_to_router.send_pyobj(req)
|
||||
|
||||
async def update_weights(self, obj: UpdateWeightReqInput, request):
|
||||
async def update_weights(
|
||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
|
||||
@@ -659,12 +656,11 @@ class TokenizerManager:
|
||||
)
|
||||
return top_logprobs
|
||||
|
||||
async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
|
||||
aspect_ratio = (
|
||||
getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
if aspect_ratio is None
|
||||
else aspect_ratio
|
||||
)
|
||||
async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
|
||||
if not image_data:
|
||||
return None, None, None
|
||||
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
grid_pinpoints = (
|
||||
self.hf_config.image_grid_pinpoints
|
||||
if hasattr(self.hf_config, "image_grid_pinpoints")
|
||||
@@ -673,35 +669,42 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
pixel_values, image_hash, image_size = [], [], []
|
||||
# Multiple images
|
||||
if len(image_data) > 1:
|
||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||
pixel_values, image_hashes, image_sizes = [], [], []
|
||||
for img_data in image_data:
|
||||
pixel_v, image_h, image_s = await self._process_single_image(
|
||||
img_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
pixel_values.append(pixel_v)
|
||||
image_hash.append(image_h)
|
||||
image_size.append(image_s)
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
image_hashes.append(image_h)
|
||||
image_sizes.append(image_s)
|
||||
|
||||
if isinstance(pixel_values[0], np.ndarray):
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
else:
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data[0], aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hash = [image_hash]
|
||||
image_size = [image_size]
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
elif isinstance(image_data, str):
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hash = [image_hash]
|
||||
image_size = [image_size]
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
return pixel_values, image_hash, image_size
|
||||
return pixel_values, image_hashes, image_sizes
|
||||
|
||||
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
|
||||
|
||||
|
||||
def _process_single_image_task(
|
||||
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
|
||||
image_data: Union[str, bytes],
|
||||
image_aspect_ratio: Optional[str] = None,
|
||||
image_grid_pinpoints: Optional[str] = None,
|
||||
processor=None,
|
||||
):
|
||||
try:
|
||||
processor = processor or global_processor
|
||||
image, image_size = load_image(image_data)
|
||||
if image_size is not None:
|
||||
# It is a video with multiple images
|
||||
image_hash = hash(image_data)
|
||||
pixel_values = processor.image_processor(image)["pixel_values"]
|
||||
for _ in range(len(pixel_values)):
|
||||
@@ -745,6 +752,7 @@ def _process_single_image_task(
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
return pixel_values, image_hash, image_size
|
||||
else:
|
||||
# It is an image
|
||||
image_hash = hash(image_data)
|
||||
if image_aspect_ratio == "pad":
|
||||
image = expand2square(
|
||||
@@ -754,13 +762,18 @@ def _process_single_image_task(
|
||||
pixel_values = processor.image_processor(image.convert("RGB"))[
|
||||
"pixel_values"
|
||||
][0]
|
||||
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
|
||||
elif image_aspect_ratio == "anyres" or (
|
||||
image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
|
||||
):
|
||||
pixel_values = process_anyres_image(
|
||||
image, processor.image_processor, image_grid_pinpoints
|
||||
)
|
||||
else:
|
||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
|
||||
if isinstance(pixel_values, np.ndarray):
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
|
||||
return pixel_values, image_hash, image.size
|
||||
except Exception:
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
@@ -108,7 +108,7 @@ class ModelTpServer:
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = self.processor = None
|
||||
else:
|
||||
if is_multimodal_model(server_args.model_path):
|
||||
if is_multimodal_model(self.model_config.hf_config.architectures):
|
||||
self.processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
@@ -333,26 +333,24 @@ class ModelTpServer:
|
||||
if self.model_runner.is_generation:
|
||||
req.pixel_values = recv_req.pixel_values
|
||||
if req.pixel_values is not None:
|
||||
image_hash = (
|
||||
hash(tuple(recv_req.image_hash))
|
||||
if isinstance(recv_req.image_hash, list)
|
||||
else recv_req.image_hash
|
||||
)
|
||||
# Use image hash as fake token_ids, which is then used
|
||||
# for prefix matching
|
||||
image_hash = hash(tuple(recv_req.image_hashes))
|
||||
req.pad_value = [
|
||||
(image_hash) % self.model_config.vocab_size,
|
||||
(image_hash >> 16) % self.model_config.vocab_size,
|
||||
(image_hash >> 32) % self.model_config.vocab_size,
|
||||
(image_hash >> 64) % self.model_config.vocab_size,
|
||||
]
|
||||
req.image_size = recv_req.image_size
|
||||
req.image_sizes = recv_req.image_sizes
|
||||
(
|
||||
req.origin_input_ids,
|
||||
req.image_offset,
|
||||
req.image_offsets,
|
||||
) = self.model_runner.model.pad_input_ids(
|
||||
req.origin_input_ids_unpadded,
|
||||
req.pad_value,
|
||||
req.pixel_values.shape,
|
||||
req.image_size,
|
||||
req.pixel_values,
|
||||
req.image_sizes,
|
||||
)
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
req.logprob_start_len = recv_req.logprob_start_len
|
||||
@@ -368,6 +366,7 @@ class ModelTpServer:
|
||||
req.jump_forward_map = self.jump_forward_cache.query(
|
||||
computed_regex_string
|
||||
)
|
||||
|
||||
# Init regex fsm
|
||||
elif req.sampling_params.regex is not None:
|
||||
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||
|
||||
@@ -16,7 +16,7 @@ limitations under the License.
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -58,6 +58,7 @@ class InputMetadata:
|
||||
|
||||
# For extend
|
||||
extend_seq_lens: torch.Tensor = None
|
||||
extend_prefix_lens: torch.Tensor = None
|
||||
extend_start_loc: torch.Tensor = None
|
||||
extend_no_prefix: bool = None
|
||||
|
||||
@@ -69,8 +70,8 @@ class InputMetadata:
|
||||
|
||||
# For multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
image_sizes: List[List[int]] = None
|
||||
image_offsets: List[int] = None
|
||||
image_sizes: List[List[List[int]]] = None
|
||||
image_offsets: List[List[int]] = None
|
||||
|
||||
# Trition attention backend
|
||||
triton_max_seq_len: int = 0
|
||||
@@ -87,20 +88,8 @@ class InputMetadata:
|
||||
def init_multimuldal_info(self, batch: ScheduleBatch):
|
||||
reqs = batch.reqs
|
||||
self.pixel_values = [r.pixel_values for r in reqs]
|
||||
self.image_sizes = [r.image_size for r in reqs]
|
||||
self.image_offsets = []
|
||||
for r in reqs:
|
||||
if isinstance(r.image_offset, list):
|
||||
self.image_offsets.append(
|
||||
[
|
||||
(image_offset - len(r.prefix_indices))
|
||||
for image_offset in r.image_offset
|
||||
]
|
||||
)
|
||||
elif isinstance(r.image_offset, int):
|
||||
self.image_offsets.append(r.image_offset - len(r.prefix_indices))
|
||||
elif r.image_offset is None:
|
||||
self.image_offsets.append(0)
|
||||
self.image_sizes = [r.image_sizes for r in reqs]
|
||||
self.image_offsets = [r.image_offsets for r in reqs]
|
||||
|
||||
def compute_positions(self, batch: ScheduleBatch):
|
||||
position_ids_offsets = batch.position_ids_offsets
|
||||
@@ -153,6 +142,7 @@ class InputMetadata:
|
||||
for i, r in enumerate(batch.reqs)
|
||||
]
|
||||
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
|
||||
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
||||
@@ -238,10 +228,10 @@ class InputMetadata:
|
||||
prefix_lens_cpu,
|
||||
flashinfer_use_ragged,
|
||||
):
|
||||
if self.forward_mode != ForwardMode.DECODE:
|
||||
prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
|
||||
else:
|
||||
if self.forward_mode == ForwardMode.DECODE:
|
||||
prefix_lens = None
|
||||
else:
|
||||
prefix_lens = self.extend_prefix_lens
|
||||
|
||||
update_flashinfer_indices(
|
||||
self.forward_mode,
|
||||
|
||||
@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.model_config import AttentionArch
|
||||
from sglang.srt.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
|
||||
class ModelRunner:
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
model_config: ModelConfig,
|
||||
mem_fraction_static: float,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
@@ -85,7 +85,9 @@ class ModelRunner:
|
||||
self.tp_size = tp_size
|
||||
self.nccl_port = nccl_port
|
||||
self.server_args = server_args
|
||||
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
||||
self.is_multimodal_model = is_multimodal_model(
|
||||
self.model_config.hf_config.architectures
|
||||
)
|
||||
global_server_args_dict.update(
|
||||
{
|
||||
"disable_flashinfer": server_args.disable_flashinfer,
|
||||
@@ -95,6 +97,13 @@ class ModelRunner:
|
||||
}
|
||||
)
|
||||
|
||||
if self.is_multimodal_model:
|
||||
logger.info(
|
||||
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
||||
)
|
||||
server_args.chunked_prefill_size = None
|
||||
server_args.mem_fraction_static *= 0.95
|
||||
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
self.load_model()
|
||||
self.init_memory_pool(
|
||||
@@ -507,9 +516,9 @@ class ModelRunner:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n"
|
||||
"Possible solutions:\n"
|
||||
"1. disable torch compile by not using --enable-torch-compile\n"
|
||||
"2. disable cuda graph by --disable-cuda-graph\n"
|
||||
"3. set --mem-fraction-static to a smaller value\n"
|
||||
"1. disable cuda graph by --disable-cuda-graph\n"
|
||||
"2. set --mem-fraction-static to a smaller value\n"
|
||||
"3. disable torch compile by not using --enable-torch-compile\n"
|
||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ limitations under the License.
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
hidden_states.mul_(self.config.embedding_multiplier_scale)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
|
||||
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Grok1ModelForCausalLM(nn.Module):
|
||||
class Grok1ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
EntryClass = Grok1ModelForCausalLM
|
||||
class Grok1ModelForCausalLM(Grok1ForCausalLM):
|
||||
"""An alias for backward-compatbility."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
|
||||
|
||||
@@ -357,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
return
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
return
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -364,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@@ -374,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
return
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
return
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
return
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
return
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@@ -28,7 +28,6 @@ from transformers import (
|
||||
LlavaConfig,
|
||||
MistralConfig,
|
||||
Qwen2Config,
|
||||
SiglipVisionConfig,
|
||||
SiglipVisionModel,
|
||||
)
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
@@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||
)
|
||||
|
||||
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
|
||||
|
||||
def pad_input_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
pad_value: List[int],
|
||||
pixel_values: List,
|
||||
image_sizes: List[List[int]],
|
||||
):
|
||||
# hardcode for spatial_unpad + anyres
|
||||
image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
|
||||
image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
|
||||
offset_list = []
|
||||
for image_s in image_size:
|
||||
if len(image_size) > 16:
|
||||
for image_s in image_sizes:
|
||||
if len(image_sizes) > 16:
|
||||
# 2x2 pooling with stride 2
|
||||
new_image_feature_len = (
|
||||
math.ceil(self.image_size / self.patch_size / 2) ** 2
|
||||
@@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
bs = input_metadata.batch_size
|
||||
|
||||
# Embed text input
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
# Embed vision input
|
||||
need_vision = (
|
||||
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
|
||||
.cpu()
|
||||
.numpy()
|
||||
|
||||
# Whether the requests need vision inputs
|
||||
max_image_offset = np.array(
|
||||
[max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
|
||||
)
|
||||
# FIXME: We need to substract the length of the system prompt
|
||||
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
|
||||
need_vision = need_vision & has_pixel
|
||||
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||
need_vision = start_positions <= max_image_offset
|
||||
|
||||
if need_vision.any():
|
||||
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
||||
@@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
new_image_features.append(image_feature)
|
||||
image_features = new_image_features
|
||||
|
||||
# Fill in the placeholder for the image
|
||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
continue
|
||||
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
pad_dim = image_features[pt].shape[-1] # 576, 4096
|
||||
dim = input_embeds.shape[1]
|
||||
assert (
|
||||
pad_dim == dim
|
||||
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
|
||||
# Fill in the placeholder for the image
|
||||
try:
|
||||
for j, image_off in enumerate(image_offsets[i]):
|
||||
# print("actual image_features length: ", image_features[pt][j].shape[0])
|
||||
pad_len = image_features[pt][j].shape[0]
|
||||
input_embeds[
|
||||
start_idx + image_off : start_idx + image_off + pad_len
|
||||
] = image_features[pt][j]
|
||||
except RuntimeError as e:
|
||||
print(f"RuntimeError in llava image encoding: {e}")
|
||||
print(image_features[pt].shape)
|
||||
print(input_embeds.shape)
|
||||
print(start_idx, image_offsets[i])
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
|
||||
# Multiple images
|
||||
for j, image_offset in enumerate(image_offsets[i]):
|
||||
if image_offset < prefix_len:
|
||||
continue
|
||||
|
||||
tmp_image_feature = image_features[pt][j]
|
||||
pad_len = tmp_image_feature.shape[0]
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len)
|
||||
right_idx = start_idx + (image_offset - prefix_len) + pad_len
|
||||
try:
|
||||
input_embeds[left_idx:right_idx] = tmp_image_feature
|
||||
except RuntimeError as e:
|
||||
print(f"RuntimeError in image encoding: {e}")
|
||||
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
|
||||
print(
|
||||
f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
|
||||
)
|
||||
pt += 1
|
||||
|
||||
return self.language_model(
|
||||
@@ -366,8 +372,9 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# load clip vision model by cfg['mm_vision_tower']:
|
||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||
# Load clip vision model by cfg['mm_vision_tower']:
|
||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||
# We put the initialization here instead of __init__ to allow it being reused by other subclasses.
|
||||
vision_path = self.config.mm_vision_tower
|
||||
if "clip" in vision_path:
|
||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||
@@ -422,8 +429,6 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
# load language model
|
||||
self.language_model.load_weights(weights)
|
||||
|
||||
monkey_path_clip_vision_embed_forward()
|
||||
|
||||
@property
|
||||
def num_patches_per_side(self):
|
||||
return self.image_size // self.patch_size
|
||||
@@ -495,36 +500,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
|
||||
)
|
||||
|
||||
|
||||
first_call = True
|
||||
|
||||
|
||||
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
|
||||
global first_call
|
||||
if first_call:
|
||||
self.patch_embedding.cpu().float()
|
||||
first_call = False
|
||||
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
|
||||
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
|
||||
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
def monkey_path_clip_vision_embed_forward():
|
||||
import transformers
|
||||
|
||||
setattr(
|
||||
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
|
||||
"forward",
|
||||
clip_vision_embed_forward,
|
||||
)
|
||||
|
||||
|
||||
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
||||
|
||||
@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.mm_utils import (
|
||||
get_anyres_image_grid_shape,
|
||||
unpad_image,
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
|
||||
@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||
)
|
||||
|
||||
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
|
||||
def pad_input_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
pad_value: List[int],
|
||||
pixel_values: List,
|
||||
image_sizes: List[List[int]],
|
||||
):
|
||||
new_image_feature_len = self.image_feature_len
|
||||
# now only support spatial_unpad + anyres
|
||||
# if self.mm_patch_merge_type.startswith("spatial"):
|
||||
# height = width = self.num_patches_per_side
|
||||
# if pt_shape[0] > 1:
|
||||
# if self.image_aspect_ratio == "anyres":
|
||||
# num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
# image_size,
|
||||
# self.image_grid_pinpoints,
|
||||
# self.vision_tower.config.image_size,
|
||||
# )
|
||||
# if "unpad" in self.mm_patch_merge_type:
|
||||
# h = num_patch_height * height
|
||||
# w = num_patch_width * width
|
||||
# new_h, new_w = unpad_image_shape(h, w, image_size)
|
||||
# new_image_feature_len += new_h * (new_w + 1)
|
||||
|
||||
pad_ids = pad_value * (
|
||||
(new_image_feature_len + len(pad_value)) // len(pad_value)
|
||||
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
+ pad_ids[:new_image_feature_len]
|
||||
+ input_ids[offset + 1 :]
|
||||
)
|
||||
return new_input_ids, offset
|
||||
return new_input_ids, [offset]
|
||||
|
||||
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
bs = input_metadata.batch_size
|
||||
|
||||
# Embed text input
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
# Embed vision input
|
||||
need_vision = (
|
||||
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
|
||||
.cpu()
|
||||
.numpy()
|
||||
# Whether the requests need vision inputs
|
||||
max_image_offset = np.array(
|
||||
[max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
|
||||
)
|
||||
# FIXME: We need to substract the length of the system prompt
|
||||
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
|
||||
need_vision = need_vision & has_pixel
|
||||
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
|
||||
need_vision = start_positions <= max_image_offset
|
||||
|
||||
if need_vision.any():
|
||||
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
|
||||
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
|
||||
|
||||
########## Encode Image ########
|
||||
|
||||
@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
new_image_features.append(image_feature.flatten(0, 1))
|
||||
image_features = new_image_features
|
||||
|
||||
# Fill in the placeholder for the image
|
||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
continue
|
||||
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
pad_len, pad_dim = image_features[pt].shape # 576, 4096
|
||||
dim = input_embeds.shape[1]
|
||||
assert (
|
||||
pad_dim == dim
|
||||
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
|
||||
# Fill in the placeholder for the image
|
||||
try:
|
||||
input_embeds[
|
||||
start_idx
|
||||
+ image_offsets[i] : start_idx
|
||||
+ image_offsets[i]
|
||||
+ pad_len
|
||||
] = image_features[pt]
|
||||
except RuntimeError as e:
|
||||
print(f"RuntimeError in llava image encoding: {e}")
|
||||
print(input_embeds.shape)
|
||||
print(start_idx, image_offsets[i])
|
||||
pt += 1
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
|
||||
# Multiple images
|
||||
for image_offset in image_offsets[i]:
|
||||
if image_offset < prefix_len:
|
||||
continue
|
||||
|
||||
tmp_image_feature = image_features[pt]
|
||||
pad_len = tmp_image_feature.shape[0]
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len)
|
||||
right_idx = start_idx + (image_offset - prefix_len) + pad_len
|
||||
try:
|
||||
input_embeds[left_idx:right_idx] = tmp_image_feature
|
||||
except RuntimeError as e:
|
||||
print(f"RuntimeError in image encoding: {e}")
|
||||
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
|
||||
print(
|
||||
f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
|
||||
)
|
||||
pt += 1
|
||||
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, input_embeds=input_embeds
|
||||
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
return self.language_model(input_ids, positions, input_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# load clip vision model by cfg['mm_vision_tower']:
|
||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||
# Load clip vision model by cfg['mm_vision_tower']:
|
||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||
# We put the initialization here instead of __init__ to allow it being reused by other subclasses.
|
||||
vision_path = self.config.mm_vision_tower
|
||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||
vision_path, torch_dtype=torch.float16
|
||||
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
# load language model
|
||||
self.language_model.load_weights(weights)
|
||||
|
||||
monkey_path_clip_vision_embed_forward()
|
||||
|
||||
@property
|
||||
def num_patches_per_side(self):
|
||||
return self.image_size // self.patch_size
|
||||
|
||||
|
||||
first_call = True
|
||||
|
||||
|
||||
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
|
||||
global first_call
|
||||
if first_call:
|
||||
self.patch_embedding.cpu().float()
|
||||
first_call = False
|
||||
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
|
||||
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
|
||||
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
def monkey_path_clip_vision_embed_forward():
|
||||
import transformers
|
||||
|
||||
setattr(
|
||||
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
|
||||
"forward",
|
||||
clip_vision_embed_forward,
|
||||
)
|
||||
|
||||
|
||||
EntryClass = LlavaVidForCausalLM
|
||||
|
||||
@@ -312,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -319,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@@ -329,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.models.llava import (
|
||||
LlavaLlamaForCausalLM,
|
||||
monkey_path_clip_vision_embed_forward,
|
||||
)
|
||||
from sglang.srt.models.llava import LlavaLlamaForCausalLM
|
||||
|
||||
|
||||
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
self.config._name_or_path,
|
||||
torch_dtype=torch.float16,
|
||||
subfolder=self.vision_tower_subfolder,
|
||||
).cuda()
|
||||
).to("cuda")
|
||||
|
||||
self.vision_tower.eval()
|
||||
|
||||
@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
# load language model
|
||||
self.language_model.load_weights(weights)
|
||||
|
||||
monkey_path_clip_vision_embed_forward()
|
||||
|
||||
|
||||
class YiVLMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaConfig):
|
||||
|
||||
@@ -335,12 +335,12 @@ def launch_server(
|
||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||
|
||||
if server_args.dp_size == 1:
|
||||
start_process = start_controller_process_single
|
||||
start_controller_process = start_controller_process_single
|
||||
else:
|
||||
start_process = start_controller_process_multi
|
||||
start_controller_process = start_controller_process_multi
|
||||
|
||||
proc_controller = mp.Process(
|
||||
target=start_process,
|
||||
target=start_controller_process,
|
||||
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
|
||||
)
|
||||
proc_controller.start()
|
||||
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if not server_args.disable_flashinfer:
|
||||
assert_pkg_version(
|
||||
"flashinfer",
|
||||
"0.1.6",
|
||||
"0.1.5",
|
||||
"Please uninstall the old version and "
|
||||
"reinstall the latest version by following the instructions "
|
||||
"at https://docs.flashinfer.ai/installation.html.",
|
||||
|
||||
@@ -26,7 +26,7 @@ import struct
|
||||
import time
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@@ -193,35 +193,16 @@ def allocate_init_ports(
     return ret_ports[0], ret_ports[1:num_ports_needed]


-def get_int_token_logit_bias(tokenizer, vocab_size):
-    """Get the logit bias for integer-only tokens."""
-    # a bug when model's vocab size > tokenizer.vocab_size
-    if tokenizer == None:
-        return [-1e5] * vocab_size
-    vocab_size = tokenizer.vocab_size
-    logit_bias = np.zeros(vocab_size, dtype=np.float32)
-    for t_id in range(vocab_size):
-        ss = tokenizer.decode([t_id]).strip()
-        if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
-            logit_bias[t_id] = -1e5
-
-    return logit_bias
-
-
-def is_multimodal_model(model):
-    from sglang.srt.model_config import ModelConfig
-
-    if isinstance(model, str):
-        model = model.lower()
-        return "llava" in model or "yi-vl" in model or "llava-next" in model
-
-    if isinstance(model, ModelConfig):
-        model_path = model.path.lower()
-        return (
-            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
-        )
-
-    raise ValueError("unrecognized type")
+def is_multimodal_model(model_architectures):
+    if (
+        "LlavaLlamaForCausalLM" in model_architectures
+        or "LlavaQwenForCausalLM" in model_architectures
+        or "LlavaMistralForCausalLM" in model_architectures
+        or "LlavaVidForCausalLM" in model_architectures
+    ):
+        return True
+    else:
+        return False


 def is_generation_model(model_architectures, is_embedding: bool = False):
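Note (not part of the diff): the check now keys on the architecture names recorded in the model's config.json instead of substrings of the model path; illustrative calls against the function above:

    is_multimodal_model(["LlavaLlamaForCausalLM"])   # True
    is_multimodal_model(["LlamaForCausalLM"])        # False
    # call sites in this PR pass hf_config.architectures, e.g. is_multimodal_model(self.hf_config.architectures)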
@@ -317,12 +298,14 @@ def decode_video_base64(video_base64):
     )  # Return an empty array and size tuple if no frames were found


-def load_image(image_file):
+def load_image(image_file: Union[str, bytes]):
     from PIL import Image

     image = image_size = None

-    if image_file.startswith("http://") or image_file.startswith("https://"):
+    if isinstance(image_file, bytes):
+        image = Image.open(BytesIO(image_file))
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
         response = requests.get(image_file, timeout=timeout)
         image = Image.open(BytesIO(response.content))
@@ -334,8 +317,10 @@ def load_image(image_file):
     elif image_file.startswith("video:"):
         image_file = image_file.replace("video:", "")
         image, image_size = decode_video_base64(image_file)
-    else:
+    elif isinstance(image_file, str):
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    else:
+        raise ValueError(f"Invalid image: {image}")

     return image, image_size
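Note (not part of the diff): load_image now accepts raw bytes in addition to URLs, base64 strings, and "video:" payloads; a usage sketch with illustrative URL and file name:

    from sglang.srt.utils import load_image

    image, _ = load_image("https://example.com/cat.jpg")   # fetched over HTTP
    with open("cat.jpg", "rb") as f:
        image, _ = load_image(f.read())                    # raw bytes are handled by the new first branch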