[Fix] Add torch compile for torch.clamp back (#4936)

This commit is contained in:
Baizhou Zhang
2025-03-30 20:46:07 -07:00
committed by GitHub
parent a303325fdb
commit 4a63bc32b7

View File

@@ -39,6 +39,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -299,7 +300,7 @@ class ForwardBatch:
# Init position information
if ret.forward_mode.is_decode():
if ret.positions is None:
ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
ret.positions = clamp_position(batch.seq_lens)
else:
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
@@ -519,3 +520,8 @@ def compute_position_torch(
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
return positions.to(torch.int64), extend_start_loc
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)