Llama3.2 vision model support (#1551)
This commit is contained in:
@@ -33,20 +33,9 @@ def init_global_processor(server_args: ServerArgs):
|
||||
|
||||
|
||||
class BaseImageProcessor(ABC):
|
||||
@abstractmethod
|
||||
async def process_images_async(self, image_data, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
async def process_images_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._image_processor = _image_processor
|
||||
self._processor = _processor
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
@@ -54,6 +43,23 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def process_images_async(self, image_data, input_text, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def process_images_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
image_data: Union[str, bytes],
|
||||
@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], request_obj
|
||||
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
}
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(images, input_text):
|
||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
||||
return global_processor(images, input_text, return_tensors="pt")
|
||||
|
||||
async def _process_single_image(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MllamaImageProcessor._process_single_image_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(images, input_text, return_tensors="pt")
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
if len(image_data) > 0:
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
image_inputs = await self._process_single_image(images, input_text)
|
||||
image_inputs["image_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
|
||||
return image_inputs
|
||||
|
||||
|
||||
class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
self.hf_config = hf_config
|
||||
@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
return self._process_single_image_task(image_data)
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], request_obj
|
||||
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
|
||||
|
||||
def get_image_processor(
|
||||
hf_config, server_args: ServerArgs, _image_processor
|
||||
hf_config, server_args: ServerArgs, processor
|
||||
) -> BaseImageProcessor:
|
||||
if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
|
||||
if "MllamaForConditionalGeneration" in hf_config.architectures:
|
||||
return MllamaImageProcessor(hf_config, server_args, processor)
|
||||
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
|
||||
else:
|
||||
return LlavaImageProcessor(hf_config, server_args, _image_processor)
|
||||
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
|
||||
|
||||
|
||||
def get_dummy_image_processor():
|
||||
|
||||
@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
@@ -121,11 +122,12 @@ class ImageInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
pixel_values: torch.Tensor
|
||||
image_hash: int
|
||||
image_hashes: Optional[list] = None
|
||||
image_sizes: Optional[list] = None
|
||||
image_offsets: Optional[list] = None
|
||||
pad_values: Optional[list] = None
|
||||
modalities: Optional[list] = None
|
||||
num_image_tokens: Optional[int] = None
|
||||
|
||||
image_embeds: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||
@@ -138,19 +140,27 @@ class ImageInputs:
|
||||
# Use image hash as fake token_ids, which is then used for prefix matching
|
||||
ret = ImageInputs(
|
||||
pixel_values=obj["pixel_values"],
|
||||
image_hash=hash(tuple(obj["image_hashes"])),
|
||||
image_grid_thws=obj.get("image_grid_thws"),
|
||||
image_hashes=hash(tuple(obj["image_hashes"])),
|
||||
)
|
||||
image_hash = ret.image_hash
|
||||
image_hash = ret.image_hashes
|
||||
ret.pad_values = [
|
||||
(image_hash) % vocab_size,
|
||||
(image_hash >> 16) % vocab_size,
|
||||
(image_hash >> 32) % vocab_size,
|
||||
(image_hash >> 64) % vocab_size,
|
||||
]
|
||||
ret.image_sizes = obj["image_sizes"]
|
||||
# Only when pixel values is not None we have modalities
|
||||
ret.modalities = obj["modalities"] or ["image"]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
"modalities",
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
setattr(ret, arg, obj[arg])
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -416,6 +426,10 @@ class ScheduleBatch:
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
# For utility
|
||||
model_config: ModelConfig = None
|
||||
|
||||
forward_mode: ForwardMode = None
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
@@ -440,6 +454,12 @@ class ScheduleBatch:
|
||||
extend_num_tokens: int = None
|
||||
decoding_reqs: List[Req] = None
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
encoder_lens: Optional[torch.Tensor] = None
|
||||
encoder_lens_cpu: Optional[List[int]] = None
|
||||
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
|
||||
@@ -450,12 +470,20 @@ class ScheduleBatch:
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
def init_new(
|
||||
cls,
|
||||
reqs,
|
||||
req_to_token_pool,
|
||||
token_to_kv_pool,
|
||||
tree_cache,
|
||||
model_config,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_regex=any(req.regex_fsm for req in reqs),
|
||||
@@ -493,7 +521,78 @@ class ScheduleBatch:
|
||||
|
||||
return out_cache_loc
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int):
|
||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||
self.encoder_lens_cpu = []
|
||||
self.encoder_cached = []
|
||||
|
||||
for req in self.reqs:
|
||||
im = req.image_inputs
|
||||
if im is None or im.num_image_tokens is None:
|
||||
# No image input
|
||||
self.encoder_lens_cpu.append(0)
|
||||
self.encoder_cached.append(True)
|
||||
else:
|
||||
self.encoder_lens_cpu.append(im.num_image_tokens)
|
||||
self.encoder_cached.append(
|
||||
self.forward_mode.is_decode()
|
||||
or len(req.prefix_indices) >= im.num_image_tokens
|
||||
)
|
||||
|
||||
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
# Strip encoder infos
|
||||
pt = 0
|
||||
decoder_out_cache_loc = []
|
||||
encoder_out_cache_loc = []
|
||||
for i, req in enumerate(self.reqs):
|
||||
encoder_len = self.encoder_lens_cpu[i]
|
||||
seq_lens[i] -= encoder_len
|
||||
|
||||
if len(req.prefix_indices) < encoder_len:
|
||||
# NOTE: the encoder part should considered as a whole
|
||||
assert len(req.prefix_indices) == 0
|
||||
input_ids[i] = input_ids[i][encoder_len:]
|
||||
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
|
||||
decoder_out_cache_loc.append(
|
||||
self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
|
||||
)
|
||||
self.extend_lens[i] -= encoder_len
|
||||
self.extend_num_tokens -= encoder_len
|
||||
else:
|
||||
decoder_out_cache_loc.append(
|
||||
self.out_cache_loc[pt : pt + req.extend_input_len]
|
||||
)
|
||||
self.prefix_lens[i] -= encoder_len
|
||||
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Reassign
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
if not decoder_out_cache_loc:
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
|
||||
|
||||
if not encoder_out_cache_loc:
|
||||
self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
|
||||
|
||||
assert len(self.out_cache_loc) == self.extend_num_tokens
|
||||
|
||||
def prepare_for_extend(self):
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
|
||||
bs = len(self.reqs)
|
||||
@@ -561,8 +660,13 @@ class ScheduleBatch:
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self, vocab_size, global_server_args_dict["disable_penalizer"]
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
global_server_args_dict["disable_penalizer"],
|
||||
)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
@@ -752,6 +856,10 @@ class ScheduleBatch:
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_encoder_info_decode(self):
|
||||
# Reset the encoder cached status
|
||||
self.encoder_cached = [True] * len(self.reqs)
|
||||
|
||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
@@ -766,16 +874,22 @@ class ScheduleBatch:
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
locs = self.encoder_lens + self.seq_lens
|
||||
self.prepare_encoder_info_decode()
|
||||
else:
|
||||
locs = self.seq_lens
|
||||
|
||||
if enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
@@ -802,6 +916,10 @@ class ScheduleBatch:
|
||||
# No need to filter
|
||||
return
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = self.encoder_lens[keep_indices]
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
@@ -828,6 +946,11 @@ class ScheduleBatch:
|
||||
# needs to be called with pre-merged Batch.reqs.
|
||||
self.sampling_info.merge_batch(other.sampling_info)
|
||||
|
||||
# Encoder-decoder infos
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||
|
||||
self.req_pool_indices = torch.concat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
@@ -850,14 +973,11 @@ class ScheduleBatch:
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
if self.forward_mode.is_decode():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
|
||||
image_inputs
|
||||
) = None
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
image_inputs = [r.image_inputs for r in self.reqs]
|
||||
|
||||
if self.has_regex:
|
||||
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
|
||||
@@ -887,7 +1007,11 @@ class ScheduleBatch:
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||
image_inputs=image_inputs,
|
||||
image_inputs=[r.image_inputs for r in self.reqs],
|
||||
encoder_cached=self.encoder_cached,
|
||||
encoder_lens=self.encoder_lens,
|
||||
encoder_lens_cpu=self.encoder_lens_cpu,
|
||||
encoder_out_cache_loc=self.encoder_out_cache_loc,
|
||||
lora_paths=[req.lora_path for req in self.reqs],
|
||||
sampling_info=self.sampling_info,
|
||||
mrope_positions_delta=mrope_positions_delta,
|
||||
@@ -897,6 +1021,7 @@ class ScheduleBatch:
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
model_config=self.model_config,
|
||||
forward_mode=self.forward_mode,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]]
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]]
|
||||
encoder_lens: Optional[torch.Tensor]
|
||||
encoder_lens_cpu: Optional[List[int]]
|
||||
encoder_out_cache_loc: Optional[torch.Tensor]
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]]
|
||||
|
||||
|
||||
@@ -662,8 +662,9 @@ class Scheduler:
|
||||
self.req_to_token_pool,
|
||||
self.token_to_kv_pool,
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
)
|
||||
new_batch.prepare_for_extend(self.model_config.vocab_size)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
|
||||
@@ -122,7 +122,7 @@ class TokenizerManager:
|
||||
|
||||
# We want to parallelize the image pre-processing so we create an executor for it
|
||||
self.image_processor = get_image_processor(
|
||||
self.hf_config, server_args, self.processor.image_processor
|
||||
self.hf_config, server_args, self.processor
|
||||
)
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
@@ -191,8 +191,10 @@ class TokenizerManager:
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params)
|
||||
if self.is_generation:
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data, obj
|
||||
obj.image_data, input_text or input_ids, obj
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
@@ -217,8 +219,10 @@ class TokenizerManager:
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||
if self.is_generation:
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data[index], obj
|
||||
obj.image_data[index], input_text or input_ids, obj
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
return_logprob = obj.return_logprob[index]
|
||||
logprob_start_len = obj.logprob_start_len[index]
|
||||
top_logprobs_num = obj.top_logprobs_num[index]
|
||||
@@ -263,8 +267,10 @@ class TokenizerManager:
|
||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||
sampling_params.max_new_tokens = 0
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data[0], obj
|
||||
obj.image_data[0], input_text or input_ids, obj
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
return_logprob = obj.return_logprob[0]
|
||||
logprob_start_len = obj.logprob_start_len[0]
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
|
||||
Reference in New Issue
Block a user