[feature] add_rms_norm: support bias (#5790)

### What this PR does / why we need it?
This PR replaces the AddRmsNorm + Add operator pair with a single fused AddRmsNormBias. Folding the bias addition into the norm kernel avoids a second elementwise pass over the output and is more efficient.
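
For reference, here is a minimal pure-PyTorch sketch of what the fused kernel computes. The `(x, residual, weight, bias, eps)` signature and the three-element return value are inferred from the unit-test mocks below, not taken from the kernel source:

```python
import torch


def add_rms_norm_bias_ref(x, residual, weight, bias, eps):
    # Fused semantics: residual add, RMSNorm, and optional bias add in one
    # pass, instead of AddRmsNorm followed by a separate elementwise Add.
    added = x + residual
    rstd = torch.rsqrt(added.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = added * rstd * weight
    if bias is not None:
        y = y + bias  # previously a second kernel launch
    return y, rstd, added
```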

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Full test suite passes.

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

Signed-off-by: Chen_HaoWen <chenhaowen12@huawei.com>
Co-authored-by: Chen_HaoWen <chenhaowen12@huawei.com>

Excerpt from the updated RMSNorm unit tests, which now register the custom ops at import time:

```diff
@@ -6,6 +6,8 @@ from vllm.config import set_current_vllm_config
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm_ascend.utils import AscendDeviceType
+from vllm_ascend.utils import enable_custom_op
 
+enable_custom_op()
 
 @pytest.fixture
```
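
The module-level `enable_custom_op()` call is load-bearing for the patches further down: it presumably loads the custom-op library so that `torch.ops._C_ascend.npu_add_rms_norm_bias` resolves, and `unittest.mock.patch` requires its target attribute to exist when the patch activates at test time. A standalone illustration of that constraint, using a plain namespace as a stand-in for `torch.ops._C_ascend`:

```python
from types import SimpleNamespace
from unittest.mock import patch

ops_ns = SimpleNamespace()  # stand-in for torch.ops._C_ascend

# Patching a name that is not yet registered fails when the patch is
# entered, which is why enable_custom_op() must run before the tests.
try:
    with patch.object(ops_ns, "npu_add_rms_norm_bias", lambda *args: None):
        pass
except AttributeError as err:
    print(f"patch target must exist first: {err}")
```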

```diff
@@ -20,6 +22,13 @@ def mock_rms_norm(x, weight, eps):
 def mock_add_rms_norm(x, residual, weight, eps):
     return 2 * x, None, 2 * residual
 
 
+def mock_add_rms_norm_bias(x, residual, weight, bias, eps):
+    if bias is None:
+        return 2 * x, None, 2 * residual
+    else:
+        return 2 * x + bias, None, 2 * residual
+
+
 @pytest.fixture(autouse=True)
 def default_vllm_config():
```
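
The mocks mirror the three-element tuple returned by the real NPU ops (the normalized output, what appears to be the inverse-RMS statistic, and the updated residual), using doubled inputs as easily recognizable sentinel values. A quick standalone check of the new bias branch:

```python
import torch

def mock_add_rms_norm_bias(x, residual, weight, bias, eps):
    if bias is None:
        return 2 * x, None, 2 * residual
    return 2 * x + bias, None, 2 * residual

x = torch.ones(2, 4)
bias = torch.full((4,), 0.5)
out, _, res = mock_add_rms_norm_bias(x, x, None, bias, 1e-6)
assert torch.allclose(out, 2 * x + bias)  # bias folded into the output
assert torch.allclose(res, 2 * x)         # residual path unchanged by bias
```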

```diff
@@ -35,7 +44,8 @@ def default_vllm_config():
                          [None, torch.randn(4, 8, dtype=torch.float32)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
 @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
-def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
+@patch("torch.ops._C_ascend.npu_add_rms_norm_bias", side_effect=mock_add_rms_norm_bias)
+def test_RMSNorm_forward(mock_add_rms_norm_bias, mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
                          dummy_tensor, default_vllm_config):
     with patch("vllm_ascend.utils.get_ascend_device_type",
```

```diff
@@ -56,7 +66,7 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
         else:
             expected_out_x = 2 * dummy_tensor
             expected_out_residual = 2 * residual
-            mock_add_rmsnorm.assert_called_once()
+            mock_add_rms_norm_bias.assert_called_once()
         assert torch.allclose(out_x, expected_out_x)
         assert torch.allclose(out_residual, expected_out_residual)
     else:
```
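
The assertion swap tracks the routing change: in this branch the residual-add path now goes through the fused op, so the test verifies exactly one call to `mock_add_rms_norm_bias` instead of `mock_add_rmsnorm` (and the plain doubled expectations suggest the op is exercised here with `bias=None`). The call-count check itself behaves like this:

```python
from unittest.mock import MagicMock

fused = MagicMock(side_effect=lambda x, r, w, b, eps: (2 * x, None, 2 * r))
out, _, res = fused(1.0, 1.0, None, None, 1e-6)
fused.assert_called_once()  # raises AssertionError on zero or repeated calls
assert out == 2.0 and res == 2.0
```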