[Lint] Style: Convert vllm-ascend/compilation to ruff format (#5912)

### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.

### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```

**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def register_ascend_customop(vllm_config: VllmConfig | None)`, it evaluates the annotation expression `VllmConfig | None` at import time. Since `VllmConfig` has been bound to `None` at runtime, the expression is effectively `None | None`. The `|` operator is implemented for class objects (`type.__or__`, which builds a `types.UnionType`), but not for `NoneType` instances such as `None` itself, so module import fails with the `TypeError` shown above.
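
The failure is easy to reproduce in isolation. This minimal sketch (no vllm install needed, since the `TYPE_CHECKING` branch never runs at runtime) fails at import time:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig  # seen only by type checkers
else:
    VllmConfig = None  # runtime placeholder to break the circular import

# The annotation is evaluated eagerly when this `def` executes, so importing
# this module computes `None | None` and raises:
#   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```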

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
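
A sketch of the resulting module layout, reusing the names from the repro above:
```python
from __future__ import annotations  # must come before all other imports/statements

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # placeholder can stay; it is never fed to `|` anymore

# `VllmConfig | None` is now stored as the string "VllmConfig | None" and is
# never evaluated, so the module imports without error.
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```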

**3. Impact and Benefits:**
- With `from __future__ import annotations`, Python no longer evaluates `VllmConfig | None` during module load. The annotation is stored as a string literal instead, so the `None | None` operation is never executed.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` now contains strings instead of type objects. Since the affected modules do not rely on runtime type enforcement or reflection, this has no negative impact on existing functionality.
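
As a concrete illustration, continuing the sketch above (a hypothetical interactive check, not code from this PR):
```python
# Inspecting the function shows the string-valued annotation:
print(register_ascend_customop.__annotations__)
# -> {'vllm_config': 'VllmConfig | None'}

# Note: eager runtime resolution (e.g. typing.get_type_hints) would still hit
# the same TypeError while `VllmConfig` is the None placeholder, which is fine
# here because these modules never inspect annotations at runtime.
```
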
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Author: SILONG ZENG
Date: 2026-01-16 20:57:46 +08:00
Committed by: GitHub
Parent: 3af91e5ac4
Commit: 52086394ae
16 changed files with 996 additions and 1140 deletions

View File

@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections.abc import Callable
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any
from unittest.mock import patch
import numpy as np
@@ -27,12 +28,12 @@ from ..utils import weak_ref_tensors
@dataclasses.dataclass
class ACLGraphEntry:
batch_descriptor: BatchDescriptor
aclgraph: Optional[torch.npu.NPUGraph] = None
output: Optional[Any] = None
aclgraph: torch.npu.NPUGraph | None = None
output: Any | None = None
# for aclgraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
input_addresses: list[int] | None = None
class ACLGraphWrapper:
@@ -60,11 +61,13 @@ class ACLGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""
def __init__(self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: Optional[CUDAGraphOptions] = None):
def __init__(
self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None,
):
self.runnable = runnable
self.vllm_config = vllm_config
self.runtime_mode = runtime_mode
@@ -83,15 +86,13 @@ class ACLGraphWrapper:
self.aclgraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture
# aclgraphs for.
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
= {}
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} not exists in the runnable of "
f"aclgraph wrapper: {self.runnable}")
raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}")
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
@@ -102,8 +103,7 @@ class ACLGraphWrapper:
batch_descriptor = forward_context.batch_descriptor
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
aclgraph_runtime_mode != self.runtime_mode:
if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without aclgraphs.
# We do not trigger capture/replay if the runtime mode is not
@@ -114,8 +114,7 @@ class ACLGraphWrapper:
if batch_descriptor not in self.concrete_aclgraph_entries:
# create a new entry for this batch descriptor
self.concrete_aclgraph_entries[batch_descriptor] = \
ACLGraphEntry(batch_descriptor=batch_descriptor)
self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)
entry = self.concrete_aclgraph_entries[batch_descriptor]
@@ -125,14 +124,11 @@ class ACLGraphWrapper:
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
logger.debug("Capturing a aclgraph on (%s,%s)",
self.runtime_mode.name, entry.batch_descriptor)
logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
# validate that aclgraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
entry.input_addresses = input_addresses
aclgraph = torch.npu.NPUGraph()
@@ -145,8 +141,7 @@ class ACLGraphWrapper:
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.npu.empty_cache", lambda: None))
stack.enter_context(patch("torch.npu.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True
@@ -183,13 +178,12 @@ class ACLGraphWrapper:
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for aclgraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}")
f"got {new_input_addresses}"
)
logger.info_once("Replaying aclgraph")
# In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that
@@ -209,8 +203,7 @@ def weak_ref_workspaces(params):
for num_tokens in params.workspaces:
if params.workspaces[num_tokens] is None:
continue
params.workspaces[num_tokens] = weak_ref_tensors(
params.workspaces[num_tokens])
params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
@@ -219,10 +212,10 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
@@ -254,18 +247,21 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
out=output,
)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -282,18 +278,29 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
# filters out the update operations for linear_attn.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(query, key_cache, value, block_tables, attn_mask, block_size,
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
attn_output, softmax_lse) = param
(
query,
key_cache,
value,
block_tables,
attn_mask,
block_size,
seq_lens,
query_start_loc,
num_kv_heads,
num_heads,
scale,
attn_output,
softmax_lse,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens_list
actual_seq_lengths_q = forward_context.attn_metadata[
key].actual_seq_lengths_q
actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
@@ -317,16 +324,14 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
event.record(update_stream)
def update_attn_params(update_stream, forward_context, runtime_shape,
vllm_config):
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
if using_paged_attention(runtime_shape, vllm_config):
_update_attn_pa_params(update_stream, forward_context, runtime_shape)
else:
_update_attn_fia_params(update_stream, forward_context, runtime_shape)
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
speculative_config):
def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
@@ -335,41 +340,44 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, attn_output,
softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[
key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" \
and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
attn_mask,
sparse_mode,
scale,
block_table,
block_size,
seq_lens_list,
actual_seq_lengths,
attn_output,
softmax_lse,
) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (
runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)]
elif forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[
key].decode.block_table
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[key].decode.block_table
# TODO: This is a hack and should be fixed in the future.
if speculative_config.disable_padded_drafter_batch:
block_table = block_table[:len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * (
len(actual_seq_lengths) - len(seq_lens_list))
block_table = block_table[: len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
@@ -391,7 +399,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse])
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -403,34 +412,40 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
block_table, block_size, actual_seq_lengths_kv,
actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
pcp_rank, dcp_rank) = param
(
q_nope,
k_nope,
value,
num_heads,
num_kv_heads,
scale,
block_table,
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
attn_output,
softmax_lse,
dcp_size,
pcp_rank,
dcp_rank,
) = param
attn_metadata = forward_context.attn_metadata[key]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
pcp_rank,
dcp_rank]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0:
pad_tensor = np.zeros(pad_length,
dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate(
[actual_seq_lengths_kv, pad_tensor])
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
attn_metadata
.
num_decode_tokens]
if (runtime_shape - len(actual_seq_lengths_q)):
actual_seq_lengths_q = actual_seq_lengths_q + [
actual_seq_lengths_q[-1]
] * (runtime_shape - len(actual_seq_lengths_q))
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decode_tokens]
if runtime_shape - len(actual_seq_lengths_q):
actual_seq_lengths_q = actual_seq_lengths_q + [actual_seq_lengths_q[-1]] * (
runtime_shape - len(actual_seq_lengths_q)
)
if dcp_size > 1:
num_heads = num_heads * dcp_size
@@ -453,14 +468,14 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse])
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
runtime_shape):
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
@@ -469,13 +484,24 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
scale, num_kv_heads, attn_output, softmax_lse) = param
(
q_nope,
q_pe,
k_nope,
k_pe,
block_table,
seq_len,
num_heads,
scale,
num_kv_heads,
attn_output,
softmax_lse,
) = param
decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
@@ -484,9 +510,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
# to avoid irregular attn_mask shape,
# so there's no need to divide runtime_shape by spec_multiple
pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length,
dtype=seq_len.dtype,
device=seq_len.device)
pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
torch.npu.graph_task_update_begin(update_stream, handle)
@@ -505,7 +529,8 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
calc_type="calc_type_ring",
workspace=graph_params.workspaces.get(runtime_shape),
output=attn_output,
lse=softmax_lse)
lse=softmax_lse,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -519,7 +544,7 @@ class GraphParams:
attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None
_graph_params: GraphParams | None = None
def set_graph_params(aclgraph_capture_sizes: list[int]):
@@ -527,14 +552,10 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: None for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
)
@@ -548,7 +569,7 @@ def get_graph_params():
return _graph_params
_draft_graph_params: Optional[GraphParams] = None
_draft_graph_params: GraphParams | None = None
def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
@@ -556,14 +577,10 @@ def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
if _draft_graph_params is not None:
raise ValueError("DraftGraph parameters have already been set!")
_draft_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: None for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
)

View File

@@ -16,13 +16,13 @@
# limitations under the License.
#
import functools
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any
import torch
import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import (graph_returns_tuple,
make_graph_return_tuple)
from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
@@ -32,15 +32,11 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
def compile_fx(graph: GraphModule, example_inputs: list,
inner_compile: Callable, decompositions: dict) -> Callable:
recursive_compile_fx = functools.partial(compile_fx,
inner_compile=inner_compile,
decompositions=decompositions)
def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
if not graph_returns_tuple(graph):
return make_graph_return_tuple(graph, example_inputs,
recursive_compile_fx)
return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
@@ -49,9 +45,8 @@ def fusion_pass_compile(
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
graph = current_pass_manager(graph)
@@ -74,8 +69,8 @@ def npugraph_ex_compile(
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
# When currently using the FULL_DECODE_ONLY mode,
# the piecewise compilation level slicing process
# in vllm is also encountered.
@@ -87,10 +82,8 @@ def npugraph_ex_compile(
output_node = fx_graph.output_node()
with fx_graph.inserting_before(output_node):
return_value = output_node.args[0]
tuple_node = fx_graph.create_node("call_function",
tuple,
args=([return_value], ))
output_node.args = (tuple_node, )
tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
output_node.args = (tuple_node,)
graph.recompile()
import torchair
@@ -119,6 +112,7 @@ class AscendCompiler(CompilerInterface):
This class provides a method to compile a PyTorch FX graph module with
specific configurations for graph fusion and decomposition.
"""
name = "AscendCompiler"
def compile(
@@ -127,13 +121,10 @@ class AscendCompiler(CompilerInterface):
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
compile_range, key)
return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
compile_range, key)
return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)

View File

@@ -26,7 +26,7 @@ class GraphFusionPassManager:
"""
A pass manager for graph fusion passes.
It handles the configuration and execution of passes.
The counterpart in vllm is PostGradPassManager. Since torch_npu
The counterpart in vllm is PostGradPassManager. Since torch_npu
does not support triton for now, we define our own pass manager.
"""
@@ -48,13 +48,13 @@ class GraphFusionPassManager:
def configure(self, config: VllmConfig):
# By default, we enable the graph fusion and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get(
"ascend_compilation_config", {})
self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
if self.ascend_compilation_config.get("fuse_norm_quant", True):
from .passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass
from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
self.passes.append(AddRMSNormQuantFusionPass(config))
if self.ascend_compilation_config.get("fuse_qknorm_rope", True):
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
self.passes.append(QKNormRopeFusionPass(config))

View File

@@ -48,7 +48,8 @@ def _extra_stream_scope_check(match: Match) -> bool:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations.")
f"Fusion is not supported for cross-stream operations."
)
return False
return True
@@ -57,24 +58,29 @@ def _extra_stream_scope_check(match: Match) -> bool:
@functools.lru_cache(None)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
@@ -82,10 +88,12 @@ def replacement_add_rms_norm_quant(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon)
epsilon=epsilon,
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
@@ -103,33 +111,39 @@ def replacement_add_rms_norm_quant(epsilon):
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_with_bias(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuantWithBias fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for AddRMSNormQuantWithBias fusion.
"""
@@ -137,11 +151,13 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon,
beta=bias)
beta=bias,
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
@@ -156,40 +172,41 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu")
return [
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuantSPPattern fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuantSPPattern fusion.
"""
@@ -197,14 +214,15 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon)
epsilon=epsilon,
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
def get_inputs():
@@ -220,34 +238,40 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuantSPPatternWithBias fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
"""
@@ -255,15 +279,16 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon,
beta=bias)
beta=bias,
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
def get_inputs():
@@ -276,25 +301,19 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu")
return [
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# register converter for pass
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
logger.info(
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
)
logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}")
replacement_add_rms_norm_quant(eps)
replacement_add_rms_norm_quant_with_bias(eps)
replacement_add_rms_norm_quant_sp_pattern(eps)

View File

@@ -25,7 +25,6 @@ from vllm.logger import logger
class AddRMSNormQuantPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -41,50 +40,48 @@ class AddRMSNormQuantPattern:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps,
beta=bias)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantSPPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantSPPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps,
beta=bias)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantFusionPass(VllmInductorPass):
@@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
pass_name="rmsnorm_quant_fusion_pass")
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.debug("Quant fusion not enabled: unsupported dtype %s",
dtype)
logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype)
return
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config,
eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()

View File

@@ -17,8 +17,7 @@
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -27,13 +26,7 @@ from vllm.logger import logger
class QKNormRopeFusionPattern:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
self.vllm_config = vllm_config
self.head_dim = head_dim
self.num_heads = num_heads
@@ -45,65 +38,38 @@ class QKNormRopeFusionPattern:
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
return [qkv, q_weight, k_weight, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
q_flat = q_norm_out.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
k_flat = k_norm_out.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
def replacement(
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
@@ -115,22 +81,16 @@ class QKNormRopeFusionPattern:
q_bias=None,
k_bias=None,
sin=sin,
cos=cos)
cos=cos,
)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class QKNormRopeFusionPatternWithBias:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
self.head_dim = head_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
@@ -142,71 +102,55 @@ class QKNormRopeFusionPatternWithBias:
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
q_normed = q_norm_out + q_bias
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
k_normed = k_norm_out + k_bias
q_flat = q_normed.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
k_flat = k_normed.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
def replacement(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
@@ -218,11 +162,11 @@ class QKNormRopeFusionPatternWithBias:
q_bias=q_bias,
k_bias=k_bias,
cos=cos,
sin=sin)
sin=sin,
)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class QKNormRopeFusionPass(VllmInductorPass):
@@ -232,44 +176,38 @@ class QKNormRopeFusionPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
pass_name="qknorm_rope_fusion_pass")
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.debug(
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
dtype)
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
return
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
vllm_config, Attention)
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention)
if len(attn_layers) == 0:
logger.debug(
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
)
logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.")
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-6, 1e-5]:
if layer.head_size != 128:
logger.debug(
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
layer.head_size)
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
continue
QKNormRopeFusionPattern(vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
QKNormRopeFusionPattern(
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register(self.pattern_match_passes)
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
QKNormRopeFusionPatternWithBias(
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()