Llama3.2 vision model support (#1551)
This commit is contained in:
@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
@@ -121,11 +122,12 @@ class ImageInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
pixel_values: torch.Tensor
|
||||
image_hash: int
|
||||
image_hashes: Optional[list] = None
|
||||
image_sizes: Optional[list] = None
|
||||
image_offsets: Optional[list] = None
|
||||
pad_values: Optional[list] = None
|
||||
modalities: Optional[list] = None
|
||||
num_image_tokens: Optional[int] = None
|
||||
|
||||
image_embeds: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||
@@ -138,19 +140,27 @@ class ImageInputs:
|
||||
# Use image hash as fake token_ids, which is then used for prefix matching
|
||||
ret = ImageInputs(
|
||||
pixel_values=obj["pixel_values"],
|
||||
image_hash=hash(tuple(obj["image_hashes"])),
|
||||
image_grid_thws=obj.get("image_grid_thws"),
|
||||
image_hashes=hash(tuple(obj["image_hashes"])),
|
||||
)
|
||||
image_hash = ret.image_hash
|
||||
image_hash = ret.image_hashes
|
||||
ret.pad_values = [
|
||||
(image_hash) % vocab_size,
|
||||
(image_hash >> 16) % vocab_size,
|
||||
(image_hash >> 32) % vocab_size,
|
||||
(image_hash >> 64) % vocab_size,
|
||||
]
|
||||
ret.image_sizes = obj["image_sizes"]
|
||||
# Only when pixel values is not None we have modalities
|
||||
ret.modalities = obj["modalities"] or ["image"]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
"modalities",
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
setattr(ret, arg, obj[arg])
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -416,6 +426,10 @@ class ScheduleBatch:
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
# For utility
|
||||
model_config: ModelConfig = None
|
||||
|
||||
forward_mode: ForwardMode = None
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
@@ -440,6 +454,12 @@ class ScheduleBatch:
|
||||
extend_num_tokens: int = None
|
||||
decoding_reqs: List[Req] = None
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
encoder_lens: Optional[torch.Tensor] = None
|
||||
encoder_lens_cpu: Optional[List[int]] = None
|
||||
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
|
||||
@@ -450,12 +470,20 @@ class ScheduleBatch:
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
def init_new(
|
||||
cls,
|
||||
reqs,
|
||||
req_to_token_pool,
|
||||
token_to_kv_pool,
|
||||
tree_cache,
|
||||
model_config,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_regex=any(req.regex_fsm for req in reqs),
|
||||
@@ -493,7 +521,78 @@ class ScheduleBatch:
|
||||
|
||||
return out_cache_loc
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int):
|
||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||
self.encoder_lens_cpu = []
|
||||
self.encoder_cached = []
|
||||
|
||||
for req in self.reqs:
|
||||
im = req.image_inputs
|
||||
if im is None or im.num_image_tokens is None:
|
||||
# No image input
|
||||
self.encoder_lens_cpu.append(0)
|
||||
self.encoder_cached.append(True)
|
||||
else:
|
||||
self.encoder_lens_cpu.append(im.num_image_tokens)
|
||||
self.encoder_cached.append(
|
||||
self.forward_mode.is_decode()
|
||||
or len(req.prefix_indices) >= im.num_image_tokens
|
||||
)
|
||||
|
||||
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
# Strip encoder infos
|
||||
pt = 0
|
||||
decoder_out_cache_loc = []
|
||||
encoder_out_cache_loc = []
|
||||
for i, req in enumerate(self.reqs):
|
||||
encoder_len = self.encoder_lens_cpu[i]
|
||||
seq_lens[i] -= encoder_len
|
||||
|
||||
if len(req.prefix_indices) < encoder_len:
|
||||
# NOTE: the encoder part should considered as a whole
|
||||
assert len(req.prefix_indices) == 0
|
||||
input_ids[i] = input_ids[i][encoder_len:]
|
||||
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
|
||||
decoder_out_cache_loc.append(
|
||||
self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
|
||||
)
|
||||
self.extend_lens[i] -= encoder_len
|
||||
self.extend_num_tokens -= encoder_len
|
||||
else:
|
||||
decoder_out_cache_loc.append(
|
||||
self.out_cache_loc[pt : pt + req.extend_input_len]
|
||||
)
|
||||
self.prefix_lens[i] -= encoder_len
|
||||
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Reassign
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
if not decoder_out_cache_loc:
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
|
||||
|
||||
if not encoder_out_cache_loc:
|
||||
self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
|
||||
|
||||
assert len(self.out_cache_loc) == self.extend_num_tokens
|
||||
|
||||
def prepare_for_extend(self):
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
|
||||
bs = len(self.reqs)
|
||||
@@ -561,8 +660,13 @@ class ScheduleBatch:
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self, vocab_size, global_server_args_dict["disable_penalizer"]
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
global_server_args_dict["disable_penalizer"],
|
||||
)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
@@ -752,6 +856,10 @@ class ScheduleBatch:
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_encoder_info_decode(self):
|
||||
# Reset the encoder cached status
|
||||
self.encoder_cached = [True] * len(self.reqs)
|
||||
|
||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
@@ -766,16 +874,22 @@ class ScheduleBatch:
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
locs = self.encoder_lens + self.seq_lens
|
||||
self.prepare_encoder_info_decode()
|
||||
else:
|
||||
locs = self.seq_lens
|
||||
|
||||
if enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
@@ -802,6 +916,10 @@ class ScheduleBatch:
|
||||
# No need to filter
|
||||
return
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = self.encoder_lens[keep_indices]
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
@@ -828,6 +946,11 @@ class ScheduleBatch:
|
||||
# needs to be called with pre-merged Batch.reqs.
|
||||
self.sampling_info.merge_batch(other.sampling_info)
|
||||
|
||||
# Encoder-decoder infos
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||
|
||||
self.req_pool_indices = torch.concat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
@@ -850,14 +973,11 @@ class ScheduleBatch:
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
if self.forward_mode.is_decode():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
|
||||
image_inputs
|
||||
) = None
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
image_inputs = [r.image_inputs for r in self.reqs]
|
||||
|
||||
if self.has_regex:
|
||||
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
|
||||
@@ -887,7 +1007,11 @@ class ScheduleBatch:
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||
image_inputs=image_inputs,
|
||||
image_inputs=[r.image_inputs for r in self.reqs],
|
||||
encoder_cached=self.encoder_cached,
|
||||
encoder_lens=self.encoder_lens,
|
||||
encoder_lens_cpu=self.encoder_lens_cpu,
|
||||
encoder_out_cache_loc=self.encoder_out_cache_loc,
|
||||
lora_paths=[req.lora_path for req in self.reqs],
|
||||
sampling_info=self.sampling_info,
|
||||
mrope_positions_delta=mrope_positions_delta,
|
||||
@@ -897,6 +1021,7 @@ class ScheduleBatch:
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
model_config=self.model_config,
|
||||
forward_mode=self.forward_mode,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]]
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]]
|
||||
encoder_lens: Optional[torch.Tensor]
|
||||
encoder_lens_cpu: Optional[List[int]]
|
||||
encoder_out_cache_loc: Optional[torch.Tensor]
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user