Support qwen2 vl model (#1721)

Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: ispobock <ISPObaoke@163.com>
Author: Yineng Zhang
Date:   2024-10-19 21:44:38 -07:00 (committed by GitHub)
Parent: 8bee20f80b
Commit: cbbc82b7b8

15 changed files with 1310 additions and 9 deletions


@@ -36,6 +36,8 @@ from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
if TYPE_CHECKING:
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
@@ -112,14 +114,88 @@ class ForwardBatch:
token_to_kv_pool: BaseTokenToKVPool = None
attn_backend: AttentionBackend = None
# For Qwen2-VL
mrope_positions: torch.Tensor = None
def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
device = model_runner.device
hf_config = model_runner.model_config.hf_config
mrope_positions_list = [None] * self.seq_lens.shape[0]
if self.forward_mode.is_decode():
for i, _ in enumerate(mrope_positions_list):
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
batch.mrope_positions_delta[i][0],
int(self.seq_lens[i]) - 1,
int(self.seq_lens[i]),
)
elif self.forward_mode.is_extend():
for i, image_inputs in enumerate(batch.image_inputs):
if image_inputs is None:
# text only: all three rope sections (temporal/height/width)
# share the plain positions 0..seq_len-1
mrope_positions = [list(range(self.seq_lens[i]))] * 3
mrope_position_delta = 0
else:
extend_start_loc, extend_seq_len, extend_prefix_len = (
self.extend_start_loc[i],
self.extend_seq_lens[i],
self.extend_prefix_lens[i],
)
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
].tolist(),
image_grid_thw=image_inputs.image_grid_thws,
video_grid_thw=None,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
extend_prefix_len=extend_prefix_len.item(),
)
)
mrope_positions_list[i] = mrope_positions
batch.mrope_positions_delta[i].append(mrope_position_delta)
self.mrope_positions = torch.tensor(
np.concatenate(
[np.array(pos) for pos in mrope_positions_list],
axis=1,
),
device=device,
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
device = model_runner.device
if self.forward_mode.is_decode():
self.positions = (self.seq_lens - 1).to(torch.int64)
else:
self.positions = torch.tensor(
np.concatenate(
[
np.arange(prefix_len, prefix_len + extend_len)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
),
device=device,
).to(torch.int64)
@classmethod
def init_new(
cls,
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
device = model_runner.device
ret = cls(
forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens),
@@ -156,6 +232,13 @@ class ForwardBatch:
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
# Init position information
is_mrope = model_runner.model_is_mrope
if is_mrope:
ret.compute_mrope_positions(model_runner, batch)
else:
ret.compute_positions(model_runner, batch)
# Init attention information
ret.req_to_token_pool = model_runner.req_to_token_pool
ret.token_to_kv_pool = model_runner.token_to_kv_pool
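
Note on the shapes above: unlike the standard positions tensor of shape (num_tokens,) built by compute_positions, M-RoPE positions have shape (3, num_tokens), one row per rotary section (temporal, height, width); text tokens advance all three rows in lockstep, offset by the per-request mrope_position_delta accumulated during prefill. Below is a minimal, self-contained sketch of the decode-step assembly; the get_next_input_positions here is a simplified stand-in for MRotaryEmbedding's method, and the sample deltas are illustrative, not taken from a real run.

import numpy as np
import torch

def get_next_input_positions(delta, context_len, seq_len):
    # Simplified stand-in: every rope section gets the same text positions,
    # shifted by the request's accumulated delta.
    return [list(range(context_len + delta, seq_len + delta)) for _ in range(3)]

seq_lens = torch.tensor([7, 12])  # two running requests
deltas = [0, -4]                  # deltas from prefill (negative when an image
                                  # packs many tokens into few positions)

per_request = [
    get_next_input_positions(deltas[i], int(seq_lens[i]) - 1, int(seq_lens[i]))
    for i in range(seq_lens.shape[0])
]
mrope_positions = torch.tensor(
    np.concatenate([np.array(p) for p in per_request], axis=1)
).to(torch.int64)

print(mrope_positions)
# tensor([[6, 7],
#         [6, 7],
#         [6, 7]])  -> shape (3, 2): one new position per request

For a text-only extend request the same three rows are simply 0..seq_len-1 with delta 0, which is what the image_inputs is None branch above produces.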


@@ -125,6 +125,11 @@ class ModelRunner:
)
server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95
# TODO: Qwen2-VL does not support CUDA graph yet; set disable_cuda_graph=True automatically
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
]:
server_args.disable_cuda_graph = True
# Global vars
if server_args.show_time_cost:
@@ -622,6 +627,15 @@ class ModelRunner:
return logits
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
@lru_cache()
def import_model_classes():
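
For reference, model_is_mrope keys off the rope_scaling block that Qwen2-VL ships in its HF config. A hedged sketch of what the property matches, with a plain dict standing in for hf_config; the mrope_section values illustrate the config shape and are not taken from a specific checkpoint.

# Dict standing in for model_config.hf_config; values are illustrative.
hf_config = {
    "rope_scaling": {
        "type": "mrope",
        "mrope_section": [16, 24, 24],  # rotary dims split across t/h/w sections
    }
}

rope_scaling = hf_config.get("rope_scaling") or {}
is_mrope = rope_scaling.get("type", None) == "mrope"
print(is_mrope)  # True -> ForwardBatch.init_new takes the compute_mrope_positions path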