Support qwen2 vl model (#1721)
Co-authored-by: yizhang2077 <1109276519@qq.com> Co-authored-by: ispobock <ISPObaoke@163.com>
This commit is contained in:
@@ -36,6 +36,8 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||
@@ -112,14 +114,88 @@ class ForwardBatch:
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
attn_backend: AttentionBackend = None
|
||||
|
||||
# For Qwen2-VL
|
||||
mrope_positions: torch.Tensor = None
|
||||
|
||||
def compute_mrope_positions(
|
||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||
):
|
||||
device = model_runner.device
|
||||
hf_config = model_runner.model_config.hf_config
|
||||
mrope_positions_list = [None] * self.seq_lens.shape[0]
|
||||
if self.forward_mode.is_decode():
|
||||
for i, _ in enumerate(mrope_positions_list):
|
||||
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
|
||||
batch.mrope_positions_delta[i][0],
|
||||
int(self.seq_lens[i]) - 1,
|
||||
int(self.seq_lens[i]),
|
||||
)
|
||||
elif self.forward_mode.is_extend():
|
||||
for i, image_inputs in enumerate(batch.image_inputs):
|
||||
if image_inputs is None:
|
||||
# text only
|
||||
mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3
|
||||
mrope_position_delta = 0
|
||||
else:
|
||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||
self.extend_start_loc[i],
|
||||
self.extend_seq_lens[i],
|
||||
self.extend_prefix_lens[i],
|
||||
)
|
||||
mrope_positions, mrope_position_delta = (
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
input_tokens=self.input_ids[
|
||||
extend_start_loc : extend_start_loc + extend_seq_len
|
||||
].tolist(),
|
||||
image_grid_thw=image_inputs.image_grid_thws,
|
||||
video_grid_thw=None,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
video_token_id=hf_config.video_token_id,
|
||||
vision_start_token_id=hf_config.vision_start_token_id,
|
||||
vision_end_token_id=hf_config.vision_end_token_id,
|
||||
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||
context_len=0,
|
||||
extend_prefix_len=extend_prefix_len.item(),
|
||||
)
|
||||
)
|
||||
mrope_positions_list[i] = mrope_positions
|
||||
batch.mrope_positions_delta[i].append(mrope_position_delta)
|
||||
|
||||
self.mrope_positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[np.array(pos) for pos in mrope_positions_list],
|
||||
axis=1,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||
|
||||
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
|
||||
device = model_runner.device
|
||||
if self.forward_mode.is_decode():
|
||||
self.positions = (self.seq_lens - 1).to(torch.int64)
|
||||
else:
|
||||
self.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(prefix_len, prefix_len + extend_len)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device=device,
|
||||
).to(torch.int64)
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
batch: ModelWorkerBatch,
|
||||
model_runner: ModelRunner,
|
||||
):
|
||||
device = model_runner.device
|
||||
|
||||
device = model_runner.device
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=len(batch.seq_lens),
|
||||
@@ -156,6 +232,13 @@ class ForwardBatch:
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
# Init position information
|
||||
is_mrope = model_runner.model_is_mrope
|
||||
if is_mrope:
|
||||
ret.compute_mrope_positions(model_runner, batch)
|
||||
else:
|
||||
ret.compute_positions(model_runner, batch)
|
||||
|
||||
# Init attention information
|
||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||
|
||||
@@ -125,6 +125,11 @@ class ModelRunner:
|
||||
)
|
||||
server_args.chunked_prefill_size = None
|
||||
server_args.mem_fraction_static *= 0.95
|
||||
# TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
|
||||
if self.model_config.hf_config.architectures == [
|
||||
"Qwen2VLForConditionalGeneration"
|
||||
]:
|
||||
server_args.disable_cuda_graph = True
|
||||
|
||||
# Global vars
|
||||
if server_args.show_time_cost:
|
||||
@@ -622,6 +627,15 @@ class ModelRunner:
|
||||
|
||||
return logits
|
||||
|
||||
@property
|
||||
def model_is_mrope(self) -> bool:
|
||||
"""Detect if the model has "mrope" rope_scaling type.
|
||||
mrope requires keep "rope_deltas" between prompt and decoding phases."""
|
||||
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
||||
if rope_scaling is None:
|
||||
return False
|
||||
return rope_scaling.get("type", None) == "mrope"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_model_classes():
|
||||
|
||||
Reference in New Issue
Block a user