### What this PR does / why we need it?
Modernizes and reformats the e2e single-card test suite: legacy `typing` annotations (`Optional[X]`, `List[X]`) are replaced with PEP 604/585 syntax (`X | None`, `list[X]`), multi-line signatures and call sites are collapsed where they fit on one line, string quoting is unified to double quotes, and long assert messages are wrapped in parentheses. Test behavior is unchanged. Files touched:
| File Path |
| :--- |
| `tests/e2e/singlecard/compile/backend.py` |
| `tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py` |
| `tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py` |
| `tests/e2e/singlecard/compile/test_norm_quant_fusion.py` |
| `tests/e2e/singlecard/model_runner_v2/test_basic.py` |
| `tests/e2e/singlecard/test_aclgraph_accuracy.py` |
| `tests/e2e/singlecard/test_aclgraph_batch_invariant.py` |
| `tests/e2e/singlecard/test_aclgraph_mem.py` |
| `tests/e2e/singlecard/test_async_scheduling.py` |
| `tests/e2e/singlecard/test_auto_fit_max_mode_len.py` |
| `tests/e2e/singlecard/test_batch_invariant.py` |
| `tests/e2e/singlecard/test_camem.py` |
| `tests/e2e/singlecard/test_completion_with_prompt_embeds.py` |
| `tests/e2e/singlecard/test_cpu_offloading.py` |
| `tests/e2e/singlecard/test_guided_decoding.py` |
| `tests/e2e/singlecard/test_ilama_lora.py` |
| `tests/e2e/singlecard/test_llama32_lora.py` |
| `tests/e2e/singlecard/test_models.py` |
| `tests/e2e/singlecard/test_multistream_overlap_shared_expert.py` |
| `tests/e2e/singlecard/test_quantization.py` |
| `tests/e2e/singlecard/test_qwen3_multi_loras.py` |
| `tests/e2e/singlecard/test_sampler.py` |
| `tests/e2e/singlecard/test_vlm.py` |
| `tests/e2e/singlecard/test_xlite.py` |
| `tests/e2e/singlecard/utils.py` |
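For reference, the recurring typing change is the standard PEP 585/604 modernization. A minimal sketch (the function names here are illustrative, not from the diff):

```python
from typing import Any, List, Optional


# Before: pre-PEP 604/585 spelling, as removed by this PR.
def make_backend_old(custom_passes: Optional[List[Any]] = None): ...


# After: built-in generics and union syntax, as introduced by this PR.
def make_backend_new(custom_passes: list[Any] | None = None): ...
```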
### Does this PR introduce _any_ user-facing change?
No. All changes are confined to the e2e test suite.
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 9562912cea
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
**`tests/e2e/singlecard/compile/backend.py`**

@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from collections.abc import Callable, Sequence
 from copy import deepcopy
-from typing import Any, Callable, List, Optional, Sequence
+from typing import Any

 import torch.fx as fx
 from torch._inductor.decomposition import select_decomp_table
@@ -37,7 +38,7 @@ class TestBackend:
     records the FX graph before and after the transformation.
     """

-    def __init__(self, custom_passes: Optional[List[Any]] = None):
+    def __init__(self, custom_passes: list[Any] | None = None):
         vllm_config = get_current_vllm_config()
         compile_config = vllm_config.compilation_config
         self.inductor_config = compile_config.inductor_compile_config
@@ -48,9 +49,7 @@ class TestBackend:
         self.graph_pre_pass = None
         self.graph_post_pass = None

-    def post_pass(self,
-                  graph: fx.Graph,
-                  runtime_shape: int | None = None) -> fx.Graph:
+    def post_pass(self, graph: fx.Graph, runtime_shape: int | None = None) -> fx.Graph:
         """
         Apply custom graph transformation passes.
         """
@@ -62,13 +61,13 @@ class TestBackend:
         return graph

     def compile(
-            self,
-            graph: fx.GraphModule,
-            example_inputs: list[Any],
-            compiler_config: dict[str, Any],
-            runtime_shape: Optional[int] = None,
-            key: Optional[str] = None
-    ) -> tuple[Optional[Callable], Optional[Any]]:
+        self,
+        graph: fx.GraphModule,
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
+        runtime_shape: int | None = None,
+        key: str | None = None,
+    ) -> tuple[Callable | None, Any | None]:
         """
         Compile the FX graph using vLLM's Ascend compiler interface.
         Wraps the post-pass logic into the inner_compile callback.
@@ -87,8 +86,7 @@ class TestBackend:
         )
         return compiled_fn, None

-    def __call__(self, gm: fx.GraphModule,
-                 example_inputs: Optional[List[Any]]):
+    def __call__(self, gm: fx.GraphModule, example_inputs: list[Any] | None):
         """
         Make the backend callable by torch.compile().
         Returns a compiled executable function.
@@ -103,17 +101,11 @@ class TestBackend:
         )
         return compiled_fn

-    def find_nodes_by_target(self, graph: fx.GraphModule,
-                             target: OpOverload) -> List[fx.Node]:
+    def find_nodes_by_target(self, graph: fx.GraphModule, target: OpOverload) -> list[fx.Node]:
         """Helper to find all FX nodes that call a specific operator."""
-        return [
-            node for node in graph.graph.nodes
-            if hasattr(node, 'target') and node.target == target
-        ]
+        return [node for node in graph.graph.nodes if hasattr(node, "target") and node.target == target]

-    def check_before_ops(self,
-                         ops: Sequence[OpOverload],
-                         fully_replaced: bool = True):
+    def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced: bool = True):
         """
         Verify that the original (unfused) operators exist before the pass
         and are fully removed afterward (if fully_replaced=True).
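A note for readers unfamiliar with the harness: `TestBackend` follows the standard `torch.compile` custom-backend protocol, in which any callable that accepts a `GraphModule` plus example inputs and returns a callable can serve as a backend. A minimal runnable sketch with illustrative names (not the actual `TestBackend` implementation):

```python
import torch
import torch.fx as fx


class EchoBackend:
    """Records the FX graph it receives, then runs it unmodified."""

    def __init__(self):
        self.graph: fx.Graph | None = None

    def __call__(self, gm: fx.GraphModule, example_inputs):
        self.graph = gm.graph  # keep the captured graph for later inspection
        return gm.forward      # any callable may be returned as the "compiled" function


backend = EchoBackend()
fn = torch.compile(lambda x: torch.relu(x) + torch.relu(x), backend=backend)
fn(torch.ones(4))  # first call triggers capture and hands the graph to the backend

# The same node-by-target scan that TestBackend.find_nodes_by_target performs;
# expect two call_function nodes targeting torch.relu here.
print([n.target for n in backend.graph.nodes])
```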
**`tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py`**

@@ -215,6 +215,7 @@ def register_pattern_safe(pattern_class, vllm_config, eps, pattern_key):
     try:
         # Import the required pass class
         from torch._inductor.pattern_matcher import PatternMatcherPass
+
         pm_pass = PatternMatcherPass()
         pattern.register(pm_pass)
         _registered_patterns.add(pattern_key)
@@ -243,7 +244,7 @@ def test_rmsnorm_quant_fusion(
     sp_enable: bool,
 ):
     # Check if fusion operator is available
-    if not hasattr(torch.ops.npu, 'npu_add_rms_norm_quant'):
+    if not hasattr(torch.ops.npu, "npu_add_rms_norm_quant"):
         pytest.skip("Fusion operator npu_add_rms_norm_quant not available, skipping test")

     vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
@@ -266,7 +267,7 @@ def test_rmsnorm_quant_fusion(
         if not enable_custom_op():
             pytest.skip("Custom ops not available, skipping bias test")
         # Check if the bias operator exists
-        if not hasattr(torch.ops._C_ascend, 'npu_add_rms_norm_bias'):
+        if not hasattr(torch.ops._C_ascend, "npu_add_rms_norm_bias"):
             pytest.skip("Operator npu_add_rms_norm_bias not available, skipping bias test")
         if sp_enable:
             model = ModelSPWithBias(hidden_size, dtype, eps, device="npu")
@@ -281,13 +282,11 @@ def test_rmsnorm_quant_fusion(
     else:
         # The non-bias patterns currently use npu_add_rms_norm_bias in their pattern matching
         # so we need to skip if it's not available
-        if not hasattr(torch.ops._C_ascend, 'npu_add_rms_norm_bias'):
+        if not hasattr(torch.ops._C_ascend, "npu_add_rms_norm_bias"):
             pytest.skip("Operator npu_add_rms_norm_bias not available, skipping test")
         if sp_enable:
             model = ModelSPWithoutBias(hidden_size, dtype, eps, device="npu")
-            register_pattern_safe(
-                AddRMSNormQuantSPPattern, vllm_config, eps, "GraphEXAddRMSNormQuantSPPattern"
-            )
+            register_pattern_safe(AddRMSNormQuantSPPattern, vllm_config, eps, "GraphEXAddRMSNormQuantSPPattern")
         else:
             model = ModelWithoutBias(hidden_size, dtype, eps, device="npu")
             register_pattern_safe(AddRMSNormQuantPattern, vllm_config, eps, "GraphEXAddRMSNormQuantPattern")
@@ -302,5 +301,9 @@ def test_rmsnorm_quant_fusion(
     compiled_out, compiled_res = compiled_model(x)

     # Verify output shapes are correct
-    assert compiled_out.shape == (num_tokens, hidden_size), f"Expected shape {(num_tokens, hidden_size)}, got {compiled_out.shape}"
-    assert compiled_res.shape == (num_tokens, hidden_size), f"Expected shape {(num_tokens, hidden_size)}, got {compiled_res.shape}"
+    assert compiled_out.shape == (num_tokens, hidden_size), (
+        f"Expected shape {(num_tokens, hidden_size)}, got {compiled_out.shape}"
+    )
+    assert compiled_res.shape == (num_tokens, hidden_size), (
+        f"Expected shape {(num_tokens, hidden_size)}, got {compiled_res.shape}"
+    )
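A general note on the guard idiom above: probing `torch.ops.<namespace>` with `hasattr` returns `False` when an extension op was never registered, so the test skips cleanly instead of failing with an `AttributeError` on builds without the fused kernel. Sketched standalone, using the same op name as the test:

```python
import pytest
import torch


def test_requires_fused_op():
    # torch.ops namespaces are created lazily, so this probe is safe even
    # when the NPU extension is absent; missing ops simply report False.
    if not hasattr(torch.ops.npu, "npu_add_rms_norm_quant"):
        pytest.skip("Fusion operator npu_add_rms_norm_quant not available")
    ...
```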
**`tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py`**

@@ -201,6 +201,7 @@ def test_rmsnorm_quant_fusion(
         vllm_config=vllm_config, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, eps=eps
     )
     from torch._inductor.pattern_matcher import PatternMatcherPass
+
     pm_pass = PatternMatcherPass()
     fusion_pattern.register(pm_pass)
     model = model.to("npu")
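Context for the `PatternMatcherPass` lines above, reduced to a standalone sketch. Note that `torch._inductor.pattern_matcher` is a private Inductor API, so this reflects an assumption about current torch releases rather than a stable contract:

```python
import torch
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

# The pass object is a container: fusion patterns are registered onto it
# (pattern.register(pm_pass) in the tests), and applying it rewrites any
# matching subgraphs, returning the number of matches found.
pm_pass = PatternMatcherPass()
gm = fx.symbolic_trace(lambda x: torch.relu(x))
num_matches = pm_pass.apply(gm.graph)  # nothing registered yet -> 0
print(num_matches)
```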
**`tests/e2e/singlecard/compile/test_norm_quant_fusion.py`**

@@ -14,25 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import List
-
 import pytest
 import torch
 import torch.nn as nn
 import torch_npu
 import vllm.config
 from vllm.config import ModelConfig, VllmConfig
-from vllm.distributed import (ensure_model_parallel_initialized,
-                              init_distributed_environment)
+from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
 from vllm.utils.system_utils import update_environment_variables

 import vllm_ascend.ops.register_custom_ops  # noqa
 from tests.e2e.singlecard.compile.backend import TestBackend
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
-    AddRMSNormQuantFusionPass
-from vllm_ascend.utils import enable_custom_op
-from vllm_ascend.utils import vllm_version_is
+from vllm_ascend.compilation.passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
+from vllm_ascend.utils import enable_custom_op, vllm_version_is

 if vllm_version_is("0.15.0"):
     from vllm.compilation.fx_utils import OpOverload  # type: ignore
@@ -48,34 +43,24 @@ def get_or_create_backend(vllm_config):
     """Get or create backend with fusion passes (cached to avoid duplicate pattern registration)."""
     global _backend_cache
     if _backend_cache is None:
-        _backend_cache = TestBackend(custom_passes=[
-            AddRMSNormQuantFusionPass(vllm_config=vllm_config)
-        ])
+        _backend_cache = TestBackend(custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
     return _backend_cache


 class TestModelWithoutBias(nn.Module):
     """
     A minimal test model that simulates the pattern:
     AddRMSNorm → Quantization (without bias)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 dtype: torch.dtype,
-                 eps: float = 1e-6,
-                 device="npu"):
+    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
-        self.rms_norm_weight = nn.Parameter(
-            torch.randn(hidden_size, device=device))
+        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, device=device))
         self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
-        self.quant_scale_reciprocal = torch.ones(hidden_size,
-                                                 dtype=dtype,
-                                                 device=device)
-        self.quant_offset = torch.zeros(hidden_size,
-                                        dtype=dtype,
-                                        device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

     def forward(self, x):
         """
@@ -87,23 +72,20 @@ class TestModelWithoutBias(nn.Module):
         residual = torch.zeros_like(x)

         norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
-            x, residual, self.rms_norm_weight, None, self.eps)
+            x, residual, self.rms_norm_weight, None, self.eps
+        )

-        quantized_output = torch.ops.vllm.quantize(norm_output,
-                                                   self.quant_scale,
-                                                   self.quant_scale_reciprocal,
-                                                   self.quant_offset)
+        quantized_output = torch.ops.vllm.quantize(
+            norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
+        )

         return quantized_output, new_residual

-    def ops_in_model_before(self) -> List[OpOverload]:
+    def ops_in_model_before(self) -> list[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
-        return [
-            torch.ops._C_ascend.npu_add_rms_norm_bias.default,
-            torch.ops.vllm.quantize.default
-        ]
+        return [torch.ops._C_ascend.npu_add_rms_norm_bias.default, torch.ops.vllm.quantize.default]

-    def ops_in_model_after(self) -> List[OpOverload]:
+    def ops_in_model_after(self) -> list[OpOverload]:
         """Return the list of expected operators AFTER successful fusion."""
         return [torch.ops.npu.npu_add_rms_norm_quant.default]

@@ -114,24 +96,15 @@ class TestModelWithBias(nn.Module):
     AddRMSNorm → Add Bias → Quantization (with bias)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 dtype: torch.dtype,
-                 eps: float = 1e-6,
-                 device="npu"):
+    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
-        self.rms_norm_weight = nn.Parameter(
-            torch.randn(hidden_size, device=device))
+        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, device=device))
         self.bias = nn.Parameter(torch.randn(hidden_size, device=device))
         self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
-        self.quant_scale_reciprocal = torch.ones(hidden_size,
-                                                 dtype=dtype,
-                                                 device=device)
-        self.quant_offset = torch.zeros(hidden_size,
-                                        dtype=dtype,
-                                        device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

     def forward(self, x):
         """
@@ -144,23 +117,20 @@ class TestModelWithBias(nn.Module):
         residual = torch.zeros_like(x)

         norm_output_with_bias, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
-            x, residual, self.rms_norm_weight, self.bias, self.eps)
+            x, residual, self.rms_norm_weight, self.bias, self.eps
+        )

-        quantized_output = torch.ops.vllm.quantize(norm_output_with_bias,
-                                                   self.quant_scale,
-                                                   self.quant_scale_reciprocal,
-                                                   self.quant_offset)
+        quantized_output = torch.ops.vllm.quantize(
+            norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
+        )

         return quantized_output, new_residual

-    def ops_in_model_before(self) -> List[OpOverload]:
+    def ops_in_model_before(self) -> list[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
-        return [
-            torch.ops._C_ascend.npu_add_rms_norm_bias.default,
-            torch.ops.vllm.quantize.default
-        ]
+        return [torch.ops._C_ascend.npu_add_rms_norm_bias.default, torch.ops.vllm.quantize.default]

-    def ops_in_model_after(self) -> List[OpOverload]:
+    def ops_in_model_after(self) -> list[OpOverload]:
         """Return the list of expected operators AFTER successful fusion."""
         return [torch.ops.npu.npu_add_rms_norm_quant.default]

@@ -171,23 +141,14 @@ class TestModelSPWithoutBias(nn.Module):
     AddRMSNorm → maybe_allgather → Quantization (without bias)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 dtype: torch.dtype,
-                 eps: float = 1e-6,
-                 device="npu"):
+    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
-        self.rms_norm_weight = nn.Parameter(
-            torch.randn(hidden_size, device=device))
+        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, device=device))
         self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
-        self.quant_scale_reciprocal = torch.ones(hidden_size,
-                                                 dtype=dtype,
-                                                 device=device)
-        self.quant_offset = torch.zeros(hidden_size,
-                                        dtype=dtype,
-                                        device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

     def forward(self, x):
         """
@@ -200,32 +161,28 @@ class TestModelSPWithoutBias(nn.Module):
         residual = torch.zeros_like(x)

         norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
-            x, residual, self.rms_norm_weight, None, self.eps)
+            x, residual, self.rms_norm_weight, None, self.eps
+        )

-        norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            norm_output, True)
+        norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output, True)

-        quantized_output = torch.ops.vllm.quantize(norm_output,
-                                                   self.quant_scale,
-                                                   self.quant_scale_reciprocal,
-                                                   self.quant_offset)
+        quantized_output = torch.ops.vllm.quantize(
+            norm_output, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
+        )

         return quantized_output, new_residual

-    def ops_in_model_before(self) -> List[OpOverload]:
+    def ops_in_model_before(self) -> list[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
         return [
             torch.ops._C_ascend.npu_add_rms_norm_bias.default,
             torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
-            torch.ops.vllm.quantize.default
+            torch.ops.vllm.quantize.default,
         ]

-    def ops_in_model_after(self) -> List[OpOverload]:
+    def ops_in_model_after(self) -> list[OpOverload]:
         """Return the list of expected operators AFTER successful fusion."""
-        return [
-            torch.ops.npu.npu_add_rms_norm_quant.default,
-            torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default
-        ]
+        return [torch.ops.npu.npu_add_rms_norm_quant.default, torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default]


 class TestModelSPWithBias(nn.Module):
@@ -234,24 +191,15 @@ class TestModelSPWithBias(nn.Module):
     AddRMSNorm → Add bias → maybe_allgather → Quantization (without bias)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 dtype: torch.dtype,
-                 eps: float = 1e-6,
-                 device="npu"):
+    def __init__(self, hidden_size: int, dtype: torch.dtype, eps: float = 1e-6, device="npu"):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
-        self.rms_norm_weight = nn.Parameter(
-            torch.randn(hidden_size, device=device))
+        self.rms_norm_weight = nn.Parameter(torch.randn(hidden_size, device=device))
         self.bias = nn.Parameter(torch.randn(hidden_size, device=device))
         self.quant_scale = torch.ones(hidden_size, dtype=dtype, device=device)
-        self.quant_scale_reciprocal = torch.ones(hidden_size,
-                                                 dtype=dtype,
-                                                 device=device)
-        self.quant_offset = torch.zeros(hidden_size,
-                                        dtype=dtype,
-                                        device=device)
+        self.quant_scale_reciprocal = torch.ones(hidden_size, dtype=dtype, device=device)
+        self.quant_offset = torch.zeros(hidden_size, dtype=dtype, device=device)

     def forward(self, x):
         """
@@ -265,32 +213,28 @@ class TestModelSPWithBias(nn.Module):
         residual = torch.zeros_like(x)

         norm_output_with_bias, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
-            x, residual, self.rms_norm_weight, self.bias, self.eps)
+            x, residual, self.rms_norm_weight, self.bias, self.eps
+        )

-        norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            norm_output_with_bias, True)
+        norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(norm_output_with_bias, True)

-        quantized_output = torch.ops.vllm.quantize(norm_output_with_bias,
-                                                   self.quant_scale,
-                                                   self.quant_scale_reciprocal,
-                                                   self.quant_offset)
+        quantized_output = torch.ops.vllm.quantize(
+            norm_output_with_bias, self.quant_scale, self.quant_scale_reciprocal, self.quant_offset
+        )

         return quantized_output, new_residual

-    def ops_in_model_before(self) -> List[OpOverload]:
+    def ops_in_model_before(self) -> list[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
         return [
             torch.ops._C_ascend.npu_add_rms_norm_bias.default,
             torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
-            torch.ops.vllm.quantize.default
+            torch.ops.vllm.quantize.default,
         ]

-    def ops_in_model_after(self) -> List[OpOverload]:
+    def ops_in_model_after(self) -> list[OpOverload]:
         """Return the list of expected operators AFTER successful fusion."""
-        return [
-            torch.ops.npu.npu_add_rms_norm_quant.default,
-            torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default
-        ]
+        return [torch.ops.npu.npu_add_rms_norm_quant.default, torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default]


 @pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -317,58 +261,42 @@ def test_rmsnorm_quant_fusion(
     vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))

     with vllm.config.set_current_vllm_config(vllm_config):
-        update_environment_variables({
-            "RANK": "0",
-            "LOCAL_RANK": "0",
-            "WORLD_SIZE": "1",
-            "MASTER_ADDR": "localhost",
-            "MASTER_PORT": "12345",
-        })
+        update_environment_variables(
+            {
+                "RANK": "0",
+                "LOCAL_RANK": "0",
+                "WORLD_SIZE": "1",
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": "12345",
+            }
+        )
         init_distributed_environment()
         ensure_model_parallel_initialized(1, 1)

-    with vllm.config.set_current_vllm_config(vllm_config):
-        with set_ascend_forward_context(None, vllm_config):
-            backend = get_or_create_backend(vllm_config)
-            if use_bias:
-                if not enable_custom_op():
-                    return
-                if sp_enable:
-                    model = TestModelSPWithBias(hidden_size,
-                                                dtype,
-                                                eps,
-                                                device="npu")
-                else:
-                    model = TestModelWithBias(hidden_size,
-                                              dtype,
-                                              eps,
-                                              device="npu")
-            else:
-                if sp_enable:
-                    model = TestModelSPWithoutBias(hidden_size,
-                                                   dtype,
-                                                   eps,
-                                                   device="npu")
-                else:
-                    model = TestModelWithoutBias(hidden_size,
-                                                 dtype,
-                                                 eps,
-                                                 device="npu")
-            model = model.to("npu")
+    with vllm.config.set_current_vllm_config(vllm_config), set_ascend_forward_context(None, vllm_config):
+        backend = get_or_create_backend(vllm_config)
+        if use_bias:
+            if not enable_custom_op():
+                return
+            if sp_enable:
+                model = TestModelSPWithBias(hidden_size, dtype, eps, device="npu")
+            else:
+                model = TestModelWithBias(hidden_size, dtype, eps, device="npu")
+        else:
+            if sp_enable:
+                model = TestModelSPWithoutBias(hidden_size, dtype, eps, device="npu")
+            else:
+                model = TestModelWithoutBias(hidden_size, dtype, eps, device="npu")
+        model = model.to("npu")

-            x = torch.rand(num_tokens,
-                           hidden_size,
-                           device="npu",
-                           dtype=dtype,
-                           requires_grad=False)
+        x = torch.rand(num_tokens, hidden_size, device="npu", dtype=dtype, requires_grad=False)

-            result_unfused = model(x)
-            print("Unfused result:", [t.shape for t in result_unfused])
-            model_fused = torch.compile(model, backend=backend)
-            result_fused = model_fused(x)
-            print("Fused result:", [t.shape for t in result_fused])
+        result_unfused = model(x)
+        print("Unfused result:", [t.shape for t in result_unfused])
+        model_fused = torch.compile(model, backend=backend)
+        result_fused = model_fused(x)
+        print("Fused result:", [t.shape for t in result_fused])

-            print("=== Checking operator fusion ===")
-            backend.check_before_ops(model.ops_in_model_before(),
-                                     fully_replaced=not sp_enable)
-            backend.check_after_ops(model.ops_in_model_after())
+        print("=== Checking operator fusion ===")
+        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=not sp_enable)
+        backend.check_after_ops(model.ops_in_model_after())
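One design note worth calling out: `get_or_create_backend` memoizes the backend at module level because, per its docstring, the fusion pass's Inductor patterns must not be registered twice in one process. The idiom, reduced to a standalone sketch with an illustrative factory parameter:

```python
# Module-level cache: build the expensive, register-once object a single time.
_backend_cache = None


def get_or_create_backend(make_backend):
    """Return the cached backend, constructing it on first use."""
    global _backend_cache
    if _backend_cache is None:
        _backend_cache = make_backend()
    return _backend_cache


# Usage: every test gets the same instance, so fusion patterns are
# registered exactly once per process.
backend_a = get_or_create_backend(object)
backend_b = get_or_create_backend(object)
assert backend_a is backend_b
```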