From 6be321b95ec49b623ae3ca3c452f18bafaab56aa Mon Sep 17 00:00:00 2001
From: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Date: Fri, 24 Oct 2025 16:29:08 +0800
Subject: [PATCH] remove useless code (#3685)

### What this PR does / why we need it?
`vanilla_chunked_prefill_mla` and `vanilla_decode_mla` are unused, so remove them.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: zzzzwwjj <1183291235@qq.com>
---
 vllm_ascend/ops/attention.py | 177 +----------------------------------
 1 file changed, 1 insertion(+), 176 deletions(-)

diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py
index 05600aee..f530de03 100644
--- a/vllm_ascend/ops/attention.py
+++ b/vllm_ascend/ops/attention.py
@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import torch
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 
 
 # Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
@@ -133,177 +132,3 @@ def vanilla_chunked_prefill(
                                  head_dim]).to(output.dtype))
     output.copy_(attn_output)
     return attn_output
-
-
-def vanilla_chunked_prefill_mla(
-        output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
-        query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_cache: Tuple[
-            torch.Tensor],  # [nope, rope] (num_blocks, block_size, latent_kv)
-        block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
-        query_lens: torch.Tensor,  # (batch_size)
-        context_lens: torch.Tensor,  # (batch_size)
-        kv_b_proj: ColumnParallelLinear,  # ()
-        max_query_len: int,
-        max_context_len: int,
-        nope_dim: int,
-        rope_dim: int,
-        v_head_dim: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-        causal: bool = True) -> None:
-    batch_size = block_tables.size(0)
-    assert len(kv_cache) > 1
-    assert query_lens.size(0) == batch_size
-    num_heads = query.size(1)
-    nope_cache = kv_cache[0]
-    rope_cache = kv_cache[1]
-    block_size = nope_cache.size(1)
-    latent_kv_dim = nope_cache.size(-1)
-    max_num_blocks_per_seq = block_tables.size(1)
-    batch_size = query_lens.size(0)
-    nope_cache = nope_cache.squeeze()
-    # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe
-    # cached_kv_c: [batch_size, max_context_len, latent_kv]
-    # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    cache_kv_c = nope_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        latent_kv_dim)[:, :max_context_len, :]
-    cache_k_pe = rope_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        rope_dim)[:, :max_context_len, :]
-    # get k_rope and v
-    # k_nope: [batch_size, max_context_len, num_heads, nope_dim]
-    # value: [batch_size, max_context_len, num_heads, v_head_dim]
-    k_nope, value = kv_b_proj(cache_kv_c)[0].view(
-        batch_size, max_context_len, num_heads,
-        nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
-    # key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim]
-    key = torch.cat(
-        [k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
-        dim=-1)
-
-    context_lens = context_lens.view(-1, 1).to("npu")
-    query_lens = query_lens.view(-1, 1).to("npu")
-    seq_diff = context_lens - query_lens
-
-    q_idx_mask = (torch.arange(0, max_query_len,
-                               device="npu").view(1, -1).repeat(batch_size, 1))
-    kv_c_idx_mask = (torch.arange(0, max_context_len,
-                                  device="npu").view(1,
-                                                     -1).repeat(batch_size, 1))
-    kv_c_mask = kv_c_idx_mask < context_lens
-    q_mask = q_idx_mask < query_lens
-
-    # calculate idx for causal mask of query [batch, max_seqlen_q]
-    causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
-
-    # generate causal mask [batch, max_seqlen_q, max_seqlen_k]
-    tril_mask = torch.tril(
-        torch.ones(max_context_len, max_context_len, device="npu"))
-    tril_mask[tril_mask == 0] = float("-inf")
-    tril_mask[tril_mask == 1] = 0
-    causal_mask = tril_mask[causal_mask_idx]
-    causal_mask_padding = torch.empty(
-        [batch_size, max_query_len, max_context_len],
-        device="npu").fill_(float("-inf"))
-    causal_mask_padding[q_mask] = causal_mask
-    # to [batch, num_heads, max_seqlen_q, max_seqlen_k]
-    causal_mask_padding = causal_mask_padding.unsqueeze(1)
-
-    pad_q = torch.zeros(
-        [batch_size, max_query_len, num_heads, rope_dim + nope_dim],
-        device="npu",
-        dtype=query.dtype,
-    )
-    pad_k = torch.zeros(
-        [batch_size, max_context_len, num_heads, rope_dim + nope_dim],
-        device="npu",
-        dtype=key.dtype,
-    )
-    pad_v = torch.zeros(
-        [batch_size, max_context_len, num_heads, v_head_dim],
-        device="npu",
-        dtype=value.dtype,
-    )
-    num_query = torch.sum(q_mask).item()
-    num_add_query = num_query - query.size(0)
-    # mtp will come in
-    if num_add_query > 0:
-        add_query_size = query.size()
-        add_query_size = list(add_query_size)
-        add_query_size[0] = num_add_query
-        pad_tensor = torch.zeros(add_query_size,
-                                 dtype=query.dtype,
-                                 device=query.device)
-        query = torch.cat([query, pad_tensor], dim=0)
-    pad_q[q_mask] = query
-    pad_k[kv_c_mask] = key[kv_c_mask]
-    pad_v[kv_c_mask] = value[kv_c_mask]
-
-    pad_q = pad_q.permute(0, 2, 1, 3)
-    pad_k = pad_k.permute(0, 2, 1, 3)
-    pad_v = pad_v.permute(0, 2, 1, 3)
-    attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
-                            device="npu").fill_(float("-inf"))
-    attn_mask[:, :, :, :max_context_len].masked_fill_(
-        kv_c_mask[:, None, None, :], 0)
-    # [b, h, f, t]
-    attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
-    attn_weights *= scale
-    attn_mask = attn_mask.float()
-    attn_weights = attn_weights + attn_mask
-    if causal:
-        attn_weights = attn_weights + causal_mask_padding
-
-    attn_weights = torch.softmax(attn_weights, dim=-1)
-    attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
-    attn_output = attn_output.permute(0, 2, 1, 3)
-
-    attn_output = (attn_output[q_mask].view([-1, num_heads,
-                                             v_head_dim]).to(output.dtype))
-    attn_output = attn_output.view_as(output)
-    output.copy_(attn_output)
-    return attn_output
-
-
-def vanilla_decode_mla(
-        query: torch.Tensor,  # [num_tokens, num_heads, latent_dim + rope_dim]
-        key_cache: torch.
-        Tensor,  # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
-        num_kv_heads: int,
-        num_heads: int,
-        scale: float,
-        block_table: torch.Tensor,  # [batch_size, max_block_size]
-        context_lens: List[int],
-        mla_vhead_size: int,
-        rope_dim: int,
-        output: torch.Tensor):
-    batch_size = block_table.size()[0]
-    max_block_size = block_table.size()[1]
-    reduce_dim = key_cache.size()[-1]
-    block_size = key_cache.size()[1]
-    latent_dim = reduce_dim - rope_dim
-    kv_c_and_pe = key_cache[block_table].view(
-        [batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
-    # [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim]
-    # since the kv head is 1 in deepseek, we use expand here for perf
-    kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
-        -1, -1, num_heads, 1)
-    kv_c = kv_c_and_pe[..., :latent_dim]
-    kv_idx_mask = (torch.arange(0, max_context_len,
-                                device="npu").view(1,
-                                                   -1).repeat(batch_size, 1))
-    # [batch_size, max_context_len]
-    kv_idx_mask = kv_idx_mask < context_lens
-    query = query.unsqueeze(1)
-    attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
-    attn_weights *= scale
-    attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float()
-    attn_weights = torch.softmax(attn_weights, dim=-1)
-    attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
-                               kv_c.float()).view(-1, num_heads, latent_dim)
-    output.copy_(attn_output)
-    return output