Llava-hd Support (#92)
Co-authored-by: Haotian Liu <liuhaotian.cn@gmail.com>
This commit is contained in:
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
|
||||
@@ -48,14 +49,25 @@ def init_global_processor(server_args: ServerArgs):
|
||||
)
|
||||
|
||||
|
||||
def get_pixel_values(image_data, processor=None):
|
||||
def get_pixel_values(image_data, model_cfg, processor=None):
|
||||
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
||||
try:
|
||||
processor = processor or global_processor
|
||||
image = load_image(image_data)
|
||||
image_hash = hash(image_data)
|
||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||
if image_aspect_ratio == "pad":
|
||||
image = expand2square(
|
||||
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
|
||||
)
|
||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||
elif image_aspect_ratio == "anyres":
|
||||
pixel_values = process_anyres_image(
|
||||
image, processor.image_processor, model_cfg.image_grid_pinpoints
|
||||
)
|
||||
else:
|
||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
return pixel_values, image_hash
|
||||
return pixel_values, image_hash, image.size
|
||||
except Exception:
|
||||
print("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
@@ -77,6 +89,7 @@ class TokenizerManager:
|
||||
self.hf_config = get_config(
|
||||
self.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
|
||||
self.context_len = get_context_length(self.hf_config)
|
||||
|
||||
if is_multimodal_model(self.model_path):
|
||||
@@ -104,10 +117,10 @@ class TokenizerManager:
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor, get_pixel_values, image_data
|
||||
self.executor, get_pixel_values, image_data, self.hf_config
|
||||
)
|
||||
else:
|
||||
return get_pixel_values(image_data, self.processor)
|
||||
return get_pixel_values(image_data, self.hf_config, self.processor)
|
||||
|
||||
async def generate_request(self, obj: GenerateReqInput):
|
||||
if self.to_create_loop:
|
||||
@@ -123,14 +136,17 @@ class TokenizerManager:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
if obj.image_data is None:
|
||||
pixel_values, image_hash = None, None
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
else:
|
||||
pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
|
||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||
obj.image_data
|
||||
)
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hash=image_hash,
|
||||
image_size=image_size,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=obj.return_logprob,
|
||||
logprob_start_len=obj.logprob_start_len,
|
||||
@@ -162,9 +178,9 @@ class TokenizerManager:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
if obj.image_data[i] is None:
|
||||
pixel_values, image_hash = None, None
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
else:
|
||||
pixel_values, image_hash = await self.get_pixel_values(
|
||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||
obj.image_data[i]
|
||||
)
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
@@ -172,6 +188,7 @@ class TokenizerManager:
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hash=image_hash,
|
||||
image_size=image_size,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=obj.return_logprob[i],
|
||||
logprob_start_len=obj.logprob_start_len[i],
|
||||
|
||||
Reference in New Issue
Block a user