Llava-hd Support (#92)

Co-authored-by: Haotian Liu <liuhaotian.cn@gmail.com>
This commit is contained in:
shiyi.c_98
2024-01-24 01:51:21 -08:00
committed by GitHub
parent 99258181c6
commit c6576e820c
10 changed files with 429 additions and 38 deletions

View File

@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
@@ -48,14 +49,25 @@ def init_global_processor(server_args: ServerArgs):
)
def get_pixel_values(image_data, processor=None):
def get_pixel_values(image_data, model_cfg, processor=None):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
try:
processor = processor or global_processor
image = load_image(image_data)
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"][0]
if image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
pixel_values = process_anyres_image(
image, processor.image_processor, model_cfg.image_grid_pinpoints
)
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
@@ -77,6 +89,7 @@ class TokenizerManager:
self.hf_config = get_config(
self.model_path, trust_remote_code=server_args.trust_remote_code
)
self.context_len = get_context_length(self.hf_config)
if is_multimodal_model(self.model_path):
@@ -104,10 +117,10 @@ class TokenizerManager:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, get_pixel_values, image_data
self.executor, get_pixel_values, image_data, self.hf_config
)
else:
return get_pixel_values(image_data, self.processor)
return get_pixel_values(image_data, self.hf_config, self.processor)
async def generate_request(self, obj: GenerateReqInput):
if self.to_create_loop:
@@ -123,14 +136,17 @@ class TokenizerManager:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data is None:
pixel_values, image_hash = None, None
pixel_values, image_hash, image_size = None, None, None
else:
pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
image_size=image_size,
sampling_params=sampling_params,
return_logprob=obj.return_logprob,
logprob_start_len=obj.logprob_start_len,
@@ -162,9 +178,9 @@ class TokenizerManager:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data[i] is None:
pixel_values, image_hash = None, None
pixel_values, image_hash, image_size = None, None, None
else:
pixel_values, image_hash = await self.get_pixel_values(
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data[i]
)
tokenized_obj = TokenizedGenerateReqInput(
@@ -172,6 +188,7 @@ class TokenizerManager:
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
image_size=image_size,
sampling_params=sampling_params,
return_logprob=obj.return_logprob[i],
logprob_start_len=obj.logprob_start_len[i],