diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 3088efb..66bc45e 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -742,30 +742,20 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
                                       self.qk_head_dim)
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                                dim=-1)
-        if attn_metadata.num_prefills > 0:
-            assert attn_metadata.prefill_metadata is not None
-            assert attn_metadata.prefill_metadata.seq_lens is not None
-            np_positions = np.concatenate([
-                np.arange(i) for i in attn_metadata.prefill_metadata.seq_lens
-            ])
-            positions = torch.tensor(np_positions,
-                                     device=hidden_states_or_q_c.device)
-        else:
-            assert attn_metadata.decode_metadata is not None
-            np_positions = np.array(attn_metadata.decode_metadata.seq_lens) - 1
-            positions = torch.tensor(np_positions,
-                                     device=hidden_states_or_q_c.device)
+        k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
         if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
             ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
             q_pe = q_pe.reshape(num_tokens, -1)
             k_pe = k_pe.reshape(num_tokens, -1)
-            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
             q_pe = q_pe.view(ori_q_pe_shape)
             k_pe = k_pe.view(ori_k_pe_shape)
         else:
-            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
 
         if self.w_kc is None or self.w_vc is None:
             kv_b_proj_weight = self.kv_b_proj.weight.reshape(
@@ -786,16 +776,14 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
             k_cache = torch.cat(
                 [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
                 dim=2)
-            k_pe = k_pe.repeat(1, self.num_heads, 1)
+            k_pe = k_pe.expand(-1, self.num_heads, -1)
             key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
                             dim=2)
         else:
             kv_heads_num = self.num_kv_heads
-            q_nope_t = torch_npu.npu_transpose(q_nope, (1, 0, 2),
-                                               require_contiguous=True)
+            q_nope_t = torch.transpose(q_nope, 0, 1)
             q_nope_out = torch.bmm(q_nope_t, self.w_kc)
-            q_nope = torch_npu.npu_transpose(q_nope_out, (1, 0, 2),
-                                             require_contiguous=True)
+            q_nope = torch.transpose(q_nope_out, 0, 1)
             k_cache = torch.cat(
                 [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
                 dim=2)
@@ -895,12 +883,10 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
                 inputLayout=0,
                 outDataType=-1,
                 attnOut=attn_output)
-        attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2),
-                                                require_contiguous=True)
+        attn_output_t = torch.transpose(attn_output, 0, 1)
         attn_output_t = torch.bmm(attn_output_t, self.w_vc)
-        attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2),
-                                              require_contiguous=True)
+        attn_output = torch.transpose(attn_output_t, 0, 1)
 
-        output, _ = self.o_proj(attn_output.view(num_tokens, -1))
+        output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
 
         return output
diff --git a/vllm_ascend/model_runner.py b/vllm_ascend/model_runner.py
index 2bb057f..d0aa06d 100644
--- a/vllm_ascend/model_runner.py
+++ b/vllm_ascend/model_runner.py
@@ -1137,6 +1137,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
         if not bypass_model_exec:
             with set_forward_context(model_input.attn_metadata,
                                      self.vllm_config, virtual_engine):
+                if model_input.attn_metadata is not None:
+                    model_input.attn_metadata.input_positions = model_input.input_positions
                 hidden_or_intermediate_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,
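Note: the attention.py and model_runner.py hunks above work together. The model runner now stashes model_input.input_positions onto attn_metadata, so the MLA backend no longer rebuilds positions from seq_lens with NumPy on every forward pass. The remaining swaps are pure-PyTorch equivalents of the torch_npu helpers; a minimal CPU-side sketch of why they hold (shapes below are illustrative, not the model's real dimensions):

    import torch

    num_tokens, num_heads, head_dim, kv_lora_rank = 16, 8, 128, 512
    q_nope = torch.randn(num_tokens, num_heads, head_dim)
    w_kc = torch.randn(num_heads, head_dim, kv_lora_rank)
    k_pe = torch.randn(num_tokens, 1, 64)  # single KV head before broadcast

    # torch.transpose returns a view and torch.bmm accepts non-contiguous
    # inputs, so the npu_transpose(..., require_contiguous=True) copies
    # around the bmm are unnecessary.
    q_nope_out = torch.bmm(torch.transpose(q_nope, 0, 1), w_kc)
    q_nope_new = torch.transpose(q_nope_out, 0, 1)

    # expand builds a zero-copy broadcast view where repeat materializes a
    # full copy; the resulting values are identical.
    assert torch.equal(k_pe.repeat(1, num_heads, 1),
                       k_pe.expand(-1, num_heads, -1))

    # The transposed bmm result is non-contiguous, so .view() would raise;
    # .reshape() copies only when it must. Hence o_proj now takes reshape.
    flat = q_nope_new.reshape(num_tokens, -1)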
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index cbb8622..db03509 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -65,7 +65,7 @@ def group_topk(hidden_states: torch.Tensor,
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    return topk_weights, topk_ids.to(torch.int32)
 
 
 def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
@@ -126,13 +126,12 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
     down_out_list = torch.cat(down_out_list, dim=0)
     # TODO: Reorder device memory 2 times here, replace the current
     # implementation here when suitable operators become available.
-    routing_weights = topk_weights.to(down_out_list.dtype)
     hidden_states = torch_npu.npu_moe_finalize_routing(
         down_out_list,
         skip1=None,
         skip2=None,
         bias=None,
-        scales=routing_weights,
+        scales=topk_weights,
         expanded_src_to_dst_row=expanded_row_idx,
         export_for_source_row=topk_ids)
     if len(ori_shape) == 3:
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 2279ad1..1999386 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -18,7 +18,8 @@
 from typing import Optional, Tuple
 
 import torch
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 
 def rope_forward_oot(
@@ -49,8 +50,43 @@ def rope_forward_oot(
         self.cos_sin_cache,
         self.is_neox_style,
     )
+    return query, key
+
+
+def rope_deepseek_forward_oot(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    import torch_npu
+
+    if self.cos_sin_cache.device != query.device:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+    if self.cos_sin_cache.dtype != query.dtype:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+    if offsets is not None:
+        raise NotImplementedError(
+            "Batched rotary embedding is currently not supported on NPU.")
+    else:
+        # TODO: Remove the contiguous in the future.
+        ori_query_shape, ori_key_shape = query.shape, key.shape
+        query = query.contiguous().view(query.shape[0], -1)
+        key = key.contiguous().view(query.shape[0], -1)
+        torch_npu.npu_rope(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )
+        query = query.view(ori_query_shape)
+        key = key.view(ori_key_shape)
     return query, key
 
 
 RotaryEmbedding.forward_oot = rope_forward_oot
+DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
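Note: in fused_moe.py, group_topk stops upcasting topk_weights to float32, so npu_moe_finalize_routing receives the weights in the activation dtype directly and the intermediate .to(down_out_list.dtype) round-trip disappears. In rotary_embedding.py, DeepseekScalingRotaryEmbedding.forward is monkey-patched onto the NPU rope kernel using the same flatten/apply/restore pattern as rope_forward_oot. A small runnable sketch of that pattern follows; rope_kernel_stub is a hypothetical stand-in for torch_npu.npu_rope so the sketch runs off-NPU:

    import torch

    def rope_kernel_stub(positions, q2d, k2d, head_size, cos_sin_cache,
                         is_neox_style):
        # Hypothetical stand-in: a real kernel rotates q2d/k2d in place
        # using cos_sin_cache rows selected by positions.
        pass

    num_tokens, num_heads, head_size = 16, 8, 64
    positions = torch.arange(num_tokens)
    cos_sin_cache = torch.randn(4096, head_size)
    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, 1, head_size)

    # Remember the 3-D shapes, hand the kernel flat
    # (num_tokens, heads * head_size) views, then restore the shapes.
    ori_query_shape, ori_key_shape = query.shape, key.shape
    q2d = query.contiguous().view(query.shape[0], -1)
    k2d = key.contiguous().view(key.shape[0], -1)
    rope_kernel_stub(positions, q2d, k2d, head_size, cos_sin_cache, True)
    query, key = q2d.view(ori_query_shape), k2d.view(ori_key_shape)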