[feature] add_rms_norm support bias (#5790)
### What this PR does / why we need it?
This PR replaces the separate addRmsNorm and Add operations with a single fused addRmsNormBias operation, which is more efficient.
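For illustration, a minimal before/after sketch of the call pattern, assuming the operator signatures implied by the test mocks in the diff below (`torch_npu.npu_add_rms_norm` and `torch.ops._C_ascend.npu_add_rms_norm_bias`); this is a sketch, not the actual vllm-ascend layer code:

```python
import torch
import torch_npu  # Ascend build of PyTorch; the custom op namespace is registered via enable_custom_op()

x = torch.randn(4, 8).npu()
residual = torch.randn(4, 8).npu()
weight = torch.ones(8).npu()
bias = torch.zeros(8).npu()

# Before: fused residual add + RMSNorm, then a separate elementwise Add for the bias.
normed, rstd, added = torch_npu.npu_add_rms_norm(x, residual, weight, 1e-6)
normed = normed + bias

# After: one fused kernel covers the residual add, the RMSNorm, and the bias add.
normed, rstd, added = torch.ops._C_ascend.npu_add_rms_norm_bias(
    x, residual, weight, bias, 1e-6)
```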
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
All tests pass.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
Signed-off-by: Chen_HaoWen <chenhaowen12@huawei.com>
Co-authored-by: Chen_HaoWen <chenhaowen12@huawei.com>
```diff
@@ -6,6 +6,8 @@ from vllm.config import set_current_vllm_config
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 from vllm_ascend.utils import AscendDeviceType
+from vllm_ascend.utils import enable_custom_op
 
+enable_custom_op()
 
 @pytest.fixture
@@ -20,6 +22,13 @@ def mock_rms_norm(x, weight, eps):
 def mock_add_rms_norm(x, residual, weight, eps):
     return 2 * x, None, 2 * residual
 
 
+def mock_add_rms_norm_bias(x, residual, weight, bias, eps):
+    if bias is None:
+        return 2 * x, None, 2 * residual
+    else:
+        return 2 * x + bias, None, 2 * residual
+
+
 @pytest.fixture(autouse=True)
 def default_vllm_config():
@@ -35,7 +44,8 @@ def default_vllm_config():
     [None, torch.randn(4, 8, dtype=torch.float32)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
 @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
-def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
+@patch("torch.ops._C_ascend.npu_add_rms_norm_bias", side_effect=mock_add_rms_norm_bias)
+def test_RMSNorm_forward(mock_add_rms_norm_bias, mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
                          dummy_tensor, default_vllm_config):
 
     with patch("vllm_ascend.utils.get_ascend_device_type",
@@ -56,7 +66,7 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p, residual,
         else:
             expected_out_x = 2 * dummy_tensor
             expected_out_residual = 2 * residual
-            mock_add_rmsnorm.assert_called_once()
+            mock_add_rms_norm_bias.assert_called_once()
             assert torch.allclose(out_x, expected_out_x)
             assert torch.allclose(out_residual, expected_out_residual)
     else:
```
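As a cross-check on what the mocks above stand in for, here is a minimal pure-PyTorch reference of the expected fused semantics: residual add, RMSNorm over the last dimension, then the optional bias add. The `(out, rstd, added)` triple mirrors the mocks' return shape, with `None` in place of the rstd the real kernel would produce; this is a sketch of the math, not the kernel implementation:

```python
import torch

def add_rms_norm_bias_reference(x, residual, weight, bias, eps=1e-6):
    # Residual add, RMSNorm, and optional bias add, which the fused
    # NPU kernel is expected to perform in a single launch.
    added = x + residual
    rms = torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + eps)
    normed = added * rms * weight
    if bias is not None:
        normed = normed + bias
    return normed, None, added
```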