[Fusion] Add rmsnorm dynamic quant fusion pass (#6274)
### What this PR does / why we need it?
This PR introduces four new patterns to support the fusion of RMSNorm
and DynamicQuant operators. After replacing the original operator pair with
the fused operators, the execution time is reduced from 22.8us to 16.9us.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
d7de043d55
Signed-off-by: Bryan <250470359+Zhang-Bryan@users.noreply.github.com>
This commit is contained in:
@@ -268,6 +268,200 @@ class AddRMSNormQuantSPPatternWithBias:
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormDynamicQuantPattern:
    """Fuse ``npu_add_rms_norm`` + ``npu_dynamic_quant`` into the single
    fused op ``npu_add_rms_norm_dynamic_quant``.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        """
        Args:
            vllm_config: vLLM configuration; only ``model_config.dtype`` is read.
            eps: Epsilon added to the variance term of RMSNorm.
        """
        self.vllm_config = vllm_config
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the AddRMSNormDynamicQuant fusion pattern.

        The shapes are small placeholders; the pattern matcher only needs
        representative tensors to trace the pattern/replacement graphs.
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this pattern/replacement pair on *pm_pass*."""

        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Pattern for the AddRMSNormDynamicQuant fusion:
            add_rms_norm followed by dynamic quantization of its output.
            """
            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
            out0 = output[0]  # normalized output
            out1 = output[2]  # updated residual (index 2 of npu_add_rms_norm outputs)
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            # (quantized tensor, per-token scale, residual)
            return quantized_output[0], quantized_output[1], out1

        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Replacement for the AddRMSNormDynamicQuant fusion: a single fused op.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
            )
            # Fused-op output index mapping (inferred from the usage here —
            # confirm against the op schema): [0] quantized tensor,
            # [2] updated residual, [3] dynamic-quant scale.
            return (
                output[0],
                output[3],
                output[2],
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormDynamicQuantPatternWithBias:
    """Fuse the biased add-RMSNorm custom op (``_C_ascend.npu_add_rms_norm_bias``)
    followed by ``npu_dynamic_quant`` into ``npu_add_rms_norm_dynamic_quant``
    with ``beta`` supplying the bias.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        """
        Args:
            vllm_config: vLLM configuration; only ``model_config.dtype`` is read.
            eps: Epsilon added to the variance term of RMSNorm.
        """
        self.vllm_config = vllm_config
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the biased AddRMSNormDynamicQuant
        fusion pattern (includes an RMSNorm bias tensor).
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this pattern/replacement pair on *pm_pass*."""

        def pattern(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Pattern for the biased AddRMSNormDynamicQuant fusion:
            biased add_rms_norm followed by dynamic quantization.
            """
            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
                rms_norm_input, residual, rms_norm_weight, bias, self.eps
            )
            out0 = output[0]  # normalized output
            out1 = output[2]  # updated residual
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            # (quantized tensor, per-token scale, residual)
            return quantized_output[0], quantized_output[1], out1

        def replacement(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Replacement for the biased AddRMSNormDynamicQuant fusion:
            the single fused op with the bias passed as ``beta``.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
            )
            # Fused-op output index mapping (inferred from the usage here —
            # confirm against the op schema): [0] quantized tensor,
            # [2] updated residual, [3] dynamic-quant scale.
            return (
                output[0],
                output[3],
                output[2],
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormDynamicQuantSPPattern:
    """Sequence-parallel variant: fuse ``npu_add_rms_norm`` + all-gather +
    ``npu_dynamic_quant`` into ``npu_add_rms_norm_dynamic_quant`` followed by
    all-gathers of the fused op's quantized output and scale.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        """
        Args:
            vllm_config: vLLM configuration; only ``model_config.dtype`` is read.
            eps: Epsilon added to the variance term of RMSNorm.
        """
        self.vllm_config = vllm_config
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the sequence-parallel
        AddRMSNormDynamicQuant fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this pattern/replacement pair on *pm_pass*."""

        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Pattern for the sequence-parallel AddRMSNormDynamicQuant fusion:
            add_rms_norm, all-gather of the normalized output, then
            dynamic quantization of the gathered tensor.
            """
            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
            out0 = output[0]  # normalized output
            out1 = output[2]  # updated residual
            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            # (quantized tensor, per-token scale, residual)
            return quantized_output[0], quantized_output[1], out1

        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
            """
            Replacement for the sequence-parallel AddRMSNormDynamicQuant
            fusion: the fused op runs on the local shard, then the quantized
            tensor and its scale are all-gathered.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
            )
            # Fused-op output index mapping (inferred from the usage here —
            # confirm against the op schema): [0] quantized tensor,
            # [2] updated residual, [3] dynamic-quant scale.
            out3 = output[3]
            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
            return quantized_output, out3, output[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormDynamicQuantSPPatternWithBias:
    """Sequence-parallel, biased variant: fuse the biased add-RMSNorm custom op
    (``_C_ascend.npu_add_rms_norm_bias``) + all-gather + ``npu_dynamic_quant``
    into ``npu_add_rms_norm_dynamic_quant`` (bias passed as ``beta``) followed
    by all-gathers of the fused op's quantized output and scale.
    """

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        """
        Args:
            vllm_config: vLLM configuration; only ``model_config.dtype`` is read.
            eps: Epsilon added to the variance term of RMSNorm.
        """
        self.vllm_config = vllm_config
        self.dtype = vllm_config.model_config.dtype
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the sequence-parallel biased
        AddRMSNormDynamicQuant fusion pattern (includes an RMSNorm bias).
        """
        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this pattern/replacement pair on *pm_pass*."""

        def pattern(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Pattern for the sequence-parallel biased AddRMSNormDynamicQuant
            fusion: biased add_rms_norm, all-gather of the normalized output,
            then dynamic quantization of the gathered tensor.
            """
            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
                rms_norm_input, residual, rms_norm_weight, bias, self.eps
            )
            out0 = output[0]  # normalized output
            out1 = output[2]  # updated residual
            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
            # (quantized tensor, per-token scale, residual)
            return quantized_output[0], quantized_output[1], out1

        def replacement(
            rms_norm_input: torch.Tensor,
            residual: torch.Tensor,
            rms_norm_weight: torch.Tensor,
            bias: torch.Tensor,
        ):
            """
            Replacement for the sequence-parallel biased
            AddRMSNormDynamicQuant fusion: the fused op (bias as ``beta``)
            runs on the local shard, then the quantized tensor and its scale
            are all-gathered.
            """
            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
            )
            # Fused-op output index mapping (inferred from the usage here —
            # confirm against the op schema): [0] quantized tensor,
            # [2] updated residual, [3] dynamic-quant scale.
            out3 = output[3]
            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
            return quantized_output, out3, output[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
"""
|
||||
A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
|
||||
@@ -286,9 +480,13 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
for eps in common_epsilons:
|
||||
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
if enable_custom_op():
|
||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormDynamicQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
|
||||
Reference in New Issue
Block a user