Qwen2vl support cuda graph and disable radix cache (#1780)
This commit is contained in:
@@ -280,7 +280,7 @@ You can view the full example [here](https://github.com/sgl-project/sglang/tree/
|
|||||||
- Llama / Llama 2 / Llama 3 / Llama 3.1
|
- Llama / Llama 2 / Llama 3 / Llama 3.1
|
||||||
- Mistral / Mixtral / Mistral NeMo
|
- Mistral / Mixtral / Mistral NeMo
|
||||||
- Gemma / Gemma 2
|
- Gemma / Gemma 2
|
||||||
- Qwen / Qwen 2 / Qwen 2 MoE
|
- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
|
||||||
- DeepSeek / DeepSeek 2
|
- DeepSeek / DeepSeek 2
|
||||||
- OLMoE
|
- OLMoE
|
||||||
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
|
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
|
||||||
|
|||||||
@@ -22,64 +22,33 @@ class MRotaryEmbedding:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_input_positions(
|
def get_input_positions(
|
||||||
input_tokens: List[int],
|
input_tokens: torch.Tensor,
|
||||||
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
video_grid_thw: Union[List[List[int]], torch.Tensor],
|
|
||||||
image_token_id: int,
|
|
||||||
video_token_id: int,
|
|
||||||
vision_start_token_id: int,
|
vision_start_token_id: int,
|
||||||
vision_end_token_id: int,
|
|
||||||
spatial_merge_size: int,
|
spatial_merge_size: int,
|
||||||
context_len: int = 0,
|
context_len: int = 0,
|
||||||
extend_prefix_len: int = 0,
|
|
||||||
) -> Tuple[List[List[int]], int]:
|
) -> Tuple[List[List[int]], int]:
|
||||||
"""Get mrope input positions and delta value."""
|
"""Get mrope input positions and delta value."""
|
||||||
|
|
||||||
if isinstance(image_grid_thw, torch.Tensor):
|
if isinstance(image_grid_thw, torch.Tensor):
|
||||||
image_grid_thw = image_grid_thw.tolist()
|
image_grid_thw = image_grid_thw.tolist()
|
||||||
if isinstance(video_grid_thw, torch.Tensor):
|
|
||||||
video_grid_thw = video_grid_thw.tolist()
|
|
||||||
|
|
||||||
input_tokens_tensor = torch.tensor(input_tokens)
|
|
||||||
vision_start_indices = torch.argwhere(
|
vision_start_indices = torch.argwhere(
|
||||||
input_tokens_tensor == vision_start_token_id
|
input_tokens == vision_start_token_id
|
||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
image_indices = vision_start_indices + 1
|
||||||
image_nums = (vision_tokens == image_token_id).sum()
|
image_nums = image_indices.shape[0]
|
||||||
video_nums = (vision_tokens == video_token_id).sum()
|
|
||||||
llm_pos_ids_list: list = []
|
llm_pos_ids_list: list = []
|
||||||
|
|
||||||
st = 0
|
st = 0
|
||||||
remain_images, remain_videos = image_nums, video_nums
|
input_tokens_len = input_tokens.shape[0]
|
||||||
|
for image_index in range(image_nums):
|
||||||
image_index, video_index = 0, 0
|
ed = image_indices[image_index].item()
|
||||||
for _ in range(image_nums + video_nums):
|
|
||||||
if image_token_id in input_tokens and remain_images > 0:
|
|
||||||
ed_image = input_tokens.index(image_token_id, st)
|
|
||||||
else:
|
|
||||||
ed_image = len(input_tokens) + 1
|
|
||||||
if video_token_id in input_tokens and remain_videos > 0:
|
|
||||||
ed_video = input_tokens.index(video_token_id, st)
|
|
||||||
else:
|
|
||||||
ed_video = len(input_tokens) + 1
|
|
||||||
if ed_image < ed_video:
|
|
||||||
t, h, w = (
|
t, h, w = (
|
||||||
image_grid_thw[image_index][0],
|
image_grid_thw[image_index][0],
|
||||||
image_grid_thw[image_index][1],
|
image_grid_thw[image_index][1],
|
||||||
image_grid_thw[image_index][2],
|
image_grid_thw[image_index][2],
|
||||||
)
|
)
|
||||||
image_index += 1
|
|
||||||
remain_images -= 1
|
|
||||||
ed = ed_image
|
|
||||||
else:
|
|
||||||
t, h, w = (
|
|
||||||
video_grid_thw[video_index][0],
|
|
||||||
video_grid_thw[video_index][1],
|
|
||||||
video_grid_thw[video_index][2],
|
|
||||||
)
|
|
||||||
video_index += 1
|
|
||||||
remain_videos -= 1
|
|
||||||
ed = ed_video
|
|
||||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
t,
|
t,
|
||||||
h // spatial_merge_size,
|
h // spatial_merge_size,
|
||||||
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
|
|||||||
)
|
)
|
||||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
if st < len(input_tokens):
|
if st < input_tokens_len:
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
text_len = len(input_tokens) - st
|
text_len = input_tokens_len - st
|
||||||
llm_pos_ids_list.append(
|
llm_pos_ids_list.append(
|
||||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
llm_positions = llm_positions[:, context_len:]
|
llm_positions = llm_positions[:, context_len:]
|
||||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
|
||||||
llm_positions += extend_prefix_len
|
|
||||||
|
|
||||||
return llm_positions.tolist(), mrope_position_delta
|
return llm_positions.tolist(), mrope_position_delta
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ class CudaGraphRunner:
|
|||||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||||
)
|
)
|
||||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||||
|
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
||||||
@@ -233,6 +234,7 @@ class CudaGraphRunner:
|
|||||||
encoder_lens = None
|
encoder_lens = None
|
||||||
|
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
mrope_positions = self.mrope_positions[:, :bs]
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
@@ -259,6 +261,7 @@ class CudaGraphRunner:
|
|||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=[0] * bs,
|
top_logprobs_nums=[0] * bs,
|
||||||
positions=clamp_position(seq_lens),
|
positions=clamp_position(seq_lens),
|
||||||
|
mrope_positions=mrope_positions,
|
||||||
)
|
)
|
||||||
return forward(input_ids, forward_batch.positions, forward_batch)
|
return forward(input_ids, forward_batch.positions, forward_batch)
|
||||||
|
|
||||||
@@ -301,6 +304,8 @@ class CudaGraphRunner:
|
|||||||
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||||
|
if forward_batch.mrope_positions is not None:
|
||||||
|
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
|
|||||||
@@ -142,11 +142,12 @@ class ForwardBatch:
|
|||||||
int(self.seq_lens[i]),
|
int(self.seq_lens[i]),
|
||||||
)
|
)
|
||||||
elif self.forward_mode.is_extend():
|
elif self.forward_mode.is_extend():
|
||||||
|
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
||||||
for i, image_inputs in enumerate(batch.image_inputs):
|
for i, image_inputs in enumerate(batch.image_inputs):
|
||||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||||
self.extend_start_loc[i],
|
extend_start_loc_cpu[i],
|
||||||
self.extend_seq_lens[i],
|
batch.extend_seq_lens[i],
|
||||||
self.extend_prefix_lens[i],
|
batch.extend_prefix_lens[i],
|
||||||
)
|
)
|
||||||
if image_inputs is None:
|
if image_inputs is None:
|
||||||
# text only
|
# text only
|
||||||
@@ -160,20 +161,16 @@ class ForwardBatch:
|
|||||||
] * 3
|
] * 3
|
||||||
mrope_position_delta = 0
|
mrope_position_delta = 0
|
||||||
else:
|
else:
|
||||||
|
# TODO: qwen2-vl currently does not support radix cache because of how mrope positions are calculated
|
||||||
mrope_positions, mrope_position_delta = (
|
mrope_positions, mrope_position_delta = (
|
||||||
MRotaryEmbedding.get_input_positions(
|
MRotaryEmbedding.get_input_positions(
|
||||||
input_tokens=self.input_ids[
|
input_tokens=self.input_ids[
|
||||||
extend_start_loc : extend_start_loc + extend_seq_len
|
extend_start_loc : extend_start_loc + extend_seq_len
|
||||||
].tolist(),
|
],
|
||||||
image_grid_thw=image_inputs.image_grid_thws,
|
image_grid_thw=image_inputs.image_grid_thws,
|
||||||
video_grid_thw=None,
|
|
||||||
image_token_id=hf_config.image_token_id,
|
|
||||||
video_token_id=hf_config.video_token_id,
|
|
||||||
vision_start_token_id=hf_config.vision_start_token_id,
|
vision_start_token_id=hf_config.vision_start_token_id,
|
||||||
vision_end_token_id=hf_config.vision_end_token_id,
|
|
||||||
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||||
context_len=0,
|
context_len=0,
|
||||||
extend_prefix_len=extend_prefix_len.item(),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mrope_positions_list[i] = mrope_positions
|
mrope_positions_list[i] = mrope_positions
|
||||||
|
|||||||
@@ -125,11 +125,11 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
server_args.chunked_prefill_size = None
|
server_args.chunked_prefill_size = None
|
||||||
server_args.mem_fraction_static *= 0.95
|
server_args.mem_fraction_static *= 0.95
|
||||||
# TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
|
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
||||||
if self.model_config.hf_config.architectures == [
|
if self.model_config.hf_config.architectures == [
|
||||||
"Qwen2VLForConditionalGeneration"
|
"Qwen2VLForConditionalGeneration"
|
||||||
]:
|
]:
|
||||||
server_args.disable_cuda_graph = True
|
server_args.disable_radix_cache = True
|
||||||
|
|
||||||
# Global vars
|
# Global vars
|
||||||
if server_args.show_time_cost:
|
if server_args.show_time_cost:
|
||||||
|
|||||||
Reference in New Issue
Block a user