diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
index 9bea5ca5..5dcb98d1 100644
--- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
+++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
@@ -268,6 +268,200 @@ class AddRMSNormQuantSPPatternWithBias:
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
 
 
+class AddRMSNormDynamicQuantPattern:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormDynamicQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Pattern for AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+            out0 = output[0]
+            out1 = output[2]
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Replacement for the AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
+            )
+            return (
+                output[0],
+                output[3],
+                output[2],
+            )
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
+class AddRMSNormDynamicQuantPatternWithBias:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormDynamicQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Pattern for AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+                rms_norm_input, residual, rms_norm_weight, bias, self.eps
+            )
+            out0 = output[0]
+            out1 = output[2]
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Replacement for the AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
+            )
+            return (
+                output[0],
+                output[3],
+                output[2],
+            )
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
+class AddRMSNormDynamicQuantSPPattern:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormDynamicQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Pattern for AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+            out0 = output[0]
+            out1 = output[2]
+            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Replacement for the AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
+            )
+            out3 = output[3]
+            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
+            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
+            return quantized_output, out3, output[2]
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
+class AddRMSNormDynamicQuantSPPatternWithBias:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormDynamicQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Pattern for AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+                rms_norm_input, residual, rms_norm_weight, bias, self.eps
+            )
+            out0 = output[0]
+            out1 = output[2]
+            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Replacement for the AddRMSNormDynamicQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
+            )
+            out3 = output[3]
+            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
+            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
+            return quantized_output, out3, output[2]
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
 class AddRMSNormQuantFusionPass(VllmInductorPass):
     """
     A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
@@ -286,9 +480,13 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
         for eps in common_epsilons:
             AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
             AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
+            AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
+            AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
             if enable_custom_op():
                 AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
                 AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
+                AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
+                AddRMSNormDynamicQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()