AITER backend extension and workload optimizations (#6838)
Co-authored-by: wunhuang <wunhuang@amd.com> Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
This commit is contained in:
@@ -27,12 +27,19 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
|
||||
try:
|
||||
from aiter import mha_batch_prefill_func, paged_attention_ragged
|
||||
from aiter import (
|
||||
flash_attn_varlen_func,
|
||||
mha_batch_prefill_func,
|
||||
paged_attention_ragged,
|
||||
)
|
||||
from aiter.mla import mla_decode_fwd
|
||||
except ImportError:
|
||||
print(
|
||||
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||
)
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
|
||||
|
||||
class WrapperDispatch(Enum):
|
||||
SLIDING_WINDOW = auto()
|
||||
@@ -43,6 +50,10 @@ class WrapperDispatch(Enum):
|
||||
class ForwardMetadata:
|
||||
kv_indptr: torch.Tensor
|
||||
kv_indices: torch.Tensor
|
||||
qo_indptr: torch.Tensor
|
||||
kv_last_page_len: torch.Tensor
|
||||
max_extend_len: int
|
||||
max_prefix_extend_len: int
|
||||
max_q_len: int
|
||||
max_kv_len: int
|
||||
|
||||
@@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend):
|
||||
|
||||
self.device = model_runner.device
|
||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||
self.num_head = (
|
||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||
)
|
||||
@@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend):
|
||||
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
|
||||
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||
|
||||
# Parse constants
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.skip_prefill = skip_prefill
|
||||
@@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend):
|
||||
self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
|
||||
model_runner, self
|
||||
)
|
||||
if self.use_mla:
|
||||
self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
|
||||
model_runner, self
|
||||
)
|
||||
|
||||
# aiter kernel related initialization
|
||||
self.max_num_partitions = (
|
||||
@@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend):
|
||||
|
||||
nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
|
||||
|
||||
self.workspace_buffer = torch.empty(
|
||||
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
|
||||
* nbyes_per_qo_elem
|
||||
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.use_mla:
|
||||
self.workspace_buffer = torch.empty(
|
||||
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
|
||||
* nbyes_per_qo_elem
|
||||
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.scale = float(1.0 / (self.head_dim**0.5))
|
||||
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
self.logits_soft_cap = 0.0
|
||||
|
||||
self.forward_metadata: ForwardMetadata = None
|
||||
|
||||
if self.use_mla:
|
||||
self.qo_indptr_ = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
|
||||
bs = forward_batch.batch_size
|
||||
kv_indptr = self.kv_indptr
|
||||
spec_info = forward_batch.spec_info
|
||||
qo_indptr = None
|
||||
kv_last_page_len = None
|
||||
max_extend_len = None
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
# update for aiter
|
||||
# create kv_indices and kv_inptr
|
||||
bs = forward_batch.batch_size
|
||||
kv_indptr = self.kv_indptr
|
||||
spec_info = forward_batch.spec_info
|
||||
if spec_info is None:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
@@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend):
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
|
||||
if self.use_mla:
|
||||
qo_indptr = self.qo_indptr_[: bs + 1]
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
|
||||
kv_last_page_len = self.kv_last_page_len[:bs]
|
||||
max_extend_len = 1
|
||||
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
kv_last_page_len,
|
||||
max_extend_len,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
elif forward_batch.forward_mode.is_draft_extend():
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
prefix_lens=None,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
if self.use_mla:
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
self.mla_indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
prefix_lens.sum().item(),
|
||||
forward_batch.extend_seq_lens,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=None,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.mla_indices_updater_prefill.kv_indptr,
|
||||
self.mla_indices_updater_prefill.kv_indices,
|
||||
self.mla_indices_updater_prefill.qo_indptr,
|
||||
self.mla_indices_updater_prefill.kv_last_page_len,
|
||||
self.mla_indices_updater_prefill.max_extend_len,
|
||||
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
prefix_lens=None,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
prefix_lens=None,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
if self.use_mla:
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
self.mla_indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
prefix_lens.sum().item(),
|
||||
forward_batch.extend_seq_lens,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=None,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.mla_indices_updater_prefill.kv_indptr,
|
||||
self.mla_indices_updater_prefill.kv_indices,
|
||||
self.mla_indices_updater_prefill.qo_indptr,
|
||||
self.mla_indices_updater_prefill.kv_last_page_len,
|
||||
self.mla_indices_updater_prefill.max_extend_len,
|
||||
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
prefix_lens=None,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
else:
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
|
||||
@@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend):
|
||||
else:
|
||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
prefix_lens,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=None,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
if self.use_mla:
|
||||
self.mla_indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
prefix_lens.sum().item(),
|
||||
forward_batch.extend_seq_lens,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=None,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.mla_indices_updater_prefill.kv_indptr,
|
||||
self.mla_indices_updater_prefill.kv_indices,
|
||||
self.mla_indices_updater_prefill.qo_indptr,
|
||||
self.mla_indices_updater_prefill.kv_last_page_len,
|
||||
self.mla_indices_updater_prefill.max_extend_len,
|
||||
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
prefix_lens,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=None,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||
):
|
||||
self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
|
||||
if kv_indices_buf is None:
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
@@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend):
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
qo_indptr = None
|
||||
kv_last_page_len = None
|
||||
max_extend_len = None
|
||||
|
||||
if spec_info is None:
|
||||
kv_indptr = self.kv_indptr
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||
@@ -255,24 +374,82 @@ class AiterAttnBackend(AttentionBackend):
|
||||
)
|
||||
else:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
|
||||
|
||||
if self.use_mla:
|
||||
qo_indptr = self.qo_indptr_[: bs + 1]
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(
|
||||
self.cuda_graph_kv_last_page_len[:bs], dim=0
|
||||
)
|
||||
max_extend_len = 1
|
||||
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
|
||||
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
kv_last_page_len,
|
||||
max_extend_len,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
elif forward_mode.is_target_verify():
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
self.indices_updater_prefill.update(
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
prefix_lens=None,
|
||||
encoder_lens=encoder_lens,
|
||||
spec_info=spec_info,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
if self.use_mla:
|
||||
qo_indptr = self.qo_indptr[: bs + 1]
|
||||
qo_indptr[: bs + 1] = torch.arange(
|
||||
0,
|
||||
(1 + bs) * self.num_draft_tokens,
|
||||
step=self.num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
kv_indptr = self.kv_indptr[: bs + 1]
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||
kv_indices = self.cuda_graph_kv_indices
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
max_extend_len = self.num_draft_tokens
|
||||
kv_last_page_len = None
|
||||
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
kv_last_page_len,
|
||||
max_extend_len,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
self.indices_updater_prefill.update(
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
prefix_lens=None,
|
||||
encoder_lens=encoder_lens,
|
||||
spec_info=spec_info,
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
self.indices_updater_prefill.kv_indptr,
|
||||
self.indices_updater_prefill.kv_indices,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self.indices_updater_prefill.max_q_len,
|
||||
self.indices_updater_prefill.max_kv_len,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {forward_mode=}")
|
||||
@@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend):
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
if self.use_mla:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
else:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
if self.use_mla:
|
||||
max_extend_len = self.forward_metadata.max_extend_len
|
||||
max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
|
||||
kv_indptr = self.forward_metadata.kv_indptr
|
||||
kv_indices = self.forward_metadata.kv_indices
|
||||
kv_last_page_lens = self.forward_metadata.kv_last_page_len
|
||||
qo_indptr = self.forward_metadata.qo_indptr
|
||||
K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||
kv_lora_rank = V_Buffer.shape[-1]
|
||||
qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
|
||||
qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
|
||||
assert len(q.shape) == 3
|
||||
assert len(k.shape) == 3
|
||||
assert len(v.shape) == 3
|
||||
|
||||
if kv_indices.shape[0] == 0:
|
||||
o = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
qo_indptr,
|
||||
qo_indptr,
|
||||
max_extend_len,
|
||||
max_extend_len,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
)
|
||||
return o
|
||||
elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
|
||||
K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
|
||||
kvc, k_pe = torch.split(
|
||||
K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
|
||||
)
|
||||
kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
|
||||
|
||||
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||
kvprefix = kvprefix.view(
|
||||
-1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
|
||||
)
|
||||
k_prefix, v_prefix = torch.split(
|
||||
kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
|
||||
)
|
||||
k_prefix = torch.cat(
|
||||
[
|
||||
k_prefix,
|
||||
torch.broadcast_to(
|
||||
k_pe,
|
||||
(k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
assert (
|
||||
forward_batch.extend_prefix_lens.shape
|
||||
== forward_batch.extend_seq_lens.shape
|
||||
)
|
||||
k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
|
||||
k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
|
||||
assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
|
||||
k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
|
||||
v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
|
||||
v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
|
||||
v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
|
||||
|
||||
bs0 = forward_batch.batch_size + 1
|
||||
o = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
max_extend_len,
|
||||
max_prefix_extend_len,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
)
|
||||
return o
|
||||
else:
|
||||
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||
layer.layer_id
|
||||
)
|
||||
|
||||
o = mha_batch_prefill_func(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k_cache,
|
||||
v_cache,
|
||||
self.qo_indptr[:bs0],
|
||||
self.forward_metadata.kv_indptr[:bs0],
|
||||
self.forward_metadata.kv_indices,
|
||||
self.forward_metadata.max_q_len,
|
||||
self.forward_metadata.max_kv_len,
|
||||
causal=True,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
alibi_slopes=None,
|
||||
return_lse=False,
|
||||
return_attn_probs=False,
|
||||
)
|
||||
bs0 = forward_batch.batch_size + 1
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
o = mha_batch_prefill_func(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k_cache,
|
||||
v_cache,
|
||||
self.qo_indptr[:bs0],
|
||||
self.forward_metadata.kv_indptr[:bs0],
|
||||
self.forward_metadata.kv_indices,
|
||||
self.forward_metadata.max_q_len,
|
||||
self.forward_metadata.max_kv_len,
|
||||
causal=True,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
alibi_slopes=None,
|
||||
return_lse=False,
|
||||
return_attn_probs=False,
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
@@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend):
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
|
||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
@@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend):
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
self.logits_soft_cap = layer.logit_cap
|
||||
paged_attention_ragged(
|
||||
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
self.workspace_buffer,
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
|
||||
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
|
||||
),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
|
||||
-1, 1, layer.tp_v_head_num, layer.v_head_dim
|
||||
),
|
||||
self.scale,
|
||||
self.forward_metadata.kv_indptr,
|
||||
self.forward_metadata.kv_indices,
|
||||
self.kv_last_page_lens,
|
||||
1,
|
||||
self.max_num_partitions,
|
||||
None,
|
||||
"auto",
|
||||
"NHD",
|
||||
self.logits_soft_cap,
|
||||
self.k_scale,
|
||||
self.v_scale,
|
||||
None,
|
||||
_AITER_PARTITION_SIZE_ROCM,
|
||||
)
|
||||
if self.use_mla:
|
||||
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
mla_decode_fwd(
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
k_buffer.view(-1, 1, 1, layer.qk_head_dim),
|
||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||
self.forward_metadata.qo_indptr,
|
||||
self.forward_metadata.kv_indptr,
|
||||
self.forward_metadata.kv_indices,
|
||||
self.forward_metadata.kv_last_page_len,
|
||||
self.forward_metadata.max_extend_len,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
)
|
||||
k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
|
||||
else:
|
||||
self.logits_soft_cap = layer.logit_cap
|
||||
paged_attention_ragged(
|
||||
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
self.workspace_buffer,
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
|
||||
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
|
||||
),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
|
||||
-1, 1, layer.tp_v_head_num, layer.v_head_dim
|
||||
),
|
||||
self.scale,
|
||||
self.forward_metadata.kv_indptr,
|
||||
self.forward_metadata.kv_indices,
|
||||
self.kv_last_page_len,
|
||||
1,
|
||||
self.max_num_partitions,
|
||||
None,
|
||||
"auto",
|
||||
"NHD",
|
||||
self.logits_soft_cap,
|
||||
self.k_scale,
|
||||
self.v_scale,
|
||||
None,
|
||||
_AITER_PARTITION_SIZE_ROCM,
|
||||
)
|
||||
|
||||
return o
|
||||
|
||||
@@ -506,9 +782,97 @@ class AiterIndicesUpdaterPrefill:
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
None,
|
||||
paged_kernel_lens_sum,
|
||||
self.req_to_token,
|
||||
)
|
||||
)
|
||||
|
||||
self.kv_indices = kv_indices
|
||||
|
||||
|
||||
class AiterMlaIndicesUpdaterPrefill:
|
||||
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
||||
# Parse Constants
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
# Buffers and wrappers
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.update = self.update_single_wrapper
|
||||
|
||||
self.kv_indptr = None
|
||||
self.kv_indices = None
|
||||
self.qo_indptr = None
|
||||
self.kv_last_page_len = None
|
||||
self.max_extend_len = 0
|
||||
self.max_prefix_extend_len = 0
|
||||
|
||||
def update(
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefix_lens_sum: int,
|
||||
extend_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_single_wrapper(
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefix_lens_sum: int,
|
||||
extend_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
|
||||
paged_kernel_lens = prefix_lens
|
||||
paged_kernel_lens_sum = prefix_lens_sum
|
||||
|
||||
bs = len(req_pool_indices)
|
||||
|
||||
kv_indptr = self.attn_backend.kv_indptr
|
||||
|
||||
if spec_info is None:
|
||||
# Normal extend
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum,
|
||||
dtype=torch.int32,
|
||||
device=req_pool_indices.device,
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
qo_indptr = self.attn_backend.qo_indptr
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
|
||||
max_extend_len = torch.max(extend_lens).item()
|
||||
max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
|
||||
kv_indptr += qo_indptr
|
||||
else:
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
paged_kernel_lens_sum,
|
||||
self.req_to_token,
|
||||
)
|
||||
)
|
||||
|
||||
self.kv_indptr = kv_indptr
|
||||
self.kv_indices = kv_indices
|
||||
self.qo_indptr = qo_indptr
|
||||
self.max_extend_len = max_extend_len
|
||||
self.max_prefix_extend_len = max_prefix_extend_len
|
||||
|
||||
Reference in New Issue
Block a user