From 554fbf93cd67234fa63f811aa458fe0f60f17e42 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sun, 20 Oct 2024 17:38:35 +0800 Subject: [PATCH] [Bugfix] qwen2vl forward_extend (#1727) --- .../srt/model_executor/forward_batch_info.py | 59 ++++++++----------- test/srt/test_vision_openai_server.py | 4 -- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index eaf268cc2..49ef754a2 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -35,7 +35,6 @@ from dataclasses import dataclass from enum import IntEnum, auto from typing import TYPE_CHECKING, List, Optional -import numpy as np import torch from sglang.srt.layers.rotary_embedding import MRotaryEmbedding @@ -134,16 +133,23 @@ class ForwardBatch: ) elif self.forward_mode.is_extend(): for i, image_inputs in enumerate(batch.image_inputs): + extend_start_loc, extend_seq_len, extend_prefix_len = ( + self.extend_start_loc[i], + self.extend_seq_lens[i], + self.extend_prefix_lens[i], + ) if image_inputs is None: # text only - mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3 + mrope_positions = [ + [ + pos + for pos in range( + extend_prefix_len, extend_prefix_len + extend_seq_len + ) + ] + ] * 3 mrope_position_delta = 0 else: - extend_start_loc, extend_seq_len, extend_prefix_len = ( - self.extend_start_loc[i], - self.extend_seq_lens[i], - self.extend_prefix_lens[i], - ) mrope_positions, mrope_position_delta = ( MRotaryEmbedding.get_input_positions( input_tokens=self.input_ids[ @@ -163,12 +169,9 @@ class ForwardBatch: mrope_positions_list[i] = mrope_positions batch.mrope_positions_delta[i].append(mrope_position_delta) - self.mrope_positions = torch.tensor( - np.concatenate( - [np.array(pos) for pos in mrope_positions_list], - axis=1, - ), - device=device, + self.mrope_positions = torch.concat( + [torch.tensor(pos, device=device) for pos in mrope_positions_list], + axis=1, ) self.mrope_positions = self.mrope_positions.to(torch.int64) @@ -177,18 +180,15 @@ class ForwardBatch: if self.forward_mode.is_decode(): self.positions = (self.seq_lens - 1).to(torch.int64) else: - self.positions = torch.tensor( - np.concatenate( - [ - np.arange(prefix_len, prefix_len + extend_len) - for prefix_len, extend_len in zip( - batch.extend_prefix_lens, batch.extend_seq_lens - ) - ], - axis=0, - ), - device=device, - ).to(torch.int64) + self.positions = torch.concat( + [ + torch.arange(prefix_len, prefix_len + extend_len, device=device) + for prefix_len, extend_len in zip( + batch.extend_prefix_lens, batch.extend_seq_lens + ) + ], + axis=0, + ) @classmethod def init_new( @@ -213,15 +213,6 @@ class ForwardBatch: # Init position information if not ret.forward_mode.is_decode(): - ret.positions = torch.concat( - [ - torch.arange(prefix_len, prefix_len + extend_len, device=device) - for prefix_len, extend_len in zip( - batch.extend_prefix_lens, batch.extend_seq_lens - ) - ], - axis=0, - ) ret.image_inputs = batch.image_inputs ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32 diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index ba7a30026..296572ea9 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -362,10 +362,6 @@ class TestQWen2VLServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_mixed_batch(self): - # FIXME: Temporarily skip this test. - pass - if __name__ == "__main__": unittest.main()