[Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)
This commit is contained in:
@@ -491,10 +491,10 @@ class CudaGraphRunner:
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||
if forward_batch.decode_seq_lens_cpu is not None:
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
|
||||
@@ -39,7 +39,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -148,6 +147,9 @@ class ForwardBatch:
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int
|
||||
|
||||
# Optional seq_lens on cpu
|
||||
seq_lens_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
@@ -162,9 +164,6 @@ class ForwardBatch:
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
|
||||
# For decode
|
||||
decode_seq_lens_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
@@ -293,12 +292,14 @@ class ForwardBatch:
|
||||
):
|
||||
ret.positions = ret.spec_info.positions
|
||||
|
||||
# Get seq_lens_cpu if needed
|
||||
if ret.seq_lens_cpu is None:
|
||||
ret.seq_lens_cpu = batch.seq_lens_cpu
|
||||
|
||||
# Init position information
|
||||
if ret.forward_mode.is_decode():
|
||||
if ret.positions is None:
|
||||
ret.positions = clamp_position(batch.seq_lens)
|
||||
if ret.decode_seq_lens_cpu is None:
|
||||
ret.decode_seq_lens_cpu = batch.decode_seq_lens
|
||||
ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
|
||||
else:
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
@@ -518,8 +519,3 @@ def compute_position_torch(
|
||||
extend_start_loc = torch.zeros_like(extend_seq_lens)
|
||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
||||
return positions.to(torch.int64), extend_start_loc
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def clamp_position(seq_lens):
|
||||
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
||||
|
||||
Reference in New Issue
Block a user