From 33ef5dc8130fb7132b23c14d13a5c9361118d912 Mon Sep 17 00:00:00 2001
From: xudongLi-cmss <41658239+xudongLi-cmss@users.noreply.github.com>
Date: Fri, 18 Jul 2025 11:05:17 +0800
Subject: [PATCH] add unit test for func wrapper (#1863)

### What this PR does / why we need it?
Add unit tests for the quantization `func_wrapper` module, covering
`wrapper_rmsnorm_init` and `wrapper_rmsnorm_forward_oot` (both the
quantized NPU path and the `ignore_anti` fallback, each with and
without a residual input).

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with the newly added tests.

- vLLM version: v0.9.2
- vLLM main: https://github.com/vllm-project/vllm/commit/8dfb45ca3379b3a789ec529af4bf725daa07f10d

Signed-off-by: lixudong
---
 tests/ut/quantization/test_func_wrapper.py | 134 +++++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 tests/ut/quantization/test_func_wrapper.py

diff --git a/tests/ut/quantization/test_func_wrapper.py b/tests/ut/quantization/test_func_wrapper.py
new file mode 100644
index 0000000..2f9bd89
--- /dev/null
+++ b/tests/ut/quantization/test_func_wrapper.py
@@ -0,0 +1,134 @@
+import unittest
+from typing import Optional
+from unittest.mock import patch
+
+import torch
+
+from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
+                                                   wrapper_rmsnorm_init)
+
+
+class MockRMSNorm:
+    """Minimal stand-in for an RMSNorm layer, exposing only the
+    attributes the wrappers under test read."""
+
+    def __init__(self, hidden_size: int, **extra_args):
+        self.hidden_size = hidden_size
+        self.weight = torch.ones(hidden_size)
+        self.input_scale = 1.0
+        self.input_offset = 0.0
+        self.variance_epsilon = 1e-6
+        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
+                                       requires_grad=False)
+        self.ignore_anti = extra_args.get('ignore_anti', True)
+
+
+class TestFuncWrapper(unittest.TestCase):
+
+    def test_wrapper_rmsnorm_init(self):
+        # The wrapped __init__ should attach the anti-outlier attributes
+        # (ignore_anti, bias) on top of whatever the original sets.
+        @wrapper_rmsnorm_init
+        def init(self, hidden_size: int, **extra_args) -> None:
+            self.hidden_size = hidden_size
+
+        hidden_size = 128
+        extra_args = {'arg1': 'value1'}
+
+        rms_norm = MockRMSNorm(hidden_size, **extra_args)
+        init(rms_norm, hidden_size, **extra_args)
+
+        self.assertTrue(hasattr(rms_norm, 'ignore_anti'))
+        self.assertTrue(rms_norm.ignore_anti)
+
+        self.assertTrue(hasattr(rms_norm, 'bias'))
+        self.assertIsInstance(rms_norm.bias, torch.nn.Parameter)
+        self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size]))
+        self.assertFalse(rms_norm.bias.requires_grad)
+
+    @patch('torch_npu._npu_quant_rms_norm')
+    def test_wrapper_rmsnorm_forward_oot_with_residual(
+            self, mock_npu_quant_rms_norm):
+        # With ignore_anti=False the wrapper should dispatch to the NPU
+        # quantized kernel; the kernel is mocked so the test runs without
+        # Ascend hardware.
+        hidden_size = 128
+        x = torch.randn(hidden_size)
+        residual = torch.randn(hidden_size)
+        expected_out = torch.randn(hidden_size)
+
+        mock_npu_quant_rms_norm.return_value = (expected_out, residual)
+
+        @wrapper_rmsnorm_forward_oot
+        def forward_oot(self, x: torch.Tensor,
+                        residual: Optional[torch.Tensor] = None):
+            return x, residual
+
+        rms_norm = MockRMSNorm(hidden_size)
+        rms_norm.ignore_anti = False
+
+        output, res = forward_oot(rms_norm, x, residual)
+
+        mock_npu_quant_rms_norm.assert_called_once()
+
+        # args[0] is deliberately not checked: the wrapper may fold the
+        # residual into the kernel input before the call.
+        args, kwargs = mock_npu_quant_rms_norm.call_args
+        self.assertTrue(torch.equal(args[1], rms_norm.weight))
+        self.assertTrue(torch.equal(args[2], rms_norm.bias))
+        self.assertEqual(args[3], rms_norm.input_scale)
+        self.assertEqual(args[4], rms_norm.input_offset)
+        self.assertEqual(args[5], rms_norm.variance_epsilon)
+        self.assertTrue(torch.equal(res, residual))
+
+    @patch('torch_npu._npu_quant_rms_norm')
+    def test_wrapper_rmsnorm_forward_oot_without_residual(
+            self, mock_npu_quant_rms_norm):
+        hidden_size = 128
+        x = torch.randn(hidden_size)
+        expected_out = torch.randn(hidden_size)
+
+        mock_npu_quant_rms_norm.return_value = expected_out
+
+        @wrapper_rmsnorm_forward_oot
+        def forward_oot(self, x: torch.Tensor,
+                        residual: Optional[torch.Tensor] = None):
+            return x
+
+        rms_norm = MockRMSNorm(hidden_size)
+        rms_norm.ignore_anti = False
+
+        output = forward_oot(rms_norm, x)
+
+        mock_npu_quant_rms_norm.assert_called_once()
+
+        args, kwargs = mock_npu_quant_rms_norm.call_args
+        self.assertTrue(torch.equal(args[0], x))
+        self.assertTrue(torch.equal(args[1], rms_norm.weight))
+        self.assertTrue(torch.equal(args[2], rms_norm.bias))
+        self.assertEqual(args[3], rms_norm.input_scale)
+        self.assertEqual(args[4], rms_norm.input_offset)
+        self.assertEqual(args[5], rms_norm.variance_epsilon)
+
+        self.assertTrue(torch.equal(output, expected_out))
+
+    def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self):
+        # With ignore_anti=True the wrapper should bypass the NPU kernel
+        # and fall back to the original forward plus the bias (all zeros
+        # here, so x is effectively unchanged).
+        hidden_size = 128
+        x = torch.randn(hidden_size)
+        residual = torch.randn(hidden_size)
+
+        @wrapper_rmsnorm_forward_oot
+        def forward_oot(self, x: torch.Tensor,
+                        residual: Optional[torch.Tensor] = None):
+            return x, residual
+
+        rms_norm = MockRMSNorm(hidden_size)
+        rms_norm.ignore_anti = True
+
+        output, res = forward_oot(rms_norm, x, residual)
+
+        self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
+        self.assertTrue(torch.equal(res, residual))
+
+    def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self):
+        hidden_size = 128
+        x = torch.randn(hidden_size)
+
+        @wrapper_rmsnorm_forward_oot
+        def forward_oot(self, x: torch.Tensor,
+                        residual: Optional[torch.Tensor] = None):
+            return x
+
+        rms_norm = MockRMSNorm(hidden_size)
+        rms_norm.ignore_anti = True
+
+        output = forward_oot(rms_norm, x)
+
+        self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
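
For context when reviewing: the two decorators under test are not part of this diff. The snippet below is a minimal sketch of the behaviour the assertions imply, written for this description rather than copied from `vllm_ascend/quantization/func_wrapper.py`, so the function bodies, the residual pre-add, and the exact `torch_npu._npu_quant_rms_norm` calling convention are all assumptions inferred from the tests:

```python
import functools
from typing import Optional

import torch


def wrapper_rmsnorm_init(func):
    # Assumed shape of the init wrapper: run the wrapped __init__, then
    # attach the anti-outlier bias and the ignore_anti flag that the
    # forward wrapper keys off (matches test_wrapper_rmsnorm_init).
    @functools.wraps(func)
    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.ignore_anti = True
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)

    return init


def wrapper_rmsnorm_forward_oot(func):
    # Assumed shape of the forward wrapper: dispatch to the quantized NPU
    # kernel unless ignore_anti is set, in which case fall back to the
    # wrapped forward and fold the bias in (matches the four forward tests,
    # including the args[1:6] order the tests assert on).
    @functools.wraps(func)
    def forward(self, x: torch.Tensor,
                residual: Optional[torch.Tensor] = None):
        if not self.ignore_anti:
            import torch_npu  # only needed on the quantized path
            if residual is not None:
                residual = residual + x  # assumed residual pre-add
                out = torch_npu._npu_quant_rms_norm(
                    residual, self.weight, self.bias, self.input_scale,
                    self.input_offset, self.variance_epsilon)
                return out, residual
            return torch_npu._npu_quant_rms_norm(
                x, self.weight, self.bias, self.input_scale,
                self.input_offset, self.variance_epsilon)
        if residual is not None:
            x, residual = func(self, x, residual)
            return x.add_(self.bias), residual
        return func(self, x).add_(self.bias)

    return forward
```

Under this reading, the `ignore_anti` tests never touch `torch_npu` at all, which is why only the two quantized-path tests need the `@patch('torch_npu._npu_quant_rms_norm')` decorator.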