diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 7cffccf6b..05f068557 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         return position_ids, mrope_position_deltas
 
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 65c0a07f8..8904e89f1 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -516,24 +516,23 @@ class ForwardBatch:
         for batch_idx in range(batch_size):
             mm_input = batch.multimodal_inputs[batch_idx]
             if self.forward_mode.is_decode():
-                mrope_position_deltas = (
-                    [0]
-                    if mm_input is None
-                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
-                )
-                next_input_positions = []
-                for mrope_position_delta in mrope_position_deltas:
-                    # batched deltas needs to be processed separately
-                    # Convert list of lists to tensor with shape [3, seq_len]
-                    next_input_positions += [
-                        MRotaryEmbedding.get_next_input_positions(
-                            mrope_position_delta,
-                            int(self.seq_lens[batch_idx]) - 1,
-                            int(self.seq_lens[batch_idx]),
-                        )
-                    ]  # 3 * N
-                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+                if mm_input is None:
+                    mrope_positions_list[batch_idx] = torch.full(
+                        (3, 1),
+                        self.seq_lens[batch_idx] - 1,
+                        dtype=torch.int64,
+                        device=model_runner.device,
+                    )
+                else:
+                    mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
+                        model_runner.device, non_blocking=True
+                    )
+                    mrope_positions_list[batch_idx] = (
+                        (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
+                        .unsqueeze(0)
+                        .repeat(3, 1)
+                    )
             elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
                     batch.extend_seq_lens[batch_idx],
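
For reviewers, a minimal standalone sketch of why the vectorized decode path is equivalent to the removed helper. The names (`seq_len`, `deltas`, `old`, `new`) are illustrative only, not part of the sglang API. During decode the old code called `get_next_input_positions(delta, seq_len - 1, seq_len)`, so each per-delta range collapses to the single value `seq_len - 1 + delta`:

```python
import torch

seq_len = 10                       # decode step: context_len = seq_len - 1
deltas = torch.tensor([-3, 0, 2])  # example mrope_position_delta values (illustrative)

# Old path: one [3, 1] tensor per delta, built from Python ranges,
# then concatenated along dim=1 (mirrors the removed loop).
old = torch.cat(
    [
        torch.tensor([list(range(seq_len - 1 + d, seq_len + d)) for _ in range(3)])
        for d in deltas.tolist()
    ],
    dim=1,
)

# New path: a single vectorized expression with no Python-level loop,
# matching the else-branch added above.
new = (deltas + seq_len - 1).unsqueeze(0).repeat(3, 1)

assert torch.equal(old, new)  # both [3, 3]; every row is [6, 9, 11]
```

The `mm_input is None` branch is the delta-0 case of the same formula: `torch.full((3, 1), seq_len - 1, ...)` produces what the old path built for `mrope_position_deltas = [0]`, but directly on `model_runner.device` rather than on CPU.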