[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition
`def register_ascend_customop(vllm_config: VllmConfig | None)`, it eagerly
evaluates the annotation expression `VllmConfig | None`.
Since `VllmConfig` is bound to `None` at runtime, the expression
effectively becomes `None | None`. In Python, `None` is an instance of
`NoneType`. While the `|` operator is implemented for type objects
(classes), it is not supported between two `NoneType` instances, leading
to the `TypeError` shown above.
**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
**3. Impact and Benefits:**
- With the `annotations` future import enabled, Python no longer evaluates
the `VllmConfig | None` expression when the function is defined at module
load. Instead, it stores the annotation as a string, completely avoiding
the `None | None` calculation.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` will now return strings
instead of type objects. Since this module does not use runtime type
enforcement or reflection, this change has zero negative impact on
existing functionality.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 11b6af5280
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -2,9 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
@@ -27,12 +28,12 @@ from ..utils import weak_ref_tensors
|
||||
@dataclasses.dataclass
|
||||
class ACLGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
aclgraph: Optional[torch.npu.NPUGraph] = None
|
||||
output: Optional[Any] = None
|
||||
aclgraph: torch.npu.NPUGraph | None = None
|
||||
output: Any | None = None
|
||||
|
||||
# for aclgraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
input_addresses: list[int] | None = None
|
||||
|
||||
|
||||
class ACLGraphWrapper:
|
||||
@@ -60,11 +61,13 @@ class ACLGraphWrapper:
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: CUDAGraphOptions | None = None,
|
||||
):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
@@ -83,15 +86,13 @@ class ACLGraphWrapper:
|
||||
self.aclgraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# aclgraphs for.
|
||||
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
|
||||
= {}
|
||||
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"aclgraph wrapper: {self.runnable}")
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
@@ -102,8 +103,7 @@ class ACLGraphWrapper:
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
aclgraph_runtime_mode != self.runtime_mode:
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without aclgraphs.
|
||||
# We do not trigger capture/replay if the runtime mode is not
|
||||
@@ -114,8 +114,7 @@ class ACLGraphWrapper:
|
||||
|
||||
if batch_descriptor not in self.concrete_aclgraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_aclgraph_entries[batch_descriptor] = \
|
||||
ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||
self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||
|
||||
entry = self.concrete_aclgraph_entries[batch_descriptor]
|
||||
|
||||
@@ -125,14 +124,11 @@ class ACLGraphWrapper:
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug("Capturing a aclgraph on (%s,%s)",
|
||||
self.runtime_mode.name, entry.batch_descriptor)
|
||||
logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
|
||||
# validate that aclgraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
|
||||
entry.input_addresses = input_addresses
|
||||
aclgraph = torch.npu.NPUGraph()
|
||||
|
||||
@@ -145,8 +141,7 @@ class ACLGraphWrapper:
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.npu.empty_cache", lambda: None))
|
||||
stack.enter_context(patch("torch.npu.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
forward_context.capturing = True
|
||||
@@ -183,13 +178,12 @@ class ACLGraphWrapper:
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for aclgraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}")
|
||||
f"got {new_input_addresses}"
|
||||
)
|
||||
|
||||
logger.info_once("Replaying aclgraph")
|
||||
# In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that
|
||||
@@ -209,8 +203,7 @@ def weak_ref_workspaces(params):
|
||||
for num_tokens in params.workspaces:
|
||||
if params.workspaces[num_tokens] is None:
|
||||
continue
|
||||
params.workspaces[num_tokens] = weak_ref_tensors(
|
||||
params.workspaces[num_tokens])
|
||||
params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])
|
||||
|
||||
|
||||
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
@@ -219,10 +212,10 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
# for each layer's attention op in the graph.
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
@@ -254,18 +247,21 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
out=output,
|
||||
)
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace,
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -282,18 +278,29 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
|
||||
# filters out the update operations for linear_attn.
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(query, key_cache, value, block_tables, attn_mask, block_size,
|
||||
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
|
||||
attn_output, softmax_lse) = param
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value,
|
||||
block_tables,
|
||||
attn_mask,
|
||||
block_size,
|
||||
seq_lens,
|
||||
query_start_loc,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
) = param
|
||||
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens_list
|
||||
actual_seq_lengths_q = forward_context.attn_metadata[
|
||||
key].actual_seq_lengths_q
|
||||
actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
@@ -317,16 +324,14 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape,
|
||||
vllm_config):
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
|
||||
if using_paged_attention(runtime_shape, vllm_config):
|
||||
_update_attn_pa_params(update_stream, forward_context, runtime_shape)
|
||||
else:
|
||||
_update_attn_fia_params(update_stream, forward_context, runtime_shape)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
speculative_config):
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
|
||||
if forward_context.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
else:
|
||||
@@ -335,41 +340,44 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
# for each layer's attention op in the graph.
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||
attn_mask, sparse_mode, scale, block_table, block_size,
|
||||
seq_lens_list, actual_seq_lengths, attn_output,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[
|
||||
key].decode.seq_lens_list
|
||||
if speculative_config and speculative_config.method == "mtp" \
|
||||
and not forward_context.is_draft_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
(
|
||||
q_nope,
|
||||
k_nope,
|
||||
q_pe,
|
||||
k_pe,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
input_layout,
|
||||
attn_mask,
|
||||
sparse_mode,
|
||||
scale,
|
||||
block_table,
|
||||
block_size,
|
||||
seq_lens_list,
|
||||
actual_seq_lengths,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
) = param
|
||||
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
||||
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
|
||||
spec_multiple = speculative_config.num_speculative_tokens + 1
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
runtime_shape // spec_multiple - len(seq_lens_list))
|
||||
actual_seq_lengths = [
|
||||
spec_multiple * (i + 1)
|
||||
for i in range(runtime_shape // spec_multiple)
|
||||
]
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list))
|
||||
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)]
|
||||
elif forward_context.is_draft_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
block_table = forward_context.attn_metadata[
|
||||
key].decode.block_table
|
||||
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
|
||||
block_table = forward_context.attn_metadata[key].decode.block_table
|
||||
# TODO: This is a hack and should be fixed in the future.
|
||||
if speculative_config.disable_padded_drafter_batch:
|
||||
block_table = block_table[:len(actual_seq_lengths)]
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
len(actual_seq_lengths) - len(seq_lens_list))
|
||||
block_table = block_table[: len(actual_seq_lengths)]
|
||||
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
|
||||
else:
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list))
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
@@ -391,7 +399,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
actual_seq_lengths_kv=seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse])
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -403,34 +412,40 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params = get_graph_params()
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
|
||||
block_table, block_size, actual_seq_lengths_kv,
|
||||
actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
|
||||
pcp_rank, dcp_rank) = param
|
||||
(
|
||||
q_nope,
|
||||
k_nope,
|
||||
value,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_table,
|
||||
block_size,
|
||||
actual_seq_lengths_kv,
|
||||
actual_seq_lengths_q,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
dcp_size,
|
||||
pcp_rank,
|
||||
dcp_rank,
|
||||
) = param
|
||||
attn_metadata = forward_context.attn_metadata[key]
|
||||
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
|
||||
pcp_rank,
|
||||
dcp_rank]
|
||||
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
|
||||
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
||||
if pad_length > 0:
|
||||
pad_tensor = np.zeros(pad_length,
|
||||
dtype=actual_seq_lengths_kv.dtype)
|
||||
actual_seq_lengths_kv = np.concatenate(
|
||||
[actual_seq_lengths_kv, pad_tensor])
|
||||
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
|
||||
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
|
||||
|
||||
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
|
||||
attn_metadata
|
||||
.
|
||||
num_decode_tokens]
|
||||
if (runtime_shape - len(actual_seq_lengths_q)):
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + [
|
||||
actual_seq_lengths_q[-1]
|
||||
] * (runtime_shape - len(actual_seq_lengths_q))
|
||||
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decode_tokens]
|
||||
if runtime_shape - len(actual_seq_lengths_q):
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + [actual_seq_lengths_q[-1]] * (
|
||||
runtime_shape - len(actual_seq_lengths_q)
|
||||
)
|
||||
if dcp_size > 1:
|
||||
num_heads = num_heads * dcp_size
|
||||
|
||||
@@ -453,14 +468,14 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
actual_seq_lengths=actual_seq_lengths_q,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse])
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
runtime_shape):
|
||||
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
if forward_context.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
else:
|
||||
@@ -469,13 +484,24 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
# for each layer's attention op in the graph.
|
||||
with torch.npu.stream(update_stream):
|
||||
for key, param, handle, event in zip(
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
forward_context.attn_metadata,
|
||||
graph_params.attn_params[runtime_shape],
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
|
||||
scale, num_kv_heads, attn_output, softmax_lse) = param
|
||||
(
|
||||
q_nope,
|
||||
q_pe,
|
||||
k_nope,
|
||||
k_pe,
|
||||
block_table,
|
||||
seq_len,
|
||||
num_heads,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
) = param
|
||||
|
||||
decode_meta = forward_context.attn_metadata[key].decode
|
||||
seq_len = decode_meta.cp_seq_len
|
||||
@@ -484,9 +510,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
# to avoid irregular attn_mask shape,
|
||||
# so there's no need to divide runtime_shape by spec_multiple
|
||||
pad_length = runtime_shape - len(seq_len)
|
||||
pad_tensor = torch.zeros(pad_length,
|
||||
dtype=seq_len.dtype,
|
||||
device=seq_len.device)
|
||||
pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
|
||||
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
|
||||
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
@@ -505,7 +529,8 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
calc_type="calc_type_ring",
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
output=attn_output,
|
||||
lse=softmax_lse)
|
||||
lse=softmax_lse,
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -519,7 +544,7 @@ class GraphParams:
|
||||
attn_params: dict[int, list[tuple]]
|
||||
|
||||
|
||||
_graph_params: Optional[GraphParams] = None
|
||||
_graph_params: GraphParams | None = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
@@ -527,14 +552,10 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
if _graph_params is not None:
|
||||
raise ValueError("Graph parameters have already been set!")
|
||||
_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: None for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
@@ -548,7 +569,7 @@ def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
|
||||
_draft_graph_params: Optional[GraphParams] = None
|
||||
_draft_graph_params: GraphParams | None = None
|
||||
|
||||
|
||||
def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
@@ -556,14 +577,10 @@ def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
if _draft_graph_params is not None:
|
||||
raise ValueError("DraftGraph parameters have already been set!")
|
||||
_draft_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: None for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import functools
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._inductor.compile_fx import (graph_returns_tuple,
|
||||
make_graph_return_tuple)
|
||||
from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
|
||||
from torch._inductor.decomposition import select_decomp_table
|
||||
from torch.fx import GraphModule
|
||||
from vllm.compilation.compiler_interface import CompilerInterface
|
||||
@@ -32,15 +32,11 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import COMPILATION_PASS_KEY
|
||||
|
||||
|
||||
def compile_fx(graph: GraphModule, example_inputs: list,
|
||||
inner_compile: Callable, decompositions: dict) -> Callable:
|
||||
recursive_compile_fx = functools.partial(compile_fx,
|
||||
inner_compile=inner_compile,
|
||||
decompositions=decompositions)
|
||||
def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
|
||||
recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
|
||||
|
||||
if not graph_returns_tuple(graph):
|
||||
return make_graph_return_tuple(graph, example_inputs,
|
||||
recursive_compile_fx)
|
||||
return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
|
||||
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
|
||||
|
||||
|
||||
@@ -49,9 +45,8 @@ def fusion_pass_compile(
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
def compile_inner(graph, example_inputs):
|
||||
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
|
||||
graph = current_pass_manager(graph)
|
||||
@@ -74,8 +69,8 @@ def npugraph_ex_compile(
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
# When currently using the FULL_DECODE_ONLY mode,
|
||||
# the piecewise compilation level slicing process
|
||||
# in vllm is also encountered.
|
||||
@@ -87,10 +82,8 @@ def npugraph_ex_compile(
|
||||
output_node = fx_graph.output_node()
|
||||
with fx_graph.inserting_before(output_node):
|
||||
return_value = output_node.args[0]
|
||||
tuple_node = fx_graph.create_node("call_function",
|
||||
tuple,
|
||||
args=([return_value], ))
|
||||
output_node.args = (tuple_node, )
|
||||
tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
|
||||
output_node.args = (tuple_node,)
|
||||
graph.recompile()
|
||||
|
||||
import torchair
|
||||
@@ -119,6 +112,7 @@ class AscendCompiler(CompilerInterface):
|
||||
This class provides a method to compile a PyTorch FX graph module with
|
||||
specific configurations for graph fusion and decomposition.
|
||||
"""
|
||||
|
||||
name = "AscendCompiler"
|
||||
|
||||
def compile(
|
||||
@@ -127,13 +121,10 @@ class AscendCompiler(CompilerInterface):
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.enable_npugraph_ex:
|
||||
return npugraph_ex_compile(graph, example_inputs, compiler_config,
|
||||
compile_range, key)
|
||||
return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key)
|
||||
else:
|
||||
return fusion_pass_compile(graph, example_inputs, compiler_config,
|
||||
compile_range, key)
|
||||
return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)
|
||||
|
||||
@@ -26,7 +26,7 @@ class GraphFusionPassManager:
|
||||
"""
|
||||
A pass manager for graph fusion passes.
|
||||
It handles the configuration and execution of passes.
|
||||
The counterpart in vllm is PostGradPassManager. Since torch_npu
|
||||
The counterpart in vllm is PostGradPassManager. Since torch_npu
|
||||
does not support triton for now, we define our own pass manager.
|
||||
"""
|
||||
|
||||
@@ -48,13 +48,13 @@ class GraphFusionPassManager:
|
||||
|
||||
def configure(self, config: VllmConfig):
|
||||
# By default, we enable the graph fusion and quantization fusion pass.
|
||||
self.ascend_compilation_config: dict = config.additional_config.get(
|
||||
"ascend_compilation_config", {})
|
||||
self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
|
||||
if self.ascend_compilation_config.get("fuse_norm_quant", True):
|
||||
from .passes.norm_quant_fusion_pass import \
|
||||
AddRMSNormQuantFusionPass
|
||||
from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
|
||||
|
||||
self.passes.append(AddRMSNormQuantFusionPass(config))
|
||||
|
||||
if self.ascend_compilation_config.get("fuse_qknorm_rope", True):
|
||||
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
|
||||
|
||||
self.passes.append(QKNormRopeFusionPass(config))
|
||||
|
||||
@@ -48,7 +48,8 @@ def _extra_stream_scope_check(match: Match) -> bool:
|
||||
logger.debug(
|
||||
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
|
||||
f"Multiple streams found: {non_default_streams}. "
|
||||
f"Fusion is not supported for cross-stream operations.")
|
||||
f"Fusion is not supported for cross-stream operations."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -57,24 +58,29 @@ def _extra_stream_scope_check(match: Match) -> bool:
|
||||
@functools.lru_cache(None)
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
@@ -82,10 +88,12 @@ def replacement_add_rms_norm_quant(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon)
|
||||
epsilon=epsilon,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
@@ -103,33 +111,39 @@ def replacement_add_rms_norm_quant(epsilon):
|
||||
|
||||
import torchair
|
||||
|
||||
torchair.register_replacement(search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=get_inputs(),
|
||||
extra_check=_extra_stream_scope_check)
|
||||
torchair.register_replacement(
|
||||
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
|
||||
)
|
||||
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant_with_bias(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuantWithBias fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for AddRMSNormQuantWithBias fusion.
|
||||
"""
|
||||
@@ -137,11 +151,13 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon,
|
||||
beta=bias)
|
||||
beta=bias,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
@@ -156,40 +172,41 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
|
||||
rmsnorm_bias = torch.randn(4, device="npu")
|
||||
scale = torch.ones(4, device="npu")
|
||||
offset = torch.zeros(4, device="npu")
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset,
|
||||
rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
|
||||
|
||||
import torchair
|
||||
|
||||
torchair.register_replacement(search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=get_inputs(),
|
||||
extra_check=_extra_stream_scope_check)
|
||||
torchair.register_replacement(
|
||||
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
|
||||
)
|
||||
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuantSPPattern fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuantSPPattern fusion.
|
||||
"""
|
||||
@@ -197,14 +214,15 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon)
|
||||
epsilon=epsilon,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
def get_inputs():
|
||||
@@ -220,34 +238,40 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
|
||||
|
||||
import torchair
|
||||
|
||||
torchair.register_replacement(search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=get_inputs(),
|
||||
extra_check=_extra_stream_scope_check)
|
||||
torchair.register_replacement(
|
||||
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
|
||||
)
|
||||
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuantSPPatternWithBias fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
|
||||
"""
|
||||
@@ -255,15 +279,16 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon,
|
||||
beta=bias)
|
||||
beta=bias,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
def get_inputs():
|
||||
@@ -276,25 +301,19 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
|
||||
rmsnorm_bias = torch.randn(4, device="npu")
|
||||
scale = torch.ones(4, device="npu")
|
||||
offset = torch.zeros(4, device="npu")
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset,
|
||||
rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
|
||||
|
||||
import torchair
|
||||
|
||||
torchair.register_replacement(search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=get_inputs(),
|
||||
extra_check=_extra_stream_scope_check)
|
||||
torchair.register_replacement(
|
||||
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
|
||||
)
|
||||
|
||||
|
||||
# register converter for pass
|
||||
common_epsilons = [1e-5, 1e-6]
|
||||
for eps in common_epsilons:
|
||||
logger.info(
|
||||
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
|
||||
)
|
||||
logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}")
|
||||
replacement_add_rms_norm_quant(eps)
|
||||
replacement_add_rms_norm_quant_with_bias(eps)
|
||||
replacement_add_rms_norm_quant_sp_pattern(eps)
|
||||
|
||||
@@ -25,7 +25,6 @@ from vllm.logger import logger
|
||||
|
||||
|
||||
class AddRMSNormQuantPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -41,50 +40,48 @@ class AddRMSNormQuantPattern:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
@@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rmsnorm_quant_fusion_pass")
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
|
||||
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
logger.debug("Quant fusion not enabled: unsupported dtype %s",
|
||||
dtype)
|
||||
logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
|
||||
common_epsilons = [1e-5, 1e-6]
|
||||
for eps in common_epsilons:
|
||||
AddRMSNormQuantPattern(vllm_config,
|
||||
eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
|
||||
self.pattern_match_passes)
|
||||
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
|
||||
@@ -17,8 +17,7 @@
|
||||
#
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
||||
PatternPrettyPrinter)
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
@@ -27,13 +26,7 @@ from vllm.logger import logger
|
||||
|
||||
|
||||
class QKNormRopeFusionPattern:
|
||||
|
||||
def __init__(self,
|
||||
vllm_config,
|
||||
head_dim,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
eps=1e-6):
|
||||
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
@@ -45,65 +38,38 @@ class QKNormRopeFusionPattern:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
qkv = torch.empty(T,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
q_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
k_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
cos = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
sin = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos, sin]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
|
||||
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
|
||||
self.eps)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
|
||||
self.eps)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
|
||||
q_flat = q_norm_out.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_norm_out.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
|
||||
q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
def replacement(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -115,22 +81,16 @@ class QKNormRopeFusionPattern:
|
||||
q_bias=None,
|
||||
k_bias=None,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
cos=cos,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
def __init__(self,
|
||||
vllm_config,
|
||||
head_dim,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
eps=1e-6):
|
||||
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@@ -142,71 +102,55 @@ class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
qkv = torch.empty(T,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
q_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
k_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
sin = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
|
||||
self.eps)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
|
||||
q_normed = q_norm_out + q_bias
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
|
||||
self.eps)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
k_normed = k_norm_out + k_bias
|
||||
|
||||
q_flat = q_normed.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_normed.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
|
||||
q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -218,11 +162,11 @@ class QKNormRopeFusionPatternWithBias:
|
||||
q_bias=q_bias,
|
||||
k_bias=k_bias,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
sin=sin,
|
||||
)
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class QKNormRopeFusionPass(VllmInductorPass):
|
||||
@@ -232,44 +176,38 @@ class QKNormRopeFusionPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="qknorm_rope_fusion_pass")
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
|
||||
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
|
||||
dtype)
|
||||
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
|
||||
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
|
||||
vllm_config, Attention)
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
if len(attn_layers) == 0:
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
|
||||
)
|
||||
logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.")
|
||||
return
|
||||
layer = next(iter(attn_layers.values()))
|
||||
for epsilon in [1e-6, 1e-5]:
|
||||
if layer.head_size != 128:
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
|
||||
layer.head_size)
|
||||
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
|
||||
continue
|
||||
QKNormRopeFusionPattern(vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon).register(
|
||||
self.pattern_match_passes)
|
||||
QKNormRopeFusionPattern(
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
).register(self.pattern_match_passes)
|
||||
|
||||
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon).register(
|
||||
self.pattern_match_passes)
|
||||
QKNormRopeFusionPatternWithBias(
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
).register(self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
|
||||
Reference in New Issue
Block a user