[Fusion] Add rmsnorm dynamic quant fusion pass (#6274)

### What this PR does / why we need it?

This PR introduces four new patterns to support the fusion of RMSNorm
and DynamicQuant operators. After replacing the fusion operators, the
execution time has been reduced from 22.8us to 16.9us.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?


- vLLM version: v0.14.1
- vLLM main:
d7de043d55

Signed-off-by: Bryan <250470359+Zhang-Bryan@users.noreply.github.com>
This commit is contained in:
Zhang-Bryan
2026-02-04 15:53:53 +08:00
committed by GitHub
parent e7a13beedb
commit 804a9ec4e6

View File

@@ -268,6 +268,200 @@ class AddRMSNormQuantSPPatternWithBias:
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormDynamicQuantPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
)
return (
output[0],
output[3],
output[2],
)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormDynamicQuantPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
rms_norm_input, residual, rms_norm_weight, bias, self.eps
)
out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
)
return (
output[0],
output[3],
output[2],
)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormDynamicQuantSPPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
)
out3 = output[3]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
return quantized_output, out3, output[2]
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormDynamicQuantSPPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.eps = eps
def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops._C_ascend.npu_add_rms_norm_bias(
rms_norm_input, residual, rms_norm_weight, bias, self.eps
)
out0 = output[0]
out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
return quantized_output[0], quantized_output[1], out1
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
)
out3 = output[3]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
return quantized_output, out3, output[2]
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantFusionPass(VllmInductorPass):
"""
A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
@@ -286,9 +480,13 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
if enable_custom_op():
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormDynamicQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()