@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
    """Return ``seq_lens - 1`` floored at zero, cast to ``torch.int64``.

    Compiled with ``torch.compile`` (dynamic shapes enabled) using the
    backend returned by ``get_compiler_backend()``.

    Args:
        seq_lens: Sequence lengths; presumably an integer tensor with one
            entry per request (TODO confirm against the caller).

    Returns:
        A tensor of the same shape as ``seq_lens`` holding
        ``max(seq_lens - 1, 0)`` element-wise, with dtype ``torch.int64``.
    """
    # Shift each length down by one before applying the lower bound.
    shifted = seq_lens - 1
    # A length of 0 would give -1; the lower bound keeps the result at 0.
    bounded = torch.clamp(shifted, min=0)
    return bounded.to(torch.int64)