### What this PR does / why we need it?
Use the shared base test class so common patches no longer have to be repeated in every test.
Follow-up here: https://github.com/vllm-project/vllm-ascend/pull/1566
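To illustrate the pattern (this is not the actual contents of `tests/ut/base.py`), a base test class can start the shared patches in `setUp` and register their teardown with `addCleanup`, so subclasses inherit the mocked environment instead of repeating `@patch` decorators everywhere. The `platform.machine` target below is only a placeholder for whatever vllm-ascend really stubs out:

```python
# Minimal sketch of the base-test idea; the real TestBase may differ.
import unittest
from unittest.mock import patch


class TestBase(unittest.TestCase):
    """Shared base class: common patches live here instead of in every test."""

    def setUp(self):
        super().setUp()
        # Placeholder target -- stands in for whatever hardware/driver lookups
        # the real base class stubs out for unit tests.
        patcher = patch('platform.machine', return_value='aarch64')
        patcher.start()
        # addCleanup guarantees the patch is removed even if the test fails.
        self.addCleanup(patcher.stop)
```

Test classes such as `TestFuncWrapper` below then only patch what is specific to them (here, `torch_npu._npu_quant_rms_norm`).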
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT CI passed.
- vLLM version: v0.9.2
- vLLM main: 8d0a01a5f2
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
```python
from unittest.mock import patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
                                                   wrapper_rmsnorm_init)


class MockRMSNorm:
    """Minimal stand-in for an RMSNorm layer with only the attributes the wrappers touch."""

    def __init__(self, hidden_size: int, **extra_args):
        self.hidden_size = hidden_size
        self.weight = torch.ones(hidden_size)
        self.input_scale = 1.0
        self.input_offset = 0.0
        self.variance_epsilon = 1e-6
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)
        self.ignore_anti = extra_args.get('ignore_anti', True)


class TestFuncWrapper(TestBase):

    def test_wrapper_rmsnorm_init(self):

        @wrapper_rmsnorm_init
        def init(self, hidden_size: int, **extra_args) -> None:
            self.hidden_size = hidden_size

        hidden_size = 128
        extra_args = {'arg1': 'value1'}

        rms_norm = MockRMSNorm(hidden_size, **extra_args)
        init(rms_norm, hidden_size, **extra_args)

        # After wrapping, the layer should expose the ignore_anti flag and a
        # frozen bias parameter.
        self.assertTrue(hasattr(rms_norm, 'ignore_anti'))
        self.assertTrue(rms_norm.ignore_anti)

        self.assertTrue(hasattr(rms_norm, 'bias'))
        self.assertIsInstance(rms_norm.bias, torch.nn.Parameter)
        self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size]))
        self.assertFalse(rms_norm.bias.requires_grad)

    @patch('torch_npu._npu_quant_rms_norm')
    def test_wrapper_rmsnorm_forward_oot_with_residual(
            self, mock_npu_quant_rms_norm):
        hidden_size = 128
        x = torch.randn(hidden_size)
        residual = torch.randn(hidden_size)
        expected_out = torch.randn(hidden_size)

        mock_npu_quant_rms_norm.return_value = (expected_out, residual)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x, residual

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = False

        output, res = forward_oot(rms_norm, x, residual)

        # With ignore_anti disabled, the wrapper should call the NPU quant op.
        mock_npu_quant_rms_norm.assert_called_once()

        args, kwargs = mock_npu_quant_rms_norm.call_args
        self.assertTrue(torch.equal(args[1], rms_norm.weight))
        self.assertTrue(torch.equal(args[2], rms_norm.bias))
        self.assertEqual(args[3], rms_norm.input_scale)
        self.assertEqual(args[4], rms_norm.input_offset)
        self.assertEqual(args[5], rms_norm.variance_epsilon)
        self.assertTrue(torch.equal(res, residual))

    @patch('torch_npu._npu_quant_rms_norm')
    def test_wrapper_rmsnorm_forward_oot_without_residual(
            self, mock_npu_quant_rms_norm):
        hidden_size = 128
        x = torch.randn(hidden_size)
        expected_out = torch.randn(hidden_size)

        mock_npu_quant_rms_norm.return_value = expected_out

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = False

        output = forward_oot(rms_norm, x)

        mock_npu_quant_rms_norm.assert_called_once()

        args, kwargs = mock_npu_quant_rms_norm.call_args
        self.assertTrue(torch.equal(args[0], x))
        self.assertTrue(torch.equal(args[1], rms_norm.weight))
        self.assertTrue(torch.equal(args[2], rms_norm.bias))
        self.assertEqual(args[3], rms_norm.input_scale)
        self.assertEqual(args[4], rms_norm.input_offset)
        self.assertEqual(args[5], rms_norm.variance_epsilon)

        self.assertTrue(torch.equal(output, expected_out))

    def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self):
        hidden_size = 128
        x = torch.randn(hidden_size)
        residual = torch.randn(hidden_size)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x, residual

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = True

        output, res = forward_oot(rms_norm, x, residual)

        # With ignore_anti enabled, the output is expected to be x plus the bias.
        self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
        self.assertTrue(torch.equal(res, residual))

    def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self):
        hidden_size = 128
        x = torch.randn(hidden_size)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = True

        output = forward_oot(rms_norm, x)

        self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
```