[Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)

This commit is contained in:
Baizhou Zhang
2025-03-30 13:53:44 -07:00
committed by GitHub
parent 032f8faaab
commit e62d60fe6d
7 changed files with 30 additions and 35 deletions

View File

@@ -1398,21 +1398,22 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def get_model_worker_batch(self) -> ModelWorkerBatch:
if self.forward_mode.is_decode_or_idle():
if (
global_server_args_dict["enable_flashinfer_mla"]
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "fa3"
):
decode_seq_lens = self.seq_lens.cpu()
else:
decode_seq_lens = None
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
decode_seq_lens = None
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
# Create seq_lens_cpu when needed
if (
global_server_args_dict["enable_flashinfer_mla"]
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "fa3"
):
seq_lens_cpu = self.seq_lens.cpu()
else:
seq_lens_cpu = None
if self.sampling_info:
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
@@ -1435,7 +1436,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
decode_seq_lens=decode_seq_lens,
seq_lens_cpu=seq_lens_cpu,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
@@ -1496,6 +1497,7 @@ class ModelWorkerBatch:
req_pool_indices: torch.Tensor
# The sequence length
seq_lens: torch.Tensor
seq_lens_cpu: Optional[torch.Tensor]
# The indices of output tokens in the token_to_kv_pool_allocator
out_cache_loc: torch.Tensor
@@ -1512,9 +1514,6 @@ class ModelWorkerBatch:
global_num_tokens_for_logprob: Optional[List[int]]
can_run_dp_cuda_graph: bool
# For decode
decode_seq_lens: Optional[torch.Tensor]
# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]