[Graph][Fusion]Add new pattern for AddRmsnormQuant with SP. (#5077)
### What this PR does / why we need it?
1. In addition to
[#4168](https://github.com/vllm-project/vllm-ascend/pull/4168) and
[#5011](https://github.com/vllm-project/vllm-ascend/pull/5011), this PR
adds two more patterns for AddRmsnormQuant with SP enabled. The key
difference is that an additional `maybe_all_gather_and_maybe_unpad` is
inserted between `addrmsnorm` and `quantize`.
2. This PR also introduces another API, `torch.ops.vllm.quantize`, so that
we can pass `input_scale` and `input_scale_reciprocal` at the same time.
This is needed because `npu_add_rms_norm_quant` and `npu_quantize` require
different `div_mode` settings. To avoid an additional reciprocal
computation at runtime, we have to pass both values to the quantize API.
3. Removes the redundant `AscendQuantRmsnorm`.
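
As a rough illustration of point 2, here is a pure-Python sketch (illustrative only, not the real vllm-ascend kernel signatures) of why the wrapper takes both the scale and its precomputed reciprocal: one `div_mode` divides by the scale while the other multiplies by the reciprocal, and precomputing the reciprocal once keeps the extra division out of the runtime graph.

```python
def quantize(x, input_scale, input_scale_reciprocal, div_mode):
    """Quantize values to the int8 range.

    div_mode=True models kernels that divide by the scale (as the PR says
    npu_quantize expects); div_mode=False models kernels that multiply by
    the reciprocal (as npu_add_rms_norm_quant expects). Passing both
    precomputed values avoids recomputing 1/scale at runtime.
    """
    if div_mode:
        scaled = [v / input_scale for v in x]
    else:
        scaled = [v * input_scale_reciprocal for v in x]
    return [max(-128, min(127, round(v))) for v in scaled]


vals = [0.5, -1.0, 2.0]
scale = 0.05
# Both paths produce the same int8 values from the same precomputed pair.
print(quantize(vals, scale, 1.0 / scale, div_mode=True))
print(quantize(vals, scale, 1.0 / scale, div_mode=False))
```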
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Angazenn <supperccell@163.com>
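
For readers unfamiliar with the fusion pass, the unfused SP pattern from point 1 can be sketched with pure-Python stand-ins (names and implementations are illustrative, not the real NPU ops or pattern-matcher API): with SP enabled, the all-gather/unpad sits between the residual-add RMSNorm and the quantize, which is why a dedicated pattern is needed.

```python
import math


def add_rms_norm(x, residual, weight, eps=1e-6):
    """Residual add followed by RMSNorm (simplified, 1-D stand-in)."""
    added = [a + b for a, b in zip(x, residual)]
    rms = math.sqrt(sum(v * v for v in added) / len(added) + eps)
    return [v / rms * w for v, w in zip(added, weight)], added


def maybe_all_gather_and_maybe_unpad(x, shards=None):
    """With SP enabled this would gather shards across ranks; the
    single-rank stand-in just concatenates the provided shards."""
    if shards is None:
        return x
    out = []
    for shard in shards:
        out.extend(shard)
    return out


def quantize_int8(x, scale):
    return [max(-128, min(127, round(v / scale))) for v in x]


def pattern_sp(x, residual, weight, scale):
    """The unfused pattern the pass matches: the gather/unpad sits
    BETWEEN the residual-add RMSNorm and the quantize."""
    normed, new_residual = add_rms_norm(x, residual, weight)
    gathered = maybe_all_gather_and_maybe_unpad(normed)
    return quantize_int8(gathered, scale), new_residual
```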
```diff
@@ -70,10 +70,9 @@ class TestAscendW8A8LinearMethod(TestBase):
         self.assertEqual(params['weight_offset'].shape, (10, 1))
 
     @patch("vllm_ascend.quantization.w8a8.get_forward_context")
-    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
+    @patch("torch.ops.vllm.quantize")
     @patch("torch_npu.npu_quant_matmul")
-    def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
-                                   mock_quant_per_tensor,
+    def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
                                    mock_get_forward_context):
         layer = MagicMock()
         layer.aclnn_input_scale = 0.1
@@ -88,10 +87,10 @@ class TestAscendW8A8LinearMethod(TestBase):
 
         x = torch.randn(32, 128)
         bias = torch.randn(256)
-        mock_quant_per_tensor.return_value = torch.randint(-128,
-                                                           127,
-                                                           x.shape,
-                                                           dtype=torch.int8)
+        mock_quantize.return_value = torch.randint(-128,
+                                                   127,
+                                                   x.shape,
+                                                   dtype=torch.int8)
 
         expected_y_output = torch.randn(32, 256)
         mock_npu_quant_matmul.return_value = expected_y_output
```