Use torch.compile to fuse flash attention decode metadata preparation (#6973)
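The gist of the change, as a minimal standalone sketch (the function and variable names below are illustrative, not sglang's): the normal-decode replay path issues several tiny tensor ops per step (int32 cast, cumsum, indexed copy_, fill_), each a separate kernel launch in eager mode; wrapping them in a single torch.compile'd function lets the compiler fuse them into fewer kernels.

import torch

@torch.compile(dynamic=True)
def set_decode_metadata(cache_seqlens_int32, cu_seqlens_k, seq_lens):
    # Both writes are captured in one compiled region instead of
    # separate eager kernel launches.
    cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
    cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))

seq_lens = torch.tensor([3, 5, 2, 7])
cache_seqlens_int32 = torch.zeros(4, dtype=torch.int32)
cu_seqlens_k = torch.zeros(5, dtype=torch.int32)
set_decode_metadata(cache_seqlens_int32, cu_seqlens_k, seq_lens)
# cu_seqlens_k is now [0, 3, 8, 10, 17]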
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -1657,30 +1658,22 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
+                metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
-                metadata = self.decode_cuda_graph_metadata[bs]
                 max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
                 metadata.max_seq_len_k = max_len
 
-                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-                # Optimize cumulative sequence length calculation
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
+                normal_decode_set_medadata(
+                    metadata,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.page_size,
                 )
 
-                max_seq_pages = (
-                    metadata.max_seq_len_k + self.page_size - 1
-                ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
-                        None, :
-                    ],
-                ]
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
-                metadata.page_table[:, max_seq_pages:].fill_(0)
-
                 self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
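Note the split in the replay path above: max_len and max_seq_pages are plain Python ints computed from the CPU copy of the sequence lengths, while all tensor writes are delegated to the compiled helper. Keeping the .item() call on seq_lens_cpu outside the helper avoids a device sync and the graph break that scalar extraction typically causes under torch.compile. A small sketch of that host-side part (values illustrative):

import torch

page_size = 4                                  # illustrative value
seq_lens_cpu = torch.tensor([5, 9])            # CPU mirror of the device seq_lens
max_len = seq_lens_cpu.max().item()            # plain Python int, no GPU sync
max_seq_pages = (max_len + page_size - 1) // page_size  # ceil-div on the host
# max_len == 9, max_seq_pages == 3; these ints plus the device tensors are
# what the normal_decode_set_medadata(...) call above receives.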
@@ -2063,3 +2056,23 @@ class FlashAttentionMultiStepBackend:
             seq_lens_cpu=forward_batch.seq_lens_cpu,
             out_cache_loc=forward_batch.out_cache_loc,
         )
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def normal_decode_set_medadata(
+    metadata,
+    req_to_token,
+    req_pool_indices,
+    strided_indices,
+    max_seq_pages,
+    seq_lens,
+    page_size,
+):
+    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    page_indices = req_to_token[
+        req_pool_indices[:, None],
+        strided_indices[:max_seq_pages][None, :],
+    ]
+    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
+    metadata.page_table[:, max_seq_pages:].fill_(0)
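For reference, a rough standalone exercise of the new helper on toy tensors, assuming the function above is in scope. The metadata container here is a hypothetical stand-in (sglang's real FlashAttentionMetadata has more fields), and the shapes and page size are made up:

import torch

class ToyMetadata:  # stand-in for the real metadata object
    def __init__(self, bs, page_table_cols):
        self.cache_seqlens_int32 = torch.zeros(bs, dtype=torch.int32)
        self.cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
        self.page_table = torch.zeros(bs, page_table_cols, dtype=torch.int64)

bs, page_size, max_seq_pages = 2, 4, 3
metadata = ToyMetadata(bs, page_table_cols=8)
req_to_token = torch.arange(bs * page_size * max_seq_pages).reshape(bs, -1)
req_pool_indices = torch.tensor([0, 1])
strided_indices = torch.arange(0, req_to_token.shape[1], page_size)
seq_lens = torch.tensor([5, 9])

normal_decode_set_medadata(
    metadata, req_to_token, req_pool_indices, strided_indices,
    max_seq_pages, seq_lens, page_size,
)
# metadata.cu_seqlens_k      -> [0, 5, 14]
# metadata.page_table[0, :3] -> [0, 1, 2]; columns 3 and up stay 0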