diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 1abd67424..4d67ce6ff 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -133,6 +133,7 @@ class ImageInputs:
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
     def from_dict(obj, vocab_size):
@@ -251,9 +252,6 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-        # For Qwen2-VL
-        self.mrope_position_delta = []  # use mutable object
-
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -983,8 +981,6 @@ class ScheduleBatch:
         global bid
         bid += 1
 
-        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1007,7 +1003,6 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
-            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
@@ -1074,9 +1069,6 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-    # For Qwen2-VL
-    mrope_positions_delta: List[List[int]]
-
     def copy(self):
         return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index d314af944..8bd5f197a 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -136,8 +136,13 @@ class ForwardBatch:
         mrope_positions_list = [None] * self.seq_lens.shape[0]
         if self.forward_mode.is_decode():
             for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
                 mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    batch.mrope_positions_delta[i][0],
+                    mrope_position_delta,
                     int(self.seq_lens[i]) - 1,
                     int(self.seq_lens[i]),
                 )
@@ -159,7 +164,6 @@ class ForwardBatch:
                             )
                         ]
                     ] * 3
-                    mrope_position_delta = 0
                 else:
                     # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
@@ -173,8 +177,8 @@ class ForwardBatch:
                             context_len=0,
                         )
                     )
+                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                 mrope_positions_list[i] = mrope_positions
-                batch.mrope_positions_delta[i].append(mrope_position_delta)
 
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 4e4fef3dc..cedaa8e5c 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -649,8 +649,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     ]
                     image_embeds_offset += num_image_tokens
 
-            input_ids = None
-
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,