[Graph][Fusion]Add new pattern for AddRmsnormQuant with SP. (#5077)
### What this PR does / why we need it?
1. In addition to
[#4168](https://github.com/vllm-project/vllm-ascend/pull/4168),
[#5011](https://github.com/vllm-project/vllm-ascend/pull/5011), this PR
adds two more pattern for AddRmsnormQuant with SP enabled. The key
difference is to insert an additional `maybe_all_gather_and_maybe_unpad`
between `addrmsnorm` and `quantize`.
2. This PR also introduce another api `torch.ops.vllm.quantize`, so that
we pass `input_scale` and `input_scale_reciprocal` at the same time.
This is because `npu_add_rms_norm_quant` and `npu_quantize` requires
different `div_mode`. To avoid introducing additional reciprocal
calculation in runtime, we have to pass both of them to quantize api.
3. Removes redundant `AscendQuantRmsnorm`.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -282,6 +282,26 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
|
||||
return output
|
||||
|
||||
|
||||
# TODO(Angazenn): The reason why we use a custom op to encapsulate npu_quantize
|
||||
# is that aclnnAscendQuantV3(npu_quantize) use div_mode=False, while
|
||||
# aclnnAddRmsNormQuantV2(npu_add_rms_norm_quant) use div_moe=True. We have to
|
||||
# pass input_scale and input_scale_reciprocal at the same time to avoid redundant
|
||||
# reciprocal calculation in fussion pass. We shall remove this once
|
||||
# aclnnAddRmsNormQuantV2 supports div_moe=False.
|
||||
def _quantize_impl(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
||||
input_scale_reciprocal: torch.Tensor,
|
||||
input_offset: torch.Tensor) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
||||
input_offset, torch.qint8, -1, False)
|
||||
|
||||
|
||||
def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor,
|
||||
input_scale_reciprocal: torch.Tensor,
|
||||
input_offset: torch.Tensor) -> torch.Tensor:
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
|
||||
input_offset, torch.qint8, -1, False)
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="maybe_chunk_residual",
|
||||
op_func=_maybe_chunk_residual_impl,
|
||||
fake_impl=lambda x, residual: x,
|
||||
@@ -341,3 +361,9 @@ direct_register_custom_op(op_name="matmul_and_reduce",
|
||||
fake_impl=_matmul_and_reduce_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="quantize",
|
||||
op_func=_quantize_impl,
|
||||
fake_impl=_quantize_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
Reference in New Issue
Block a user