refine aiter_backend for mtp (#7279)
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
@@ -32,7 +32,7 @@ try:
|
|||||||
mha_batch_prefill_func,
|
mha_batch_prefill_func,
|
||||||
paged_attention_ragged,
|
paged_attention_ragged,
|
||||||
)
|
)
|
||||||
from aiter.mla import mla_decode_fwd
|
from aiter.mla import mla_decode_fwd, mla_prefill_fwd
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(
|
print(
|
||||||
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||||
@@ -52,10 +52,8 @@ class ForwardMetadata:
|
|||||||
kv_indices: torch.Tensor
|
kv_indices: torch.Tensor
|
||||||
qo_indptr: torch.Tensor
|
qo_indptr: torch.Tensor
|
||||||
kv_last_page_len: torch.Tensor
|
kv_last_page_len: torch.Tensor
|
||||||
max_extend_len: int
|
|
||||||
max_prefix_extend_len: int
|
|
||||||
max_q_len: int
|
max_q_len: int
|
||||||
max_kv_len: int
|
max_kv_len: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
global_workspace_buffer = None
|
global_workspace_buffer = None
|
||||||
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# Lazy import to avoid the initialization of cuda context
|
||||||
|
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||||
|
extend_attention_fwd,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
|
||||||
|
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||||
self.num_head = (
|
self.num_head = (
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
)
|
)
|
||||||
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
spec_info = forward_batch.spec_info
|
spec_info = forward_batch.spec_info
|
||||||
qo_indptr = None
|
qo_indptr = None
|
||||||
kv_last_page_len = None
|
kv_last_page_len = None
|
||||||
max_extend_len = None
|
max_q_len = None
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
kv_indices = torch.zeros(
|
kv_indices = torch.empty(
|
||||||
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
qo_indptr = self.qo_indptr_[: bs + 1]
|
qo_indptr = self.qo_indptr_[: bs + 1]
|
||||||
qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
|
qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
|
||||||
kv_last_page_len = self.kv_last_page_len[:bs]
|
kv_last_page_len = self.kv_last_page_len[:bs]
|
||||||
max_extend_len = 1
|
max_q_len = 1
|
||||||
|
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_last_page_len,
|
kv_last_page_len,
|
||||||
max_extend_len,
|
max_q_len,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif forward_batch.forward_mode.is_draft_extend():
|
elif forward_batch.forward_mode.is_draft_extend():
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
self.mla_indices_updater_prefill.update(
|
spec_info.generate_attn_arg_prefill(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
prefix_lens,
|
forward_batch.seq_lens,
|
||||||
prefix_lens.sum().item(),
|
forward_batch.seq_lens_sum,
|
||||||
forward_batch.extend_seq_lens,
|
self.req_to_token,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
)
|
||||||
spec_info=None,
|
|
||||||
)
|
)
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
self.mla_indices_updater_prefill.kv_indptr,
|
kv_indptr,
|
||||||
self.mla_indices_updater_prefill.kv_indices,
|
kv_indices,
|
||||||
self.mla_indices_updater_prefill.qo_indptr,
|
qo_indptr,
|
||||||
self.mla_indices_updater_prefill.kv_last_page_len,
|
# self.mla_indices_updater_prefill.kv_last_page_len,
|
||||||
self.mla_indices_updater_prefill.max_extend_len,
|
self.kv_last_page_len[:bs],
|
||||||
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
max(forward_batch.extend_seq_lens_cpu),
|
||||||
None,
|
forward_batch.seq_lens_cpu.max().item(),
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.kv_indices,
|
self.indices_updater_prefill.kv_indices,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
self.indices_updater_prefill.max_q_len,
|
self.indices_updater_prefill.max_q_len,
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.indices_updater_prefill.max_kv_len,
|
||||||
)
|
)
|
||||||
elif forward_batch.forward_mode.is_target_verify():
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
draft_num = spec_info.draft_token_num
|
||||||
self.mla_indices_updater_prefill.update(
|
kv_lens = forward_batch.seq_lens + draft_num
|
||||||
|
kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
|
||||||
|
device = forward_batch.seq_lens.device
|
||||||
|
|
||||||
|
qo_indptr = torch.arange(
|
||||||
|
0,
|
||||||
|
(1 + bs) * draft_num,
|
||||||
|
step=draft_num,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
kv_indptr = self.kv_indptr
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
kv_indices = torch.empty(
|
||||||
|
kv_lens_sum,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
prefix_lens,
|
kv_lens,
|
||||||
prefix_lens.sum().item(),
|
kv_indptr,
|
||||||
forward_batch.extend_seq_lens,
|
None,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
kv_indices,
|
||||||
spec_info=None,
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
self.mla_indices_updater_prefill.kv_indptr,
|
kv_indptr,
|
||||||
self.mla_indices_updater_prefill.kv_indices,
|
kv_indices,
|
||||||
self.mla_indices_updater_prefill.qo_indptr,
|
qo_indptr,
|
||||||
self.mla_indices_updater_prefill.kv_last_page_len,
|
# self.mla_indices_updater_prefill.kv_last_page_len,
|
||||||
self.mla_indices_updater_prefill.max_extend_len,
|
self.kv_last_page_len[:bs],
|
||||||
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
draft_num,
|
||||||
None,
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.kv_indices,
|
self.indices_updater_prefill.kv_indices,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
self.indices_updater_prefill.max_q_len,
|
self.indices_updater_prefill.max_q_len,
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.indices_updater_prefill.max_kv_len,
|
||||||
)
|
)
|
||||||
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
extend_no_prefix = False
|
extend_no_prefix = False
|
||||||
else:
|
else:
|
||||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
self.mla_indices_updater_prefill.update(
|
self.mla_indices_updater_prefill.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
prefix_lens,
|
forward_batch.extend_prefix_lens,
|
||||||
prefix_lens.sum().item(),
|
sum(forward_batch.extend_prefix_lens_cpu),
|
||||||
forward_batch.extend_seq_lens,
|
forward_batch.extend_seq_lens,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
max(forward_batch.extend_seq_lens_cpu),
|
||||||
|
forward_batch.seq_lens_cpu.max().item(),
|
||||||
spec_info=None,
|
spec_info=None,
|
||||||
)
|
)
|
||||||
|
self.mla_indices_updater_prefill.kv_indptr += (
|
||||||
|
self.mla_indices_updater_prefill.qo_indptr
|
||||||
|
)
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
self.mla_indices_updater_prefill.kv_indptr,
|
self.mla_indices_updater_prefill.kv_indptr,
|
||||||
self.mla_indices_updater_prefill.kv_indices,
|
self.mla_indices_updater_prefill.kv_indices,
|
||||||
self.mla_indices_updater_prefill.qo_indptr,
|
self.mla_indices_updater_prefill.qo_indptr,
|
||||||
self.mla_indices_updater_prefill.kv_last_page_len,
|
self.kv_last_page_len[:bs],
|
||||||
self.mla_indices_updater_prefill.max_extend_len,
|
self.mla_indices_updater_prefill.max_q_len,
|
||||||
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
self.mla_indices_updater_prefill.max_kv_len,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.kv_indices,
|
self.indices_updater_prefill.kv_indices,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
self.indices_updater_prefill.max_q_len,
|
self.indices_updater_prefill.max_q_len,
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.indices_updater_prefill.max_kv_len,
|
||||||
)
|
)
|
||||||
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
qo_indptr = None
|
qo_indptr = None
|
||||||
kv_last_page_len = None
|
kv_last_page_len = None
|
||||||
max_extend_len = None
|
max_q_len = None
|
||||||
|
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
kv_indptr = self.kv_indptr
|
kv_indptr = self.kv_indptr
|
||||||
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
qo_indptr[1 : bs + 1] = torch.cumsum(
|
qo_indptr[1 : bs + 1] = torch.cumsum(
|
||||||
self.cuda_graph_kv_last_page_len[:bs], dim=0
|
self.cuda_graph_kv_last_page_len[:bs], dim=0
|
||||||
)
|
)
|
||||||
max_extend_len = 1
|
|
||||||
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
|
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
|
||||||
|
max_q_len = 1
|
||||||
|
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_last_page_len,
|
kv_last_page_len,
|
||||||
max_extend_len,
|
max_q_len,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
|
||||||
max_extend_len = self.num_draft_tokens
|
max_q_len = self.num_draft_tokens
|
||||||
kv_last_page_len = None
|
|
||||||
|
|
||||||
self.forward_metadata = ForwardMetadata(
|
self.forward_metadata = ForwardMetadata(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_last_page_len,
|
kv_last_page_len,
|
||||||
max_extend_len,
|
max_q_len,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.kv_indices,
|
self.indices_updater_prefill.kv_indices,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
None,
|
|
||||||
self.indices_updater_prefill.max_q_len,
|
self.indices_updater_prefill.max_q_len,
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.indices_updater_prefill.max_kv_len,
|
||||||
)
|
)
|
||||||
|
elif forward_mode.is_draft_extend():
|
||||||
|
num_tokens_per_bs = self.speculative_num_steps + 1
|
||||||
|
qo_indptr = self.qo_indptr[: bs + 1]
|
||||||
|
qo_indptr[: bs + 1] = torch.arange(
|
||||||
|
0,
|
||||||
|
bs * num_tokens_per_bs + 1,
|
||||||
|
step=num_tokens_per_bs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
kv_indptr = self.kv_indptr[: bs + 1]
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||||
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
|
||||||
|
max_q_len = num_tokens_per_bs
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
qo_indptr,
|
||||||
|
kv_last_page_len,
|
||||||
|
max_q_len,
|
||||||
|
None,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid mode: {forward_mode=}")
|
raise ValueError(f"Invalid mode: {forward_mode=}")
|
||||||
|
|
||||||
@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
||||||
|
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
self.indices_updater_prefill.update(
|
bs = len(req_pool_indices)
|
||||||
req_pool_indices[:bs],
|
qo_indptr = self.qo_indptr[: bs + 1]
|
||||||
seq_lens[:bs],
|
qo_indptr[: bs + 1] = torch.arange(
|
||||||
seq_lens_sum,
|
0,
|
||||||
prefix_lens=None,
|
(1 + bs) * self.num_draft_tokens,
|
||||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
step=self.num_draft_tokens,
|
||||||
spec_info=spec_info,
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
kv_lens = seq_lens + self.num_draft_tokens
|
||||||
|
kv_indptr = self.kv_indptr[: bs + 1]
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
|
||||||
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
kv_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
elif forward_mode.is_draft_extend():
|
||||||
|
seq_lens = seq_lens[:bs]
|
||||||
|
accept_lens = spec_info.accept_length[:bs]
|
||||||
|
qo_indptr = self.qo_indptr[: bs + 1]
|
||||||
|
qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
|
||||||
|
kv_indptr = self.kv_indptr[: bs + 1]
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||||
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid forward mode")
|
raise ValueError("Invalid forward mode")
|
||||||
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
max_extend_len = self.forward_metadata.max_extend_len
|
max_q_len = self.forward_metadata.max_q_len
|
||||||
max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
|
max_kv_len = self.forward_metadata.max_kv_len
|
||||||
kv_indptr = self.forward_metadata.kv_indptr
|
kv_indptr = self.forward_metadata.kv_indptr
|
||||||
kv_indices = self.forward_metadata.kv_indices
|
kv_indices = self.forward_metadata.kv_indices
|
||||||
kv_last_page_lens = self.forward_metadata.kv_last_page_len
|
|
||||||
qo_indptr = self.forward_metadata.qo_indptr
|
qo_indptr = self.forward_metadata.qo_indptr
|
||||||
K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||||
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
v,
|
v,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
max_extend_len,
|
max_q_len,
|
||||||
max_extend_len,
|
max_q_len,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
v,
|
v,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
max_extend_len,
|
max_q_len,
|
||||||
max_prefix_extend_len,
|
max_kv_len,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
|
||||||
|
mla_decode_fwd(
|
||||||
|
q,
|
||||||
|
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
|
||||||
|
o,
|
||||||
|
self.forward_metadata.qo_indptr,
|
||||||
|
self.forward_metadata.kv_indptr,
|
||||||
|
self.forward_metadata.kv_indices,
|
||||||
|
self.forward_metadata.kv_last_page_len,
|
||||||
|
self.forward_metadata.max_q_len,
|
||||||
|
layer.scaling,
|
||||||
|
layer.logit_cap,
|
||||||
|
)
|
||||||
|
K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
|
||||||
|
return o
|
||||||
|
elif forward_batch.forward_mode.is_draft_extend():
|
||||||
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
|
||||||
|
causal = True
|
||||||
|
sliding_window_size = -1
|
||||||
|
kv_indptr = self.forward_metadata.kv_indptr
|
||||||
|
kv_indices = self.forward_metadata.kv_indices
|
||||||
|
mla_prefill_fwd(
|
||||||
|
q,
|
||||||
|
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
|
||||||
|
o,
|
||||||
|
self.forward_metadata.qo_indptr,
|
||||||
|
self.forward_metadata.kv_indptr,
|
||||||
|
self.forward_metadata.kv_indices,
|
||||||
|
self.forward_metadata.kv_last_page_len,
|
||||||
|
self.forward_metadata.max_q_len,
|
||||||
|
layer.scaling,
|
||||||
|
layer.logit_cap,
|
||||||
|
)
|
||||||
|
K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
|
||||||
|
return o
|
||||||
|
# self.extend_attention_fwd(
|
||||||
|
# q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
|
# k.contiguous(),
|
||||||
|
# v.contiguous(),
|
||||||
|
# o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
|
# forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
|
# forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
|
# self.forward_metadata.qo_indptr,
|
||||||
|
# kv_indptr,
|
||||||
|
# kv_indices,
|
||||||
|
# None,
|
||||||
|
# causal,
|
||||||
|
# None,
|
||||||
|
# self.forward_metadata.max_q_len,
|
||||||
|
# layer.scaling,
|
||||||
|
# layer.logit_cap,
|
||||||
|
# sliding_window_size,
|
||||||
|
# )
|
||||||
|
# return o
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||||
layer.layer_id
|
layer.layer_id
|
||||||
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
self.forward_metadata.kv_indptr,
|
self.forward_metadata.kv_indptr,
|
||||||
self.forward_metadata.kv_indices,
|
self.forward_metadata.kv_indices,
|
||||||
self.forward_metadata.kv_last_page_len,
|
self.forward_metadata.kv_last_page_len,
|
||||||
self.forward_metadata.max_extend_len,
|
self.forward_metadata.max_q_len,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
)
|
)
|
||||||
@@ -816,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
|
|||||||
self.kv_indices = None
|
self.kv_indices = None
|
||||||
self.qo_indptr = None
|
self.qo_indptr = None
|
||||||
self.kv_last_page_len = None
|
self.kv_last_page_len = None
|
||||||
self.max_extend_len = 0
|
self.max_q_len = 0
|
||||||
self.max_prefix_extend_len = 0
|
self.max_kv_len = 0
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
prefix_lens: torch.Tensor,
|
kv_lens: torch.Tensor,
|
||||||
prefix_lens_sum: int,
|
kv_lens_sum: int,
|
||||||
extend_lens: torch.Tensor,
|
extend_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
max_q_len: int,
|
||||||
|
max_kv_len: int,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
@@ -834,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
|
|||||||
def update_single_wrapper(
|
def update_single_wrapper(
|
||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
prefix_lens: torch.Tensor,
|
kv_lens: torch.Tensor,
|
||||||
prefix_lens_sum: int,
|
kv_lens_sum: int,
|
||||||
extend_lens: torch.Tensor,
|
extend_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
max_q_len: int,
|
||||||
|
max_kv_len: int,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
|
|
||||||
paged_kernel_lens = prefix_lens
|
|
||||||
paged_kernel_lens_sum = prefix_lens_sum
|
|
||||||
|
|
||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
|
|
||||||
kv_indptr = self.attn_backend.kv_indptr
|
kv_indptr = self.attn_backend.kv_indptr
|
||||||
|
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
# Normal extend
|
# Normal extend
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
kv_indices = torch.empty(
|
kv_indices = torch.empty(
|
||||||
paged_kernel_lens_sum,
|
kv_lens_sum,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=req_pool_indices.device,
|
device=req_pool_indices.device,
|
||||||
)
|
)
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
kv_lens,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
None,
|
None,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
@@ -870,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
|
|||||||
qo_indptr = self.attn_backend.qo_indptr
|
qo_indptr = self.attn_backend.qo_indptr
|
||||||
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
|
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
|
||||||
qo_indptr = qo_indptr[: bs + 1]
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
|
|
||||||
max_extend_len = torch.max(extend_lens).item()
|
|
||||||
max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
|
|
||||||
kv_indptr += qo_indptr
|
|
||||||
else:
|
else:
|
||||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
spec_info.generate_attn_arg_prefill(
|
spec_info.generate_attn_arg_prefill(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
kv_lens,
|
||||||
paged_kernel_lens_sum,
|
kv_lens_sum,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -887,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
|
|||||||
self.kv_indptr = kv_indptr
|
self.kv_indptr = kv_indptr
|
||||||
self.kv_indices = kv_indices
|
self.kv_indices = kv_indices
|
||||||
self.qo_indptr = qo_indptr
|
self.qo_indptr = qo_indptr
|
||||||
self.max_extend_len = max_extend_len
|
self.max_q_len = max_q_len
|
||||||
self.max_prefix_extend_len = max_prefix_extend_len
|
self.max_kv_len = max_kv_len
|
||||||
|
|
||||||
|
|
||||||
|
class AiterMultiStepDraftBackend:
|
||||||
|
"""
|
||||||
|
Wrap multiple triton attention backends as one for multiple consecutive
|
||||||
|
draft decoding steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_runner: ModelRunner,
|
||||||
|
topk: int,
|
||||||
|
speculative_num_steps: int,
|
||||||
|
):
|
||||||
|
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||||
|
|
||||||
|
self.topk = topk
|
||||||
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
||||||
|
max_bs = model_runner.req_to_token_pool.size * self.topk
|
||||||
|
self.kv_indptr = torch.zeros(
|
||||||
|
(
|
||||||
|
self.speculative_num_steps,
|
||||||
|
max_bs + 1,
|
||||||
|
),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=model_runner.device,
|
||||||
|
)
|
||||||
|
self.attn_backends = []
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
self.attn_backends.append(
|
||||||
|
AiterAttnBackend(
|
||||||
|
model_runner,
|
||||||
|
skip_prefill=True,
|
||||||
|
kv_indptr_buf=self.kv_indptr[i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.max_context_len = self.attn_backends[0].max_context_len
|
||||||
|
self.num_head = (
|
||||||
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
|
)
|
||||||
|
self.device = model_runner.device
|
||||||
|
# Cached variables for generate_draft_decode_kv_indices
|
||||||
|
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||||
|
self.page_size = model_runner.server_args.page_size
|
||||||
|
assert self.page_size == 1, "Page size must be 1"
|
||||||
|
|
||||||
|
def common_template(
|
||||||
|
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
||||||
|
):
|
||||||
|
num_seqs = forward_batch.batch_size
|
||||||
|
bs = self.topk * num_seqs
|
||||||
|
seq_lens_sum = forward_batch.seq_lens_sum
|
||||||
|
|
||||||
|
self.generate_draft_decode_kv_indices[
|
||||||
|
(self.speculative_num_steps, num_seqs, self.topk)
|
||||||
|
](
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.req_to_token_pool.req_to_token,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
kv_indices_buffer,
|
||||||
|
self.kv_indptr,
|
||||||
|
forward_batch.positions,
|
||||||
|
self.pool_len,
|
||||||
|
kv_indices_buffer.shape[1],
|
||||||
|
self.kv_indptr.shape[1],
|
||||||
|
triton.next_power_of_2(num_seqs),
|
||||||
|
triton.next_power_of_2(self.speculative_num_steps),
|
||||||
|
triton.next_power_of_2(bs),
|
||||||
|
self.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||||
|
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||||
|
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||||
|
]
|
||||||
|
call_fn(i, forward_batch)
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
kv_indices = torch.empty(
|
||||||
|
(
|
||||||
|
self.speculative_num_steps,
|
||||||
|
forward_batch.batch_size * self.topk * self.max_context_len,
|
||||||
|
),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def call_fn(i, forward_batch):
|
||||||
|
forward_batch.spec_info.kv_indptr = (
|
||||||
|
forward_batch.spec_info.kv_indptr.clone()
|
||||||
|
)
|
||||||
|
forward_batch.spec_info.kv_indices = (
|
||||||
|
forward_batch.spec_info.kv_indices.clone()
|
||||||
|
)
|
||||||
|
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
|
self.common_template(forward_batch, kv_indices, call_fn)
|
||||||
|
|
||||||
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||||
|
self.cuda_graph_kv_indices = torch.zeros(
|
||||||
|
(self.speculative_num_steps, max_num_tokens * self.max_context_len),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
for i in range(self.speculative_num_steps):
|
||||||
|
self.attn_backends[i].init_cuda_graph_state(
|
||||||
|
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||||
|
def call_fn(i, forward_batch):
|
||||||
|
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||||
|
forward_batch.batch_size,
|
||||||
|
forward_batch.batch_size * self.topk,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
encoder_lens=None,
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self, forward_batch: ForwardBatch, bs: int
|
||||||
|
):
|
||||||
|
def call_fn(i, forward_batch):
|
||||||
|
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||||
|
bs,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
seq_lens_sum=-1,
|
||||||
|
encoder_lens=None,
|
||||||
|
forward_mode=ForwardMode.DECODE,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
seq_lens_cpu=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||||
|
|||||||
@@ -1722,6 +1722,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
or attention_backend_str == "cutlass_mla"
|
or attention_backend_str == "cutlass_mla"
|
||||||
or attention_backend_str == "ascend"
|
or attention_backend_str == "ascend"
|
||||||
or attention_backend_str == "trtllm_mha"
|
or attention_backend_str == "trtllm_mha"
|
||||||
|
or attention_backend_str == "aiter"
|
||||||
or global_server_args_dict["enable_two_batch_overlap"]
|
or global_server_args_dict["enable_two_batch_overlap"]
|
||||||
):
|
):
|
||||||
seq_lens_cpu = (
|
seq_lens_cpu = (
|
||||||
|
|||||||
@@ -226,6 +226,22 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.draft_model_runner,
|
self.draft_model_runner,
|
||||||
skip_prefill=False,
|
skip_prefill=False,
|
||||||
)
|
)
|
||||||
|
elif self.server_args.attention_backend == "aiter":
|
||||||
|
from sglang.srt.layers.attention.aiter_backend import (
|
||||||
|
AiterAttnBackend,
|
||||||
|
AiterMultiStepDraftBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.draft_attn_backend = AiterMultiStepDraftBackend(
|
||||||
|
self.draft_model_runner,
|
||||||
|
self.topk,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
)
|
||||||
|
self.draft_extend_attn_backend = AiterAttnBackend(
|
||||||
|
self.draft_model_runner,
|
||||||
|
skip_prefill=False,
|
||||||
|
)
|
||||||
|
self.has_prefill_wrapper_verify = False
|
||||||
elif self.server_args.attention_backend == "fa3":
|
elif self.server_args.attention_backend == "fa3":
|
||||||
from sglang.srt.layers.attention.flashattention_backend import (
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
FlashAttentionBackend,
|
FlashAttentionBackend,
|
||||||
|
|||||||
Reference in New Issue
Block a user