Use torch.compile to fuse flash attention decode metadata preparation (#6973)
This commit is contained in:
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
from sglang.srt.utils import get_compiler_backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -1657,30 +1658,22 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
|
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
|
||||||
else:
|
else:
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
|
||||||
# Normal Decode
|
# Normal Decode
|
||||||
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
max_len = seq_lens_cpu.max().item()
|
max_len = seq_lens_cpu.max().item()
|
||||||
|
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
||||||
metadata.max_seq_len_k = max_len
|
metadata.max_seq_len_k = max_len
|
||||||
|
|
||||||
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
normal_decode_set_medadata(
|
||||||
# Optimize cumulative sequence length calculation
|
metadata,
|
||||||
metadata.cu_seqlens_k[1:].copy_(
|
self.req_to_token,
|
||||||
torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
|
req_pool_indices,
|
||||||
|
self.decode_cuda_graph_metadata["strided_indices"],
|
||||||
|
max_seq_pages,
|
||||||
|
seq_lens,
|
||||||
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_seq_pages = (
|
|
||||||
metadata.max_seq_len_k + self.page_size - 1
|
|
||||||
) // self.page_size
|
|
||||||
page_indices = self.req_to_token[
|
|
||||||
req_pool_indices[:, None],
|
|
||||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
|
|
||||||
None, :
|
|
||||||
],
|
|
||||||
]
|
|
||||||
page_indices //= self.page_size
|
|
||||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
|
||||||
metadata.page_table[:, max_seq_pages:].fill_(0)
|
|
||||||
|
|
||||||
self._update_local_attn_metadata_for_replay(metadata, bs)
|
self._update_local_attn_metadata_for_replay(metadata, bs)
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
if self.topk <= 1:
|
if self.topk <= 1:
|
||||||
@@ -2063,3 +2056,23 @@ class FlashAttentionMultiStepBackend:
|
|||||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||||
out_cache_loc=forward_batch.out_cache_loc,
|
out_cache_loc=forward_batch.out_cache_loc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
def normal_decode_set_medadata(
    metadata,
    req_to_token,
    req_pool_indices,
    strided_indices,
    max_seq_pages,
    seq_lens,
    page_size,
):
    """Refresh the decode attention metadata for a normal (non-speculative) step.

    The sequence-length, cumulative-length and page-table updates are grouped
    into one compiled function so the compiler backend can fuse them.

    Args:
        metadata: Metadata object to update. Its ``cache_seqlens_int32``
            attribute is rebound, while ``cu_seqlens_k`` and ``page_table``
            are written in place.
        req_to_token: Lookup table indexed by (request pool index, token
            position). Presumably holds KV-cache slot ids -- confirm against
            the caller.
        req_pool_indices: 1-D index tensor selecting the rows of
            ``req_to_token`` for the requests in the batch.
        strided_indices: Column positions to sample from ``req_to_token``;
            only the first ``max_seq_pages`` entries are used. Assumed to be
            spaced ``page_size`` apart -- TODO confirm where it is built.
        max_seq_pages: Number of leading page-table columns to populate. The
            caller computes it as ``ceil(max_len / page_size)``.
        seq_lens: Per-request sequence lengths.
        page_size: Divisor that converts the gathered entries into page ids.

    Returns:
        None. All results are stored on ``metadata``.
    """
    # NOTE(review): this rebinds the attribute rather than copying into an
    # existing buffer; ``.to`` only allocates when the dtype differs.
    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)

    # Slot 0 of cu_seqlens_k is left untouched; the running totals fill the rest.
    cumulative_lens = torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
    metadata.cu_seqlens_k[1:].copy_(cumulative_lens)

    # Broadcast a column of request rows against a row of sampled positions
    # to gather a (num_requests, max_seq_pages) block in one indexing op.
    row_selector = req_pool_indices[:, None]
    col_selector = strided_indices[:max_seq_pages][None, :]
    gathered_entries = req_to_token[row_selector, col_selector]

    page_table = metadata.page_table
    page_table[:, :max_seq_pages].copy_(gathered_entries // page_size)
    # Zero the unused tail so no stale page ids remain past max_seq_pages.
    page_table[:, max_seq_pages:].fill_(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user