[Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)
This commit is contained in:
@@ -79,7 +79,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
# Precompute maximum sequence length
|
||||
metadata.max_seq_len_k = seqlens_in_batch.max().item()
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
# Precompute page table
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
|
||||
@@ -797,7 +797,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
|
||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
if spec_info is None:
|
||||
max_seqlen_pad = triton.cdiv(
|
||||
forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
|
||||
forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
|
||||
)
|
||||
block_kv_indices = torch.full(
|
||||
(bs, max_seqlen_pad),
|
||||
|
||||
@@ -1398,21 +1398,22 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
def get_model_worker_batch(self) -> ModelWorkerBatch:
|
||||
if self.forward_mode.is_decode_or_idle():
|
||||
if (
|
||||
global_server_args_dict["enable_flashinfer_mla"]
|
||||
or global_server_args_dict["enable_flashmla"]
|
||||
or global_server_args_dict["attention_backend"] == "fa3"
|
||||
):
|
||||
decode_seq_lens = self.seq_lens.cpu()
|
||||
else:
|
||||
decode_seq_lens = None
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
decode_seq_lens = None
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
|
||||
# Create seq_lens_cpu when needed
|
||||
if (
|
||||
global_server_args_dict["enable_flashinfer_mla"]
|
||||
or global_server_args_dict["enable_flashmla"]
|
||||
or global_server_args_dict["attention_backend"] == "fa3"
|
||||
):
|
||||
seq_lens_cpu = self.seq_lens.cpu()
|
||||
else:
|
||||
seq_lens_cpu = None
|
||||
|
||||
if self.sampling_info:
|
||||
if self.has_grammar:
|
||||
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||
@@ -1435,7 +1436,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
decode_seq_lens=decode_seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
extend_num_tokens=self.extend_num_tokens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
@@ -1496,6 +1497,7 @@ class ModelWorkerBatch:
|
||||
req_pool_indices: torch.Tensor
|
||||
# The sequence length
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_cpu: Optional[torch.Tensor]
|
||||
# The indices of output tokens in the token_to_kv_pool_allocator
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
@@ -1512,9 +1514,6 @@ class ModelWorkerBatch:
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
can_run_dp_cuda_graph: bool
|
||||
|
||||
# For decode
|
||||
decode_seq_lens: Optional[torch.Tensor]
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int]
|
||||
extend_seq_lens: Optional[List[int]]
|
||||
|
||||
@@ -491,10 +491,10 @@ class CudaGraphRunner:
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||
if forward_batch.decode_seq_lens_cpu is not None:
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
if bs != raw_bs:
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
|
||||
@@ -39,7 +39,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -148,6 +147,9 @@ class ForwardBatch:
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int
|
||||
|
||||
# Optional seq_lens on cpu
|
||||
seq_lens_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
@@ -162,9 +164,6 @@ class ForwardBatch:
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
|
||||
# For decode
|
||||
decode_seq_lens_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
@@ -293,12 +292,14 @@ class ForwardBatch:
|
||||
):
|
||||
ret.positions = ret.spec_info.positions
|
||||
|
||||
# Get seq_lens_cpu if needed
|
||||
if ret.seq_lens_cpu is None:
|
||||
ret.seq_lens_cpu = batch.seq_lens_cpu
|
||||
|
||||
# Init position information
|
||||
if ret.forward_mode.is_decode():
|
||||
if ret.positions is None:
|
||||
ret.positions = clamp_position(batch.seq_lens)
|
||||
if ret.decode_seq_lens_cpu is None:
|
||||
ret.decode_seq_lens_cpu = batch.decode_seq_lens
|
||||
ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
|
||||
else:
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
@@ -518,8 +519,3 @@ def compute_position_torch(
|
||||
extend_start_loc = torch.zeros_like(extend_seq_lens)
|
||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
||||
return positions.to(torch.int64), extend_start_loc
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def clamp_position(seq_lens):
|
||||
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
||||
|
||||
@@ -214,10 +214,10 @@ class EAGLEDraftCudaGraphRunner:
|
||||
forward_batch.positions = self.positions[:num_tokens]
|
||||
|
||||
# Special handle for seq_len_cpu used when flashinfer mla is used
|
||||
if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
|
||||
if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
|
||||
self.seq_lens_cpu.fill_(1)
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
|
||||
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||
|
||||
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
forward_batch, bs
|
||||
@@ -233,7 +233,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
forward_batch.positions = self.positions[:raw_num_token]
|
||||
forward_batch.seq_lens = self.seq_lens[:raw_bs]
|
||||
forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
|
||||
if forward_batch.decode_seq_lens_cpu is not None:
|
||||
forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user