[Fix] Add torch compile for torch.clamp back (#4936)
@@ -39,6 +39,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -299,7 +300,7 @@ class ForwardBatch:
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
-                ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
+                ret.positions = clamp_position(batch.seq_lens)
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -519,3 +520,8 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
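
For context, a minimal self-contained sketch of the compiled helper this commit reintroduces is shown below. It is an illustration only: the plain "inductor" backend stands in for sglang's get_compiler_backend(), which selects a backend for the current runtime, and the dynamic=True flag keeps the compiled graph valid as the batch size (the length of seq_lens) changes between decode steps.

import torch


@torch.compile(dynamic=True, backend="inductor")
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Decode position of each sequence is seq_len - 1, clamped at 0 for
    # empty sequences, then cast to int64 as in the original eager code.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)


if __name__ == "__main__":
    seq_lens = torch.tensor([0, 1, 5, 128], dtype=torch.int32)
    print(clamp_position(seq_lens))  # tensor([  0,   0,   4, 127])

Factoring the clamp-and-cast into a module-level decorated function, rather than compiling it inline in ForwardBatch.init_new, lets torch.compile cache the compiled kernel once and reuse it on every decode step.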