[Inductor] Change pass to adapt to new addrmsnormBias operator (#6094)
### What this PR does / why we need it?

#5790 changes the default add-RMSNorm-with-bias operator to `npu_add_rms_norm_bias` when custom ops are enabled. This PR modifies the AddRmsNormQuant fusion pass (and its tests) to align with that operator.

---------

Signed-off-by: Angazenn <supperccell@163.com>
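For context, a minimal plain-PyTorch sketch of what the fused operator is expected to compute; this is not the NPU kernel, and the helper name, shapes, and float32 accumulation are illustrative assumptions. It covers the residual add, RMSNorm, and bias add that the old graph expressed as `npu_add_rms_norm` followed by a separate `aten.add`:

```python
import torch


def add_rms_norm_bias_ref(x: torch.Tensor,
                          residual: torch.Tensor,
                          weight: torch.Tensor,
                          bias: torch.Tensor,
                          eps: float = 1e-6):
    """Reference semantics sketch for the fused add-RMSNorm-bias op.

    Mirrors what the old pattern computed in two steps (npu_add_rms_norm
    followed by `+ bias`); the real kernel may differ in dtype/layout handling
    and returns additional outputs.
    """
    # The residual add happens first; its result is returned as the new residual.
    new_residual = x + residual
    # RMSNorm over the last dimension (accumulate in float32 for stability).
    variance = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (new_residual.float() * torch.rsqrt(variance + eps)).to(x.dtype)
    normed = normed * weight
    # The bias add is now part of the fused op instead of a separate aten.add.
    return normed + bias, new_residual


# Hypothetical shapes, for illustration only.
x = torch.randn(4, 128, dtype=torch.float16)
residual = torch.zeros_like(x)
weight = torch.ones(128, dtype=torch.float16)
bias = torch.zeros(128, dtype=torch.float16)
out, new_res = add_rms_norm_bias_ref(x, residual, weight, bias)
```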
@@ -32,6 +32,7 @@ from tests.e2e.singlecard.compile.backend import TestBackend
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
     AddRMSNormQuantFusionPass
+from vllm_ascend.utils import enable_custom_op
 
 
 class TestModelWithoutBias(nn.Module):
@@ -124,11 +125,8 @@ class TestModelWithBias(nn.Module):
         """
         residual = torch.zeros_like(x)
 
-        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-            x, residual, self.rms_norm_weight, self.eps)
-
-        # Add bias
-        norm_output_with_bias = norm_output + self.bias
+        norm_output_with_bias, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            x, residual, self.rms_norm_weight, self.bias, self.eps)
 
         quantized_output = torch.ops.vllm.quantize(norm_output_with_bias,
                                                    self.quant_scale,
@@ -140,8 +138,7 @@ class TestModelWithBias(nn.Module):
     def ops_in_model_before(self) -> List[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
         return [
-            torch.ops.npu.npu_add_rms_norm.default,
-            torch.ops.aten.add.Tensor,  # Add bias operation
+            torch.ops._C_ascend.npu_add_rms_norm_bias.default,
             torch.ops.vllm.quantize.default
         ]
 
@@ -249,11 +246,8 @@ class TestModelSPWithBias(nn.Module):
         """
         residual = torch.zeros_like(x)
 
-        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-            x, residual, self.rms_norm_weight, self.eps)
-
-        # Add bias
-        norm_output_with_bias = norm_output + self.bias
+        norm_output_with_bias, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            x, residual, self.rms_norm_weight, self.bias, self.eps)
 
         norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             norm_output_with_bias, True)
@@ -268,8 +262,7 @@ class TestModelSPWithBias(nn.Module):
     def ops_in_model_before(self) -> List[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
        return [
-            torch.ops.npu.npu_add_rms_norm.default,
-            torch.ops.aten.add.Tensor,  # Add bias operation
+            torch.ops._C_ascend.npu_add_rms_norm_bias.default,
             torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
             torch.ops.vllm.quantize.default
         ]
@@ -322,6 +315,8 @@ def test_rmsnorm_quant_fusion(
        AddRMSNormQuantFusionPass(vllm_config=vllm_config)
    ])
    if use_bias:
+        if not enable_custom_op():
+            return
        if sp_enable:
            model = TestModelSPWithBias(hidden_size,
                                        dtype,
@@ -23,6 +23,8 @@ from vllm.config import VllmConfig
 from vllm.config.compilation import Range
 from vllm.logger import logger
 
+from vllm_ascend.utils import enable_custom_op
+
 
 class AddRMSNormQuantPattern:
     def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
@@ -113,10 +115,11 @@ class AddRMSNormQuantPatternWithBias:
             """
             Pattern for AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+                rms_norm_input, residual, rms_norm_weight, bias, self.eps
+            )
             out0 = output[0]
             out1 = output[2]
-            out0 = out0 + bias
             quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
             return quantized_output, out1
 
@@ -233,10 +236,11 @@ class AddRMSNormQuantSPPatternWithBias:
             """
             Pattern for AddRMSNormQuant fusion.
             """
-            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+                rms_norm_input, residual, rms_norm_weight, bias, self.eps
+            )
             out0 = output[0]
             out1 = output[2]
-            out0 = out0 + bias
             out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
             quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
             return quantized_output, out1
@@ -281,9 +285,10 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
         common_epsilons = [1e-5, 1e-6]
         for eps in common_epsilons:
             AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
-            AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
             AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
-            AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
+            if enable_custom_op():
+                AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
+                AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()
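A note on the registration change above: the bias patterns trace through `torch.ops._C_ascend.npu_add_rms_norm_bias`, which only exists when vllm-ascend's custom ops are built and enabled, so the pass now registers them conditionally (and the test returns early in the same case). Below is a rough, illustrative sketch of the same guard; the classes and `enable_custom_op` helper come from this PR, but the standalone function is hypothetical and not the pass's actual method:

```python
from vllm.config import VllmConfig

from vllm_ascend.compilation.passes.norm_quant_fusion_pass import (
    AddRMSNormQuantPatternWithBias, AddRMSNormQuantSPPatternWithBias)
from vllm_ascend.utils import enable_custom_op


def register_bias_patterns(vllm_config: VllmConfig, pattern_match_passes,
                           eps_values=(1e-5, 1e-6)) -> None:
    """Illustrative helper mirroring the guarded registration in the pass."""
    if not enable_custom_op():
        # The _C_ascend custom op is unavailable here, so the bias patterns
        # cannot be traced or matched and are simply skipped.
        return
    for eps in eps_values:
        AddRMSNormQuantPatternWithBias(
            vllm_config, eps=eps).register(pattern_match_passes)
        AddRMSNormQuantSPPatternWithBias(
            vllm_config, eps=eps).register(pattern_match_passes)
```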