port deepseekv2 and mtp to main branch (#429)
### What this PR does / why we need it? This PR ports all the deepseek graph mode code and mtp code from v0.7.3 to the main branch --------- Signed-off-by: SidaoY <1024863041@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: mengwei805 <mengwei25@huawei.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: q00832892 <qiaoyang19@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: SidaoY <1024863041@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
@@ -31,13 +31,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
CommonMetadataBuilder,
|
||||
compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||
|
||||
@@ -222,7 +223,7 @@ class AscendMLAAttentionBackend(AscendAttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (1, num_blocks, block_size, num_kv_heads * head_size)
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -552,10 +553,33 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
inter_data.block_tables,
|
||||
)
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int,
|
||||
block_tables: List[List[int]]) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
graph_pad_size: int,
|
||||
):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
@@ -568,6 +592,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_torchair_graph = graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
@@ -582,12 +607,36 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.attn_mask = None
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
if self.num_prefills == 0 and use_torchair_graph:
|
||||
num_seqs = len(seq_lens)
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
|
||||
self.block_tables.extend([[]] * graph_pad_size)
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.num_prefills > 0:
|
||||
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_prefill_seq_len,
|
||||
self.input_builder.runner.model_config.dtype,
|
||||
self.input_builder.runner.device)
|
||||
else:
|
||||
self.attn_mask = None
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||
|
||||
assert device is not None
|
||||
@@ -855,14 +904,100 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
self.q_proj = extra_impl_args['q_proj']
|
||||
self.kv_b_proj = extra_impl_args['kv_b_proj']
|
||||
self.o_proj = extra_impl_args['o_proj']
|
||||
self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
|
||||
None)
|
||||
self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
|
||||
self.k_pe_cache = None
|
||||
self.k_nope_cache = None
|
||||
self.w_kc = None
|
||||
self.w_vc = None
|
||||
|
||||
def exec_kv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
kv_cache: Tuple,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
B = hidden_states.shape[0]
|
||||
N = self.num_kv_heads
|
||||
S = 1
|
||||
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
|
||||
k_pe, k_nope = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
||||
kv,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode="PA",
|
||||
)
|
||||
|
||||
return k_pe, k_nope
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
def rope_single(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
B, N, D = x.shape
|
||||
S = 1
|
||||
x = x.view(B, N, S, D)
|
||||
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
||||
return x.view(B, N, D)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
if self.w_kc is None or self.w_vc is None:
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
||||
self.kv_lora_rank)
|
||||
self.w_kc = kv_b_proj_weight[:, :self.
|
||||
qk_nope_head_dim, :].contiguous()
|
||||
self.w_vc = kv_b_proj_weight[:,
|
||||
self.qk_nope_head_dim:, :].transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
hidden_states_or_kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
@@ -873,7 +1008,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
Args:
|
||||
hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
|
||||
num_tokens = batch_size * seq_len
|
||||
kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
||||
hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
||||
k_pe: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache: shape = [1, num_blocks, block_size,
|
||||
num_kv_heads * head_size]
|
||||
@@ -889,71 +1024,86 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# for profile run
|
||||
return hidden_states_or_q_c
|
||||
|
||||
num_tokens = hidden_states_or_q_c.shape[0]
|
||||
q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
|
||||
self.qk_head_dim)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
if k_pe is None and attn_metadata.decode_metadata:
|
||||
seq_len = self.rotary_emb.max_position_embeddings
|
||||
|
||||
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
||||
cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
|
||||
sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
|
||||
cos = cos[attn_metadata.input_positions]
|
||||
sin = sin[attn_metadata.input_positions]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
||||
q_pe = q_pe.reshape(num_tokens, -1)
|
||||
k_pe = k_pe.reshape(num_tokens, -1)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
q_pe = q_pe.view(ori_q_pe_shape)
|
||||
k_pe = k_pe.view(ori_k_pe_shape)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
|
||||
kv_cache, attn_metadata.slot_mapping)
|
||||
else:
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
|
||||
if self.w_kc is None or self.w_vc is None:
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
||||
self.kv_lora_rank)
|
||||
self.w_kc = kv_b_proj_weight[:, :self.
|
||||
qk_nope_head_dim, :].contiguous()
|
||||
self.w_vc = kv_b_proj_weight[:,
|
||||
self.qk_nope_head_dim:, :].transpose(
|
||||
1, 2).contiguous()
|
||||
if k_pe is None:
|
||||
# NOTE: k_pe is None when graph mode enabled
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
else:
|
||||
kv_c_normed = hidden_states_or_kv_c_normed
|
||||
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
# NOTE: When scaling not specified
|
||||
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
||||
q_pe = q_pe.reshape(num_tokens, -1)
|
||||
k_pe = k_pe.reshape(num_tokens, -1)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
||||
q_pe, k_pe)
|
||||
q_pe = q_pe.view(ori_q_pe_shape)
|
||||
k_pe = k_pe.view(ori_k_pe_shape)
|
||||
else:
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
||||
q_pe, k_pe)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
kv_heads_num = self.num_heads
|
||||
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num,
|
||||
-1)
|
||||
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
|
||||
self.num_heads, -1)
|
||||
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
|
||||
dim=-1)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
k_pe = k_pe.expand(-1, self.num_heads, -1)
|
||||
key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
|
||||
dim=2)
|
||||
else:
|
||||
kv_heads_num = self.num_kv_heads
|
||||
q_nope_t = torch.transpose(q_nope, 0, 1)
|
||||
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
|
||||
q_nope = torch.transpose(q_nope_out, 0, 1)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
|
||||
query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
|
||||
self.num_heads, -1)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
num_blocks, block_size, _ = key_cache.shape
|
||||
|
||||
key_cache = key_cache.view(
|
||||
num_blocks, block_size, self.num_kv_heads,
|
||||
self.qk_rope_head_dim + self.kv_lora_rank)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache_siso(key=k_cache,
|
||||
key_cache=key_cache,
|
||||
slot_indices=slots)
|
||||
# TODO: Replace the env with more flexible expressions
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||
) > 0 and attn_metadata.num_prefills > 0:
|
||||
slots = attn_metadata.slot_mapping
|
||||
# NOTE: Seperate the kv cache in advance to avoid OOM or other issues
|
||||
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
||||
num_tokens, self.num_kv_heads, -1),
|
||||
value=k_pe,
|
||||
key_cache=kv_cache[0],
|
||||
value_cache=kv_cache[1],
|
||||
slot_indices=slots)
|
||||
else:
|
||||
if kv_cache.numel() > 0:
|
||||
key = torch.cat([
|
||||
kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe
|
||||
],
|
||||
dim=2)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache_siso(key=key,
|
||||
key_cache=kv_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
@@ -964,12 +1114,15 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
if (attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
assert attn_metadata.prefill_metadata.seq_lens is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
k_pe = k_pe.repeat(1, self.num_heads, 1)
|
||||
key = torch.cat(
|
||||
[k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
@@ -987,29 +1140,55 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
# if torch.empty is used here, the preemptive scheduling case of
|
||||
# test_mtp_correctness.py will fail to run.
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
attn_output_t = torch.transpose(attn_output, 0, 1)
|
||||
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
||||
attn_output = torch.transpose(attn_output_t, 0, 1)
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=1,
|
||||
input_layout="BNSD",
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_size=kv_cache[0].shape[1],
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
)
|
||||
attn_output = attn_output.view(num_tokens, -1,
|
||||
self.kv_lora_rank).transpose(
|
||||
0, 1)
|
||||
attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
|
||||
else:
|
||||
# if torch.empty is used here, the preemptive scheduling case of
|
||||
# test_mtp_correctness.py will fail to run.
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=kv_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
attn_output_t = torch.transpose(attn_output, 0, 1)
|
||||
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
||||
attn_output = torch.transpose(attn_output_t, 0, 1)
|
||||
|
||||
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user