[Feat.][310P] addrmsnorm for 300I DUO (#6704)
### What this PR does / why we need it?
This PR integrates the `npu_add_rms_norm` fused kernel for RMSNorm operations with residual connections on 310P devices. It replaces a two-step process (manual residual addition followed by RMSNorm) with a single, more efficient fused operation, improving the performance of models that use RMSNorm with residual connections on the 310P architecture.

Fixes #

### Does this PR introduce _any_ user-facing change?
No. This PR is an internal optimization and does not change any user-facing APIs or behaviors.

### How was this patch tested?
Tested with updated unit tests (`test_RMSNorm_forward_310p`) that mock the `npu_add_rms_norm` operation to verify the fused kernel is invoked and produces the expected outputs.

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
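For context, here is a minimal PyTorch sketch of the two steps being fused, under the assumption stated in the comments; the helper name `add_rms_norm_ref` is illustrative only, and the real `torch_npu.npu_add_rms_norm` executes both steps as a single NPU kernel:

```python
import torch

def add_rms_norm_ref(x: torch.Tensor, residual: torch.Tensor,
                     weight: torch.Tensor, eps: float = 1e-5):
    """Unfused reference for what the fused kernel computes in one pass."""
    # Step 1: residual addition (previously a separate elementwise op).
    x_sum = x + residual
    # Step 2: RMSNorm over the last dimension (previously npu_rms_norm).
    variance = x_sum.float().pow(2).mean(dim=-1, keepdim=True)
    y = (x_sum.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    # The fused kernel returns the normalized output plus the residual sum,
    # which callers thread back as the new residual stream.
    return y, x_sum
```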
```diff
@@ -5,7 +5,7 @@ import torch
 from vllm.config import set_current_vllm_config
 from vllm.model_executor.layers.layernorm import RMSNorm
 
-from vllm_ascend.utils import AscendDeviceType, enable_custom_op
+from vllm_ascend.utils import enable_custom_op
+from vllm_ascend.utils import is_310p as is_310p_hw
 
 enable_custom_op()
@@ -39,8 +39,8 @@ def default_vllm_config():
     with set_current_vllm_config(mock_config):
         yield mock_config
 
 
-@pytest.mark.skip(
-    "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
+@pytest.mark.skip("Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
+@pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
 @pytest.mark.parametrize("residual", [None, torch.randn(4, 8, dtype=torch.float32)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@@ -68,19 +68,18 @@ def test_RMSNorm_forward(
 @pytest.mark.skipif(not is_310p_hw(), reason="310P device unittest case.")
 @pytest.mark.parametrize("residual", [None, torch.randn(4, 8, dtype=torch.float16)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
-def test_RMSNorm_forward_310p(
-    mock_rmsnorm, residual, dummy_tensor, default_vllm_config
-):
+@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+def test_RMSNorm_forward_310p(mock_add_rmsnorm, mock_rmsnorm, residual, dummy_tensor, default_vllm_config):
     layer = RMSNorm(hidden_size=8, eps=1e-05)
     if residual is not None:
         out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
-        expected_out_residual = dummy_tensor + residual
-        expected_out_x = expected_out_residual + 1
-        mock_rmsnorm.assert_called_once()
+        expected_out_x = 2 * dummy_tensor
+        expected_out_residual = 2 * residual
+        mock_add_rmsnorm.assert_called_once()
         assert torch.allclose(out_x, expected_out_x)
         assert torch.allclose(out_residual, expected_out_residual)
     else:
         out_x = layer.forward_oot(dummy_tensor, residual)
         expected_out_x = dummy_tensor + 1
         mock_rmsnorm.assert_called_once()
         assert torch.allclose(out_x, expected_out_x)
```
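The `mock_rms_norm` and `mock_add_rms_norm` helpers live elsewhere in the test module and are not shown in this diff; a plausible shape for them, inferred from the expected values asserted above (an assumption, not the actual test code):

```python
def mock_rms_norm(x, weight, eps):
    # Stand-in for torch_npu.npu_rms_norm, which returns (output, rstd).
    # Returning x + 1 makes the call observable: expected_out_x = dummy_tensor + 1.
    return x + 1, None


def mock_add_rms_norm(x, residual, weight, eps):
    # Stand-in for torch_npu.npu_add_rms_norm, which returns
    # (output, rstd, residual_sum). Doubling both tensors matches
    # expected_out_x = 2 * dummy_tensor and expected_out_residual = 2 * residual.
    return 2 * x, None, 2 * residual
```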
```diff
@@ -11,13 +11,9 @@ class AscendRMSNorm310(AscendRMSNorm):
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if residual is not None:
-            if x is None or x.numel() == 0 or x.shape[-1] == 0:
-                x = residual
-            else:
-                x = x + residual
-
-            residual = x
-            x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
+            x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon)
             if self.bias is not None:
                 x.add_(self.bias)
             return x, residual
 
         x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
```
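Two things worth noting about the fused path: `npu_add_rms_norm` returns three tensors, of which `forward_oot` keeps only the normalized output and the residual sum; and the old guard for an empty `x` is gone, so the fused path assumes a well-formed input. A minimal sketch of the three-output contract (shapes and eps here are illustrative, and running it requires an Ascend NPU with `torch_npu`):

```python
import torch
import torch_npu  # Ascend-only; this snippet needs an NPU device

x = torch.randn(4, 8, dtype=torch.float16).npu()     # raw hidden states
res = torch.randn(4, 8, dtype=torch.float16).npu()   # residual stream
weight = torch.ones(8, dtype=torch.float16).npu()    # RMSNorm gamma

# Fused: returns (normalized output, reciprocal RMS, x + res).
y, rstd, x_sum = torch_npu.npu_add_rms_norm(x, res, weight, 1e-5)
# In AscendRMSNorm310.forward_oot above: y becomes the new `x`,
# x_sum becomes the new `residual`, and rstd is discarded (`_`).
```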
```diff
@@ -1,61 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-
-import torch
-import torch_npu
-
-from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
-
-
-class AscendMMEncoderAttention310(AscendMMEncoderAttention):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def forward_oot(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        cu_seqlens: torch.Tensor | None = None,
-        max_seqlen: int | None = None,
-        **kwargs,
-    ):
-        bsz, q_len = query.size()[:2]
-        kv_len = key.size(1)
-        query = query.view(bsz * q_len, self.num_heads, self.head_size)
-        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
-        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
-
-        if cu_seqlens is None:
-            seq_len = torch.tensor([q_len] * bsz, device="cpu", dtype=torch.int32)
-        else:
-            seq_len = torch.diff(cu_seqlens.to("cpu", dtype=torch.int32))
-
-        output = torch.empty_like(query)
-        torch_npu._npu_flash_attention_unpad(
-            query=query,
-            key=key,
-            value=value,
-            seq_len=seq_len,
-            scale_value=self.head_size**-0.5,
-            num_heads=self.num_heads,
-            num_kv_heads=self.num_kv_heads,
-            out=output,
-        )
-
-        output = output.view(bsz, -1, self.num_heads, self.head_size)
-        return output
```
```diff
@@ -628,13 +628,11 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
         from vllm_ascend._310p.fused_moe.fused_moe import AscendFusedMoE310, AscendSharedFusedMoE310
         from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
         from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
-        from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
         from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
 
         REGISTERED_ASCEND_OPS.update(
             {
                 "SiluAndMul": AscendSiluAndMul310,
-                "MMEncoderAttention": AscendMMEncoderAttention310,
                 "RotaryEmbedding": AscendRotaryEmbedding310,
                 "RMSNorm": AscendRMSNorm310,
                 "GemmaRMSNorm": AscendGemmaRMSNorm310,
```