[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)

### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.

### Does this PR introduce _any_ user-facing change?
No user-facing change is intended; this is a lint/format migration. During
the migration, however, we encountered **errors** in our CI and testing
environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```

**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it evaluates
the annotation expression `VllmConfig | None` at that point.
Since `VllmConfig` has been assigned `None` at runtime, the expression
effectively becomes `None | None`. The `|` union syntax (PEP 604) is
implemented for type objects (classes), but not for the `None` singleton,
which is an instance of `NoneType`, so evaluating the annotation raises the
`TypeError` shown above.
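
For illustration, here is a minimal self-contained sketch (a hypothetical
file, not the actual `vllm_ascend` module) that reproduces the failure;
because `TYPE_CHECKING` is `False` at runtime, the guarded import never
executes, so the snippet reproduces the error without vLLM installed:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig  # only visible to type checkers
else:
    VllmConfig = None  # runtime placeholder to break the circular import

# TYPE_CHECKING is False at runtime, so VllmConfig is the None object here.
# Executing this `def` evaluates the annotation `VllmConfig | None`,
# i.e. `None | None`, and raises:
#   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```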

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
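
Applied to the same illustrative sketch (again, a hypothetical file rather
than the real module), the fix keeps the placeholder intact while the
annotation is stored as a plain string and never evaluated at import time:
```python
from __future__ import annotations  # must be the first statement in the file

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # placeholder stays; other modules can still import it

# With PEP 563 active, the annotation below is kept as the string
# "VllmConfig | None", so importing this module no longer computes
# `None | None` and no TypeError is raised.
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```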

**3. Impact and Benefits:**
- With the `annotations` future import enabled, Python no longer evaluates
the `VllmConfig | None` expression at import time. Instead, it stores the
annotation as a string literal, avoiding the `None | None` evaluation
entirely.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` now contains strings
instead of type objects (illustrated in the short sketch below). Since
these modules do not use runtime type enforcement or reflection, this has
no negative impact on existing functionality.
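
As a quick sanity check of that side effect, here is a tiny standalone
sketch (generic `int`/`str` annotations, not the real vllm-ascend code)
showing that the stored annotations are strings, and that code needing real
type objects can still resolve them on demand (on Python 3.10+):
```python
from __future__ import annotations

import typing


def f(x: int | None = None) -> str | None:
    return None


# Under PEP 563 the stored annotations are plain strings:
print(f.__annotations__)         # {'x': 'int | None', 'return': 'str | None'}

# Code that needs actual type objects can still resolve them explicitly:
print(typing.get_type_hints(f))  # e.g. {'x': int | None, 'return': str | None}
```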
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
SILONG ZENG, 2026-01-16 20:57:46 +08:00, committed by GitHub
parent 3af91e5ac4
commit 52086394ae
16 changed files with 996 additions and 1140 deletions

View File

@@ -49,7 +49,25 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",
"vllm_ascend/**",
"vllm_ascend/_cann_ops_custom",
"vllm_ascend/attention",
"vllm_ascend/core",
"vllm_ascend/device",
"vllm_ascend/device_allocator",
"vllm_ascend/distributed",
"vllm_ascend/eplb",
"vllm_ascend/kv_offload",
"vllm_ascend/lora",
"vllm_ascend/model_loader",
"vllm_ascend/ops",
"vllm_ascend/patch",
"vllm_ascend/quantization",
"vllm_ascend/sample",
"vllm_ascend/spec_decode",
"vllm_ascend/worker",
"vllm_ascend/xlite",
"vllm_ascend/envs.py",
"vllm_ascend/batch_invariant.py",
]
[tool.ruff.lint]

View File

@@ -24,14 +24,17 @@ def register():
def register_connector():
from vllm_ascend.distributed.kv_transfer import register_connector
register_connector()
def register_model_loader():
from .model_loader.netloader import register_netloader
register_netloader()
def register_service_profiling():
from .profiling_config import generate_service_profiling_config
generate_service_profiling_config()

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from vllm.logger import logger
from vllm.triton_utils import HAS_TRITON
@@ -32,18 +32,13 @@ class AscendConfig:
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
xlite_graph_config = additional_config.get("xlite_graph_config", {})
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
vllm_config)
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, vllm_config)
ascend_compilation_config = additional_config.get(
"ascend_compilation_config", {})
self.ascend_compilation_config = AscendCompilationConfig(
**ascend_compilation_config)
ascend_compilation_config = additional_config.get("ascend_compilation_config", {})
self.ascend_compilation_config = AscendCompilationConfig(**ascend_compilation_config)
finegrained_tp_config = additional_config.get("finegrained_tp_config",
{})
self.finegrained_tp_config = FinegrainedTPConfig(
finegrained_tp_config, vllm_config)
finegrained_tp_config = additional_config.get("finegrained_tp_config", {})
self.finegrained_tp_config = FinegrainedTPConfig(finegrained_tp_config, vllm_config)
eplb_config = additional_config.get("eplb_config", {})
self.eplb_config = EplbConfig(eplb_config)
@@ -51,10 +46,8 @@ class AscendConfig:
# Dump / PrecisionDebugger configuration
self.dump_config_path = additional_config.get("dump_config_path", None)
weight_prefetch_config = additional_config.get(
"weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(
weight_prefetch_config)
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
self.layer_sharding = additional_config.get("layer_sharding", None)
logger.info_once(
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
@@ -62,30 +55,25 @@ class AscendConfig:
"using it without these features may result in significant performance degradation."
)
self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp",
False) and vllm_config.parallel_config.enable_expert_parallel
self.enable_shared_expert_dp = (
additional_config.get("enable_shared_expert_dp", False)
and vllm_config.parallel_config.enable_expert_parallel
)
if self.enable_shared_expert_dp:
from vllm_ascend.utils import enable_sp
assert enable_sp(vllm_config=vllm_config,
enable_shared_expert_dp=True)
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.multistream_overlap_gate = additional_config.get(
"multistream_overlap_gate", False)
self.recompute_scheduler_enable = additional_config.get(
"recompute_scheduler_enable", False)
self.enable_cpu_binding = additional_config.get(
"enable_cpu_binding", False)
assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True)
self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False)
self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False)
self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)
self.enable_cpu_binding = additional_config.get("enable_cpu_binding", False)
self.pd_tp_ratio = 1
self.pd_head_ratio = 1
self.num_head_replica = 1
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {"tp_size": 1})["tp_size"]
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"decode", {"tp_size": 1})["tp_size"]
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {"tp_size": 1})["tp_size"]
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("decode", {"tp_size": 1})["tp_size"]
assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size."
self.pd_tp_ratio = prefill_tp_size // decode_tp_size
if self.pd_tp_ratio > 1:
@@ -106,36 +94,29 @@ class AscendConfig:
)
if self.pd_tp_ratio == 0:
raise AssertionError(
"Only support P node tp size lagger then D node tp size")
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
raise AssertionError("Only support P node tp size lagger then D node tp size")
self.SLO_limits_for_dynamic_batch = additional_config.get("SLO_limits_for_dynamic_batch", -1)
from vllm_ascend.utils import get_flashcomm2_config_and_validate
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
self, vllm_config)
self.enable_npugraph_ex = additional_config.get(
"enable_npugraph_ex", False)
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)
self.enable_npugraph_ex = additional_config.get("enable_npugraph_ex", False)
# We find that _npu_paged_attention still performs better than
# npu_fused_infer_attention_score in some cases. We allow to execute
# _npu_paged_attention in this cases. This should be removed once
# npu_fused_infer_attention_score performs better on all scenarios.
self.pa_shape_list = additional_config.get("pa_shape_list", [])
self.enable_async_exponential = bool(
additional_config.get("enable_async_exponential", False))
self.enable_async_exponential = bool(additional_config.get("enable_async_exponential", False))
self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
if self.enable_kv_nz:
use_sparse = hasattr(vllm_config.model_config.hf_text_config,
"index_topk")
use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
raise RuntimeError(
"enable_kv_nz is only supported for mla currently.")
if vllm_config.kv_transfer_config is None \
or not vllm_config.kv_transfer_config.is_kv_consumer:
raise RuntimeError("enable_kv_nz is only supported for mla currently.")
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
raise NotImplementedError(
"enable_kv_nz is only supported in pd scenario and can "
"only be used in D node.")
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
)
class FinegrainedTPConfig:
@@ -144,40 +125,28 @@ class FinegrainedTPConfig:
"""
def __init__(self, finegrained_tp_config: dict, vllm_config):
self.oproj_tensor_parallel_size = finegrained_tp_config.get(
"oproj_tensor_parallel_size", 0)
self.lmhead_tensor_parallel_size = finegrained_tp_config.get(
"lmhead_tensor_parallel_size", 0)
self.embedding_tensor_parallel_size = finegrained_tp_config.get(
"embedding_tensor_parallel_size", 0)
self.mlp_tensor_parallel_size = finegrained_tp_config.get(
"mlp_tensor_parallel_size", 0)
self.oproj_tensor_parallel_size = finegrained_tp_config.get("oproj_tensor_parallel_size", 0)
self.lmhead_tensor_parallel_size = finegrained_tp_config.get("lmhead_tensor_parallel_size", 0)
self.embedding_tensor_parallel_size = finegrained_tp_config.get("embedding_tensor_parallel_size", 0)
self.mlp_tensor_parallel_size = finegrained_tp_config.get("mlp_tensor_parallel_size", 0)
enabled_configs = []
if self.oproj_tensor_parallel_size > 0:
enabled_configs.append(
f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}"
)
# dummy_run does not run the entire attention module in eager mode,, so the o_proj tp split can only be used in graph mode.
enabled_configs.append(f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}")
# dummy_run does not run the entire attention module in eager mode,
# so the o_proj tp split can only be used in graph mode.
if vllm_config.model_config.enforce_eager is True:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in graph mode"
)
raise AssertionError("oproj_tensor_parallel_size is only supported in graph mode")
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
)
if self.lmhead_tensor_parallel_size > 0:
enabled_configs.append(
f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}"
)
enabled_configs.append(f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}")
if self.embedding_tensor_parallel_size > 0:
enabled_configs.append(
f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}"
)
enabled_configs.append(f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}")
if self.mlp_tensor_parallel_size > 0:
enabled_configs.append(
f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
enabled_configs.append(f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
module_tp_sizes = [
self.oproj_tensor_parallel_size,
self.lmhead_tensor_parallel_size,
@@ -186,11 +155,9 @@ class FinegrainedTPConfig:
]
for module_tp_size in module_tp_sizes:
if module_tp_size > 0 and vllm_config.parallel_config.data_parallel_size % module_tp_size != 0:
raise AssertionError(
"module tp sizes must divide data_parallel_size")
raise AssertionError("module tp sizes must divide data_parallel_size")
if any(size > 0 for size in module_tp_sizes) and enabled_configs:
logger.info(
f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
logger.info(f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
class AscendCompilationConfig:
@@ -202,10 +169,7 @@ class AscendCompilationConfig:
deployed on Ascend platforms.
"""
def __init__(self,
fuse_norm_quant: bool = True,
fuse_qknorm_rope: bool = False,
**kwargs):
def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, **kwargs):
"""
Initialize the configuration.
@@ -236,7 +200,8 @@ class XliteGraphConfig:
)
if vllm_config.parallel_config.pipeline_parallel_size > 1:
raise RuntimeError(
"Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1."
"Xlite graph mode is not compatible with pipeline parallelism. "
"Please set pipeline_parallel_size to 1."
)
if vllm_config.cache_config.block_size != 128:
raise RuntimeError(
@@ -254,21 +219,19 @@ class WeightPrefetchConfig:
"qkv": 1.0,
"o": 1.0,
},
"moe": {
"gate_up": 0.8
}
"moe": {"gate_up": 0.8},
}
def __init__(self, weight_prefetch_config: dict):
self.enabled = weight_prefetch_config.get("enabled", False)
self.prefetch_ratio = weight_prefetch_config.get(
"prefetch_ratio", self.prefetch_ratio)
self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
class EplbConfig:
"""
Configuration Object for xlite_graph_config from additional_config
"""
_defaults = {
"dynamic_eplb": False,
"expert_map_path": None,
@@ -276,10 +239,12 @@ class EplbConfig:
"algorithm_execution_interval": 30,
"expert_map_record_path": None,
"num_redundant_experts": 0,
"eplb_policy_type": 1
"eplb_policy_type": 1,
}
def __init__(self, user_config: dict = {}):
def __init__(self, user_config: dict | None = None):
if user_config is None:
user_config = {}
self.config = self._defaults.copy()
if user_config and isinstance(user_config, dict):
for key, value in user_config.items():
@@ -307,27 +272,21 @@ class EplbConfig:
raise TypeError("The expert_map_record_path is not json.")
dirname = os.path.dirname(self.expert_map_record_path)
os.makedirs(dirname, exist_ok=True)
for key in [
"expert_heat_collection_interval",
"algorithm_execution_interval", "num_redundant_experts"
]:
for key in ["expert_heat_collection_interval", "algorithm_execution_interval", "num_redundant_experts"]:
if not isinstance(self.config[key], int):
raise TypeError(f"{key} must be an integer")
if self.config[key] < 0: # type: ignore
raise ValueError(
f"{key} must greater than 0; got {self.config[key]} instead"
)
raise ValueError(f"{key} must greater than 0; got {self.config[key]} instead")
if self.eplb_policy_type not in [0, 1, 2, 3]:
raise ValueError("eplb_policy_type must in [0, 1, 2, 3]")
_ASCEND_CONFIG: Optional[AscendConfig] = None
_ASCEND_CONFIG: AscendConfig | None = None
def init_ascend_config(vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
refresh = additional_config.get("refresh",
False) if additional_config else False
refresh = additional_config.get("refresh", False) if additional_config else False
global _ASCEND_CONFIG
if _ASCEND_CONFIG is not None and not refresh:
return _ASCEND_CONFIG
@@ -343,7 +302,5 @@ def clear_ascend_config():
def get_ascend_config():
global _ASCEND_CONFIG
if _ASCEND_CONFIG is None:
raise RuntimeError(
"Ascend config is not initialized. Please call init_ascend_config first."
)
raise RuntimeError("Ascend config is not initialized. Please call init_ascend_config first.")
return _ASCEND_CONFIG

View File

@@ -1,21 +1,25 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional
from typing import Any
import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size)
from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context)
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
get_ascend_device_type, has_layer_idx,
is_drafter_moe_model, is_moe_model,
speculative_enable_dispatch_gmm_combine_decode)
from vllm_ascend.utils import (
AscendDeviceType,
enable_sp,
flashcomm2_enable,
get_ascend_device_type,
has_layer_idx,
is_drafter_moe_model,
is_moe_model,
speculative_enable_dispatch_gmm_combine_decode,
)
class MoECommType(Enum):
@@ -31,13 +35,14 @@ def set_ascend_forward_context(
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None,
num_tokens_across_dp: torch.Tensor | None = None,
in_profile_run: bool = False,
num_actual_tokens: Optional[int] = None,
num_actual_tokens: int | None = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
batch_descriptor: BatchDescriptor | None = None,
model_instance: torch.nn.Module = None,
is_draft_model=False):
is_draft_model=False,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -53,10 +58,9 @@ def set_ascend_forward_context(
):
forward_context = get_forward_context()
from vllm_ascend.ops.fused_moe.moe_comm_method import \
get_moe_comm_method
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
is_draft_model)
from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -68,31 +72,28 @@ def set_ascend_forward_context(
# due to multiple warmups before actual capturing
forward_context.capturing = False
# set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# set for sequence parallelism, 1000 is the batch size concurrency threshold
# for enabling the flashcomm_v1 or sequence_parallelism feature.
# Currently, it is an empirical value. In normal scenarios, if the concurrency
# exceeds this threshold, the performance benefits can be maximized.
# Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods.
mmrs_fusion = True
# main model and drafter model may have different architecture
is_context_moe_model = is_drafter_moe_model(vllm_config) \
if is_draft_model else is_moe_model(vllm_config)
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
if is_context_moe_model:
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
mmrs_fusion = False
else:
sp_enabled = enable_sp(vllm_config) and \
num_tokens is not None and num_tokens > 1000
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
forward_context.mmrs_fusion = mmrs_fusion
forward_context.num_tokens = num_tokens
forward_context.sp_enabled = sp_enabled
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
forward_context.flashcomm_v2_enabled = flashcomm2_enable(
) and tp_world_size > 1 and num_tokens is not None
forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
if (forward_context.sp_enabled
or forward_context.flashcomm_v2_enabled):
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
# set this for rope forward_oot using
@@ -106,9 +107,12 @@ def set_ascend_forward_context(
# TODO(rjg-lyh): refactor mlp weight prefetch method
# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
prefetch_mlp_enabled = (
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
and forward_context.layer_idx is not None
and num_tokens is not None
and num_tokens < 500
)
if prefetch_mlp_enabled:
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
@@ -121,12 +125,9 @@ def set_ascend_forward_context(
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = \
forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
if (forward_context.sp_enabled
or forward_context.flashcomm_v2_enabled):
padded_length = (max_tokens_across_dp + tp_world_size -
1) // tp_world_size * tp_world_size
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
forward_context.padded_length = padded_length
forward_context.pad_size = pad_size
@@ -139,12 +140,10 @@ def set_ascend_forward_context(
if num_actual_tokens is None:
num_actual_tokens = num_tokens
# NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
reserved_mc2_mask = get_mc2_mask()
if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:forward_context.
padded_num_tokens]
mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
mc2_mask[:num_actual_tokens] = True
mc2_mask[num_actual_tokens:] = False
forward_context.mc2_mask = mc2_mask
@@ -155,14 +154,13 @@ def set_ascend_forward_context(
pass
_mc2_tokens_capacity: Optional[int] = None
_reserved_mc2_mask: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None
_sin: torch.Tensor | None = None
_cos: torch.Tensor | None = None
def set_mc2_tokens_capacity(vllm_config, max_num_reqs,
uniform_decode_query_len):
def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
global _mc2_tokens_capacity
if _mc2_tokens_capacity is not None:
return
@@ -186,9 +184,7 @@ def set_mc2_mask(vllm_config, device):
if _reserved_mc2_mask is not None:
return
if is_moe_model(vllm_config):
_reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(),
dtype=torch.bool,
device=device)
_reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
else:
_reserved_mc2_mask = None
@@ -197,9 +193,7 @@ def get_mc2_mask():
return _reserved_mc2_mask
def select_moe_comm_method(num_tokens: int,
vllm_config: VllmConfig,
is_draft_model=False) -> Optional[MoECommType]:
def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None:
"""Select the MoE communication method according to parallel settings,
device generation, token count, and quantization.
@@ -227,16 +221,19 @@ def select_moe_comm_method(num_tokens: int,
mc2_tokens_capacity = get_mc2_tokens_capacity()
soc_version = get_ascend_device_type()
quant_type = getattr(
vllm_config.model_config.hf_text_config, 'moe_quantize',
getattr(vllm_config.model_config.hf_text_config, 'quantize', None))
vllm_config.model_config.hf_text_config,
"moe_quantize",
getattr(vllm_config.model_config.hf_text_config, "quantize", None),
)
if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group(
).world_size == 1:
if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendDeviceType.A2}:
if (num_tokens <= mc2_tokens_capacity
and vllm_config.parallel_config.world_size_across_dp /
vllm_config.parallel_config.pipeline_parallel_size >= 16):
if (
num_tokens <= mc2_tokens_capacity
and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size
>= 16
):
moe_comm_type = MoECommType.MC2
else:
moe_comm_type = MoECommType.ALLGATHER
@@ -246,15 +243,13 @@ def select_moe_comm_method(num_tokens: int,
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (
not is_draft_model) and (not dynamic_eplb)
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
if num_tokens <= mc2_tokens_capacity:
fused_decode_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
fused_decode_enable = fused_mc2_enable and \
speculative_enable_dispatch_gmm_combine_decode(vllm_config)
fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
else:
fused_prefill_enable = fused_mc2_enable

View File

@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections.abc import Callable
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any
from unittest.mock import patch
import numpy as np
@@ -27,12 +28,12 @@ from ..utils import weak_ref_tensors
@dataclasses.dataclass
class ACLGraphEntry:
batch_descriptor: BatchDescriptor
aclgraph: Optional[torch.npu.NPUGraph] = None
output: Optional[Any] = None
aclgraph: torch.npu.NPUGraph | None = None
output: Any | None = None
# for aclgraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
input_addresses: list[int] | None = None
class ACLGraphWrapper:
@@ -60,11 +61,13 @@ class ACLGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""
def __init__(self,
def __init__(
self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: Optional[CUDAGraphOptions] = None):
cudagraph_options: CUDAGraphOptions | None = None,
):
self.runnable = runnable
self.vllm_config = vllm_config
self.runtime_mode = runtime_mode
@@ -83,15 +86,13 @@ class ACLGraphWrapper:
self.aclgraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture
# aclgraphs for.
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
= {}
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} not exists in the runnable of "
f"aclgraph wrapper: {self.runnable}")
raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}")
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
@@ -102,8 +103,7 @@ class ACLGraphWrapper:
batch_descriptor = forward_context.batch_descriptor
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
aclgraph_runtime_mode != self.runtime_mode:
if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without aclgraphs.
# We do not trigger capture/replay if the runtime mode is not
@@ -114,8 +114,7 @@ class ACLGraphWrapper:
if batch_descriptor not in self.concrete_aclgraph_entries:
# create a new entry for this batch descriptor
self.concrete_aclgraph_entries[batch_descriptor] = \
ACLGraphEntry(batch_descriptor=batch_descriptor)
self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)
entry = self.concrete_aclgraph_entries[batch_descriptor]
@@ -125,14 +124,11 @@ class ACLGraphWrapper:
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
logger.debug("Capturing a aclgraph on (%s,%s)",
self.runtime_mode.name, entry.batch_descriptor)
logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
# validate that aclgraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
entry.input_addresses = input_addresses
aclgraph = torch.npu.NPUGraph()
@@ -145,8 +141,7 @@ class ACLGraphWrapper:
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.npu.empty_cache", lambda: None))
stack.enter_context(patch("torch.npu.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True
@@ -183,13 +178,12 @@ class ACLGraphWrapper:
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for aclgraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}")
f"got {new_input_addresses}"
)
logger.info_once("Replaying aclgraph")
# In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that
@@ -209,8 +203,7 @@ def weak_ref_workspaces(params):
for num_tokens in params.workspaces:
if params.workspaces[num_tokens] is None:
continue
params.workspaces[num_tokens] = weak_ref_tensors(
params.workspaces[num_tokens])
params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
@@ -254,9 +247,11 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
out=output,
)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
@@ -265,7 +260,8 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace)
workspace=workspace,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -287,13 +283,24 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(query, key_cache, value, block_tables, attn_mask, block_size,
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
attn_output, softmax_lse) = param
(
query,
key_cache,
value,
block_tables,
attn_mask,
block_size,
seq_lens,
query_start_loc,
num_kv_heads,
num_heads,
scale,
attn_output,
softmax_lse,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens_list
actual_seq_lengths_q = forward_context.attn_metadata[
key].actual_seq_lengths_q
actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
@@ -317,16 +324,14 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
event.record(update_stream)
def update_attn_params(update_stream, forward_context, runtime_shape,
vllm_config):
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
if using_paged_attention(runtime_shape, vllm_config):
_update_attn_pa_params(update_stream, forward_context, runtime_shape)
else:
_update_attn_fia_params(update_stream, forward_context, runtime_shape)
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
speculative_config):
def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
@@ -340,36 +345,39 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, attn_output,
softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[
key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" \
and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
attn_mask,
sparse_mode,
scale,
block_table,
block_size,
seq_lens_list,
actual_seq_lengths,
attn_output,
softmax_lse,
) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (
runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)]
elif forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[
key].decode.block_table
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[key].decode.block_table
# TODO: This is a hack and should be fixed in the future.
if speculative_config.disable_padded_drafter_batch:
block_table = block_table[: len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * (
len(actual_seq_lengths) - len(seq_lens_list))
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
@@ -391,7 +399,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse])
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -408,29 +417,35 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
block_table, block_size, actual_seq_lengths_kv,
actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
pcp_rank, dcp_rank) = param
attn_metadata = forward_context.attn_metadata[key]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
(
q_nope,
k_nope,
value,
num_heads,
num_kv_heads,
scale,
block_table,
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
attn_output,
softmax_lse,
dcp_size,
pcp_rank,
dcp_rank]
dcp_rank,
) = param
attn_metadata = forward_context.attn_metadata[key]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0:
pad_tensor = np.zeros(pad_length,
dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate(
[actual_seq_lengths_kv, pad_tensor])
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
attn_metadata
.
num_decode_tokens]
if (runtime_shape - len(actual_seq_lengths_q)):
actual_seq_lengths_q = actual_seq_lengths_q + [
actual_seq_lengths_q[-1]
] * (runtime_shape - len(actual_seq_lengths_q))
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decode_tokens]
if runtime_shape - len(actual_seq_lengths_q):
actual_seq_lengths_q = actual_seq_lengths_q + [actual_seq_lengths_q[-1]] * (
runtime_shape - len(actual_seq_lengths_q)
)
if dcp_size > 1:
num_heads = num_heads * dcp_size
@@ -453,14 +468,14 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse])
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
runtime_shape):
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
@@ -474,8 +489,19 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
scale, num_kv_heads, attn_output, softmax_lse) = param
(
q_nope,
q_pe,
k_nope,
k_pe,
block_table,
seq_len,
num_heads,
scale,
num_kv_heads,
attn_output,
softmax_lse,
) = param
decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
@@ -484,9 +510,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
# to avoid irregular attn_mask shape,
# so there's no need to divide runtime_shape by spec_multiple
pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length,
dtype=seq_len.dtype,
device=seq_len.device)
pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
torch.npu.graph_task_update_begin(update_stream, handle)
@@ -505,7 +529,8 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
calc_type="calc_type_ring",
workspace=graph_params.workspaces.get(runtime_shape),
output=attn_output,
lse=softmax_lse)
lse=softmax_lse,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -519,7 +544,7 @@ class GraphParams:
attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None
_graph_params: GraphParams | None = None
def set_graph_params(aclgraph_capture_sizes: list[int]):
@@ -527,14 +552,10 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: None for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
)
@@ -548,7 +569,7 @@ def get_graph_params():
return _graph_params
_draft_graph_params: Optional[GraphParams] = None
_draft_graph_params: GraphParams | None = None
def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
@@ -556,14 +577,10 @@ def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
if _draft_graph_params is not None:
raise ValueError("DraftGraph parameters have already been set!")
_draft_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: None for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
)

View File

@@ -16,13 +16,13 @@
# limitations under the License.
#
import functools
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any
import torch
import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import (graph_returns_tuple,
make_graph_return_tuple)
from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
@@ -32,15 +32,11 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
def compile_fx(graph: GraphModule, example_inputs: list,
inner_compile: Callable, decompositions: dict) -> Callable:
recursive_compile_fx = functools.partial(compile_fx,
inner_compile=inner_compile,
decompositions=decompositions)
def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
if not graph_returns_tuple(graph):
return make_graph_return_tuple(graph, example_inputs,
recursive_compile_fx)
return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
@@ -49,9 +45,8 @@ def fusion_pass_compile(
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
graph = current_pass_manager(graph)
@@ -74,8 +69,8 @@ def npugraph_ex_compile(
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
# When currently using the FULL_DECODE_ONLY mode,
# the piecewise compilation level slicing process
# in vllm is also encountered.
@@ -87,9 +82,7 @@ def npugraph_ex_compile(
output_node = fx_graph.output_node()
with fx_graph.inserting_before(output_node):
return_value = output_node.args[0]
tuple_node = fx_graph.create_node("call_function",
tuple,
args=([return_value], ))
tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
output_node.args = (tuple_node,)
graph.recompile()
@@ -119,6 +112,7 @@ class AscendCompiler(CompilerInterface):
This class provides a method to compile a PyTorch FX graph module with
specific configurations for graph fusion and decomposition.
"""
name = "AscendCompiler"
def compile(
@@ -127,13 +121,10 @@ class AscendCompiler(CompilerInterface):
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
compile_range, key)
return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
compile_range, key)
return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)

View File

@@ -48,13 +48,13 @@ class GraphFusionPassManager:
def configure(self, config: VllmConfig):
# By default, we enable the graph fusion and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get(
"ascend_compilation_config", {})
self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
if self.ascend_compilation_config.get("fuse_norm_quant", True):
from .passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass
from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
self.passes.append(AddRMSNormQuantFusionPass(config))
if self.ascend_compilation_config.get("fuse_qknorm_rope", True):
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
self.passes.append(QKNormRopeFusionPass(config))

View File

@@ -48,7 +48,8 @@ def _extra_stream_scope_check(match: Match) -> bool:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations.")
f"Fusion is not supported for cross-stream operations."
)
return False
return True
@@ -57,24 +58,29 @@ def _extra_stream_scope_check(match: Match) -> bool:
@functools.lru_cache(None)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
@@ -82,10 +88,12 @@ def replacement_add_rms_norm_quant(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon)
epsilon=epsilon,
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
@@ -103,33 +111,39 @@ def replacement_add_rms_norm_quant(epsilon):
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_with_bias(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuantWithBias fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for AddRMSNormQuantWithBias fusion.
"""
@@ -137,11 +151,13 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon,
beta=bias)
beta=bias,
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
@@ -156,40 +172,41 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu")
return [
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuantSPPattern fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuantSPPattern fusion.
"""
@@ -197,14 +214,15 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon)
epsilon=epsilon,
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
def get_inputs():
@@ -220,34 +238,40 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuantSPPatternWithBias fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
"""
@@ -255,15 +279,16 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
# The inverse of scale is required by npu_add_rms_norm_quant kernel
# which is opposite to the npu_quantize kernel.
1.0 / scale,
offset,
epsilon=epsilon,
beta=bias)
beta=bias,
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
def get_inputs():
@@ -276,25 +301,19 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu")
return [
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)
# register converter for pass
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
logger.info(
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
)
logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}")
replacement_add_rms_norm_quant(eps)
replacement_add_rms_norm_quant_with_bias(eps)
replacement_add_rms_norm_quant_sp_pattern(eps)
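A note on the `1.0 / scale` arguments kept in the hunks above: per the in-code comments, the fused `npu_add_rms_norm_quant` kernel expects the reciprocal of the scale that `npu_quantize` takes. A minimal CPU-only sketch of why the two conventions agree, assuming a simple symmetric int8 quantization (the real kernels are Ascend-specific and may differ in details):

```python
import torch


# Hypothetical reference semantics, for illustration only: an npu_quantize-style
# call divides by `scale`, while the fused npu_add_rms_norm_quant-style call
# multiplies by the reciprocal.
def quantize_by_scale(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)


def quantize_by_reciprocal(x: torch.Tensor, scale_reciprocal: torch.Tensor) -> torch.Tensor:
    return torch.clamp(torch.round(x * scale_reciprocal), -128, 127).to(torch.int8)


x = torch.randn(4, 8)
scale = torch.full((8,), 0.25)  # power-of-two scale keeps both paths bit-identical
assert torch.equal(quantize_by_scale(x, scale), quantize_by_reciprocal(x, 1.0 / scale))
```

With arbitrary (non power-of-two) scales the two paths can differ by one ulp before rounding, which is another reason the reciprocal is precomputed once rather than inverted inside the pattern.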

View File

@@ -25,7 +25,6 @@ from vllm.logger import logger
class AddRMSNormQuantPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -41,50 +40,48 @@ class AddRMSNormQuantPattern:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps,
beta=bias)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantSPPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantSPPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
@@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias:
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def pattern(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.vllm.quantize(out0, scale,
scale_reciprocal,
offset)
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
bias: torch.Tensor):
def replacement(
rms_norm_input: torch.Tensor,
residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
residual,
rms_norm_weight,
scale,
offset,
epsilon=self.eps,
beta=bias)
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
)
quantized_output = output[0]
out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
quantized_output, True)
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class AddRMSNormQuantFusionPass(VllmInductorPass):
@@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
pass_name="rmsnorm_quant_fusion_pass")
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.debug("Quant fusion not enabled: unsupported dtype %s",
dtype)
logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype)
return
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config,
eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
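For readers unfamiliar with the Inductor pattern matcher used by the fusion classes above, here is a minimal, self-contained sketch of the `register_replacement` flow on plain CPU ops. The pattern/replacement pair is a toy example, and exact keyword signatures can vary across torch releases:

```python
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass

patterns = PatternMatcherPass(pass_name="toy_fusion_pass")


def pattern(x: torch.Tensor, y: torch.Tensor):
    # Subgraph to search for: relu(x + y).
    return torch.relu(x + y)


def replacement(x: torch.Tensor, y: torch.Tensor):
    # Numerically equivalent subgraph to splice in instead.
    return torch.clamp(x + y, min=0.0)


example_inputs = [torch.randn(4), torch.randn(4)]
pm.register_replacement(pattern, replacement, example_inputs, pm.fwd_only, patterns)
```

Keeping pattern and replacement as ordinary functions lets the matcher trace both with the same example inputs, which is why every pattern class above carries a `get_inputs` helper.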

View File

@@ -17,8 +17,7 @@
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -27,13 +26,7 @@ from vllm.logger import logger
class QKNormRopeFusionPattern:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
self.vllm_config = vllm_config
self.head_dim = head_dim
self.num_heads = num_heads
@@ -45,65 +38,38 @@ class QKNormRopeFusionPattern:
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
return [qkv, q_weight, k_weight, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
q_flat = q_norm_out.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
k_flat = k_norm_out.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
def replacement(
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
@@ -115,22 +81,16 @@ class QKNormRopeFusionPattern:
q_bias=None,
k_bias=None,
sin=sin,
cos=cos)
cos=cos,
)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class QKNormRopeFusionPatternWithBias:
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
self.head_dim = head_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
@@ -142,71 +102,55 @@ class QKNormRopeFusionPatternWithBias:
def get_inputs(self):
T = 5
qkv = torch.empty(T,
self.q_size + 2 * self.kv_size,
dtype=torch.bfloat16,
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
q_normed = q_norm_out + q_bias
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
k_normed = k_norm_out + k_bias
q_flat = q_normed.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
self.head_dim)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
k_flat = k_normed.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
self.head_dim)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
q_reshape, k_reshape, cos, sin)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, q_bias: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
def replacement(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv,
q_weight=q_weight,
@@ -218,11 +162,11 @@ class QKNormRopeFusionPatternWithBias:
q_bias=q_bias,
k_bias=k_bias,
cos=cos,
sin=sin)
sin=sin,
)
return results
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
class QKNormRopeFusionPass(VllmInductorPass):
@@ -232,44 +176,38 @@ class QKNormRopeFusionPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
pass_name="qknorm_rope_fusion_pass")
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.debug(
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
dtype)
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
return
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
vllm_config, Attention)
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention)
if len(attn_layers) == 0:
logger.debug(
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
)
logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.")
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-6, 1e-5]:
if layer.head_size != 128:
logger.debug(
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
layer.head_size)
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
continue
QKNormRopeFusionPattern(vllm_config=vllm_config,
QKNormRopeFusionPattern(
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
eps=epsilon,
).register(self.pattern_match_passes)
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config,
QKNormRopeFusionPatternWithBias(
vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon).register(
self.pattern_match_passes)
eps=epsilon,
).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
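The pattern functions above are mostly tensor-layout bookkeeping around the NPU kernels. A minimal CPU-only sketch of that bookkeeping — the QKV split and the per-head reshapes — with toy sizes (the real pattern requires `head_dim == 128` and the Ascend RMSNorm/RoPE ops):

```python
import torch

# Toy sizes; the real pattern uses head_dim == 128 and model-specific head counts.
head_dim, num_heads, num_kv_heads, T = 4, 2, 1, 5
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

qkv = torch.randn(T, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

# Per-head views fed to the (NPU-only) RMSNorm kernel in the pattern.
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)

# Flatten back, then reshape to (1, T, heads, head_dim) for the rotary-embedding op.
q_reshape = q_by_head.view(q.shape).contiguous().view(1, T, -1, head_dim)
k_reshape = k_by_head.view(k.shape).contiguous().view(1, T, -1, head_dim)
print(q_reshape.shape, k_reshape.shape, v.shape)  # (1, 5, 2, 4) (1, 5, 1, 4) (5, 4)
```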

View File

@@ -1,10 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import subprocess
from collections import defaultdict
from typing import Dict, List, Tuple
import psutil
from vllm.logger import logger
@@ -13,26 +11,22 @@ ALLOWED_CPUS_PATH = "/proc/self/status"
ASCEND_RT_VISIBLE_DEVICES = os.getenv("ASCEND_RT_VISIBLE_DEVICES")
def execute_command(cmd: List[str]) -> Tuple[str, int]:
with subprocess.Popen(cmd,
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as p:
def execute_command(cmd: list[str]) -> tuple[str, int]:
with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
out, _ = p.communicate(timeout=1000)
return out.decode(), p.returncode
class DeviceInfo:
def __init__(self):
self.npu_map_info: Dict[str, Dict[str, str]] = self.get_npu_map_info()
self.allowed_cpus: List[int] = self.parse_allowed_cpus()
self.running_npu_list: List[int] = self.get_running_npus()
self.npu_affinity: Dict[int, List[int]] = self.parse_topo_affinity()
self.npu_map_info: dict[str, dict[str, str]] = self.get_npu_map_info()
self.allowed_cpus: list[int] = self.parse_allowed_cpus()
self.running_npu_list: list[int] = self.get_running_npus()
self.npu_affinity: dict[int, list[int]] = self.parse_topo_affinity()
@staticmethod
def expand_cpu_list(allowed_list_str: str) -> List[int]:
allowed_cpus_list: List[int] = []
def expand_cpu_list(allowed_list_str: str) -> list[int]:
allowed_cpus_list: list[int] = []
for per_range in allowed_list_str.split(","):
if "-" in per_range:
start_cpu, end_cpu = map(int, per_range.split("-"))
@@ -42,8 +36,8 @@ class DeviceInfo:
return allowed_cpus_list
@staticmethod
def get_npu_map_info() -> Dict[str, Dict[str, str]]:
npu_map_info: Dict[str, Dict[str, str]] = {}
def get_npu_map_info() -> dict[str, dict[str, str]]:
npu_map_info: dict[str, dict[str, str]] = {}
npu_info, _ = execute_command(["npu-smi", "info", "-m"])
npu_map = npu_info.strip().split("\n")[1:]
for line in npu_map:
@@ -55,7 +49,7 @@ class DeviceInfo:
npu_map_info[npu_id][chip_id] = chip_logic_id
return npu_map_info
def get_running_npus(self) -> List[int]:
def get_running_npus(self) -> list[int]:
npu_message, _ = execute_command(["npu-smi", "info"])
in_proc_section = False
running_npu_set = set()
@@ -76,36 +70,29 @@ class DeviceInfo:
continue
chip_logic_id = self.npu_map_info.get(npu_id, {}).get(chip_id)
if not chip_logic_id or not chip_logic_id.isdigit():
raise RuntimeError(
"Failed to get correct chip_logic_id from command 'npu-smi info -m'."
)
raise RuntimeError("Failed to get correct chip_logic_id from command 'npu-smi info -m'.")
running_npu_set.add(int(chip_logic_id))
if ASCEND_RT_VISIBLE_DEVICES:
devices_str = ASCEND_RT_VISIBLE_DEVICES
devices_list = [int(x) for x in devices_str.split(",")]
running_npu_set = set(devices_list) & running_npu_set
if not running_npu_set:
raise RuntimeError(
"Can not get running npu info, you can use BIND_CPU=0 to skip."
)
raise RuntimeError("Can not get running npu info, you can use BIND_CPU=0 to skip.")
return sorted(running_npu_set)
def parse_allowed_cpus(self) -> List[int]:
def parse_allowed_cpus(self) -> list[int]:
if not os.path.exists(ALLOWED_CPUS_PATH):
return []
with open(ALLOWED_CPUS_PATH) as f:
for line in f:
if line.startswith("Cpus_allowed_list"):
return self.expand_cpu_list(line.split()[1])
raise RuntimeError(
"Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file."
)
raise RuntimeError("Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file.")
def parse_topo_affinity(self) -> Dict[int, List[int]]:
def parse_topo_affinity(self) -> dict[int, list[int]]:
chip_logic_id = 0
affinity: Dict[int, List[int]] = {}
affinity_message, _ = execute_command(
["npu-smi", "info", "-t", "topo"])
affinity: dict[int, list[int]] = {}
affinity_message, _ = execute_command(["npu-smi", "info", "-t", "topo"])
for line in affinity_message.splitlines():
if line.startswith("NPU"):
parts = line.split()
@@ -117,21 +104,19 @@ class DeviceInfo:
class CpuAlloc:
def __init__(self, rank_id: int):
self.rank_id = rank_id
self.device_info: DeviceInfo = DeviceInfo()
self.cpu_node: Dict[int, int] = {}
self.numa_to_cpu_map: Dict[int, List[int]] = defaultdict(list)
self.npu_cpu_pool: Dict[int, List[int]] = {}
self.assign_main: Dict[int, List[int]] = {}
self.assign_acl: Dict[int, List[int]] = {}
self.assign_rel: Dict[int, List[int]] = {}
self.cpu_node: dict[int, int] = {}
self.numa_to_cpu_map: dict[int, list[int]] = defaultdict(list)
self.npu_cpu_pool: dict[int, list[int]] = {}
self.assign_main: dict[int, list[int]] = {}
self.assign_acl: dict[int, list[int]] = {}
self.assign_rel: dict[int, list[int]] = {}
@staticmethod
def get_threads_map(
thread_message: str) -> Dict[str, Dict[str, List[str]]]:
threads_map: Dict[str, Dict[str, List[str]]] = {}
def get_threads_map(thread_message: str) -> dict[str, dict[str, list[str]]]:
threads_map: dict[str, dict[str, list[str]]] = {}
for line in thread_message.splitlines():
parts = line.split()
if len(parts) < 2:
@@ -144,40 +129,33 @@ class CpuAlloc:
else:
continue
if main_pid not in threads_map:
threads_map[main_pid] = {
"acl_thread": [],
"release_thread": []
}
threads_map[main_pid] = {"acl_thread": [], "release_thread": []}
threads_map[main_pid][key].append(sub_pid)
return threads_map
@staticmethod
def bind(pid: str, cpus: List[int], bind_sub_thread: bool) -> None:
def bind(pid: str, cpus: list[int], bind_sub_thread: bool) -> None:
if cpus:
cpu_list = ",".join(map(str, cpus))
if bind_sub_thread:
bind_result, return_code = execute_command(
["taskset", "-acp", cpu_list, pid])
bind_result, return_code = execute_command(["taskset", "-acp", cpu_list, pid])
else:
bind_result, return_code = execute_command(
["taskset", "-cp", cpu_list, pid])
bind_result, return_code = execute_command(["taskset", "-cp", cpu_list, pid])
if return_code != 0:
raise RuntimeError(f"Failed to bind {pid} to CPU {cpu_list}.")
def average_distribute(
self, groups: Dict[str, List[int]]) -> Dict[int, List[int]]:
result: Dict[int, List[int]] = {}
def average_distribute(self, groups: dict[str, list[int]]) -> dict[int, list[int]]:
result: dict[int, list[int]] = {}
for key, npu_list in groups.items():
cpu_list = sorted(self.npu_cpu_pool[npu_list[0]])
cpu_num_per_npu = len(cpu_list) // len(npu_list)
for i, npu in enumerate(npu_list):
start_index = i * cpu_num_per_npu
end_index = (i + 1) * cpu_num_per_npu if i < len(
npu_list) - 1 else len(cpu_list)
end_index = (i + 1) * cpu_num_per_npu if i < len(npu_list) - 1 else len(cpu_list)
result[npu] = cpu_list[start_index:end_index]
return result
def extend_numa(self, cpu_list: List[int]) -> List[int]:
def extend_numa(self, cpu_list: list[int]) -> list[int]:
if not cpu_list:
return []
nodes = {self.cpu_node[c] for c in cpu_list}
@@ -203,9 +181,7 @@ class CpuAlloc:
self.cpu_node[cpu] = node
self.numa_to_cpu_map[node].append(cpu)
if len(self.numa_to_cpu_map) == 0:
raise RuntimeError(
"lscpu command output error, no NUMA node available. Please check!"
)
raise RuntimeError("lscpu command output error, no NUMA node available. Please check!")
def handle_no_affinity(self) -> None:
num_running_npu = len(self.device_info.running_npu_list)
@@ -219,10 +195,7 @@ class CpuAlloc:
index = 0
for node in sorted(self.numa_to_cpu_map):
# Available CPUs on this NUMA (constrained by allowed_cpus)
cpus = [
c for c in self.numa_to_cpu_map[node]
if c in self.device_info.allowed_cpus
]
cpus = [c for c in self.numa_to_cpu_map[node] if c in self.device_info.allowed_cpus]
if not cpus:
continue
# The actual number of NPUs to be allocated on this NUMA.
@@ -251,19 +224,16 @@ class CpuAlloc:
return
for npu in self.device_info.running_npu_list:
base_cpu_list = [
cpu for cpu in self.device_info.npu_affinity.get(npu, [])
if cpu in self.device_info.allowed_cpus
cpu for cpu in self.device_info.npu_affinity.get(npu, []) if cpu in self.device_info.allowed_cpus
]
if not base_cpu_list:
raise RuntimeError(
"CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity."
)
raise RuntimeError("CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity.")
extra_cpu_list = self.extend_numa(base_cpu_list)
self.npu_cpu_pool[npu] = extra_cpu_list
groups = defaultdict(list)
for npu, cpus in self.npu_cpu_pool.items():
groups[str(cpus)].append(npu)
final: Dict[int, List[int]] = {}
final: dict[int, list[int]] = {}
for key, npu_list in groups.items():
if len(npu_list) == 1:
final[npu_list[0]] = self.npu_cpu_pool[npu_list[0]]
@@ -279,8 +249,8 @@ class CpuAlloc:
rel = [pool[-1]]
else:
raise RuntimeError(
"The number of CPUs is insufficient to bind to the NPUs. "
"Each NPU requires at least 3 CPUs.")
"The number of CPUs is insufficient to bind to the NPUs. Each NPU requires at least 3 CPUs."
)
self.assign_main[npu] = main
self.assign_acl[npu] = acl
self.assign_rel[npu] = rel
@@ -290,10 +260,8 @@ class CpuAlloc:
current_npu = self.device_info.running_npu_list[self.rank_id]
main = " ".join(map(str, self.assign_main[current_npu]))
acl = " ".join(map(str, self.assign_acl[current_npu]))
rel = str(self.assign_rel[current_npu]
) if self.assign_rel[current_npu] else ""
logger.info(
f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]")
rel = str(self.assign_rel[current_npu]) if self.assign_rel[current_npu] else ""
logger.info(f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]")
def bind_threads(self) -> None:
thread_message, _ = execute_command(["ps", "-Te"])
@@ -303,8 +271,7 @@ class CpuAlloc:
self.bind(main_pid, self.assign_main[current_npu], True)
for acl_thread in threads_map.get(main_pid, {}).get("acl_thread", []):
self.bind(acl_thread, self.assign_acl[current_npu], False)
for release_thread in threads_map.get(main_pid,
{}).get("release_thread", []):
for release_thread in threads_map.get(main_pid, {}).get("release_thread", []):
self.bind(release_thread, self.assign_rel[current_npu], False)
def run_all(self) -> None:
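The CPU-binding module above reduces to two building blocks: expanding a `Cpus_allowed_list`-style string into CPU ids, and calling `taskset` on the resulting list. A standalone sketch of both (Linux-only for the binding part; `subprocess.run` is used here instead of the module's `execute_command` helper):

```python
import subprocess


def expand_cpu_list(allowed_list_str: str) -> list[int]:
    """Expand a 'Cpus_allowed_list'-style string such as '0-3,8,10-11'."""
    cpus: list[int] = []
    for per_range in allowed_list_str.split(","):
        if "-" in per_range:
            start_cpu, end_cpu = map(int, per_range.split("-"))
            cpus.extend(range(start_cpu, end_cpu + 1))
        else:
            cpus.append(int(per_range))
    return cpus


print(expand_cpu_list("0-3,8,10-11"))  # [0, 1, 2, 3, 8, 10, 11]


def bind(pid: str, cpus: list[int], bind_sub_thread: bool) -> None:
    """Pin a pid (and optionally all of its threads) to the given CPUs via taskset."""
    if not cpus:
        return
    cpu_list = ",".join(map(str, cpus))
    flag = "-acp" if bind_sub_thread else "-cp"
    subprocess.run(["taskset", flag, cpu_list, pid], check=True)
```

`taskset -acp` rebinds every thread of the target pid while `-cp` touches only the given task, which is what the `bind_sub_thread` flag above distinguishes for the main, ACL, and release threads.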

View File

@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.model_executor.layers.linear import LinearBase
@@ -7,26 +6,26 @@ from vllm.model_executor.layers.linear import LinearBase
@dataclass
class FlashCommon3Context:
gate: Optional[LinearBase] = None
topk_weights: Optional[torch.Tensor] = None
topk_ids: Optional[torch.Tensor] = None
row_idx: Optional[torch.Tensor] = None
shared_experts: Optional[torch.nn.Module] = None
shared_out: Optional[torch.Tensor] = None
gate: LinearBase | None = None
topk_weights: torch.Tensor | None = None
topk_ids: torch.Tensor | None = None
row_idx: torch.Tensor | None = None
shared_experts: torch.nn.Module | None = None
shared_out: torch.Tensor | None = None
_flash_common3_context: Optional[FlashCommon3Context] = None
_flash_common3_context: FlashCommon3Context | None = None
def get_flash_common3_context() -> Optional[FlashCommon3Context]:
def get_flash_common3_context() -> FlashCommon3Context | None:
return _flash_common3_context
def set_flash_common3_context(
topk_weights: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.nn.Module] = None,
shared_out: Optional[torch.Tensor] = None,
topk_weights: torch.Tensor | None = None,
topk_ids: torch.Tensor | None = None,
shared_experts: torch.nn.Module | None = None,
shared_out: torch.Tensor | None = None,
):
global _flash_common3_context
if _flash_common3_context is None:
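`FlashCommon3Context` above is a module-level holder that later layers read back through `get_flash_common3_context`. A minimal sketch of the same get/set pattern with the new `X | None` annotations; names and fields here are illustrative only, not the real vllm-ascend API:

```python
from dataclasses import dataclass

import torch


@dataclass
class _ToyMoEContext:
    # Illustrative fields only; the real context also carries the gate,
    # row indices, and shared-expert handles.
    topk_weights: torch.Tensor | None = None
    topk_ids: torch.Tensor | None = None


_toy_context: _ToyMoEContext | None = None


def get_toy_context() -> _ToyMoEContext | None:
    return _toy_context


def set_toy_context(
    topk_weights: torch.Tensor | None = None,
    topk_ids: torch.Tensor | None = None,
) -> None:
    global _toy_context
    if _toy_context is None:
        _toy_context = _ToyMoEContext()
    if topk_weights is not None:
        _toy_context.topk_weights = topk_weights
    if topk_ids is not None:
        _toy_context.topk_ids = topk_ids
```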

View File

@@ -46,17 +46,20 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
if overload != "":
op_name = op_name + "." + overload
schema_to_find = ns + "::" + op_name
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
"Meta")
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
if schema_to_find in meta_impl_list:
return
lib.impl(op_name, fn, "Meta")
def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool):
def rotary_embedding_meta(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
):
num_tokens = positions.numel()
query_hidden_size = query.numel() // num_tokens
key_hidden_size = key.numel() // num_tokens
@@ -68,38 +71,41 @@ def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
return query_dst, key_dst
def get_masked_input_and_mask_meta(input: torch.Tensor,
def get_masked_input_and_mask_meta(
input: torch.Tensor,
org_vocab_start_index: int,
org_vocab_end_index: int,
num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int):
added_vocab_end_index: int,
):
masked_input = torch.empty_like(input)
mask = torch.empty_like(input).to(torch.bool)
return masked_input, mask
def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
indices: torch.Tensor, y: torch.Tensor, slice_offset: int,
slice_size: int):
def bgmv_expand_meta(
x: torch.Tensor, weight: torch.Tensor, indices: torch.Tensor, y: torch.Tensor, slice_offset: int, slice_size: int
):
y_out = torch.empty_like(y)
return y_out
def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
lora_indices: torch.Tensor, seq_len: torch.Tensor,
y: torch.Tensor, slice_offset: int, slice_size: int):
def sgmv_expand_meta(
x: torch.Tensor,
weight: torch.Tensor,
lora_indices: torch.Tensor,
seq_len: torch.Tensor,
y: torch.Tensor,
slice_offset: int,
slice_size: int,
):
y_out = torch.empty_like(y)
return y_out
register_meta_if_necessary("_C_ascend", "rotary_embedding",
rotary_embedding_meta)
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
get_masked_input_and_mask_meta)
register_meta_if_necessary("_C_ascend", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta)
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)

View File

@@ -15,9 +15,11 @@
# This file is a part of the vllm-ascend project.
#
from __future__ import annotations
import math
import os
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from uuid import uuid4
import torch
@@ -32,11 +34,21 @@ from vllm_ascend.ascend_config import init_ascend_config
# isort: off
from vllm_ascend.utils import (
ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY,
COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config,
enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model,
is_vl_model, refresh_block_size, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
ASCEND_QUANTIZATION_METHOD,
COMPILATION_PASS_KEY,
COMPRESSED_TENSORS_METHOD,
AscendDeviceType,
check_kv_extra_config,
enable_sp,
flashcomm2_enable,
get_ascend_device_type,
is_moe_model,
is_vl_model,
refresh_block_size,
update_aclgraph_sizes,
update_cudagraph_capture_sizes,
update_default_aclgraph_sizes,
)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -80,7 +92,6 @@ def config_deprecated_logging():
class NPUPlatform(Platform):
_enum = PlatformEnum.OOT
device_name: str = "npu"
device_type: str = "npu"
@@ -89,9 +100,7 @@ class NPUPlatform(Platform):
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
dispatch_key: str = "PrivateUse1"
supported_quantization: list[str] = [
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
]
supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]
def is_sleep_mode_available(self) -> bool:
return True
@@ -122,27 +131,23 @@ class NPUPlatform(Platform):
return "vllm_ascend.compilation.compiler_interface.AscendCompiler"
@classmethod
def pre_register_and_update(cls,
parser: Optional[FlexibleArgumentParser] = None
) -> None:
def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) -> None:
# Adapt the global patch here.
from vllm_ascend.utils import adapt_patch
adapt_patch(is_global_patch=True)
# For online serving, "ascend" quantization method is not a choice natively,
# so we need to add "ascend" quantization method to quantization methods list
# and the user can enable quantization using "vllm serve --quantization ascend".
if parser is not None:
quant_action = parser._option_string_actions.get('--quantization')
if quant_action and hasattr(quant_action,
'choices') and quant_action.choices:
quant_action = parser._option_string_actions.get("--quantization")
if quant_action and hasattr(quant_action, "choices") and quant_action.choices:
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import \
AscendQuantConfig # noqa: F401
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401
config_deprecated_logging()
@@ -169,8 +174,7 @@ class NPUPlatform(Platform):
if vllm_config.kv_transfer_config is not None:
check_kv_extra_config(vllm_config)
if not getattr(vllm_config.kv_transfer_config,
"_engine_id_patched", False):
if not getattr(vllm_config.kv_transfer_config, "_engine_id_patched", False):
vllm_config.kv_transfer_config.engine_id = f"{vllm_config.kv_transfer_config.engine_id}-{uuid4().hex}"
vllm_config.kv_transfer_config._engine_id_patched = True
from vllm.config import CompilationMode # noqa: E402
@@ -181,24 +185,22 @@ class NPUPlatform(Platform):
cache_config = vllm_config.cache_config
ascend_compilation_config = ascend_config.ascend_compilation_config
if ascend_compilation_config:
vllm_config.additional_config.setdefault(
"ascend_compilation_config", {}).update(
vars(ascend_compilation_config
) if not isinstance(ascend_compilation_config, dict)
else ascend_compilation_config)
vllm_config.additional_config.setdefault("ascend_compilation_config", {}).update(
vars(ascend_compilation_config)
if not isinstance(ascend_compilation_config, dict)
else ascend_compilation_config
)
elif model_config and hasattr(model_config.hf_text_config,
"index_topk"):
vllm_config.cache_config.cache_dtype = str(
model_config.dtype).replace("torch.", "")
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
if model_config is None:
logger.warning("Model config is missing. This may indicate "
"that we are running a test case")
logger.warning("Model config is missing. This may indicate that we are running a test case")
enforce_eager = False
else:
enforce_eager = getattr(model_config, "enforce_eager", False)
from vllm.config.compilation import CUDAGraphMode
if enforce_eager:
logger.info("Compilation disabled, using eager mode by default")
compilation_config.mode = CompilationMode.NONE
@@ -207,12 +209,10 @@ class NPUPlatform(Platform):
compilation_config.cudagraph_num_of_warmups = 1
if compilation_config.mode not in [
CompilationMode.NONE, CompilationMode.VLLM_COMPILE
]:
if compilation_config.mode not in [CompilationMode.NONE, CompilationMode.VLLM_COMPILE]:
logger.warning(
"NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE",
compilation_config.mode)
"NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", compilation_config.mode
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set cudaprah sizes before extending `compilation_config.splitting_ops`
@@ -223,15 +223,18 @@ class NPUPlatform(Platform):
update_default_aclgraph_sizes(vllm_config)
# TODO delete graph size update here when compilation_config.pass_config.enable_sp
# is supported by vllm-ascend.
if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \
enable_sp(vllm_config):
if (
vllm_config.parallel_config.tensor_parallel_size > 1
and not vllm_config.model_config.enforce_eager
and enable_sp(vllm_config)
):
original_sizes = compilation_config.cudagraph_capture_sizes
sp_aclgraph_sizes = \
vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
sp_aclgraph_sizes = vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
assert sp_aclgraph_sizes, (
f"cudagraph_capture_sizes {original_sizes} does not contain"
f"values that are multiples of tp_size "
f"{vllm_config.parallel_config.tensor_parallel_size}")
f"{vllm_config.parallel_config.tensor_parallel_size}"
)
if len(sp_aclgraph_sizes) != len(original_sizes):
compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes)
@@ -243,9 +246,7 @@ class NPUPlatform(Platform):
# encoder-decoder models currently only support piecewise mode
if model_config and model_config.is_encoder_decoder is True:
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.warning(
"encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE "
)
logger.warning("encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE ")
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# get custom compile backend for graph fusion
@@ -255,15 +256,14 @@ class NPUPlatform(Platform):
compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == "
"CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
)
compilation_config.set_splitting_ops_for_v1(
all2all_backend=vllm_config.parallel_config.all2all_backend,
data_parallel_size=vllm_config.parallel_config.
data_parallel_size,
data_parallel_size=vllm_config.parallel_config.data_parallel_size,
)
compilation_config.use_inductor = False
# NOTE: Theoretically, we should also add vllm::mla_forward in the attention ops.
@@ -275,11 +275,13 @@ class NPUPlatform(Platform):
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config)
ascend_config.enable_npugraph_ex = False
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
elif (
compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
):
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode"
)
compilation_config.use_inductor = False
compilation_config.splitting_ops = []
warning_message = """\033[91m
@@ -297,30 +299,31 @@ class NPUPlatform(Platform):
logger.warning(warning_message)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
compilation_config.cudagraph_mode)
"%s cudagraph_mode is not support on NPU. falling back to NONE", compilation_config.cudagraph_mode
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False
# TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
# Then, we will have to discuss the error handling strategy and user experience
if compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \
os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1":
if (
compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1"
):
raise ValueError(
"ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. "
"Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you "
"need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — "
"for example, check the plog files (default: $HOME/ascend/log/debug) "
"for more information about runtime errors.")
"for more information about runtime errors."
)
if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
parallel_config.all2all_backend = "flashinfer_all2allv"
if ascend_config.xlite_graph_config.enabled:
logger.info(
"openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite"
)
logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite")
parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
@@ -332,10 +335,9 @@ class NPUPlatform(Platform):
compilation_config.custom_ops = ["all"]
if ascend_config.recompute_scheduler_enable:
from vllm_ascend.core.recompute_scheduler import \
RecomputeSchedulerConfig
recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
vllm_config)
from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig
recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(vllm_config)
vllm_config.scheduler_config = recompute_scheduler_config
# Extend original scheduler_config to use SchedulerDynamicBatch.
@@ -346,9 +348,11 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config.enable_chunked_prefill = True
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
if vllm_config.kv_transfer_config is not None and \
cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \
parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1:
if (
vllm_config.kv_transfer_config is not None
and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
):
raise AssertionError(
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
f"and block_size({cache_config.block_size}) "
@@ -356,12 +360,14 @@ class NPUPlatform(Platform):
)
if is_vl_model(vllm_config):
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \
bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))):
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
):
raise ValueError(
"Currently, VL models doesn't support "
"FLASHCOMM in vllm-ascend. We will fix this in the future. "
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.")
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
)
@classmethod
def import_kernels(cls) -> None:
@@ -377,14 +383,11 @@ class NPUPlatform(Platform):
if _CUSTOM_OP_REGISTERED:
return
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors",
"vllm-ascend")
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", "vllm-ascend")
if os.path.exists(CUSTOM_OPP_PATH):
current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH",
"")
current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
if current_cust_opp_path:
os.environ[
"ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
os.environ["ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
else:
os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH
_CUSTOM_OP_REGISTERED = True
@@ -393,22 +396,18 @@ class NPUPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, attn_selector_config):
backend_map = {
(True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
}
return backend_map[(attn_selector_config.use_mla,
attn_selector_config.use_sparse)]
return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)]
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
torch.npu.reset_peak_memory_stats(device)
return torch.npu.max_memory_allocated(device)
@@ -474,15 +473,16 @@ class NPUPlatform(Platform):
dict[str, Any]: _description_
"""
# NOTE(Ronald1995): avoid circular import.
from vllm_ascend.ascend_forward_context import (get_mc2_mask,
select_moe_comm_method)
from vllm_ascend.ascend_forward_context import get_mc2_mask, select_moe_comm_method
from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# NOTE(Ronald1995): avoid circular import, cudagraph_runtime_mode is
# CUDAGraphMode.NONE in vllm, but we can't set CUDAGraphMode.NONE in
# argument default value, so we set it to None first, then set it to
# CUDAGraphMode.NONE here.
from vllm.config import CUDAGraphMode
if cudagraph_runtime_mode is None:
cudagraph_runtime_mode = CUDAGraphMode.NONE
# TODO(Ronald1995): model runner v1 still use ascend_forward_context,
@@ -526,31 +526,26 @@ class NPUPlatform(Platform):
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
mmrs_fusion = False
else:
sp_enabled = enable_sp(vllm_config) and \
num_tokens is not None and num_tokens > 1000
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
flashcomm_v2_enabled = flashcomm2_enable(
) and tp_world_size > 1 and num_tokens is not None
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
pad_size = 0
if (sp_enabled or flashcomm_v2_enabled):
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
if sp_enabled or flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and dp_metadata is not None:
max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
if (sp_enabled or flashcomm_v2_enabled):
padded_length = (max_tokens_across_dp + tp_world_size -
1) // tp_world_size * tp_world_size
if sp_enabled or flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
else:
max_tokens_across_dp = num_tokens
if num_tokens is not None:
# NOTE: token num which need to pad to when mc2
padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
reserved_mc2_mask = get_mc2_mask()
if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:padded_num_tokens]
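The sequence-parallel/flashcomm padding in the hunk above is plain round-up-to-a-multiple arithmetic. A small sketch showing that the modulo form and the ceil-division form used for the cross-DP case agree:

```python
def pad_to_multiple(num_tokens: int, tp_world_size: int) -> tuple[int, int]:
    """Return (padded_num_tokens, pad_size) so padded_num_tokens % tp_world_size == 0."""
    pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
    padded = num_tokens + pad_size
    # Same result as the ceil-division form used for max_tokens_across_dp above.
    assert padded == (num_tokens + tp_world_size - 1) // tp_world_size * tp_world_size
    return padded, pad_size


print(pad_to_multiple(1000, 8))  # (1000, 0) -- already a multiple
print(pad_to_multiple(1001, 8))  # (1008, 7)
```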

View File

@@ -21,9 +21,9 @@ This module generates the service_profiling_symbols.yaml configuration file
to ~/.config/vllm_ascend/ directory.
"""
import contextlib
import tempfile
from pathlib import Path
from typing import Optional
import vllm
from vllm.logger import logger
@@ -129,7 +129,7 @@ def get_config_dir() -> Path:
return config_dir
def _cleanup_temp_file(tmp_path: Optional[Path]) -> None:
def _cleanup_temp_file(tmp_path: Path | None) -> None:
"""
Clean up a temporary file if it exists.
@@ -137,13 +137,11 @@ def _cleanup_temp_file(tmp_path: Optional[Path]) -> None:
tmp_path: Path to the temporary file to clean up.
"""
if tmp_path is not None and tmp_path.exists():
try:
with contextlib.suppress(OSError):
tmp_path.unlink()
except OSError:
pass # Ignore cleanup errors
def generate_service_profiling_config() -> Optional[Path]:
def generate_service_profiling_config() -> Path | None:
"""
Generate the service_profiling_symbols.yaml configuration file
to ~/.config/vllm_ascend/ directory.
@@ -170,9 +168,7 @@ def generate_service_profiling_config() -> Optional[Path]:
try:
config_dir.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e:
logger.error(
f"Failed to create configuration directory {config_dir}: {e}",
exc_info=True)
logger.error(f"Failed to create configuration directory {config_dir}: {e}", exc_info=True)
return None
# Write the configuration file atomically using a temporary file
@@ -180,13 +176,9 @@ def generate_service_profiling_config() -> Optional[Path]:
tmp_path = None
try:
# Create a temporary file in the same directory for atomic write
with tempfile.NamedTemporaryFile(mode='w',
encoding='utf-8',
dir=config_dir,
delete=False,
suffix='.tmp',
prefix=CONFIG_FILENAME +
'.') as tmp_file:
with tempfile.NamedTemporaryFile(
mode="w", encoding="utf-8", dir=config_dir, delete=False, suffix=".tmp", prefix=CONFIG_FILENAME + "."
) as tmp_file:
tmp_file.write(SERVICE_PROFILING_SYMBOLS_YAML)
tmp_path = Path(tmp_file.name)
@@ -194,8 +186,7 @@ def generate_service_profiling_config() -> Optional[Path]:
tmp_path.replace(config_file)
return config_file
except (OSError, PermissionError) as e:
logger.error(f"Failed to write configuration file {config_file}: {e}",
exc_info=True)
logger.error(f"Failed to write configuration file {config_file}: {e}", exc_info=True)
return None
finally:
# Clean up the temporary file if it wasn't successfully replaced
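The config writer above follows the usual atomic-write recipe: write to a temporary file in the destination directory, `Path.replace` it into place, and clean up the temp file if the swap never happened. A minimal standalone sketch of that recipe (paths here are illustrative):

```python
import tempfile
from pathlib import Path


def write_config_atomically(config_file: Path, content: str) -> Path:
    """Write `content` to `config_file` via a temp file in the same directory."""
    config_dir = config_file.parent
    config_dir.mkdir(parents=True, exist_ok=True)
    tmp_path: Path | None = None
    try:
        # Same directory => Path.replace is an atomic rename on the same filesystem.
        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", dir=config_dir, delete=False, suffix=".tmp"
        ) as tmp_file:
            tmp_file.write(content)
            tmp_path = Path(tmp_file.name)
        tmp_path.replace(config_file)
        tmp_path = None  # replaced successfully; nothing left to clean up
        return config_file
    finally:
        if tmp_path is not None and tmp_path.exists():
            tmp_path.unlink()


demo_path = Path(tempfile.gettempdir()) / "vllm_ascend_demo" / "example.yaml"
print(write_config_atomically(demo_path, "key: value\n"))
```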

View File

@@ -17,6 +17,8 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#
from __future__ import annotations
import atexit
import functools
import math
@@ -25,7 +27,7 @@ from contextlib import contextmanager, nullcontext
from enum import Enum
from functools import lru_cache
from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any
import torch
import torch_npu # noqa: F401
@@ -88,6 +90,7 @@ def acl_graph_print(*args):
resolving unexpected hangs. Usage:
```python
from vllm_ascend.utils import acl_graph_print
...
acl_graph_print("Debug info")
```
@@ -113,8 +116,7 @@ def acl_graph_print(*args):
torch_npu.npu._subscribe_report(current_compute_stream)
_SUBSCRIBED_COMPUTE_STREAMS.add(current_compute_stream)
torch_npu.npu._launch_host_func(current_compute_stream,
_print_callback_on_stream, args)
torch_npu.npu._launch_host_func(current_compute_stream, _print_callback_on_stream, args)
def _unregister_print_streams_on_exit():
@@ -196,9 +198,7 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
# after: pad_dims: [0, 2, 0, 3]
# return: (1, 2, 16, 16)
return _custom_transpose(
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
2).contiguous()
return _custom_transpose(_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1, 2).contiguous()
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
@@ -208,11 +208,9 @@ def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
tokens_pad = (num_tokens + 15) // 16 * 16
max_seq_len_pad = (max_seq_len + 15) // 16 * 16
mask_tensor_pad = \
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
mask_tensor_pad = torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
mask = mask_tensor_pad.reshape(
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
mask = mask_tensor_pad.reshape((1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
return mask
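The `nd_to_nz_spec` reflow leaves the arithmetic unchanged: both dimensions are rounded up to the next multiple of 16, and the padded mask is split into 16-wide blocks whose axis is moved forward. A small CPU-only sketch of just that rounding and block reshape, with made-up sizes for illustration:

```python
import torch

num_tokens, max_seq_len = 5, 20
tokens_pad = (num_tokens + 15) // 16 * 16        # 16
max_seq_len_pad = (max_seq_len + 15) // 16 * 16  # 32

mask = torch.ones(num_tokens, max_seq_len)
mask_pad = torch.zeros(1, tokens_pad, max_seq_len_pad)
mask_pad[0, :num_tokens, :max_seq_len] = mask

# Split the padded last dimension into 16-wide blocks and move the block
# axis forward, mirroring the reshape/permute in nd_to_nz_spec.
nz = mask_pad.reshape(1, tokens_pad, max_seq_len_pad // 16, 16).permute(0, 2, 1, 3)
print(nz.shape)  # torch.Size([1, 2, 16, 16])
```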
@@ -230,10 +228,7 @@ def aligned_16(tensor: torch.Tensor):
return tensor
# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
new_tensor = torch.zeros(n_aligned,
*tensor.shape[1:],
dtype=tensor.dtype,
device=tensor.device)
new_tensor = torch.zeros(n_aligned, *tensor.shape[1:], dtype=tensor.dtype, device=tensor.device)
# Copy the original tensor to the first N positions of the new tensor
new_tensor[:n] = tensor
@@ -254,15 +249,15 @@ def enable_custom_op():
# isort: off
# register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
# register the meta implementation for custom kernel if necessary
import vllm_ascend.meta_registration # type: ignore # noqa: F401
# isort: on
_CUSTOM_OP_ENABLED = True
except ImportError:
_CUSTOM_OP_ENABLED = False
logger.warning(
"Warning: Failed to register custom ops, all custom ops will be disabled"
)
logger.warning("Warning: Failed to register custom ops, all custom ops will be disabled")
return _CUSTOM_OP_ENABLED
@@ -277,8 +272,7 @@ def find_hccl_library() -> str:
# manually load the hccl library
if so_file:
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
so_file)
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", so_file)
else:
if torch.version.cann is not None:
so_file = "libhccl.so"
@@ -318,6 +312,7 @@ def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
global _WEIGHT_PREFETCH_METHOD
if _WEIGHT_PREFETCH_METHOD is None:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
_WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
return _WEIGHT_PREFETCH_METHOD
@@ -364,6 +359,7 @@ def vllm_version_is(target_vllm_version: str):
vllm_version = envs_ascend.VLLM_VERSION
else:
import vllm
vllm_version = vllm.__version__
try:
return Version(vllm_version) == Version(target_vllm_version)
@@ -372,7 +368,8 @@ def vllm_version_is(target_vllm_version: str):
f"Invalid vllm version {vllm_version} found. A dev version of vllm "
"is installed probably. Set the environment variable VLLM_VERSION "
"to control it by hand. And please make sure the value follows the "
"format of x.y.z.")
"format of x.y.z."
)
def get_max_hidden_layers(hf_config) -> int:
@@ -394,20 +391,19 @@ def get_max_hidden_layers(hf_config) -> int:
# Update cudagraph capture sizes for vllm config
def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
cudagraph_capture_sizes: List[int]):
valid_max_size = (cudagraph_capture_sizes[-1]
if cudagraph_capture_sizes else 0)
if (vllm_config.compilation_config.max_cudagraph_capture_size is not None
and vllm_config.compilation_config.max_cudagraph_capture_size
!= valid_max_size):
def update_cudagraph_capture_sizes(vllm_config: VllmConfig, cudagraph_capture_sizes: list[int]):
valid_max_size = cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
if (
vllm_config.compilation_config.max_cudagraph_capture_size is not None
and vllm_config.compilation_config.max_cudagraph_capture_size != valid_max_size
):
if vllm_config.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"customized max_cudagraph_capture_size"
f"(={vllm_config.compilation_config.max_cudagraph_capture_size}) "
"should be consistent with the max value of "
f"cudagraph_capture_sizes(={valid_max_size})")
f"cudagraph_capture_sizes(={valid_max_size})"
)
logger.warning(
"Truncating max_cudagraph_capture_size to %d",
valid_max_size,
@@ -415,12 +411,11 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
vllm_config.compilation_config.max_cudagraph_capture_size = valid_max_size
if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(
cudagraph_capture_sizes) < len(
vllm_config.compilation_config.cudagraph_capture_sizes):
if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(cudagraph_capture_sizes) < len(
vllm_config.compilation_config.cudagraph_capture_sizes
):
logger.warning(
("cudagraph_capture_sizes specified in compilation_config"
" %s is overridden by config %s"),
("cudagraph_capture_sizes specified in compilation_config %s is overridden by config %s"),
vllm_config.compilation_config.cudagraph_capture_sizes,
cudagraph_capture_sizes,
)
@@ -433,26 +428,17 @@ def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool:
Check whether it is vLLM default capture sizes.
"""
max_cudagraph_capture_size = \
vllm_config.compilation_config.max_cudagraph_capture_size
cudagraph_capture_sizes = [
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
]
max_cudagraph_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
cudagraph_capture_sizes = [i for i in [1, 2, 4] if i <= max_cudagraph_capture_size]
if max_cudagraph_capture_size >= 8:
# Step size 8 for small batch sizes, up to 256(not included)
cudagraph_capture_sizes += list(
range(8, min(max_cudagraph_capture_size + 1, 256), 8))
cudagraph_capture_sizes += list(range(8, min(max_cudagraph_capture_size + 1, 256), 8))
if max_cudagraph_capture_size >= 256:
# Step size 16 for larger batch sizes
cudagraph_capture_sizes += list(
range(256, max_cudagraph_capture_size + 1, 16))
cudagraph_capture_sizes += list(range(256, max_cudagraph_capture_size + 1, 16))
# In newer versions, vLLM uses ascending order for cudagraph_capture_sizes.
target_cudagraph_capture_sizes = sorted(cudagraph_capture_sizes)
if target_cudagraph_capture_sizes == \
vllm_config.compilation_config.cudagraph_capture_sizes:
return True
return False
return target_cudagraph_capture_sizes == vllm_config.compilation_config.cudagraph_capture_sizes
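The rewritten `_is_default_capture_sizes` preserves the original schedule: sizes 1, 2, 4, then steps of 8 up to 256 (exclusive), then steps of 16 up to the maximum, compared in ascending order. A quick sketch reproducing that schedule for a hypothetical maximum of 48:

```python
max_cudagraph_capture_size = 48  # hypothetical maximum, for illustration only

sizes = [i for i in [1, 2, 4] if i <= max_cudagraph_capture_size]
if max_cudagraph_capture_size >= 8:
    # Step size 8 for small batch sizes, up to 256 (not included)
    sizes += list(range(8, min(max_cudagraph_capture_size + 1, 256), 8))
if max_cudagraph_capture_size >= 256:
    # Step size 16 for larger batch sizes
    sizes += list(range(256, max_cudagraph_capture_size + 1, 16))

print(sorted(sizes))  # [1, 2, 4, 8, 16, 24, 32, 40, 48]
```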
def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
@@ -461,9 +447,11 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
are more friendly to ascend ops && hardware.
"""
if vllm_config.model_config is None or \
vllm_config.model_config.enforce_eager or \
not _is_default_capture_sizes(vllm_config):
if (
vllm_config.model_config is None
or vllm_config.model_config.enforce_eager
or not _is_default_capture_sizes(vllm_config)
):
return
# modify the default capture_sizes for Qwen3-MoE models on dp settings.
@@ -471,16 +459,15 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# on special shapes.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
if vllm_config.model_config and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe" \
and vllm_config.parallel_config.tensor_parallel_size == 1 \
and vllm_config.parallel_config.data_parallel_size > 1 :
if (
vllm_config.model_config
and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe"
and vllm_config.parallel_config.tensor_parallel_size == 1
and vllm_config.parallel_config.data_parallel_size > 1
):
max_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [
i for i in range(24, max_capture_size + 1, 8)
]
update_cudagraph_capture_sizes(vllm_config,
new_cudagraph_capture_sizes)
new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [i for i in range(24, max_capture_size + 1, 8)]
update_cudagraph_capture_sizes(vllm_config, new_cudagraph_capture_sizes)
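The Qwen3-MoE special case above swaps in a denser schedule for small batch sizes before stepping by 8. For a hypothetical `max_capture_size` of 64, the resulting list would be:

```python
max_capture_size = 64  # hypothetical value, for illustration only

new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [i for i in range(24, max_capture_size + 1, 8)]
print(new_cudagraph_capture_sizes)
# [1, 2, 5, 10, 15, 20, 24, 32, 40, 48, 56, 64]
```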
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
@@ -498,8 +485,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Store original configuration and temporarily clear it
compilation_config = vllm_config.compilation_config
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None
original_sizes, compilation_config.cudagraph_capture_sizes = compilation_config.cudagraph_capture_sizes, None
# Calculate parallel configuration factor
if not vllm_config.model_config:
@@ -510,7 +496,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
return
hf_config = vllm_config.model_config.hf_text_config
if hasattr(hf_config, 'num_hidden_layers'):
if hasattr(hf_config, "num_hidden_layers"):
num_hidden_layers = hf_config.num_hidden_layers
else:
num_hidden_layers = get_max_hidden_layers(hf_config)
@@ -519,24 +505,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Calculate maximum supported batch sizes considering model architecture
resources_per_graph = num_hidden_layers + 1
# For suffix decoding, use the suffix path when no draft_model_config is provided.
if (spec := vllm_config.speculative_config) and \
(draft := spec.draft_model_config):
if (spec := vllm_config.speculative_config) and (draft := spec.draft_model_config):
resources_per_graph += draft.hf_config.num_hidden_layers + 1
# TODO: Find out whether we need to take into account the pp_size
num_comm_groups = sum(size > 1 for size in [
num_comm_groups = sum(
size > 1
for size in [
parallel_config.data_parallel_size,
parallel_config.tensor_parallel_size,
])
]
)
if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
if os.getenv("HCCL_OP_EXPANSION_MODE") == "AIV":
# TODO: Find out whether we need to take into account the pp_size
parallel_factor = 1 + num_comm_groups + int(
parallel_config.enable_expert_parallel) + int(
vllm_config.additional_config.get(
"multistream_overlap_shared_expert", False))
parallel_factor = (
1
+ num_comm_groups
+ int(parallel_config.enable_expert_parallel)
+ int(vllm_config.additional_config.get("multistream_overlap_shared_expert", False))
)
if is_moe_model(vllm_config):
parallel_factor += (parallel_config.data_parallel_size > 1)
parallel_factor += parallel_config.data_parallel_size > 1
else:
# When AIV mode is enabled, the allreduce operator of the dense
# layer model will occupy additional streams, which are buffered here.
@@ -546,11 +536,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
resources_per_graph / parallel_factor)
logger.info(
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / resources_per_graph / parallel_factor)
logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
else:
# enabling pcp or dcp adds new communication and consumes up to approximately 100 additional streams
if parallel_config.prefill_context_parallel_size > 1:
@@ -562,18 +549,18 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Under this configuration, HCCL employs the FFTS+ method for execution unfolding,
# which adds only 1 concurrent stream without consuming collective communication execution unfolding streams.
# On A3 hardware, HCCL defaults to the AICPU method.
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case).
# Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes.
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication
# domain on the device (worst case).
# Using the default collective communication unfolding method on A3 will lead to a significant reduction
# in the maximum supported sizes.
# Therefore, the calculation formula has been modified as follows:
# Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
max_num_batch_sizes = math.floor(
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph /
(1 + num_comm_groups * 2))
logger.info(
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / (1 + num_comm_groups * 2)
)
logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
logger.warning(
"Currently, communication is performed using FFTS+ method, which reduces "
"the number of available streams and, as a result, limits the range of runtime "
@@ -598,16 +585,19 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
vllm_config.model_config.architectures[0],
num_hidden_layers,
len(original_sizes),
len(compilation_config.
cudagraph_capture_sizes # type: ignore[arg-type]
))
len(
compilation_config.cudagraph_capture_sizes # type: ignore[arg-type]
),
)
else:
# No adjustment needed
compilation_config.cudagraph_capture_sizes = original_sizes
logger.info(
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
vllm_config.model_config.architectures[0], num_hidden_layers,
len(original_sizes))
vllm_config.model_config.architectures[0],
num_hidden_layers,
len(original_sizes),
)
# TODO(wxy): Move to ops module
@@ -617,7 +607,7 @@ def dispose_tensor(x: torch.Tensor):
class ProfileExecuteDuration:
_instance = None
_observations: List[Tuple[str, Event, Event]] = []
_observations: list[tuple[str, Event, Event]] = []
_lock = Lock()
def __new__(cls):
@@ -645,8 +635,7 @@ class ProfileExecuteDuration:
observe_end = Event(enable_timing=True)
observe_end.record()
with self._lock:
self._observations.append(
(duration_tag, observe_start, observe_end))
self._observations.append((duration_tag, observe_start, observe_end))
def pop_captured_sync(self) -> dict:
"""Pop and synchronize all events in the observation list"""
@@ -663,7 +652,7 @@ class ProfileExecuteDuration:
return durations
def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
def register_ascend_customop(vllm_config: VllmConfig | None = None):
"""Register Ascend CustomOP
NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`,
@@ -675,23 +664,29 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear)
AscendRowParallelLinear,
)
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
from vllm_ascend.ops.rotary_embedding import (
AscendApplyRotaryEmb, AscendDeepseekScalingRotaryEmbedding,
AscendMRotaryEmbedding, AscendRotaryEmbedding,
AscendYaRNRotaryEmbedding)
AscendApplyRotaryEmb,
AscendDeepseekScalingRotaryEmbedding,
AscendMRotaryEmbedding,
AscendRotaryEmbedding,
AscendYaRNRotaryEmbedding,
)
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead,
AscendVocabParallelEmbedding)
AscendLogitsProcessor,
AscendParallelLMHead,
AscendVocabParallelEmbedding,
)
global REGISTERED_ASCEND_OPS
REGISTERED_ASCEND_OPS = {
@@ -738,6 +733,7 @@ _ascend_device_type = None
def _init_ascend_device_type():
global _ascend_device_type
from vllm_ascend import _build_info # type: ignore
_ascend_device_type = AscendDeviceType[_build_info.__device_type__]
@@ -758,7 +754,10 @@ def check_ascend_device_type():
else:
raise RuntimeError(f"Can not support soc_version: {soc_version}.")
assert _ascend_device_type == cur_device_type, f"Current device type: {cur_device_type} does not match the installed version's device type: {_ascend_device_type}, please check your installation package."
assert _ascend_device_type == cur_device_type, (
f"Current device type: {cur_device_type} does not match the installed version's device type: "
f"{_ascend_device_type}, please check your installation package."
)
def get_ascend_device_type():
@@ -769,23 +768,19 @@ def get_ascend_device_type():
def lmhead_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.lmhead_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.lmhead_tensor_parallel_size > 0
def embedding_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.embedding_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.embedding_tensor_parallel_size > 0
def oproj_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.oproj_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.oproj_tensor_parallel_size > 0
def mlp_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.mlp_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.mlp_tensor_parallel_size > 0
def matmul_allreduce_enable() -> bool:
@@ -797,30 +792,30 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
if _ENABLE_SP is None:
if vllm_config is None:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
_ENABLE_SP = (
vllm_config.compilation_config.pass_config.enable_sp
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
)
if not _ENABLE_SP and enable_shared_expert_dp:
_ENABLE_SP = True
logger.info(
"shared_expert_dp requires enable_sp = True. has set enable_sp to True"
)
logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")
if not _ENABLE_SP:
return _ENABLE_SP
assert vllm_config.parallel_config.tensor_parallel_size > 1, \
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
)
assert (
not is_moe_model(vllm_config)
or vllm_config.parallel_config.enable_expert_parallel
), "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
"Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
)
return _ENABLE_SP
@@ -847,25 +842,21 @@ def is_drafter_moe_model(vllm_config: VllmConfig):
"""Checks if the drafter model is a MoE model by config"""
global _IS_DRAFTER_MOE_MODEL
if _IS_DRAFTER_MOE_MODEL is None:
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \
.to_dict()
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config.to_dict()
_IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs)
return _IS_DRAFTER_MOE_MODEL
def speculative_enable_dispatch_gmm_combine_decode(
vllm_config: VllmConfig) -> bool:
def speculative_enable_dispatch_gmm_combine_decode(vllm_config: VllmConfig) -> bool:
if vllm_config.speculative_config is None:
return True
speculative_method = getattr(vllm_config.speculative_config, "method",
None)
speculative_method = getattr(vllm_config.speculative_config, "method", None)
if speculative_method in [None, "ngram", "suffix"]:
return True
if speculative_method in ["eagle", "eagle3"]:
return False
if speculative_method == "mtp":
mtp_quant_type = getattr(vllm_config.model_config.hf_text_config,
"mtp_quantize", None)
mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, "mtp_quantize", None)
return mtp_quant_type == "w8a8_dynamic"
return False
@@ -915,8 +906,8 @@ def weak_ref_tensor(tensor: Any) -> Any:
def weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
tensors: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor],
) -> torch.Tensor | list[Any] | tuple[Any] | Any:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
@@ -936,17 +927,12 @@ def weak_ref_tensors(
return tuple(weak_ref_tensor(t) for t in tensors)
# For IntermediateTensors used in pipeline parallelism
if isinstance(tensors, IntermediateTensors):
ret = IntermediateTensors({
key: weak_ref_tensor(val)
for key, val in tensors.tensors.items()
})
ret = IntermediateTensors({key: weak_ref_tensor(val) for key, val in tensors.tensors.items()})
return ret
raise ValueError("Invalid type for tensors")
def npu_stream_switch(target_stream: torch.npu.Stream,
*,
enabled: bool = True):
def npu_stream_switch(target_stream: torch.npu.Stream, *, enabled: bool = True):
"""
Switch to the target stream if enabled is True.
Otherwise, do nothing.
@@ -965,7 +951,7 @@ def create_hccl_pg_options(group_name: str):
return options
def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
def get_hccl_config_for_pg_options(group_name: str) -> dict | None:
"""
Get HCCL process group options for the given communication group name.
@@ -981,9 +967,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
if group_name and "mc2" in group_name:
return None
hccl_config_map = {
"dp": {
"hccl_buffer_size": calculate_dp_buffer_size()
},
"dp": {"hccl_buffer_size": calculate_dp_buffer_size()},
}
return hccl_config_map.get(group_name, get_default_buffer_config())
@@ -998,6 +982,7 @@ def calculate_dp_buffer_size() -> int:
dp_size + 1 (flags: with_prefill)
"""
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
dp_size = vllm_config.parallel_config.data_parallel_size
int32_size = torch.iinfo(torch.int32).bits // 8
@@ -1009,8 +994,7 @@ def calculate_dp_buffer_size() -> int:
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
# significantly improve communication performance of MC2 ops dispatch/combine.
def is_hierarchical_communication_enabled():
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
return os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1"
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
@@ -1019,8 +1003,7 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool:
global _HAS_LAYER_IDX
if _HAS_LAYER_IDX is None:
_HAS_LAYER_IDX = hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer")
_HAS_LAYER_IDX = hasattr(model_instance, "model") and hasattr(model_instance.model, "start_layer")
return _HAS_LAYER_IDX
@@ -1042,20 +1025,17 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
if not flashcomm2_enable():
return 0
logger.info(
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
)
logger.info(f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}")
layer_sharding = ascend_config.layer_sharding or []
if layer_sharding:
if layer_sharding == ["o_proj"]:
logger.info_once(
"Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
)
logger.info_once("Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption.")
else:
raise ValueError(
"FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
f"Found invalid layer_sharding: {layer_sharding}")
f"Found invalid layer_sharding: {layer_sharding}"
)
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
logger.warning_once(
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
@@ -1066,32 +1046,35 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
)
if global_tp_size <= flashcomm2_oproj_tp_size:
raise AssertionError(
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed global tensor parallel size ({global_tp_size})"
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed "
f"global tensor parallel size ({global_tp_size})"
)
if global_tp_size % flashcomm2_oproj_tp_size != 0:
raise AssertionError(
f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
f"Global tensor parallel size ({global_tp_size}) must be divisible by "
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
)
if vllm_config.kv_transfer_config is None:
logger.warning_once(
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment may lead to decode performance degradation."
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment "
"may lead to decode performance degradation."
)
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support "
"for hybrid deployment scenarios. It is not applicable in D-scenario environments."
)
return flashcomm2_oproj_tp_size
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain.
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation,
# each batch_id corresponds to the rank_id within the DP domain.
# For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
# the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
flashcomm2_otp_size = get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size
num_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)
flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
num_oproj_tensor_parallel_groups: int = global_tp_size // flashcomm2_otp_size
reorgnized_batch_ids = []
for i in range(num_oproj_tensor_parallel_groups):
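The loop body falls outside this hunk, but the comment pins down the expected grouping. A sketch that reproduces the described result for global_tp_size = 16 and flashcomm2_oproj_tensor_parallel_size = 2 (the exact construction in the source may differ):

```python
global_tp_size = 16
flashcomm2_otp_size = 2  # flashcomm2_oproj_tensor_parallel_size

num_groups = global_tp_size // flashcomm2_otp_size  # 8

# Group i collects every batch whose index is congruent to i modulo num_groups,
# so after all2all + reduce-scatter each batch lands on its own rank in the DP domain.
reorgnized_batch_ids = [[i + j * num_groups for j in range(flashcomm2_otp_size)] for i in range(num_groups)]
print(reorgnized_batch_ids)
# [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]
```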
@@ -1122,11 +1105,9 @@ def refresh_block_size(vllm_config):
return
# TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
if not model_config.hf_text_config.model_type == "qwen3_next" and cache_config.block_size != 128:
if model_config.hf_text_config.model_type != "qwen3_next" and cache_config.block_size != 128:
if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill:
logger.info(
"Block size is set to 128 if prefix cache or chunked prefill is enabled."
)
logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.")
cache_config.block_size = 128
@@ -1138,7 +1119,6 @@ def dispose_layer(layer: Any):
def check_kv_extra_config(vllm_config):
def _check(name: str, config: dict):
tp_key = "tp_size"
dp_key = "dp_size"
@@ -1148,24 +1128,21 @@ def check_kv_extra_config(vllm_config):
if config_tp != vllm_tp:
raise ValueError(
f"KV transfer '{name}' config has a conflicting tensor parallel size. "
f"Expected {vllm_tp}, but got {config_tp}.")
f"Expected {vllm_tp}, but got {config_tp}."
)
if dp_key in config:
config_dp = config[dp_key]
vllm_dp = vllm_config.parallel_config.data_parallel_size
if config_dp != vllm_dp:
raise ValueError(
f"KV transfer '{name}' config has a conflicting data parallel size. "
f"Expected {vllm_dp}, but got {config_dp}.")
f"Expected {vllm_dp}, but got {config_dp}."
)
if vllm_config.kv_transfer_config.is_kv_producer:
_check(
"prefill",
vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {}))
_check("prefill", vllm_config.kv_transfer_config.get_from_extra_config("prefill", {}))
if vllm_config.kv_transfer_config.is_kv_consumer:
_check(
"decode",
vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
_check("decode", vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
def singleton(cls):
@@ -1179,17 +1156,17 @@ def singleton(cls):
return get_instance
#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
# TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32.
# and subsequent updates will introduce new interfaces. --zzhx1
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
is_ds_v32 = hasattr(
vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk")
if is_ds_v32 and enable_sp():
return True
return False
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
return bool(is_ds_v32 and enable_sp())
@lru_cache(maxsize=1)
@@ -1197,6 +1174,7 @@ def enable_dsa_cp_with_layer_shard() -> bool:
if not enable_dsa_cp():
return False
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
return is_prefill_instance