port deepseekv2 and mtp to main branch (#429)
### What this PR does / why we need it? This PR ports all the deepseek graph mode code and mtp code from v0.7.3 to the main branch --------- Signed-off-by: SidaoY <1024863041@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: mengwei805 <mengwei25@huawei.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: q00832892 <qiaoyang19@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: SidaoY <1024863041@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
@@ -31,13 +31,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
CommonMetadataBuilder,
|
||||
compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||
|
||||
@@ -222,7 +223,7 @@ class AscendMLAAttentionBackend(AscendAttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (1, num_blocks, block_size, num_kv_heads * head_size)
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -552,10 +553,33 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
inter_data.block_tables,
|
||||
)
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int,
|
||||
block_tables: List[List[int]]) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
graph_pad_size: int,
|
||||
):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
@@ -568,6 +592,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_torchair_graph = graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
@@ -582,12 +607,36 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.attn_mask = None
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
if self.num_prefills == 0 and use_torchair_graph:
|
||||
num_seqs = len(seq_lens)
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
|
||||
self.block_tables.extend([[]] * graph_pad_size)
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.num_prefills > 0:
|
||||
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_prefill_seq_len,
|
||||
self.input_builder.runner.model_config.dtype,
|
||||
self.input_builder.runner.device)
|
||||
else:
|
||||
self.attn_mask = None
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||
|
||||
assert device is not None
|
||||
@@ -855,14 +904,100 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
self.q_proj = extra_impl_args['q_proj']
|
||||
self.kv_b_proj = extra_impl_args['kv_b_proj']
|
||||
self.o_proj = extra_impl_args['o_proj']
|
||||
self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
|
||||
None)
|
||||
self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
|
||||
self.k_pe_cache = None
|
||||
self.k_nope_cache = None
|
||||
self.w_kc = None
|
||||
self.w_vc = None
|
||||
|
||||
def exec_kv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
kv_cache: Tuple,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
B = hidden_states.shape[0]
|
||||
N = self.num_kv_heads
|
||||
S = 1
|
||||
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
|
||||
k_pe, k_nope = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
||||
kv,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode="PA",
|
||||
)
|
||||
|
||||
return k_pe, k_nope
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
def rope_single(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
B, N, D = x.shape
|
||||
S = 1
|
||||
x = x.view(B, N, S, D)
|
||||
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
||||
return x.view(B, N, D)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
if self.w_kc is None or self.w_vc is None:
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
||||
self.kv_lora_rank)
|
||||
self.w_kc = kv_b_proj_weight[:, :self.
|
||||
qk_nope_head_dim, :].contiguous()
|
||||
self.w_vc = kv_b_proj_weight[:,
|
||||
self.qk_nope_head_dim:, :].transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
hidden_states_or_kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
@@ -873,7 +1008,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
Args:
|
||||
hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
|
||||
num_tokens = batch_size * seq_len
|
||||
kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
||||
hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
||||
k_pe: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache: shape = [1, num_blocks, block_size,
|
||||
num_kv_heads * head_size]
|
||||
@@ -889,71 +1024,86 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# for profile run
|
||||
return hidden_states_or_q_c
|
||||
|
||||
num_tokens = hidden_states_or_q_c.shape[0]
|
||||
q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
|
||||
self.qk_head_dim)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
if k_pe is None and attn_metadata.decode_metadata:
|
||||
seq_len = self.rotary_emb.max_position_embeddings
|
||||
|
||||
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
||||
cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
|
||||
sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
|
||||
cos = cos[attn_metadata.input_positions]
|
||||
sin = sin[attn_metadata.input_positions]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
||||
q_pe = q_pe.reshape(num_tokens, -1)
|
||||
k_pe = k_pe.reshape(num_tokens, -1)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
q_pe = q_pe.view(ori_q_pe_shape)
|
||||
k_pe = k_pe.view(ori_k_pe_shape)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
|
||||
kv_cache, attn_metadata.slot_mapping)
|
||||
else:
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
|
||||
if self.w_kc is None or self.w_vc is None:
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
||||
self.kv_lora_rank)
|
||||
self.w_kc = kv_b_proj_weight[:, :self.
|
||||
qk_nope_head_dim, :].contiguous()
|
||||
self.w_vc = kv_b_proj_weight[:,
|
||||
self.qk_nope_head_dim:, :].transpose(
|
||||
1, 2).contiguous()
|
||||
if k_pe is None:
|
||||
# NOTE: k_pe is None when graph mode enabled
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
else:
|
||||
kv_c_normed = hidden_states_or_kv_c_normed
|
||||
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
# NOTE: When scaling not specified
|
||||
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
||||
q_pe = q_pe.reshape(num_tokens, -1)
|
||||
k_pe = k_pe.reshape(num_tokens, -1)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
||||
q_pe, k_pe)
|
||||
q_pe = q_pe.view(ori_q_pe_shape)
|
||||
k_pe = k_pe.view(ori_k_pe_shape)
|
||||
else:
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
||||
q_pe, k_pe)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
kv_heads_num = self.num_heads
|
||||
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num,
|
||||
-1)
|
||||
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
|
||||
self.num_heads, -1)
|
||||
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
|
||||
dim=-1)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
k_pe = k_pe.expand(-1, self.num_heads, -1)
|
||||
key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
|
||||
dim=2)
|
||||
else:
|
||||
kv_heads_num = self.num_kv_heads
|
||||
q_nope_t = torch.transpose(q_nope, 0, 1)
|
||||
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
|
||||
q_nope = torch.transpose(q_nope_out, 0, 1)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
|
||||
query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
|
||||
self.num_heads, -1)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
num_blocks, block_size, _ = key_cache.shape
|
||||
|
||||
key_cache = key_cache.view(
|
||||
num_blocks, block_size, self.num_kv_heads,
|
||||
self.qk_rope_head_dim + self.kv_lora_rank)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache_siso(key=k_cache,
|
||||
key_cache=key_cache,
|
||||
slot_indices=slots)
|
||||
# TODO: Replace the env with more flexible expressions
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||
) > 0 and attn_metadata.num_prefills > 0:
|
||||
slots = attn_metadata.slot_mapping
|
||||
# NOTE: Seperate the kv cache in advance to avoid OOM or other issues
|
||||
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
||||
num_tokens, self.num_kv_heads, -1),
|
||||
value=k_pe,
|
||||
key_cache=kv_cache[0],
|
||||
value_cache=kv_cache[1],
|
||||
slot_indices=slots)
|
||||
else:
|
||||
if kv_cache.numel() > 0:
|
||||
key = torch.cat([
|
||||
kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe
|
||||
],
|
||||
dim=2)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache_siso(key=key,
|
||||
key_cache=kv_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
@@ -964,12 +1114,15 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
if (attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
assert attn_metadata.prefill_metadata.seq_lens is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
k_pe = k_pe.repeat(1, self.num_heads, 1)
|
||||
key = torch.cat(
|
||||
[k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
@@ -987,29 +1140,55 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
# if torch.empty is used here, the preemptive scheduling case of
|
||||
# test_mtp_correctness.py will fail to run.
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
attn_output_t = torch.transpose(attn_output, 0, 1)
|
||||
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
||||
attn_output = torch.transpose(attn_output_t, 0, 1)
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=1,
|
||||
input_layout="BNSD",
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_size=kv_cache[0].shape[1],
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
)
|
||||
attn_output = attn_output.view(num_tokens, -1,
|
||||
self.kv_lora_rank).transpose(
|
||||
0, 1)
|
||||
attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
|
||||
else:
|
||||
# if torch.empty is used here, the preemptive scheduling case of
|
||||
# test_mtp_correctness.py will fail to run.
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=kv_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
attn_output_t = torch.transpose(attn_output, 0, 1)
|
||||
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
||||
attn_output = torch.transpose(attn_output_t, 0, 1)
|
||||
|
||||
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
|
||||
|
||||
|
||||
@@ -24,6 +24,10 @@ import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
@@ -44,6 +48,10 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
|
||||
return AscendAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
@@ -94,11 +102,11 @@ class AscendAttentionState(Enum):
|
||||
class AscendMetadata:
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
block_tables: Optional[torch.Tensor]
|
||||
block_tables: torch.Tensor
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
context_lens: Optional[List[int]] = None
|
||||
query_lens: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
@@ -117,6 +125,36 @@ class AscendMetadata:
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self, num_reqs, num_actual_tokens, max_query_len,
|
||||
common_prefix_len):
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
query_lens = self.runner.query_lens
|
||||
seq_lens = self.runner.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True)
|
||||
attn_mask = self.runner.attn_mask
|
||||
attn_state = self.runner.attn_state
|
||||
|
||||
attn_metadata = AscendMetadata(block_tables=block_table,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
@@ -229,29 +267,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
block_tables = attn_metadata.block_tables
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=attn_metadata.context_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
# use paged attention
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
context_lens=attn_metadata.context_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
# use chunked prefill for head size 192 scenario, like deepseek
|
||||
# paged_attention_splitfuse maybe crash at such scenario
|
||||
# TODO: vanilla path will be removed after the kernel support
|
||||
# head_size 192 scenario
|
||||
if self.head_size == 192:
|
||||
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
|
||||
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
|
||||
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
|
||||
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
|
||||
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
|
||||
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
|
||||
max_seqlen_q = torch.max(attn_metadata.query_lens)
|
||||
max_seqlen_k = torch.max(attn_metadata.seq_lens)
|
||||
vanilla_chunked_prefill(output, query, self.key_cache,
|
||||
self.value_cache,
|
||||
attn_metadata.block_tables,
|
||||
cu_seqlen_q, cu_seqlen_k, max_seqlen_q,
|
||||
max_seqlen_k, self.scale, None, True)
|
||||
else:
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
561
vllm_ascend/attention/mla_v1.py
Normal file
561
vllm_ascend/attention/mla_v1.py
Normal file
@@ -0,0 +1,561 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.ops.cache import concat_and_cache_mla
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AscendMLABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "VLLM_ASCEND_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return AscendMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls():
|
||||
return AscendMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["MLAAttentionImpl"]:
|
||||
return AscendMLAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMLAPrefillMetadata:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: list[int]
|
||||
context_lens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
max_query_len: int
|
||||
max_context_len: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMLADecodeMetadata:
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMLAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
attn_mask: torch.Tensor = None
|
||||
# chunked prefill by default if no attn_states passed
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
decode: Optional[AscendMLADecodeMetadata] = None
|
||||
prefill: Optional[AscendMLAPrefillMetadata] = None
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
|
||||
# if self.head_dim is not None and self.head_dim \
|
||||
# not in supported_head_sizes:
|
||||
# raise ValueError(
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
|
||||
class AscendMLAMetadataBuilder:
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
runner: "NPUModelRunner",
|
||||
metadata_cls: Optional[AscendMLAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
# self.attn_mask = None
|
||||
# if AscendMLAMetadataBuilder._attn_mask_builder is None:
|
||||
# AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
# 128, self.runner.model_config.dtype
|
||||
# )
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
first_prefill = 0
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
if decodes[num_decodes - i] >= num_decodes:
|
||||
input_batch.swap_states(prefills[first_prefill],
|
||||
decodes[num_decodes - i])
|
||||
first_prefill += 1
|
||||
modified_batch = True
|
||||
else:
|
||||
break
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def build(self,
|
||||
num_reqs: int,
|
||||
num_actual_tokens: int,
|
||||
max_query_len: int,
|
||||
common_prefix_len: Optional[int] = None) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.runner.device
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
|
||||
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
||||
query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
|
||||
num_reqs]
|
||||
seq_lens = seq_lens_cpu
|
||||
max_query_len = query_lens.max().item()
|
||||
max_context_len = seq_lens.max().item()
|
||||
|
||||
prefill_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
tokens_start = self._num_decode_tokens
|
||||
|
||||
prefill_metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=self.runner.attn_mask,
|
||||
query_lens=query_lens[tokens_start:],
|
||||
context_lens=seq_lens[tokens_start:],
|
||||
input_positions=input_positions[tokens_start:],
|
||||
block_table=block_table[reqs_start:, ...],
|
||||
max_query_len=max_query_len,
|
||||
max_context_len=max_context_len,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if self._num_decodes > 0:
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions[:self._num_decode_tokens],
|
||||
block_table=block_table[:self._num_decode_tokens, ...],
|
||||
seq_lens=seq_lens[:self._num_decode_tokens])
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
|
||||
class AscendMLAImpl(MLAAttentionImpl):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: ColumnParallelLinear,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
# Hack for V1 for now to avoid torch library overhead (since we are
|
||||
# already inside an attention custom op), pull out the forward
|
||||
# method from the rotary embedding and call it directly
|
||||
# TODO(lucas): we should probably find a cleaner way to do this
|
||||
self.rotary_emb = rotary_emb.forward_native
|
||||
|
||||
self.q_proj = q_proj
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
# self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
# if self.vllm_flash_attn_version is not None:
|
||||
# self.flash_attn_varlen_func = \
|
||||
# functools.partial(flash_attn_varlen_func,
|
||||
# fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return self.o_proj(x)[0]
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
q_nope, q_pe = self.q_proj(x)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)\
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
return ql_nope.transpose(0, 1), q_pe
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.prefill is not None
|
||||
|
||||
# TODO: enable this compute for flash attention computation
|
||||
# kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
# k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
# v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
|
||||
# value=0)
|
||||
num_tokens = query.size(0)
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
query_lens=attn_metadata.prefill.query_lens,
|
||||
context_lens=attn_metadata.prefill.context_lens,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
max_query_len=attn_metadata.prefill.max_query_len,
|
||||
max_context_len=attn_metadata.prefill.max_context_len,
|
||||
nope_dim=self.qk_nope_head_dim,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
attn_output = attn_output.view(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
return self.o_proj(attn_output)[0]
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
num_tokens = q.size(0)
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=q,
|
||||
key_cache=kv_c_and_k_pe_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.decode.block_table, # type:ignore
|
||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
return self._v_up_proj_and_o_proj(attn_output)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_toks, ...]
|
||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||
|
||||
if has_decode:
|
||||
assert attn_metadata.decode is not None
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
|
||||
decode_k_pe)
|
||||
|
||||
if has_prefill:
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.prefill.input_positions,
|
||||
prefill_q_pe.contiguous(), prefill_k_pe)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
|
||||
attn_metadata.slot_mapping.flatten())
|
||||
# TODO: replaced back to ascend ops
|
||||
# key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
|
||||
# torch_npu._npu_reshape_and_cache_siso(
|
||||
# key=key,
|
||||
# key_cache=kv_cache,
|
||||
# slot_indices=attn_metadata.slot_mapping.flatten())
|
||||
|
||||
if has_prefill:
|
||||
output[num_decode_tokens:] = self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
if has_decode:
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output_padded
|
||||
Reference in New Issue
Block a user