From feda9b11b3a7a1ad9aa22d9149e23081945eed26 Mon Sep 17 00:00:00 2001
From: Mick
Date: Sat, 26 Apr 2025 09:28:33 +0900
Subject: [PATCH] fix: fix one more bug from merging mm_inputs (#5718)

Co-authored-by: Xinyuan Tong
Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
---
 python/sglang/srt/layers/rotary_embedding.py  | 19 +++---
 python/sglang/srt/managers/schedule_batch.py  | 10 ++-
 .../srt/model_executor/forward_batch_info.py  | 63 ++++++++++---------
 test/srt/test_vision_openai_server.py         |  1 -
 4 files changed, 54 insertions(+), 39 deletions(-)

diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 6b132c965..0b68a2191 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -1040,15 +1040,18 @@ class MRotaryEmbedding(RotaryEmbedding):
         mrope_position_delta: int,
         context_len: int,
         seq_len: int,
-    ) -> List[List[int]]:
-        return [
-            list(
-                range(
-                    context_len + mrope_position_delta, seq_len + mrope_position_delta
+    ) -> torch.Tensor:
+        return torch.tensor(
+            [
+                list(
+                    range(
+                        context_len + mrope_position_delta,
+                        seq_len + mrope_position_delta,
+                    )
                 )
-            )
-            for _ in range(3)
-        ]
+                for _ in range(3)
+            ]
+        )
 
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 6dea9321c..6b8506ddc 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -351,7 +351,6 @@ class MultimodalInputs:
         optional_args = [
             "mm_items",
             "image_pad_len",
-            "mrope_position_delta",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -367,6 +366,14 @@ class MultimodalInputs:
                 [self.mrope_positions, other.mrope_positions], dim=1
             )
 
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )
         # other args would be kept intact

 
@@ -1455,7 +1462,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 9cfc1d32e..e493dec7a 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -364,23 +364,23 @@ class ForwardBatch:
 
     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
         """
-        Merge all image inputs in the batch into a single MultiModalInputs object.
+        Merge all multimodal inputs in the batch into a single MultimodalInputs object.
 
         Returns:
-            if none, current batch contains no image input
+            if None, the current batch contains no multimodal input
         """
         if not self.mm_inputs or all(x is None for x in self.mm_inputs):
             return None
 
-        # Filter out None values
         valid_inputs = [x for x in self.mm_inputs if x is not None]
 
-        # Start with the first valid image input
-        merged = valid_inputs[0]
+        # TODO: is it expensive?
+        # a workaround to avoid importing `MultimodalInputs`
+        merged = valid_inputs[0].__class__(mm_items=[])
 
-        # Merge remaining inputs
-        for mm_input in valid_inputs[1:]:
+        # Merge all inputs into the fresh empty object
+        for mm_input in valid_inputs:
             merged.merge(mm_input)
 
         return merged
 
@@ -407,26 +407,34 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.multimodal_inputs[i] is None
-                    else batch.multimodal_inputs[i].mrope_position_delta
+        # batch_size * [3 * seq_len]
+        batch_size = self.seq_lens.shape[0]
+        mrope_positions_list = [[]] * batch_size
+        for batch_idx in range(batch_size):
+            mm_input = batch.multimodal_inputs[batch_idx]
+            if self.forward_mode.is_decode():
+                mrope_position_deltas = (
+                    [0]
+                    if mm_input is None
+                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
                 )
-                mrope_positions_list[i] = torch.tensor(
-                    MRotaryEmbedding.get_next_input_positions(
-                        mrope_position_delta,
-                        int(self.seq_lens[i]) - 1,
-                        int(self.seq_lens[i]),
-                    )
-                )
-        elif self.forward_mode.is_extend():
-            for i, mm_input in enumerate(batch.multimodal_inputs):
+                next_input_positions = []
+                for mrope_position_delta in mrope_position_deltas:
+                    # batched deltas need to be processed separately;
+                    # each call returns a [3, 1] tensor for the next position
+                    next_input_positions += [
+                        MRotaryEmbedding.get_next_input_positions(
+                            mrope_position_delta,
+                            int(self.seq_lens[batch_idx]) - 1,
+                            int(self.seq_lens[batch_idx]),
+                        )
+                    ]
+                # 3 * N
+                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+            elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
+                    batch.extend_seq_lens[batch_idx],
+                    batch.extend_prefix_lens[batch_idx],
                 )
                 if mm_input is None:
                     # text only
@@ -447,13 +455,12 @@ class ForwardBatch:
                     :,
                     extend_prefix_len : extend_prefix_len + extend_seq_len,
                 ]
-                mrope_positions_list[i] = mrope_positions
+                mrope_positions_list[batch_idx] = mrope_positions
 
         self.mrope_positions = torch.cat(
             [pos.to(device=model_runner.device) for pos in mrope_positions_list],
             dim=1,
-        ).to(device=model_runner.device)
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
+        ).to(dtype=torch.int64, device=model_runner.device)
 
     def get_max_chunk_capacity(self):
         # Maximum number of tokens in each chunk
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index 44ece4784..efed5fdb9 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -307,7 +307,6 @@ class TestOpenAIVisionServer(CustomTestCase):
         self.assertGreater(len(video_response), 0)
 
     def test_regex(self):
-        return
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
 
         regex = (