[Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)

This commit is contained in:
Baizhou Zhang
2025-03-30 13:53:44 -07:00
committed by GitHub
parent 032f8faaab
commit e62d60fe6d
7 changed files with 30 additions and 35 deletions

View File

@@ -491,10 +491,10 @@ class CudaGraphRunner:
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)

View File

@@ -39,7 +39,6 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -148,6 +147,9 @@ class ForwardBatch:
# The sum of all sequence lengths
seq_lens_sum: int
# Optional seq_lens on cpu
seq_lens_cpu: Optional[torch.Tensor] = None
# For logprob
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -162,9 +164,6 @@ class ForwardBatch:
# Position information
positions: torch.Tensor = None
# For decode
decode_seq_lens_cpu: Optional[torch.Tensor] = None
# For extend
extend_num_tokens: Optional[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
@@ -293,12 +292,14 @@ class ForwardBatch:
):
ret.positions = ret.spec_info.positions
# Get seq_lens_cpu if needed
if ret.seq_lens_cpu is None:
ret.seq_lens_cpu = batch.seq_lens_cpu
# Init position information
if ret.forward_mode.is_decode():
if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens)
if ret.decode_seq_lens_cpu is None:
ret.decode_seq_lens_cpu = batch.decode_seq_lens
ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
else:
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
@@ -518,8 +519,3 @@ def compute_position_torch(
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
return positions.to(torch.int64), extend_start_loc
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
    """Compute decode positions from sequence lengths.

    For each sequence, the position of the token being decoded is
    ``seq_len - 1``, floored at 0 so empty sequences map to position 0.
    Returned as int64, as required for position/index tensors.
    """
    last_token_idx = seq_lens - 1
    return last_token_idx.clamp(min=0).to(torch.int64)