[Graph][Fusion] Add new patterns for AddRMSNormQuant with SP (#5077)
### What this PR does / why we need it?
1. In addition to
[#4168](https://github.com/vllm-project/vllm-ascend/pull/4168) and
[#5011](https://github.com/vllm-project/vllm-ascend/pull/5011), this PR
adds two more patterns for AddRMSNormQuant with SP enabled. The key
difference is that an additional `maybe_all_gather_and_maybe_unpad` is
inserted between `addrmsnorm` and `quantize`.
2. This PR also introduces another API, `torch.ops.vllm.quantize`, which
takes both `input_scale` and `input_scale_reciprocal`. This is needed
because `npu_add_rms_norm_quant` and `npu_quantize` require different
`div_mode` settings; passing both values up front avoids an extra
reciprocal computation at runtime (see the sketch after this list).
3. Removes the redundant `AscendQuantRmsnorm`.
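
A minimal sketch of both points, pieced together from the tests in this diff. The pattern pseudocode uses the op names from the test models; `quantize_reference` is an assumed eager equivalent of the new op for illustration only, not its actual implementation, and the fused kernel's exact signature is not shown in this diff:

```python
import torch

# Sketch of the new SP pattern (pseudocode; names taken from the tests below):
#   before fusion:
#     out, _, res = npu_add_rms_norm(x, residual, w, eps)
#     out = maybe_all_gather_and_maybe_unpad(out, True)   # SP all-gather
#     q = torch.ops.vllm.quantize(out, scale, scale_reciprocal, offset)
#   after fusion:
#     q, res = npu_add_rms_norm_quant(...)                # fused kernel
#     q = maybe_all_gather_and_maybe_unpad(q, True)       # gather follows the fused op


def quantize_reference(x: torch.Tensor, scale: torch.Tensor,
                       scale_reciprocal: torch.Tensor,
                       offset: torch.Tensor) -> torch.Tensor:
    """Assumed eager semantics of torch.ops.vllm.quantize (illustrative only).

    Multiplying by the precomputed reciprocal is equivalent to dividing by
    the scale, so the caller can satisfy either div_mode convention without
    computing a reciprocal at runtime.
    """
    y = x * scale_reciprocal + offset
    return torch.clamp(torch.round(y), -128, 127).to(torch.int8)
```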
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Angazenn <supperccell@163.com>
@@ -23,8 +23,13 @@ import torch_npu
 import vllm.config
 from vllm.compilation.fx_utils import OpOverload
 from vllm.config import ModelConfig, VllmConfig
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment)
+from vllm.utils.system_utils import update_environment_variables
 
+import vllm_ascend.ops.register_custom_ops  # noqa
 from tests.e2e.singlecard.compile.backend import TestBackend
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
     AddRMSNormQuantFusionPass
 
@@ -35,14 +40,23 @@ class TestModelWithoutBias(nn.Module):
     AddRMSNorm → Quantization (without bias)
     """
 
-    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
+    def __init__(self,
+                 hidden_size: int,
+                 dtype: torch.dtype,
+                 eps: float = 1e-6,
+                 device="npu"):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.rms_norm_weight = nn.Parameter(
             torch.randn(hidden_size, device=device))
-        self.quant_scale = torch.tensor([1.0], device=device)
-        self.quant_offset = torch.tensor([0.0], device=device)
+        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size,
+                                                 dtype=dtype,
+                                                 device=device)
+        self.quant_offset = torch.zeros(hidden_size,
+                                        dtype=dtype,
+                                        device=device)
 
     def forward(self, x):
         """
@@ -56,10 +70,10 @@ class TestModelWithoutBias(nn.Module):
         norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
             x, residual, self.rms_norm_weight, self.eps)
 
-        quantized_output = torch_npu.npu_quantize(norm_output,
-                                                  self.quant_scale,
-                                                  self.quant_offset,
-                                                  torch.qint8, -1, False)
+        quantized_output = torch.ops.vllm.quantize(norm_output,
+                                                   self.quant_scale,
+                                                   self.quant_scale_reciprocal,
+                                                   self.quant_offset)
 
         return quantized_output, new_residual
 
@@ -67,7 +81,7 @@ class TestModelWithoutBias(nn.Module):
         """Return the list of expected operators BEFORE fusion."""
         return [
             torch.ops.npu.npu_add_rms_norm.default,
-            torch.ops.npu.npu_quantize.default
+            torch.ops.vllm.quantize.default
         ]
 
     def ops_in_model_after(self) -> List[OpOverload]:
@@ -81,15 +95,24 @@ class TestModelWithBias(nn.Module):
     AddRMSNorm → Add Bias → Quantization (with bias)
     """
 
-    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
+    def __init__(self,
+                 hidden_size: int,
+                 dtype: torch.dtype,
+                 eps: float = 1e-6,
+                 device="npu"):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.rms_norm_weight = nn.Parameter(
             torch.randn(hidden_size, device=device))
         self.bias = nn.Parameter(torch.randn(hidden_size, device=device))
-        self.quant_scale = torch.tensor([1.0], device=device)
-        self.quant_offset = torch.tensor([0.0], device=device)
+        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size,
+                                                 dtype=dtype,
+                                                 device=device)
+        self.quant_offset = torch.zeros(hidden_size,
+                                        dtype=dtype,
+                                        device=device)
 
     def forward(self, x):
         """
@@ -107,10 +130,10 @@ class TestModelWithBias(nn.Module):
         # Add bias
         norm_output_with_bias = norm_output + self.bias
 
-        quantized_output = torch_npu.npu_quantize(norm_output_with_bias,
-                                                  self.quant_scale,
-                                                  self.quant_offset,
-                                                  torch.qint8, -1, False)
+        quantized_output = torch.ops.vllm.quantize(norm_output_with_bias,
+                                                   self.quant_scale,
+                                                   self.quant_scale_reciprocal,
+                                                   self.quant_offset)
 
         return quantized_output, new_residual
 
@@ -119,7 +142,7 @@ class TestModelWithBias(nn.Module):
         return [
             torch.ops.npu.npu_add_rms_norm.default,
             torch.ops.aten.add.Tensor,  # Add bias operation
-            torch.ops.npu.npu_quantize.default
+            torch.ops.vllm.quantize.default
         ]
 
     def ops_in_model_after(self) -> List[OpOverload]:
@@ -127,13 +150,152 @@ class TestModelWithBias(nn.Module):
         return [torch.ops.npu.npu_add_rms_norm_quant.default]
 
 
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+class TestModelSPWithoutBias(nn.Module):
+    """
+    A minimal test model that simulates the pattern:
+    AddRMSNorm → maybe_allgather → Quantization (without bias)
+    """
+
+    def __init__(self,
+                 hidden_size: int,
+                 dtype: torch.dtype,
+                 eps: float = 1e-6,
+                 device="npu"):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.rms_norm_weight = nn.Parameter(
+            torch.randn(hidden_size, device=device))
+        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size,
+                                                 dtype=dtype,
+                                                 device=device)
+        self.quant_offset = torch.zeros(hidden_size,
+                                        dtype=dtype,
+                                        device=device)
+
+    def forward(self, x):
+        """
+        Forward pass:
+        1. Perform npu_add_rms_norm
+        2. Perform a fake maybe_all_gather_and_maybe_unpad
+        3. Quantize the normalized output to int8
+        Returns both quantized output and updated residual.
+        """
+        residual = torch.zeros_like(x)
+
+        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
+            x, residual, self.rms_norm_weight, self.eps)
+
+        norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            norm_output, True)
+
+        quantized_output = torch.ops.vllm.quantize(norm_output,
+                                                   self.quant_scale,
+                                                   self.quant_scale_reciprocal,
+                                                   self.quant_offset)
+
+        return quantized_output, new_residual
+
+    def ops_in_model_before(self) -> List[OpOverload]:
+        """Return the list of expected operators BEFORE fusion."""
+        return [
+            torch.ops.npu.npu_add_rms_norm.default,
+            torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
+            torch.ops.vllm.quantize.default
+        ]
+
+    def ops_in_model_after(self) -> List[OpOverload]:
+        """Return the list of expected operators AFTER successful fusion."""
+        return [
+            torch.ops.npu.npu_add_rms_norm_quant.default,
+            torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default
+        ]
+
+
+class TestModelSPWithBias(nn.Module):
+    """
+    A minimal test model that simulates the pattern:
+    AddRMSNorm → Add bias → maybe_allgather → Quantization (with bias)
+    """
+
+    def __init__(self,
+                 hidden_size: int,
+                 dtype: torch.dtype,
+                 eps: float = 1e-6,
+                 device="npu"):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.rms_norm_weight = nn.Parameter(
+            torch.randn(hidden_size, device=device))
+        self.bias = nn.Parameter(torch.randn(hidden_size, device=device))
+        self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size,
+                                                 dtype=dtype,
+                                                 device=device)
+        self.quant_offset = torch.zeros(hidden_size,
+                                        dtype=dtype,
+                                        device=device)
+
+    def forward(self, x):
+        """
+        Forward pass:
+        1. Perform npu_add_rms_norm
+        2. Add bias
+        3. Perform a fake maybe_all_gather_and_maybe_unpad
+        4. Quantize the normalized output to int8
+        Returns both quantized output and updated residual.
+        """
+        residual = torch.zeros_like(x)
+
+        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
+            x, residual, self.rms_norm_weight, self.eps)
+
+        # Add bias
+        norm_output_with_bias = norm_output + self.bias
+
+        norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            norm_output_with_bias, True)
+
+        quantized_output = torch.ops.vllm.quantize(norm_output_with_bias,
+                                                   self.quant_scale,
+                                                   self.quant_scale_reciprocal,
+                                                   self.quant_offset)
+
+        return quantized_output, new_residual
+
+    def ops_in_model_before(self) -> List[OpOverload]:
+        """Return the list of expected operators BEFORE fusion."""
+        return [
+            torch.ops.npu.npu_add_rms_norm.default,
+            torch.ops.aten.add.Tensor,  # Add bias operation
+            torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
+            torch.ops.vllm.quantize.default
+        ]
+
+    def ops_in_model_after(self) -> List[OpOverload]:
+        """Return the list of expected operators AFTER successful fusion."""
+        return [
+            torch.ops.npu.npu_add_rms_norm_quant.default,
+            torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default
+        ]
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("use_bias", [False, True])
-def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
-                              num_tokens: int, eps: float, use_bias: bool):
+@pytest.mark.parametrize("sp_enable", [False, True])
+def test_rmsnorm_quant_fusion(
+    dtype: torch.dtype,
+    hidden_size: int,
+    num_tokens: int,
+    eps: float,
+    use_bias: bool,
+    sp_enable: bool,
+):
     """
     End-to-end test for AddRMSNorm+Quantize fusion.
     Compares: Operator presence/absence before and after graph transformation
@@ -143,27 +305,58 @@ def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
 
     vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
 
+    update_environment_variables({
+        "RANK": "0",
+        "LOCAL_RANK": "0",
+        "WORLD_SIZE": "1",
+        "MASTER_ADDR": "localhost",
+        "MASTER_PORT": "12345",
+    })
+    init_distributed_environment()
+    ensure_model_parallel_initialized(1, 1)
+
     with vllm.config.set_current_vllm_config(vllm_config):
-        backend = TestBackend(
-            custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
-        if use_bias:
-            model = TestModelWithBias(hidden_size, eps, device="npu")
-        else:
-            model = TestModelWithoutBias(hidden_size, eps, device="npu")
-        model = model.to("npu")
+        with set_ascend_forward_context(None, vllm_config):
+            backend = TestBackend(custom_passes=[
+                AddRMSNormQuantFusionPass(vllm_config=vllm_config)
+            ])
+            if use_bias:
+                if sp_enable:
+                    model = TestModelSPWithBias(hidden_size,
+                                                dtype,
+                                                eps,
+                                                device="npu")
+                else:
+                    model = TestModelWithBias(hidden_size,
+                                              dtype,
+                                              eps,
+                                              device="npu")
+            else:
+                if sp_enable:
+                    model = TestModelSPWithoutBias(hidden_size,
+                                                   dtype,
+                                                   eps,
+                                                   device="npu")
+                else:
+                    model = TestModelWithoutBias(hidden_size,
+                                                 dtype,
+                                                 eps,
+                                                 device="npu")
+            model = model.to("npu")
 
-        x = torch.rand(num_tokens,
-                       hidden_size,
-                       device="npu",
-                       dtype=dtype,
-                       requires_grad=False)
+            x = torch.rand(num_tokens,
+                           hidden_size,
+                           device="npu",
+                           dtype=dtype,
+                           requires_grad=False)
 
-        result_unfused = model(x)
-        print("Unfused result:", [t.shape for t in result_unfused])
-        model_fused = torch.compile(model, backend=backend)
-        result_fused = model_fused(x)
-        print("Fused result:", [t.shape for t in result_fused])
+            result_unfused = model(x)
+            print("Unfused result:", [t.shape for t in result_unfused])
+            model_fused = torch.compile(model, backend=backend)
+            result_fused = model_fused(x)
+            print("Fused result:", [t.shape for t in result_fused])
 
-        print("=== Checking operator fusion ===")
-        backend.check_before_ops(model.ops_in_model_before())
-        backend.check_after_ops(model.ops_in_model_after())
+            print("=== Checking operator fusion ===")
+            backend.check_before_ops(model.ops_in_model_before(),
+                                     fully_replaced=not sp_enable)
+            backend.check_after_ops(model.ops_in_model_after())