diff --git a/README.md b/README.md
index 8c5fa7ad5..da4061d81 100644
--- a/README.md
+++ b/README.md
@@ -280,7 +280,7 @@ You can view the full example [here](https://github.com/sgl-project/sglang/tree/
 - Llama / Llama 2 / Llama 3 / Llama 3.1
 - Mistral / Mixtral / Mistral NeMo
 - Gemma / Gemma 2
-- Qwen / Qwen 2 / Qwen 2 MoE
+- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
 - DeepSeek / DeepSeek 2
 - OLMoE
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 9af9285fe..cfc63fe4c 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -22,64 +22,33 @@ class MRotaryEmbedding:
 
     @staticmethod
     def get_input_positions(
-        input_tokens: List[int],
+        input_tokens: torch.Tensor,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
-        extend_prefix_len: int = 0,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
 
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
 
-        input_tokens_tensor = torch.tensor(input_tokens)
         vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
+            input_tokens == vision_start_token_id
         ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
         llm_pos_ids_list: list = []
 
         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
             llm_grid_t, llm_grid_h, llm_grid_w = (
                 t,
                 h // spatial_merge_size,
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
             )
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w
 
-        if st < len(input_tokens):
+        if st < input_tokens_len:
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
+            text_len = input_tokens_len - st
             llm_pos_ids_list.append(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
             )
 
         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
         llm_positions = llm_positions[:, context_len:]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions += extend_prefix_len
-
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
         return llm_positions.tolist(), mrope_position_delta
 
     @staticmethod
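The rewrite above simplifies `get_input_positions` to the image-only case: every `vision_start_token_id` is assumed to be followed immediately by an image's tokens, so the image/video interleaving bookkeeping goes away. The underlying scheme is unchanged: text tokens carry one shared index on all three position rows (temporal, height, width), each image contributes a 3-D grid of indices, and `mrope_position_delta` records how far the last position falls short of the sequence length so decoding can continue without re-scanning the prompt. A self-contained sketch of that scheme for a single image with toy dimensions; an illustration of the idea, not the sglang function:

```python
import torch

def mrope_positions_single_image(seq_len: int, img_start: int, t: int, h: int, w: int):
    """Positions for [text][t*h*w image tokens][text]; grid dims are already
    divided by spatial_merge_size. A sketch, not sglang's implementation."""
    pos = []
    # Leading text: the same position index on all 3 rows.
    pos.append(torch.arange(img_start).view(1, -1).expand(3, -1))
    # Image: temporal/height/width indices over the t x h x w grid,
    # offset to start right after the leading text.
    t_idx = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
    h_idx = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
    w_idx = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
    pos.append(torch.stack([t_idx, h_idx, w_idx]) + img_start)
    # Trailing text resumes after the largest position used so far.
    st_idx = pos[-1].max() + 1
    text_len = seq_len - (img_start + t * h * w)
    pos.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
    llm_positions = torch.cat(pos, dim=1)
    # Decode-time positions continue at seq_len + delta on every row.
    mrope_position_delta = (llm_positions.max() + 1 - seq_len).item()
    return llm_positions, mrope_position_delta

# 2 text tokens + a 1x4x4 image grid (16 tokens) + 2 text tokens = 20 tokens.
positions, delta = mrope_positions_single_image(20, 2, 1, 4, 4)
print(positions.shape)  # torch.Size([3, 20])
print(delta)            # -12: the 2-D grid "compresses" the position range
```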
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index d9a9861cc..52378f566 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -152,6 +152,7 @@ class CudaGraphRunner:
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
 
         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
@@ -233,6 +234,7 @@ class CudaGraphRunner:
             encoder_lens = None
 
         seq_lens_sum = seq_lens.sum().item()
+        mrope_positions = self.mrope_positions[:, :bs]
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -259,6 +261,7 @@ class CudaGraphRunner:
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
+            mrope_positions=mrope_positions,
         )
 
         return forward(input_ids, forward_batch.positions, forward_batch)
@@ -301,6 +304,8 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index f3065d7a2..d314af944 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -142,11 +142,12 @@ class ForwardBatch:
                     int(self.seq_lens[i]),
                 )
         elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
             for i, image_inputs in enumerate(batch.image_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    self.extend_start_loc[i],
-                    self.extend_seq_lens[i],
-                    self.extend_prefix_lens[i],
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
                 )
                 if image_inputs is None:
                     # text only
@@ -160,20 +161,16 @@ class ForwardBatch:
                     ] * 3
                     mrope_position_delta = 0
                 else:
+                    # TODO: Qwen2-VL does not support radix cache yet because of the mrope position calculation
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
+                            ],
                             image_grid_thw=image_inputs.image_grid_thws,
-                            video_grid_thw=None,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
-                            extend_prefix_len=extend_prefix_len.item(),
                         )
                     )
                     mrope_positions_list[i] = mrope_positions
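Two details in the changes above are worth spelling out. Pulling `extend_start_loc` to the host once with `.cpu().numpy()` avoids a device-to-host sync on every loop iteration, and reading `extend_seq_lens`/`extend_prefix_lens` from `batch` keeps the loop off GPU tensors entirely. The `mrope_positions` buffer in `CudaGraphRunner` follows the standard static-buffer discipline for CUDA graphs: a captured graph replays fixed memory addresses, so the runner allocates `(3, max_bs)` once, captures against a slice of it, and on replay refreshes the contents in place with `copy_` rather than rebinding the tensor. A minimal sketch of that capture/replay pattern, using a toy forward function and requiring a CUDA device:

```python
import torch

max_bs = 8
# Static buffers: the captured graph always reads/writes this storage.
mrope_positions = torch.zeros((3, max_bs), dtype=torch.int32, device="cuda")
static_out = torch.zeros((3, max_bs), dtype=torch.int32, device="cuda")

def toy_forward(positions: torch.Tensor) -> torch.Tensor:
    return positions * 2  # stand-in for the real model forward

bs = 4
view = mrope_positions[:, :bs]  # fixed-address slice used at capture time

# Warm up on a side stream before capture, per the PyTorch recipe.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    toy_forward(view)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out[:, :bs] = toy_forward(view)

# Replay with new inputs: copy into the same storage, never reassign.
fresh = torch.arange(3 * bs, dtype=torch.int32, device="cuda").view(3, bs)
mrope_positions[:, :bs].copy_(fresh)
g.replay()
assert static_out[:, :bs].equal(fresh * 2)
```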
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e2a2504cb..2bc048197 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -125,11 +125,11 @@ class ModelRunner:
             )
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
-            # TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
+            # TODO: Qwen2-VL does not support radix cache now; set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
-                server_args.disable_cuda_graph = True
+                server_args.disable_radix_cache = True
 
         # Global vars
         if server_args.show_time_cost:
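The flag swap records the new state of support: with the static `mrope_positions` buffer in place, CUDA graphs now work for Qwen2-VL, while radix cache does not, since `get_input_positions` lost its `extend_prefix_len` shift and always computes positions for the full prompt with `context_len=0`, so a radix-cached prefix would produce misaligned positions. On the decode side, the per-request `mrope_position_delta` lets positions advance without re-scanning the prompt. The `@staticmethod` stub visible at the end of the rotary_embedding hunk precedes a decode-side helper; a sketch of its likely shape, modeled on the equivalent vLLM `get_next_input_positions`, with the exact name and signature treated as an assumption:

```python
from typing import List

def get_next_input_positions(
    mrope_position_delta: int, context_len: int, seq_len: int
) -> List[List[int]]:
    # After the prompt, all three mrope rows (temporal, height, width)
    # advance together, shifted by the per-request delta.
    return [
        list(range(context_len + mrope_position_delta, seq_len + mrope_position_delta))
        for _ in range(3)
    ]

# A 20-token prompt whose image grid compressed positions to max 7
# gives delta = -12; the next decoded token sits at position 8 on every row.
print(get_next_input_positions(-12, 20, 21))  # [[8], [8], [8]]
```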