From c998d04b46920f06d945fbef9023884a768723fc Mon Sep 17 00:00:00 2001
From: Mick
Date: Thu, 24 Apr 2025 12:35:05 +0900
Subject: [PATCH] vlm: enable radix cache for qwen-vl models (#5349)

Co-authored-by: Xinyuan Tong
---
 benchmark/mmmu/eval_utils.py                  |  58 +++-
 python/sglang/srt/configs/model_config.py     |  13 +-
 python/sglang/srt/layers/rotary_embedding.py  | 252 ++++++++++--------
 python/sglang/srt/managers/io_struct.py       |   2 +
 python/sglang/srt/managers/mm_utils.py        | 113 ++++++--
 .../multimodal_processors/base_processor.py   |  15 +-
 .../multimodal_processors/deepseek_vl_v2.py   |  11 +-
 .../managers/multimodal_processors/gemma3.py  |   7 +-
 .../multimodal_processors/janus_pro.py        |   4 +-
 .../managers/multimodal_processors/minicpm.py |   5 +-
 .../managers/multimodal_processors/qwen_vl.py |  51 +++-
 python/sglang/srt/managers/schedule_batch.py  |  25 +-
 .../sglang/srt/managers/tokenizer_manager.py  |   5 +-
 .../srt/model_executor/forward_batch_info.py  |  97 ++-----
 .../sglang/srt/model_executor/model_runner.py |   9 -
 python/sglang/srt/models/deepseek_vl2.py      |   6 +-
 python/sglang/srt/models/minicpmo.py          |   6 +-
 python/sglang/srt/models/mllama4.py           |   4 +-
 python/sglang/srt/models/qwen2_5_vl.py        |   9 +-
 python/sglang/srt/models/qwen2_vl.py          |  10 +-
 python/sglang/srt/openai_api/adapter.py       |   9 +-
 python/sglang/srt/server_args.py              |  19 +-
 python/sglang/test/runners.py                 |  19 +-
 test/srt/run_suite.py                         |   2 +-
 test/srt/test_vision_openai_server.py         |   3 +-
 test/srt/test_vlm_accuracy.py                 |   6 +-
 26 files changed, 429 insertions(+), 331 deletions(-)

diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py
index 2613be788..59e2c4930 100644
--- a/benchmark/mmmu/eval_utils.py
+++ b/benchmark/mmmu/eval_utils.py
@@ -89,7 +89,7 @@ def set_seed(seed_value):
 
 
 def prepare_samples(eval_args: EvalArgs):
-    print("preparing samples...")
+    print("Preparing samples...")
 
     # Build prompts
     set_seed(eval_args.seed)
@@ -105,15 +105,40 @@ def prepare_samples(eval_args: EvalArgs):
             assert len(value) == 1, "key {} has more than one value".format(key)
             eval_args.config[key] = value[0]
 
-    # run for each subject
+    # run for each subject in parallel
     sub_dataset_list = []
+    subjects = list(CAT_SHORT2LONG.values())  # Get a fixed list of subjects
 
-    for subject in tqdm(CAT_SHORT2LONG.values()):
-        sub_dataset = load_dataset(
-            eval_args.dataset_path, subject, split=eval_args.split
-        )
-        sub_dataset_list.append(sub_dataset)
-        # break
+    print(f"Loading datasets for {len(subjects)} subjects...")
+    with ThreadPoolExecutor() as executor:
+        # Submit all load_dataset tasks
+        future_to_subject = {
+            executor.submit(
+                load_dataset, eval_args.dataset_path, subject, split=eval_args.split
+            ): subject
+            for subject in subjects
+        }
+
+        # Collect results as they complete
+        results = {}
+        for future in tqdm(
+            as_completed(future_to_subject),
+            total=len(subjects),
+            desc="Loading datasets",
+        ):
+            subject = future_to_subject[future]
+            try:
+                results[subject] = future.result()
+            except Exception as exc:
+                print(f"{subject} generated an exception: {exc}")
+
+    # Ensure datasets are added in the original order for consistency
+    for subject in subjects:
+        if subject in results:
+            sub_dataset_list.append(results[subject])
+        else:
+            # Handle cases where a dataset failed to load (optional, depends on desired behavior)
+            print(f"Warning: Dataset for subject '{subject}' could not be loaded.")
 
     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)
@@ -133,18 +158,25 @@ def prepare_samples(eval_args: EvalArgs):
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
             return None, True
-        image_path = f"{images_path}/image_{i}.png"
+        # Use a unique identifier for the image path to avoid potential collisions if indices reset
+        image_path = f"{images_path}/image_{sample['id']}.png"
         if not os.path.exists(image_path):
             image.save(image_path)
         sample["image_path"] = image_path
         return sample, False
 
+    print("Processing samples...")
     with ThreadPoolExecutor() as executor:
+        # Pass the sample itself to process_sample, index is less reliable now
        futures = [
-            executor.submit(process_sample, i, sample)
+            executor.submit(
+                process_sample, i, sample
+            )  # Keep index i for tqdm maybe? Or remove it. Let's keep it for now.
             for i, sample in enumerate(dataset)
         ]
-        for future in tqdm(as_completed(futures), total=len(futures)):
+        for future in tqdm(
+            as_completed(futures), total=len(dataset), desc="Processing samples"
+        ):
             sample, skipped = future.result()
             if skipped:
                 skip_count += 1
@@ -152,9 +184,9 @@ def prepare_samples(eval_args: EvalArgs):
                 samples.append(sample)
 
     print(
-        f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
+        f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
     )
-    print("samples have been prepared")
+    print("Samples have been prepared")
     return samples
 
 
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index a719bf32b..f066c5b1b 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -73,15 +73,14 @@ class ModelConfig:
         )
 
         if enable_multimodal is None:
-            if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
                 enable_multimodal = False
                 logger.info(
-                    "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
-                )
-            elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
-                enable_multimodal = False
-                logger.info(
-                    "Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                 )
             else:
                 enable_multimodal = True
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index e3a99655a..6b132c965 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -877,127 +877,163 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
+    def get_rope_index(
+        spatial_merge_size: int,
         image_token_id: int,
         video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
+        model_type: str,
         tokens_per_second: Optional[int] = None,
-    ) -> Tuple[List[List[int]], int]:
-        """
-        Get mrope input positions and delta value.
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
 
-        :arg
-            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
-                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
 
-        """
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
 
-        if isinstance(image_grid_thw, torch.Tensor):
-            image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
 
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
 
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
 
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
                 )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
             )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-            )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions = llm_positions[:, context_len:seq_len]
-
-        return llm_positions.tolist(), mrope_position_delta
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas
 
     @staticmethod
     def get_next_input_positions(
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index e6ddb03f7..9d9b576c8 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -463,6 +463,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index aa7ed1554..025a3010b 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -10,12 +10,13 @@ import torch
 from torch import nn
 
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once
 
 logger = logging.getLogger(__name__)
 
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern
 
         return padded_ids
 
 
-class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. ...., or