[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)

### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.

### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                              ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```

**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python parses the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it attempts
to evaluate the expression `VllmConfig | None`.
Since `VllmConfig` is bound to `None` at runtime, the expression
effectively becomes `None | None`. In Python, `None` is an instance of
`NoneType`; the `|` operator is defined for type objects (classes) to
build unions, but not for `NoneType` instances, which leads to the
`TypeError` shown above.

**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.

**3. Impact and Benefits:**
- With postponed evaluation enabled, Python no longer evaluates the
`VllmConfig | None` expression at module load time. Instead, it stores the
annotation as a string literal, completely avoiding the `None | None` evaluation.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` will now contain strings
instead of type objects. Since this module does not use runtime type
enforcement or reflection, this change has no negative impact on
existing functionality.
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
11b6af5280

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-16 20:57:46 +08:00
committed by GitHub
parent 3af91e5ac4
commit 52086394ae
16 changed files with 996 additions and 1140 deletions

View File

@@ -49,7 +49,25 @@ line-length = 120
# Folder to be modified # Folder to be modified
exclude = [ exclude = [
"tests/**", "tests/**",
"vllm_ascend/**", "vllm_ascend/_cann_ops_custom",
"vllm_ascend/attention",
"vllm_ascend/core",
"vllm_ascend/device",
"vllm_ascend/device_allocator",
"vllm_ascend/distributed",
"vllm_ascend/eplb",
"vllm_ascend/kv_offload",
"vllm_ascend/lora",
"vllm_ascend/model_loader",
"vllm_ascend/ops",
"vllm_ascend/patch",
"vllm_ascend/quantization",
"vllm_ascend/sample",
"vllm_ascend/spec_decode",
"vllm_ascend/worker",
"vllm_ascend/xlite",
"vllm_ascend/envs.py",
"vllm_ascend/batch_invariant.py",
] ]
[tool.ruff.lint] [tool.ruff.lint]

View File

@@ -24,14 +24,17 @@ def register():
def register_connector(): def register_connector():
from vllm_ascend.distributed.kv_transfer import register_connector from vllm_ascend.distributed.kv_transfer import register_connector
register_connector() register_connector()
def register_model_loader(): def register_model_loader():
from .model_loader.netloader import register_netloader from .model_loader.netloader import register_netloader
register_netloader() register_netloader()
def register_service_profiling(): def register_service_profiling():
from .profiling_config import generate_service_profiling_config from .profiling_config import generate_service_profiling_config
generate_service_profiling_config() generate_service_profiling_config()

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.logger import logger from vllm.logger import logger
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
@@ -32,18 +32,13 @@ class AscendConfig:
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
xlite_graph_config = additional_config.get("xlite_graph_config", {}) xlite_graph_config = additional_config.get("xlite_graph_config", {})
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, vllm_config)
vllm_config)
ascend_compilation_config = additional_config.get( ascend_compilation_config = additional_config.get("ascend_compilation_config", {})
"ascend_compilation_config", {}) self.ascend_compilation_config = AscendCompilationConfig(**ascend_compilation_config)
self.ascend_compilation_config = AscendCompilationConfig(
**ascend_compilation_config)
finegrained_tp_config = additional_config.get("finegrained_tp_config", finegrained_tp_config = additional_config.get("finegrained_tp_config", {})
{}) self.finegrained_tp_config = FinegrainedTPConfig(finegrained_tp_config, vllm_config)
self.finegrained_tp_config = FinegrainedTPConfig(
finegrained_tp_config, vllm_config)
eplb_config = additional_config.get("eplb_config", {}) eplb_config = additional_config.get("eplb_config", {})
self.eplb_config = EplbConfig(eplb_config) self.eplb_config = EplbConfig(eplb_config)
@@ -51,10 +46,8 @@ class AscendConfig:
# Dump / PrecisionDebugger configuration # Dump / PrecisionDebugger configuration
self.dump_config_path = additional_config.get("dump_config_path", None) self.dump_config_path = additional_config.get("dump_config_path", None)
weight_prefetch_config = additional_config.get( weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
"weight_prefetch_config", {}) self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
self.weight_prefetch_config = WeightPrefetchConfig(
weight_prefetch_config)
self.layer_sharding = additional_config.get("layer_sharding", None) self.layer_sharding = additional_config.get("layer_sharding", None)
logger.info_once( logger.info_once(
f"Linear layer sharding enabled with config: {self.layer_sharding}. " f"Linear layer sharding enabled with config: {self.layer_sharding}. "
@@ -62,30 +55,25 @@ class AscendConfig:
"using it without these features may result in significant performance degradation." "using it without these features may result in significant performance degradation."
) )
self.enable_shared_expert_dp = additional_config.get( self.enable_shared_expert_dp = (
"enable_shared_expert_dp", additional_config.get("enable_shared_expert_dp", False)
False) and vllm_config.parallel_config.enable_expert_parallel and vllm_config.parallel_config.enable_expert_parallel
)
if self.enable_shared_expert_dp: if self.enable_shared_expert_dp:
from vllm_ascend.utils import enable_sp from vllm_ascend.utils import enable_sp
assert enable_sp(vllm_config=vllm_config,
enable_shared_expert_dp=True) assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True)
self.multistream_overlap_shared_expert = additional_config.get( self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False)
"multistream_overlap_shared_expert", False) self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False)
self.multistream_overlap_gate = additional_config.get( self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)
"multistream_overlap_gate", False) self.enable_cpu_binding = additional_config.get("enable_cpu_binding", False)
self.recompute_scheduler_enable = additional_config.get(
"recompute_scheduler_enable", False)
self.enable_cpu_binding = additional_config.get(
"enable_cpu_binding", False)
self.pd_tp_ratio = 1 self.pd_tp_ratio = 1
self.pd_head_ratio = 1 self.pd_head_ratio = 1
self.num_head_replica = 1 self.num_head_replica = 1
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla: if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config( prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {"tp_size": 1})["tp_size"]
"prefill", {"tp_size": 1})["tp_size"] decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("decode", {"tp_size": 1})["tp_size"]
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"decode", {"tp_size": 1})["tp_size"]
assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size." assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size."
self.pd_tp_ratio = prefill_tp_size // decode_tp_size self.pd_tp_ratio = prefill_tp_size // decode_tp_size
if self.pd_tp_ratio > 1: if self.pd_tp_ratio > 1:
@@ -106,36 +94,29 @@ class AscendConfig:
) )
if self.pd_tp_ratio == 0: if self.pd_tp_ratio == 0:
raise AssertionError( raise AssertionError("Only support P node tp size lagger then D node tp size")
"Only support P node tp size lagger then D node tp size") self.SLO_limits_for_dynamic_batch = additional_config.get("SLO_limits_for_dynamic_batch", -1)
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
from vllm_ascend.utils import get_flashcomm2_config_and_validate from vllm_ascend.utils import get_flashcomm2_config_and_validate
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
self, vllm_config) self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)
self.enable_npugraph_ex = additional_config.get( self.enable_npugraph_ex = additional_config.get("enable_npugraph_ex", False)
"enable_npugraph_ex", False)
# We find that _npu_paged_attention still performs better than # We find that _npu_paged_attention still performs better than
# npu_fused_infer_attention_score in some cases. We allow to execute # npu_fused_infer_attention_score in some cases. We allow to execute
# _npu_paged_attention in this cases. This should be removed once # _npu_paged_attention in this cases. This should be removed once
# npu_fused_infer_attention_score performs better on all scenarios. # npu_fused_infer_attention_score performs better on all scenarios.
self.pa_shape_list = additional_config.get("pa_shape_list", []) self.pa_shape_list = additional_config.get("pa_shape_list", [])
self.enable_async_exponential = bool( self.enable_async_exponential = bool(additional_config.get("enable_async_exponential", False))
additional_config.get("enable_async_exponential", False))
self.enable_kv_nz = additional_config.get("enable_kv_nz", False) self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
if self.enable_kv_nz: if self.enable_kv_nz:
use_sparse = hasattr(vllm_config.model_config.hf_text_config, use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
"index_topk")
if not vllm_config.model_config.is_deepseek_mla or use_sparse: if not vllm_config.model_config.is_deepseek_mla or use_sparse:
raise RuntimeError( raise RuntimeError("enable_kv_nz is only supported for mla currently.")
"enable_kv_nz is only supported for mla currently.") if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
if vllm_config.kv_transfer_config is None \
or not vllm_config.kv_transfer_config.is_kv_consumer:
raise NotImplementedError( raise NotImplementedError(
"enable_kv_nz is only supported in pd scenario and can " "enable_kv_nz is only supported in pd scenario and can only be used in D node."
"only be used in D node.") )
class FinegrainedTPConfig: class FinegrainedTPConfig:
@@ -144,40 +125,28 @@ class FinegrainedTPConfig:
""" """
def __init__(self, finegrained_tp_config: dict, vllm_config): def __init__(self, finegrained_tp_config: dict, vllm_config):
self.oproj_tensor_parallel_size = finegrained_tp_config.get( self.oproj_tensor_parallel_size = finegrained_tp_config.get("oproj_tensor_parallel_size", 0)
"oproj_tensor_parallel_size", 0) self.lmhead_tensor_parallel_size = finegrained_tp_config.get("lmhead_tensor_parallel_size", 0)
self.lmhead_tensor_parallel_size = finegrained_tp_config.get( self.embedding_tensor_parallel_size = finegrained_tp_config.get("embedding_tensor_parallel_size", 0)
"lmhead_tensor_parallel_size", 0) self.mlp_tensor_parallel_size = finegrained_tp_config.get("mlp_tensor_parallel_size", 0)
self.embedding_tensor_parallel_size = finegrained_tp_config.get(
"embedding_tensor_parallel_size", 0)
self.mlp_tensor_parallel_size = finegrained_tp_config.get(
"mlp_tensor_parallel_size", 0)
enabled_configs = [] enabled_configs = []
if self.oproj_tensor_parallel_size > 0: if self.oproj_tensor_parallel_size > 0:
enabled_configs.append( enabled_configs.append(f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}")
f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}" # dummy_run does not run the entire attention module in eager mode,
) # so the o_proj tp split can only be used in graph mode.
# dummy_run does not run the entire attention module in eager mode,, so the o_proj tp split can only be used in graph mode.
if vllm_config.model_config.enforce_eager is True: if vllm_config.model_config.enforce_eager is True:
raise AssertionError( raise AssertionError("oproj_tensor_parallel_size is only supported in graph mode")
"oproj_tensor_parallel_size is only supported in graph mode"
)
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer: if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError( raise AssertionError(
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node." "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
) )
if self.lmhead_tensor_parallel_size > 0: if self.lmhead_tensor_parallel_size > 0:
enabled_configs.append( enabled_configs.append(f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}")
f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}"
)
if self.embedding_tensor_parallel_size > 0: if self.embedding_tensor_parallel_size > 0:
enabled_configs.append( enabled_configs.append(f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}")
f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}"
)
if self.mlp_tensor_parallel_size > 0: if self.mlp_tensor_parallel_size > 0:
enabled_configs.append( enabled_configs.append(f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
module_tp_sizes = [ module_tp_sizes = [
self.oproj_tensor_parallel_size, self.oproj_tensor_parallel_size,
self.lmhead_tensor_parallel_size, self.lmhead_tensor_parallel_size,
@@ -186,11 +155,9 @@ class FinegrainedTPConfig:
] ]
for module_tp_size in module_tp_sizes: for module_tp_size in module_tp_sizes:
if module_tp_size > 0 and vllm_config.parallel_config.data_parallel_size % module_tp_size != 0: if module_tp_size > 0 and vllm_config.parallel_config.data_parallel_size % module_tp_size != 0:
raise AssertionError( raise AssertionError("module tp sizes must divide data_parallel_size")
"module tp sizes must divide data_parallel_size")
if any(size > 0 for size in module_tp_sizes) and enabled_configs: if any(size > 0 for size in module_tp_sizes) and enabled_configs:
logger.info( logger.info(f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
class AscendCompilationConfig: class AscendCompilationConfig:
@@ -202,13 +169,10 @@ class AscendCompilationConfig:
deployed on Ascend platforms. deployed on Ascend platforms.
""" """
def __init__(self, def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, **kwargs):
fuse_norm_quant: bool = True,
fuse_qknorm_rope: bool = False,
**kwargs):
""" """
Initialize the configuration. Initialize the configuration.
Args: Args:
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization. fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
When set to True, the system will optimize norm and quant operations. When set to True, the system will optimize norm and quant operations.
@@ -236,7 +200,8 @@ class XliteGraphConfig:
) )
if vllm_config.parallel_config.pipeline_parallel_size > 1: if vllm_config.parallel_config.pipeline_parallel_size > 1:
raise RuntimeError( raise RuntimeError(
"Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1." "Xlite graph mode is not compatible with pipeline parallelism. "
"Please set pipeline_parallel_size to 1."
) )
if vllm_config.cache_config.block_size != 128: if vllm_config.cache_config.block_size != 128:
raise RuntimeError( raise RuntimeError(
@@ -254,21 +219,19 @@ class WeightPrefetchConfig:
"qkv": 1.0, "qkv": 1.0,
"o": 1.0, "o": 1.0,
}, },
"moe": { "moe": {"gate_up": 0.8},
"gate_up": 0.8
}
} }
def __init__(self, weight_prefetch_config: dict): def __init__(self, weight_prefetch_config: dict):
self.enabled = weight_prefetch_config.get("enabled", False) self.enabled = weight_prefetch_config.get("enabled", False)
self.prefetch_ratio = weight_prefetch_config.get( self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
"prefetch_ratio", self.prefetch_ratio)
class EplbConfig: class EplbConfig:
""" """
Configuration Object for xlite_graph_config from additional_config Configuration Object for xlite_graph_config from additional_config
""" """
_defaults = { _defaults = {
"dynamic_eplb": False, "dynamic_eplb": False,
"expert_map_path": None, "expert_map_path": None,
@@ -276,10 +239,12 @@ class EplbConfig:
"algorithm_execution_interval": 30, "algorithm_execution_interval": 30,
"expert_map_record_path": None, "expert_map_record_path": None,
"num_redundant_experts": 0, "num_redundant_experts": 0,
"eplb_policy_type": 1 "eplb_policy_type": 1,
} }
def __init__(self, user_config: dict = {}): def __init__(self, user_config: dict | None = None):
if user_config is None:
user_config = {}
self.config = self._defaults.copy() self.config = self._defaults.copy()
if user_config and isinstance(user_config, dict): if user_config and isinstance(user_config, dict):
for key, value in user_config.items(): for key, value in user_config.items():
@@ -307,27 +272,21 @@ class EplbConfig:
raise TypeError("The expert_map_record_path is not json.") raise TypeError("The expert_map_record_path is not json.")
dirname = os.path.dirname(self.expert_map_record_path) dirname = os.path.dirname(self.expert_map_record_path)
os.makedirs(dirname, exist_ok=True) os.makedirs(dirname, exist_ok=True)
for key in [ for key in ["expert_heat_collection_interval", "algorithm_execution_interval", "num_redundant_experts"]:
"expert_heat_collection_interval",
"algorithm_execution_interval", "num_redundant_experts"
]:
if not isinstance(self.config[key], int): if not isinstance(self.config[key], int):
raise TypeError(f"{key} must be an integer") raise TypeError(f"{key} must be an integer")
if self.config[key] < 0: # type: ignore if self.config[key] < 0: # type: ignore
raise ValueError( raise ValueError(f"{key} must greater than 0; got {self.config[key]} instead")
f"{key} must greater than 0; got {self.config[key]} instead"
)
if self.eplb_policy_type not in [0, 1, 2, 3]: if self.eplb_policy_type not in [0, 1, 2, 3]:
raise ValueError("eplb_policy_type must in [0, 1, 2, 3]") raise ValueError("eplb_policy_type must in [0, 1, 2, 3]")
_ASCEND_CONFIG: Optional[AscendConfig] = None _ASCEND_CONFIG: AscendConfig | None = None
def init_ascend_config(vllm_config): def init_ascend_config(vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
refresh = additional_config.get("refresh", refresh = additional_config.get("refresh", False) if additional_config else False
False) if additional_config else False
global _ASCEND_CONFIG global _ASCEND_CONFIG
if _ASCEND_CONFIG is not None and not refresh: if _ASCEND_CONFIG is not None and not refresh:
return _ASCEND_CONFIG return _ASCEND_CONFIG
@@ -343,7 +302,5 @@ def clear_ascend_config():
def get_ascend_config(): def get_ascend_config():
global _ASCEND_CONFIG global _ASCEND_CONFIG
if _ASCEND_CONFIG is None: if _ASCEND_CONFIG is None:
raise RuntimeError( raise RuntimeError("Ascend config is not initialized. Please call init_ascend_config first.")
"Ascend config is not initialized. Please call init_ascend_config first."
)
return _ASCEND_CONFIG return _ASCEND_CONFIG

View File

@@ -1,21 +1,25 @@
import math import math
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any
import torch import torch
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (get_dp_group, get_ep_group, from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size) from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context)
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable, from vllm_ascend.utils import (
get_ascend_device_type, has_layer_idx, AscendDeviceType,
is_drafter_moe_model, is_moe_model, enable_sp,
speculative_enable_dispatch_gmm_combine_decode) flashcomm2_enable,
get_ascend_device_type,
has_layer_idx,
is_drafter_moe_model,
is_moe_model,
speculative_enable_dispatch_gmm_combine_decode,
)
class MoECommType(Enum): class MoECommType(Enum):
@@ -27,36 +31,36 @@ class MoECommType(Enum):
@contextmanager @contextmanager
def set_ascend_forward_context( def set_ascend_forward_context(
attn_metadata: Any, attn_metadata: Any,
vllm_config: VllmConfig, vllm_config: VllmConfig,
virtual_engine: int = 0, virtual_engine: int = 0,
num_tokens: int = 0, num_tokens: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None, num_tokens_across_dp: torch.Tensor | None = None,
in_profile_run: bool = False, in_profile_run: bool = False,
num_actual_tokens: Optional[int] = None, num_actual_tokens: int | None = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None, batch_descriptor: BatchDescriptor | None = None,
model_instance: torch.nn.Module = None, model_instance: torch.nn.Module = None,
is_draft_model=False): is_draft_model=False,
):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
We add some additional param into forward_context. We add some additional param into forward_context.
""" """
with set_forward_context( with set_forward_context(
attn_metadata, attn_metadata,
vllm_config, vllm_config,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=aclgraph_runtime_mode, cudagraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
): ):
forward_context = get_forward_context() forward_context = get_forward_context()
from vllm_ascend.ops.fused_moe.moe_comm_method import \ from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
get_moe_comm_method
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
is_draft_model)
forward_context.moe_comm_type = moe_comm_type forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type) forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -68,31 +72,28 @@ def set_ascend_forward_context(
# due to multiple warmups before actual capturing # due to multiple warmups before actual capturing
forward_context.capturing = False forward_context.capturing = False
# set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature. # set for sequence parallelism, 1000 is the batch size concurrency threshold
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # for enabling the flashcomm_v1 or sequence_parallelism feature.
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # Currently, it is an empirical value. In normal scenarios, if the concurrency
# exceeds this threshold, the performance benefits can be maximized.
# Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods. # the performance may degrade due to the switching of communication methods.
mmrs_fusion = True mmrs_fusion = True
# main model and drafter model may have different architecture # main model and drafter model may have different architecture
is_context_moe_model = is_drafter_moe_model(vllm_config) \ is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
if is_draft_model else is_moe_model(vllm_config)
if is_context_moe_model: if is_context_moe_model:
sp_enabled = enable_sp(vllm_config) and num_tokens is not None sp_enabled = enable_sp(vllm_config) and num_tokens is not None
mmrs_fusion = False mmrs_fusion = False
else: else:
sp_enabled = enable_sp(vllm_config) and \ sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
num_tokens is not None and num_tokens > 1000
forward_context.mmrs_fusion = mmrs_fusion forward_context.mmrs_fusion = mmrs_fusion
forward_context.num_tokens = num_tokens forward_context.num_tokens = num_tokens
forward_context.sp_enabled = sp_enabled forward_context.sp_enabled = sp_enabled
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
forward_context.flashcomm_v2_enabled = flashcomm2_enable( forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
) and tp_world_size > 1 and num_tokens is not None
if (forward_context.sp_enabled if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
or forward_context.flashcomm_v2_enabled): pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size forward_context.pad_size = pad_size
# set this for rope forward_oot using # set this for rope forward_oot using
@@ -106,9 +107,12 @@ def set_ascend_forward_context(
# TODO(rjg-lyh): refactor mlp weight prefetch method # TODO(rjg-lyh): refactor mlp weight prefetch method
# set for mlp weight prefetch # set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \ prefetch_mlp_enabled = (
forward_context.layer_idx is not None and \ envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
num_tokens is not None and num_tokens < 500 and forward_context.layer_idx is not None
and num_tokens is not None
and num_tokens < 500
)
if prefetch_mlp_enabled: if prefetch_mlp_enabled:
forward_context.prefetch_mlp_gate_up_proj = False forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False forward_context.prefetch_mlp_down_proj = False
@@ -121,12 +125,9 @@ def set_ascend_forward_context(
dp_world_size = get_dp_group().world_size dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None: if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = \ max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
forward_context.dp_metadata.max_tokens_across_dp_cpu.item() if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
if (forward_context.sp_enabled padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
or forward_context.flashcomm_v2_enabled):
padded_length = (max_tokens_across_dp + tp_world_size -
1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens pad_size = padded_length - num_tokens
forward_context.padded_length = padded_length forward_context.padded_length = padded_length
forward_context.pad_size = pad_size forward_context.pad_size = pad_size
@@ -139,12 +140,10 @@ def set_ascend_forward_context(
if num_actual_tokens is None: if num_actual_tokens is None:
num_actual_tokens = num_tokens num_actual_tokens = num_tokens
# NOTE: token num which need to pad to when mc2 # NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil( forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
max_tokens_across_dp / tp_world_size) * tp_world_size
reserved_mc2_mask = get_mc2_mask() reserved_mc2_mask = get_mc2_mask()
if reserved_mc2_mask is not None: if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:forward_context. mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
padded_num_tokens]
mc2_mask[:num_actual_tokens] = True mc2_mask[:num_actual_tokens] = True
mc2_mask[num_actual_tokens:] = False mc2_mask[num_actual_tokens:] = False
forward_context.mc2_mask = mc2_mask forward_context.mc2_mask = mc2_mask
@@ -155,14 +154,13 @@ def set_ascend_forward_context(
pass pass
_mc2_tokens_capacity: Optional[int] = None _mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: Optional[torch.Tensor] = None _reserved_mc2_mask: torch.Tensor | None = None
_sin: Optional[torch.Tensor] = None _sin: torch.Tensor | None = None
_cos: Optional[torch.Tensor] = None _cos: torch.Tensor | None = None
def set_mc2_tokens_capacity(vllm_config, max_num_reqs, def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
uniform_decode_query_len):
global _mc2_tokens_capacity global _mc2_tokens_capacity
if _mc2_tokens_capacity is not None: if _mc2_tokens_capacity is not None:
return return
@@ -186,9 +184,7 @@ def set_mc2_mask(vllm_config, device):
if _reserved_mc2_mask is not None: if _reserved_mc2_mask is not None:
return return
if is_moe_model(vllm_config): if is_moe_model(vllm_config):
_reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
dtype=torch.bool,
device=device)
else: else:
_reserved_mc2_mask = None _reserved_mc2_mask = None
@@ -197,9 +193,7 @@ def get_mc2_mask():
return _reserved_mc2_mask return _reserved_mc2_mask
def select_moe_comm_method(num_tokens: int, def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None:
vllm_config: VllmConfig,
is_draft_model=False) -> Optional[MoECommType]:
"""Select the MoE communication method according to parallel settings, """Select the MoE communication method according to parallel settings,
device generation, token count, and quantization. device generation, token count, and quantization.
@@ -227,16 +221,19 @@ def select_moe_comm_method(num_tokens: int,
mc2_tokens_capacity = get_mc2_tokens_capacity() mc2_tokens_capacity = get_mc2_tokens_capacity()
soc_version = get_ascend_device_type() soc_version = get_ascend_device_type()
quant_type = getattr( quant_type = getattr(
vllm_config.model_config.hf_text_config, 'moe_quantize', vllm_config.model_config.hf_text_config,
getattr(vllm_config.model_config.hf_text_config, 'quantize', None)) "moe_quantize",
getattr(vllm_config.model_config.hf_text_config, "quantize", None),
)
if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group( if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
).world_size == 1:
moe_comm_type = MoECommType.ALLGATHER moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendDeviceType.A2}: elif soc_version in {AscendDeviceType.A2}:
if (num_tokens <= mc2_tokens_capacity if (
and vllm_config.parallel_config.world_size_across_dp / num_tokens <= mc2_tokens_capacity
vllm_config.parallel_config.pipeline_parallel_size >= 16): and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size
>= 16
):
moe_comm_type = MoECommType.MC2 moe_comm_type = MoECommType.MC2
else: else:
moe_comm_type = MoECommType.ALLGATHER moe_comm_type = MoECommType.ALLGATHER
@@ -246,15 +243,13 @@ def select_moe_comm_method(num_tokens: int,
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16 # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and ( dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
not is_draft_model) and (not dynamic_eplb)
if num_tokens <= mc2_tokens_capacity: if num_tokens <= mc2_tokens_capacity:
fused_decode_enable = fused_mc2_enable fused_decode_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
fused_decode_enable = fused_mc2_enable and \ fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
speculative_enable_dispatch_gmm_combine_decode(vllm_config)
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2 moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
else: else:
fused_prefill_enable = fused_mc2_enable fused_prefill_enable = fused_mc2_enable

View File

@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
from collections.abc import Callable
from contextlib import ExitStack from contextlib import ExitStack
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Optional from typing import Any
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
@@ -27,12 +28,12 @@ from ..utils import weak_ref_tensors
@dataclasses.dataclass @dataclasses.dataclass
class ACLGraphEntry: class ACLGraphEntry:
batch_descriptor: BatchDescriptor batch_descriptor: BatchDescriptor
aclgraph: Optional[torch.npu.NPUGraph] = None aclgraph: torch.npu.NPUGraph | None = None
output: Optional[Any] = None output: Any | None = None
# for aclgraph debugging, track the input addresses # for aclgraph debugging, track the input addresses
# during capture, and check if they are the same during replay # during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None input_addresses: list[int] | None = None
class ACLGraphWrapper: class ACLGraphWrapper:
@@ -60,11 +61,13 @@ class ACLGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
""" """
def __init__(self, def __init__(
runnable: Callable, self,
vllm_config: VllmConfig, runnable: Callable,
runtime_mode: CUDAGraphMode, vllm_config: VllmConfig,
cudagraph_options: Optional[CUDAGraphOptions] = None): runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None,
):
self.runnable = runnable self.runnable = runnable
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.runtime_mode = runtime_mode self.runtime_mode = runtime_mode
@@ -83,15 +86,13 @@ class ACLGraphWrapper:
self.aclgraph_options = cudagraph_options self.aclgraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture # the entries for different batch descriptors that we need to capture
# aclgraphs for. # aclgraphs for.
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\ self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
= {}
def __getattr__(self, key: str): def __getattr__(self, key: str):
# allow accessing the attributes of the runnable. # allow accessing the attributes of the runnable.
if hasattr(self.runnable, key): if hasattr(self.runnable, key):
return getattr(self.runnable, key) return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} not exists in the runnable of " raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}")
f"aclgraph wrapper: {self.runnable}")
def unwrap(self) -> Callable: def unwrap(self) -> Callable:
# in case we need to access the original runnable. # in case we need to access the original runnable.
@@ -102,8 +103,7 @@ class ACLGraphWrapper:
batch_descriptor = forward_context.batch_descriptor batch_descriptor = forward_context.batch_descriptor
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
if aclgraph_runtime_mode == CUDAGraphMode.NONE or \ if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
aclgraph_runtime_mode != self.runtime_mode:
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without aclgraphs. # running without aclgraphs.
# We do not trigger capture/replay if the runtime mode is not # We do not trigger capture/replay if the runtime mode is not
@@ -114,8 +114,7 @@ class ACLGraphWrapper:
if batch_descriptor not in self.concrete_aclgraph_entries: if batch_descriptor not in self.concrete_aclgraph_entries:
# create a new entry for this batch descriptor # create a new entry for this batch descriptor
self.concrete_aclgraph_entries[batch_descriptor] = \ self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)
ACLGraphEntry(batch_descriptor=batch_descriptor)
entry = self.concrete_aclgraph_entries[batch_descriptor] entry = self.concrete_aclgraph_entries[batch_descriptor]
@@ -125,14 +124,11 @@ class ACLGraphWrapper:
# capturing is fast, we don't need to log it for every # capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in # shape. E.g. we only log it for the first subgraph in
# piecewise mode. # piecewise mode.
logger.debug("Capturing a aclgraph on (%s,%s)", logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
self.runtime_mode.name, entry.batch_descriptor)
# validate that aclgraph capturing is legal at this point. # validate that aclgraph capturing is legal at this point.
validate_cudagraph_capturing_enabled() validate_cudagraph_capturing_enabled()
input_addresses = [ input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses entry.input_addresses = input_addresses
aclgraph = torch.npu.NPUGraph() aclgraph = torch.npu.NPUGraph()
@@ -145,8 +141,7 @@ class ACLGraphWrapper:
# therefore, we only run gc for the first graph, # therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs. # and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None)) stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context( stack.enter_context(patch("torch.npu.empty_cache", lambda: None))
patch("torch.npu.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory. # mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True forward_context.capturing = True
@@ -183,13 +178,12 @@ class ACLGraphWrapper:
if self.is_debugging_mode: if self.is_debugging_mode:
# check if the input addresses are the same # check if the input addresses are the same
new_input_addresses = [ new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, ( assert new_input_addresses == entry.input_addresses, (
f"Input addresses for aclgraphs are different " f"Input addresses for aclgraphs are different "
f"during replay. Expected {entry.input_addresses}, " f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}") f"got {new_input_addresses}"
)
logger.info_once("Replaying aclgraph") logger.info_once("Replaying aclgraph")
# In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that # In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that
@@ -209,8 +203,7 @@ def weak_ref_workspaces(params):
for num_tokens in params.workspaces: for num_tokens in params.workspaces:
if params.workspaces[num_tokens] is None: if params.workspaces[num_tokens] is None:
continue continue
params.workspaces[num_tokens] = weak_ref_tensors( params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])
params.workspaces[num_tokens])
def _update_attn_pa_params(update_stream, forward_context, runtime_shape): def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
@@ -219,10 +212,10 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
with torch.npu.stream(update_stream): with torch.npu.stream(update_stream):
for key, param, handle, event in zip( for key, param, handle, event in zip(
forward_context.attn_metadata, forward_context.attn_metadata,
graph_params.attn_params[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape], graph_params.handles[runtime_shape],
graph_params.events[runtime_shape], graph_params.events[runtime_shape],
): ):
( (
query, query,
@@ -254,18 +247,21 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
scale_value=scale, scale_value=scale,
block_table=block_table, block_table=block_table,
context_lens=seq_lens, context_lens=seq_lens,
out=output) out=output,
)
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query, torch_npu._npu_paged_attention(
key_cache=key_cache, query=query,
value_cache=value_cache, key_cache=key_cache,
num_kv_heads=num_kv_heads, value_cache=value_cache,
num_heads=num_heads, num_kv_heads=num_kv_heads,
scale_value=scale, num_heads=num_heads,
block_table=block_table, scale_value=scale,
context_lens=seq_lens, block_table=block_table,
out=output, context_lens=seq_lens,
workspace=workspace) out=output,
workspace=workspace,
)
torch.npu.graph_task_update_end(update_stream) torch.npu.graph_task_update_end(update_stream)
event.record(update_stream) event.record(update_stream)
@@ -282,18 +278,29 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
# filters out the update operations for linear_attn. # filters out the update operations for linear_attn.
with torch.npu.stream(update_stream): with torch.npu.stream(update_stream):
for key, param, handle, event in zip( for key, param, handle, event in zip(
forward_context.attn_metadata, forward_context.attn_metadata,
graph_params.attn_params[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape], graph_params.handles[runtime_shape],
graph_params.events[runtime_shape], graph_params.events[runtime_shape],
): ):
(query, key_cache, value, block_tables, attn_mask, block_size, (
seq_lens, query_start_loc, num_kv_heads, num_heads, scale, query,
attn_output, softmax_lse) = param key_cache,
value,
block_tables,
attn_mask,
block_size,
seq_lens,
query_start_loc,
num_kv_heads,
num_heads,
scale,
attn_output,
softmax_lse,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens_list seq_lens = forward_context.attn_metadata[key].seq_lens_list
actual_seq_lengths_q = forward_context.attn_metadata[ actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out( torch_npu.npu_fused_infer_attention_score.out(
query=query, query=query,
@@ -317,16 +324,14 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
event.record(update_stream) event.record(update_stream)
def update_attn_params(update_stream, forward_context, runtime_shape, def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
vllm_config):
if using_paged_attention(runtime_shape, vllm_config): if using_paged_attention(runtime_shape, vllm_config):
_update_attn_pa_params(update_stream, forward_context, runtime_shape) _update_attn_pa_params(update_stream, forward_context, runtime_shape)
else: else:
_update_attn_fia_params(update_stream, forward_context, runtime_shape) _update_attn_fia_params(update_stream, forward_context, runtime_shape)
def update_mla_attn_params(update_stream, forward_context, runtime_shape, def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
speculative_config):
if forward_context.is_draft_model: if forward_context.is_draft_model:
graph_params = get_draft_graph_params() graph_params = get_draft_graph_params()
else: else:
@@ -335,41 +340,44 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
with torch.npu.stream(update_stream): with torch.npu.stream(update_stream):
for key, param, handle, event in zip( for key, param, handle, event in zip(
forward_context.attn_metadata, forward_context.attn_metadata,
graph_params.attn_params[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape], graph_params.handles[runtime_shape],
graph_params.events[runtime_shape], graph_params.events[runtime_shape],
): ):
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, (
attn_mask, sparse_mode, scale, block_table, block_size, q_nope,
seq_lens_list, actual_seq_lengths, attn_output, k_nope,
softmax_lse) = param q_pe,
seq_lens_list = forward_context.attn_metadata[ k_pe,
key].decode.seq_lens_list num_heads,
if speculative_config and speculative_config.method == "mtp" \ num_kv_heads,
and not forward_context.is_draft_model: input_layout,
actual_seq_lengths = forward_context.attn_metadata[ attn_mask,
key].decode.actual_seq_lengths_q sparse_mode,
scale,
block_table,
block_size,
seq_lens_list,
actual_seq_lengths,
attn_output,
softmax_lse,
) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1 spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * ( seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list))
runtime_shape // spec_multiple - len(seq_lens_list)) actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)]
actual_seq_lengths = [
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
elif forward_context.is_draft_model: elif forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[ actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
key].decode.actual_seq_lengths_q block_table = forward_context.attn_metadata[key].decode.block_table
block_table = forward_context.attn_metadata[
key].decode.block_table
# TODO: This is a hack and should be fixed in the future. # TODO: This is a hack and should be fixed in the future.
if speculative_config.disable_padded_drafter_batch: if speculative_config.disable_padded_drafter_batch:
block_table = block_table[:len(actual_seq_lengths)] block_table = block_table[: len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * ( seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
len(actual_seq_lengths) - len(seq_lens_list))
else: else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape - seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list))
len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out( torch_npu.npu_fused_infer_attention_score.out(
@@ -391,7 +399,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
actual_seq_lengths_kv=seq_lens_list, actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths, actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(runtime_shape), workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse]) out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream) torch.npu.graph_task_update_end(update_stream)
event.record(update_stream) event.record(update_stream)
@@ -403,34 +412,40 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params() graph_params = get_graph_params()
with torch.npu.stream(update_stream): with torch.npu.stream(update_stream):
for key, param, handle, event in zip( for key, param, handle, event in zip(
forward_context.attn_metadata, forward_context.attn_metadata,
graph_params.attn_params[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape], graph_params.handles[runtime_shape],
graph_params.events[runtime_shape], graph_params.events[runtime_shape],
): ):
(q_nope, k_nope, value, num_heads, num_kv_heads, scale, (
block_table, block_size, actual_seq_lengths_kv, q_nope,
actual_seq_lengths_q, attn_output, softmax_lse, dcp_size, k_nope,
pcp_rank, dcp_rank) = param value,
num_heads,
num_kv_heads,
scale,
block_table,
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
attn_output,
softmax_lse,
dcp_size,
pcp_rank,
dcp_rank,
) = param
attn_metadata = forward_context.attn_metadata[key] attn_metadata = forward_context.attn_metadata[key]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
pcp_rank,
dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv) pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0: if pad_length > 0:
pad_tensor = np.zeros(pad_length, pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
dtype=actual_seq_lengths_kv.dtype) actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
actual_seq_lengths_kv = np.concatenate(
[actual_seq_lengths_kv, pad_tensor])
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decode_tokens]
attn_metadata if runtime_shape - len(actual_seq_lengths_q):
. actual_seq_lengths_q = actual_seq_lengths_q + [actual_seq_lengths_q[-1]] * (
num_decode_tokens] runtime_shape - len(actual_seq_lengths_q)
if (runtime_shape - len(actual_seq_lengths_q)): )
actual_seq_lengths_q = actual_seq_lengths_q + [
actual_seq_lengths_q[-1]
] * (runtime_shape - len(actual_seq_lengths_q))
if dcp_size > 1: if dcp_size > 1:
num_heads = num_heads * dcp_size num_heads = num_heads * dcp_size
@@ -453,14 +468,14 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
actual_seq_lengths_kv=actual_seq_lengths_kv, actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths_q, actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(runtime_shape), workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse]) out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream) torch.npu.graph_task_update_end(update_stream)
event.record(update_stream) event.record(update_stream)
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
runtime_shape):
if forward_context.is_draft_model: if forward_context.is_draft_model:
graph_params = get_draft_graph_params() graph_params = get_draft_graph_params()
else: else:
@@ -469,13 +484,24 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
with torch.npu.stream(update_stream): with torch.npu.stream(update_stream):
for key, param, handle, event in zip( for key, param, handle, event in zip(
forward_context.attn_metadata, forward_context.attn_metadata,
graph_params.attn_params[runtime_shape], graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape], graph_params.handles[runtime_shape],
graph_params.events[runtime_shape], graph_params.events[runtime_shape],
): ):
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, (
scale, num_kv_heads, attn_output, softmax_lse) = param q_nope,
q_pe,
k_nope,
k_pe,
block_table,
seq_len,
num_heads,
scale,
num_kv_heads,
attn_output,
softmax_lse,
) = param
decode_meta = forward_context.attn_metadata[key].decode decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len seq_len = decode_meta.cp_seq_len
@@ -484,9 +510,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
# to avoid irregular attn_mask shape, # to avoid irregular attn_mask shape,
# so there's no need to divide runtime_shape by spec_multiple # so there's no need to divide runtime_shape by spec_multiple
pad_length = runtime_shape - len(seq_len) pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length, pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
dtype=seq_len.dtype,
device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0) seq_len = torch.cat([seq_len, pad_tensor], dim=0)
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
@@ -505,7 +529,8 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
calc_type="calc_type_ring", calc_type="calc_type_ring",
workspace=graph_params.workspaces.get(runtime_shape), workspace=graph_params.workspaces.get(runtime_shape),
output=attn_output, output=attn_output,
lse=softmax_lse) lse=softmax_lse,
)
torch.npu.graph_task_update_end(update_stream) torch.npu.graph_task_update_end(update_stream)
event.record(update_stream) event.record(update_stream)
@@ -519,7 +544,7 @@ class GraphParams:
attn_params: dict[int, list[tuple]] attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None _graph_params: GraphParams | None = None
def set_graph_params(aclgraph_capture_sizes: list[int]): def set_graph_params(aclgraph_capture_sizes: list[int]):
@@ -527,14 +552,10 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
if _graph_params is not None: if _graph_params is not None:
raise ValueError("Graph parameters have already been set!") raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams( _graph_params = GraphParams(
{size: [] {size: [] for size in aclgraph_capture_sizes},
for size in aclgraph_capture_sizes}, {size: None for size in aclgraph_capture_sizes},
{size: None {size: [] for size in aclgraph_capture_sizes},
for size in aclgraph_capture_sizes}, {size: [] for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
) )
@@ -548,7 +569,7 @@ def get_graph_params():
return _graph_params return _graph_params
_draft_graph_params: Optional[GraphParams] = None _draft_graph_params: GraphParams | None = None
def set_draft_graph_params(aclgraph_capture_sizes: list[int]): def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
@@ -556,14 +577,10 @@ def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
if _draft_graph_params is not None: if _draft_graph_params is not None:
raise ValueError("DraftGraph parameters have already been set!") raise ValueError("DraftGraph parameters have already been set!")
_draft_graph_params = GraphParams( _draft_graph_params = GraphParams(
{size: [] {size: [] for size in aclgraph_capture_sizes},
for size in aclgraph_capture_sizes}, {size: None for size in aclgraph_capture_sizes},
{size: None {size: [] for size in aclgraph_capture_sizes},
for size in aclgraph_capture_sizes}, {size: [] for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
) )

View File

@@ -16,13 +16,13 @@
# limitations under the License. # limitations under the License.
# #
import functools import functools
from typing import Any, Callable, Optional from collections.abc import Callable
from typing import Any
import torch import torch
import torch.fx as fx import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import (graph_returns_tuple, from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
make_graph_return_tuple)
from torch._inductor.decomposition import select_decomp_table from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface from vllm.compilation.compiler_interface import CompilerInterface
@@ -32,15 +32,11 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY from vllm_ascend.utils import COMPILATION_PASS_KEY
def compile_fx(graph: GraphModule, example_inputs: list, def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
inner_compile: Callable, decompositions: dict) -> Callable: recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
recursive_compile_fx = functools.partial(compile_fx,
inner_compile=inner_compile,
decompositions=decompositions)
if not graph_returns_tuple(graph): if not graph_returns_tuple(graph):
return make_graph_return_tuple(graph, example_inputs, return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
recursive_compile_fx)
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs) return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
@@ -49,9 +45,8 @@ def fusion_pass_compile(
example_inputs: list[Any], example_inputs: list[Any],
compiler_config: dict[str, Any], compiler_config: dict[str, Any],
compile_range: Range, compile_range: Range,
key: Optional[str] = None, key: str | None = None,
) -> tuple[Optional[Callable], Optional[Any]]: ) -> tuple[Callable | None, Any | None]:
def compile_inner(graph, example_inputs): def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config[COMPILATION_PASS_KEY] current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
graph = current_pass_manager(graph) graph = current_pass_manager(graph)
@@ -74,8 +69,8 @@ def npugraph_ex_compile(
example_inputs: list[Any], example_inputs: list[Any],
compiler_config: dict[str, Any], compiler_config: dict[str, Any],
compile_range: Range, compile_range: Range,
key: Optional[str] = None, key: str | None = None,
) -> tuple[Optional[Callable], Optional[Any]]: ) -> tuple[Callable | None, Any | None]:
# When currently using the FULL_DECODE_ONLY mode, # When currently using the FULL_DECODE_ONLY mode,
# the piecewise compilation level slicing process # the piecewise compilation level slicing process
# in vllm is also encountered. # in vllm is also encountered.
@@ -87,10 +82,8 @@ def npugraph_ex_compile(
output_node = fx_graph.output_node() output_node = fx_graph.output_node()
with fx_graph.inserting_before(output_node): with fx_graph.inserting_before(output_node):
return_value = output_node.args[0] return_value = output_node.args[0]
tuple_node = fx_graph.create_node("call_function", tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
tuple, output_node.args = (tuple_node,)
args=([return_value], ))
output_node.args = (tuple_node, )
graph.recompile() graph.recompile()
import torchair import torchair
@@ -119,6 +112,7 @@ class AscendCompiler(CompilerInterface):
This class provides a method to compile a PyTorch FX graph module with This class provides a method to compile a PyTorch FX graph module with
specific configurations for graph fusion and decomposition. specific configurations for graph fusion and decomposition.
""" """
name = "AscendCompiler" name = "AscendCompiler"
def compile( def compile(
@@ -127,13 +121,10 @@ class AscendCompiler(CompilerInterface):
example_inputs: list[Any], example_inputs: list[Any],
compiler_config: dict[str, Any], compiler_config: dict[str, Any],
compile_range: Range, compile_range: Range,
key: Optional[str] = None, key: str | None = None,
) -> tuple[Optional[Callable], Optional[Any]]: ) -> tuple[Callable | None, Any | None]:
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex: if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config, return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key)
compile_range, key)
else: else:
return fusion_pass_compile(graph, example_inputs, compiler_config, return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)
compile_range, key)

View File

@@ -26,7 +26,7 @@ class GraphFusionPassManager:
""" """
A pass manager for graph fusion passes. A pass manager for graph fusion passes.
It handles the configuration and execution of passes. It handles the configuration and execution of passes.
The counterpart in vllm is PostGradPassManager. Since torch_npu The counterpart in vllm is PostGradPassManager. Since torch_npu
does not support triton for now, we define our own pass manager. does not support triton for now, we define our own pass manager.
""" """
@@ -48,13 +48,13 @@ class GraphFusionPassManager:
def configure(self, config: VllmConfig): def configure(self, config: VllmConfig):
# By default, we enable the graph fusion and quantization fusion pass. # By default, we enable the graph fusion and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get( self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
"ascend_compilation_config", {})
if self.ascend_compilation_config.get("fuse_norm_quant", True): if self.ascend_compilation_config.get("fuse_norm_quant", True):
from .passes.norm_quant_fusion_pass import \ from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
AddRMSNormQuantFusionPass
self.passes.append(AddRMSNormQuantFusionPass(config)) self.passes.append(AddRMSNormQuantFusionPass(config))
if self.ascend_compilation_config.get("fuse_qknorm_rope", True): if self.ascend_compilation_config.get("fuse_qknorm_rope", True):
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
self.passes.append(QKNormRopeFusionPass(config)) self.passes.append(QKNormRopeFusionPass(config))

View File

@@ -48,7 +48,8 @@ def _extra_stream_scope_check(match: Match) -> bool:
logger.debug( logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. " f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations.") f"Fusion is not supported for cross-stream operations."
)
return False return False
return True return True
@@ -57,24 +58,29 @@ def _extra_stream_scope_check(match: Match) -> bool:
@functools.lru_cache(None) @functools.lru_cache(None)
# The replacement registered here will be actually executed after AOT. # The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant(epsilon): def replacement_add_rms_norm_quant(epsilon):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
offset: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuant fusion. Pattern for AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
rms_norm_weight, epsilon)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
torch.qint8, -1, False)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
offset: torch.Tensor): residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
""" """
Replacement for the AddRMSNormQuant fusion. Replacement for the AddRMSNormQuant fusion.
""" """
@@ -82,10 +88,12 @@ def replacement_add_rms_norm_quant(epsilon):
rms_norm_input, rms_norm_input,
residual, residual,
rms_norm_weight, rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. # The inverse of scale is required by npu_add_rms_norm_quant kernel
1. / scale, # which is opposite to the npu_quantize kernel.
1.0 / scale,
offset, offset,
epsilon=epsilon) epsilon=epsilon,
)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
return quantized_output, out1 return quantized_output, out1
@@ -103,33 +111,39 @@ def replacement_add_rms_norm_quant(epsilon):
import torchair import torchair
torchair.register_replacement(search_fn=pattern, torchair.register_replacement(
replace_fn=replacement, search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
example_inputs=get_inputs(), )
extra_check=_extra_stream_scope_check)
# The replacement registered here will be actually executed after AOT. # The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_with_bias(epsilon): def replacement_add_rms_norm_quant_with_bias(epsilon):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuantWithBias fusion. Pattern for AddRMSNormQuantWithBias fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
rms_norm_weight, epsilon)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
out0 = out0 + bias out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
torch.qint8, -1, False)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor): residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Replacement for AddRMSNormQuantWithBias fusion. Replacement for AddRMSNormQuantWithBias fusion.
""" """
@@ -137,11 +151,13 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
rms_norm_input, rms_norm_input,
residual, residual,
rms_norm_weight, rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. # The inverse of scale is required by npu_add_rms_norm_quant kernel
1. / scale, # which is opposite to the npu_quantize kernel.
1.0 / scale,
offset, offset,
epsilon=epsilon, epsilon=epsilon,
beta=bias) beta=bias,
)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
return quantized_output, out1 return quantized_output, out1
@@ -156,40 +172,41 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
rmsnorm_bias = torch.randn(4, device="npu") rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu") scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu") offset = torch.zeros(4, device="npu")
return [ return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
import torchair import torchair
torchair.register_replacement(search_fn=pattern, torchair.register_replacement(
replace_fn=replacement, search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
example_inputs=get_inputs(), )
extra_check=_extra_stream_scope_check)
# The replacement registered here will be actually executed after AOT. # The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern(epsilon): def replacement_add_rms_norm_quant_sp_pattern(epsilon):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
offset: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuantSPPattern fusion. Pattern for AddRMSNormQuantSPPattern fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
rms_norm_weight, epsilon)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
torch.qint8, -1, False)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
offset: torch.Tensor): residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
):
""" """
Replacement for the AddRMSNormQuantSPPattern fusion. Replacement for the AddRMSNormQuantSPPattern fusion.
""" """
@@ -197,14 +214,15 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
rms_norm_input, rms_norm_input,
residual, residual,
rms_norm_weight, rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. # The inverse of scale is required by npu_add_rms_norm_quant kernel
1. / scale, # which is opposite to the npu_quantize kernel.
1.0 / scale,
offset, offset,
epsilon=epsilon) epsilon=epsilon,
)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
quantized_output, True)
return quantized_output, out1 return quantized_output, out1
def get_inputs(): def get_inputs():
@@ -220,34 +238,40 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
import torchair import torchair
torchair.register_replacement(search_fn=pattern, torchair.register_replacement(
replace_fn=replacement, search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
example_inputs=get_inputs(), )
extra_check=_extra_stream_scope_check)
# The replacement registered here will be actually executed after AOT. # The replacement registered here will be actually executed after AOT.
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon): def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuantSPPatternWithBias fusion. Pattern for AddRMSNormQuantSPPatternWithBias fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
rms_norm_weight, epsilon)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
out0 = out0 + bias out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
torch.qint8, -1, False)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor): residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Replacement for the AddRMSNormQuantSPPatternWithBias fusion. Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
""" """
@@ -255,15 +279,16 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
rms_norm_input, rms_norm_input,
residual, residual,
rms_norm_weight, rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. # The inverse of scale is required by npu_add_rms_norm_quant kernel
1. / scale, # which is opposite to the npu_quantize kernel.
1.0 / scale,
offset, offset,
epsilon=epsilon, epsilon=epsilon,
beta=bias) beta=bias,
)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
quantized_output, True)
return quantized_output, out1 return quantized_output, out1
def get_inputs(): def get_inputs():
@@ -276,25 +301,19 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
rmsnorm_bias = torch.randn(4, device="npu") rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu") scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu") offset = torch.zeros(4, device="npu")
return [ return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
import torchair import torchair
torchair.register_replacement(search_fn=pattern, torchair.register_replacement(
replace_fn=replacement, search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
example_inputs=get_inputs(), )
extra_check=_extra_stream_scope_check)
# register converter for pass # register converter for pass
common_epsilons = [1e-5, 1e-6] common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons: for eps in common_epsilons:
logger.info( logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}")
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
)
replacement_add_rms_norm_quant(eps) replacement_add_rms_norm_quant(eps)
replacement_add_rms_norm_quant_with_bias(eps) replacement_add_rms_norm_quant_with_bias(eps)
replacement_add_rms_norm_quant_sp_pattern(eps) replacement_add_rms_norm_quant_sp_pattern(eps)

View File

@@ -25,7 +25,6 @@ from vllm.logger import logger
class AddRMSNormQuantPattern: class AddRMSNormQuantPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
@@ -41,50 +40,48 @@ class AddRMSNormQuantPattern:
scale = torch.ones(4, device="npu", dtype=self.dtype) scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [ return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset
]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuant fusion. Pattern for AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
rms_norm_weight, self.eps)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
quantized_output = torch.ops.vllm.quantize(out0, scale, quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
scale_reciprocal,
offset)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor): residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
""" """
Replacement for the AddRMSNormQuant fusion. Replacement for the AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, output = torch.ops.npu.npu_add_rms_norm_quant(
residual, rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
rms_norm_weight, )
scale,
offset,
epsilon=self.eps)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
return quantized_output, out1 return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
pm.fwd_only, pm_pass)
class AddRMSNormQuantPatternWithBias: class AddRMSNormQuantPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
@@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias:
scale = torch.ones(4, device="npu", dtype=self.dtype) scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [ return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor, rms_norm_weight: torch.Tensor,
bias: torch.Tensor): scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuant fusion. Pattern for AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
rms_norm_weight, self.eps)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
out0 = out0 + bias out0 = out0 + bias
quantized_output = torch.ops.vllm.quantize(out0, scale, quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
scale_reciprocal,
offset)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor, residual: torch.Tensor,
bias: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Replacement for the AddRMSNormQuant fusion. Replacement for the AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, output = torch.ops.npu.npu_add_rms_norm_quant(
residual, rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
rms_norm_weight, )
scale,
offset,
epsilon=self.eps,
beta=bias)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
return quantized_output, out1 return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
pm.fwd_only, pm_pass)
class AddRMSNormQuantSPPattern: class AddRMSNormQuantSPPattern:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
@@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern:
scale = torch.ones(4, device="npu", dtype=self.dtype) scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [ return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset
]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuant fusion. Pattern for AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
rms_norm_weight, self.eps)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.vllm.quantize(out0, scale, quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
scale_reciprocal,
offset)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor): residual: torch.Tensor,
rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
):
""" """
Replacement for the AddRMSNormQuant fusion. Replacement for the AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, output = torch.ops.npu.npu_add_rms_norm_quant(
residual, rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
rms_norm_weight, )
scale,
offset,
epsilon=self.eps)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
quantized_output, True)
return quantized_output, out1 return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
pm.fwd_only, pm_pass)
class AddRMSNormQuantSPPatternWithBias: class AddRMSNormQuantSPPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
@@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias:
scale = torch.ones(4, device="npu", dtype=self.dtype) scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype) scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype) offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [ return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_input: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor, rms_norm_weight: torch.Tensor,
bias: torch.Tensor): scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Pattern for AddRMSNormQuant fusion. Pattern for AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
rms_norm_weight, self.eps)
out0 = output[0] out0 = output[0]
out1 = output[2] out1 = output[2]
out0 = out0 + bias out0 = out0 + bias
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
quantized_output = torch.ops.vllm.quantize(out0, scale, quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
scale_reciprocal,
offset)
return quantized_output, out1 return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, def replacement(
rms_norm_weight: torch.Tensor, scale: torch.Tensor, rms_norm_input: torch.Tensor,
scale_reciprocal: torch.Tensor, offset: torch.Tensor, residual: torch.Tensor,
bias: torch.Tensor): rms_norm_weight: torch.Tensor,
scale: torch.Tensor,
scale_reciprocal: torch.Tensor,
offset: torch.Tensor,
bias: torch.Tensor,
):
""" """
Replacement for the AddRMSNormQuant fusion. Replacement for the AddRMSNormQuant fusion.
""" """
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, output = torch.ops.npu.npu_add_rms_norm_quant(
residual, rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
rms_norm_weight, )
scale,
offset,
epsilon=self.eps,
beta=bias)
quantized_output = output[0] quantized_output = output[0]
out1 = output[2] out1 = output[2]
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
quantized_output, True)
return quantized_output, out1 return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
pm.fwd_only, pm_pass)
class AddRMSNormQuantFusionPass(VllmInductorPass): class AddRMSNormQuantFusionPass(VllmInductorPass):
@@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config) super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")
pass_name="rmsnorm_quant_fusion_pass")
dtype = vllm_config.model_config.dtype dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16): if dtype not in (torch.bfloat16, torch.float16):
logger.debug("Quant fusion not enabled: unsupported dtype %s", logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype)
dtype)
return return
common_epsilons = [1e-5, 1e-6] common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons: for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config, AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
eps=eps).register(self.pattern_match_passes) AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register( AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
self.pattern_match_passes) AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph):
self.begin() self.begin()

View File

@@ -17,8 +17,7 @@
# #
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import (PatternMatcherPass, from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
PatternPrettyPrinter)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -27,13 +26,7 @@ from vllm.logger import logger
class QKNormRopeFusionPattern: class QKNormRopeFusionPattern:
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.head_dim = head_dim self.head_dim = head_dim
self.num_heads = num_heads self.num_heads = num_heads
@@ -45,65 +38,38 @@ class QKNormRopeFusionPattern:
def get_inputs(self): def get_inputs(self):
T = 5 T = 5
qkv = torch.empty(T, qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
self.q_size + 2 * self.kv_size, q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
dtype=torch.bfloat16, k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
device="npu") cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
q_weight = torch.empty(self.head_dim, sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
cos = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
return [qkv, q_weight, k_weight, cos, sin] return [qkv, q_weight, k_weight, cos, sin]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
k_weight: torch.Tensor, cos: torch.Tensor, q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
dim=-1) k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
q_flat = q_norm_out.view(q.shape) q_flat = q_norm_out.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
self.head_dim)
k_flat = k_norm_out.view(k.shape) k_flat = k_norm_out.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, def replacement(
k_weight: torch.Tensor, cos: torch.Tensor, qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
sin: torch.Tensor): ):
results = torch.ops.vllm.qkv_rmsnorm_rope( results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv, input=qkv,
q_weight=q_weight, q_weight=q_weight,
@@ -115,22 +81,16 @@ class QKNormRopeFusionPattern:
q_bias=None, q_bias=None,
k_bias=None, k_bias=None,
sin=sin, sin=sin,
cos=cos) cos=cos,
)
return results return results
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
pm.fwd_only, pm_pass)
class QKNormRopeFusionPatternWithBias: class QKNormRopeFusionPatternWithBias:
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
self.head_dim = head_dim self.head_dim = head_dim
self.num_heads = num_heads self.num_heads = num_heads
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
@@ -142,71 +102,55 @@ class QKNormRopeFusionPatternWithBias:
def get_inputs(self): def get_inputs(self):
T = 5 T = 5
qkv = torch.empty(T, qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
self.q_size + 2 * self.kv_size, q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
dtype=torch.bfloat16, k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
device="npu")
q_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
k_weight = torch.empty(self.head_dim,
dtype=torch.bfloat16,
device="npu")
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
cos = torch.empty(1, cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
T, sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
sin = torch.empty(1,
T,
1,
self.head_dim,
dtype=torch.bfloat16,
device="npu")
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
k_weight: torch.Tensor, q_bias: torch.Tensor, q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
k_bias: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
self.eps)
q_normed = q_norm_out + q_bias q_normed = q_norm_out + q_bias
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
self.head_dim) k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
self.eps)
k_normed = k_norm_out + k_bias k_normed = k_norm_out + k_bias
q_flat = q_normed.view(q.shape) q_flat = q_normed.view(q.shape)
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
self.head_dim)
k_flat = k_normed.view(k.shape) k_flat = k_normed.view(k.shape)
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
self.head_dim)
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
q_reshape, k_reshape, cos, sin)
return q_rope, k_rope, v return q_rope, k_rope, v
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, def replacement(
k_weight: torch.Tensor, q_bias: torch.Tensor, qkv: torch.Tensor,
k_bias: torch.Tensor, cos: torch.Tensor, q_weight: torch.Tensor,
sin: torch.Tensor): k_weight: torch.Tensor,
q_bias: torch.Tensor,
k_bias: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
results = torch.ops.vllm.qkv_rmsnorm_rope( results = torch.ops.vllm.qkv_rmsnorm_rope(
input=qkv, input=qkv,
q_weight=q_weight, q_weight=q_weight,
@@ -218,11 +162,11 @@ class QKNormRopeFusionPatternWithBias:
q_bias=q_bias, q_bias=q_bias,
k_bias=k_bias, k_bias=k_bias,
cos=cos, cos=cos,
sin=sin) sin=sin,
)
return results return results
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
pm.fwd_only, pm_pass)
class QKNormRopeFusionPass(VllmInductorPass): class QKNormRopeFusionPass(VllmInductorPass):
@@ -232,44 +176,38 @@ class QKNormRopeFusionPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config) super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
pass_name="qknorm_rope_fusion_pass")
dtype = vllm_config.model_config.dtype dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16): if dtype not in (torch.bfloat16, torch.float16):
logger.debug( logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
dtype)
return return
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern # use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config( attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention)
vllm_config, Attention)
if len(attn_layers) == 0: if len(attn_layers) == 0:
logger.debug( logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.")
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
)
return return
layer = next(iter(attn_layers.values())) layer = next(iter(attn_layers.values()))
for epsilon in [1e-6, 1e-5]: for epsilon in [1e-6, 1e-5]:
if layer.head_size != 128: if layer.head_size != 128:
logger.debug( logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
layer.head_size)
continue continue
QKNormRopeFusionPattern(vllm_config=vllm_config, QKNormRopeFusionPattern(
head_dim=layer.head_size, vllm_config=vllm_config,
num_heads=layer.num_heads, head_dim=layer.head_size,
num_kv_heads=layer.num_kv_heads, num_heads=layer.num_heads,
eps=epsilon).register( num_kv_heads=layer.num_kv_heads,
self.pattern_match_passes) eps=epsilon,
).register(self.pattern_match_passes)
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config, QKNormRopeFusionPatternWithBias(
head_dim=layer.head_size, vllm_config=vllm_config,
num_heads=layer.num_heads, head_dim=layer.head_size,
num_kv_heads=layer.num_kv_heads, num_heads=layer.num_heads,
eps=epsilon).register( num_kv_heads=layer.num_kv_heads,
self.pattern_match_passes) eps=epsilon,
).register(self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph):
self.begin() self.begin()

View File

@@ -1,10 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os import os
import subprocess import subprocess
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Tuple
import psutil import psutil
from vllm.logger import logger from vllm.logger import logger
@@ -13,26 +11,22 @@ ALLOWED_CPUS_PATH = "/proc/self/status"
ASCEND_RT_VISIBLE_DEVICES = os.getenv("ASCEND_RT_VISIBLE_DEVICES") ASCEND_RT_VISIBLE_DEVICES = os.getenv("ASCEND_RT_VISIBLE_DEVICES")
def execute_command(cmd: List[str]) -> Tuple[str, int]: def execute_command(cmd: list[str]) -> tuple[str, int]:
with subprocess.Popen(cmd, with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as p:
out, _ = p.communicate(timeout=1000) out, _ = p.communicate(timeout=1000)
return out.decode(), p.returncode return out.decode(), p.returncode
class DeviceInfo: class DeviceInfo:
def __init__(self): def __init__(self):
self.npu_map_info: Dict[str, Dict[str, str]] = self.get_npu_map_info() self.npu_map_info: dict[str, dict[str, str]] = self.get_npu_map_info()
self.allowed_cpus: List[int] = self.parse_allowed_cpus() self.allowed_cpus: list[int] = self.parse_allowed_cpus()
self.running_npu_list: List[int] = self.get_running_npus() self.running_npu_list: list[int] = self.get_running_npus()
self.npu_affinity: Dict[int, List[int]] = self.parse_topo_affinity() self.npu_affinity: dict[int, list[int]] = self.parse_topo_affinity()
@staticmethod @staticmethod
def expand_cpu_list(allowed_list_str: str) -> List[int]: def expand_cpu_list(allowed_list_str: str) -> list[int]:
allowed_cpus_list: List[int] = [] allowed_cpus_list: list[int] = []
for per_range in allowed_list_str.split(","): for per_range in allowed_list_str.split(","):
if "-" in per_range: if "-" in per_range:
start_cpu, end_cpu = map(int, per_range.split("-")) start_cpu, end_cpu = map(int, per_range.split("-"))
@@ -42,8 +36,8 @@ class DeviceInfo:
return allowed_cpus_list return allowed_cpus_list
@staticmethod @staticmethod
def get_npu_map_info() -> Dict[str, Dict[str, str]]: def get_npu_map_info() -> dict[str, dict[str, str]]:
npu_map_info: Dict[str, Dict[str, str]] = {} npu_map_info: dict[str, dict[str, str]] = {}
npu_info, _ = execute_command(["npu-smi", "info", "-m"]) npu_info, _ = execute_command(["npu-smi", "info", "-m"])
npu_map = npu_info.strip().split("\n")[1:] npu_map = npu_info.strip().split("\n")[1:]
for line in npu_map: for line in npu_map:
@@ -55,7 +49,7 @@ class DeviceInfo:
npu_map_info[npu_id][chip_id] = chip_logic_id npu_map_info[npu_id][chip_id] = chip_logic_id
return npu_map_info return npu_map_info
def get_running_npus(self) -> List[int]: def get_running_npus(self) -> list[int]:
npu_message, _ = execute_command(["npu-smi", "info"]) npu_message, _ = execute_command(["npu-smi", "info"])
in_proc_section = False in_proc_section = False
running_npu_set = set() running_npu_set = set()
@@ -76,36 +70,29 @@ class DeviceInfo:
continue continue
chip_logic_id = self.npu_map_info.get(npu_id, {}).get(chip_id) chip_logic_id = self.npu_map_info.get(npu_id, {}).get(chip_id)
if not chip_logic_id or not chip_logic_id.isdigit(): if not chip_logic_id or not chip_logic_id.isdigit():
raise RuntimeError( raise RuntimeError("Failed to get correct chip_logic_id from command 'npu-smi info -m'.")
"Failed to get correct chip_logic_id from command 'npu-smi info -m'."
)
running_npu_set.add(int(chip_logic_id)) running_npu_set.add(int(chip_logic_id))
if ASCEND_RT_VISIBLE_DEVICES: if ASCEND_RT_VISIBLE_DEVICES:
devices_str = ASCEND_RT_VISIBLE_DEVICES devices_str = ASCEND_RT_VISIBLE_DEVICES
devices_list = [int(x) for x in devices_str.split(",")] devices_list = [int(x) for x in devices_str.split(",")]
running_npu_set = set(devices_list) & running_npu_set running_npu_set = set(devices_list) & running_npu_set
if not running_npu_set: if not running_npu_set:
raise RuntimeError( raise RuntimeError("Can not get running npu info, you can use BIND_CPU=0 to skip.")
"Can not get running npu info, you can use BIND_CPU=0 to skip."
)
return sorted(running_npu_set) return sorted(running_npu_set)
def parse_allowed_cpus(self) -> List[int]: def parse_allowed_cpus(self) -> list[int]:
if not os.path.exists(ALLOWED_CPUS_PATH): if not os.path.exists(ALLOWED_CPUS_PATH):
return [] return []
with open(ALLOWED_CPUS_PATH) as f: with open(ALLOWED_CPUS_PATH) as f:
for line in f: for line in f:
if line.startswith("Cpus_allowed_list"): if line.startswith("Cpus_allowed_list"):
return self.expand_cpu_list(line.split()[1]) return self.expand_cpu_list(line.split()[1])
raise RuntimeError( raise RuntimeError("Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file.")
"Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file."
)
def parse_topo_affinity(self) -> Dict[int, List[int]]: def parse_topo_affinity(self) -> dict[int, list[int]]:
chip_logic_id = 0 chip_logic_id = 0
affinity: Dict[int, List[int]] = {} affinity: dict[int, list[int]] = {}
affinity_message, _ = execute_command( affinity_message, _ = execute_command(["npu-smi", "info", "-t", "topo"])
["npu-smi", "info", "-t", "topo"])
for line in affinity_message.splitlines(): for line in affinity_message.splitlines():
if line.startswith("NPU"): if line.startswith("NPU"):
parts = line.split() parts = line.split()
@@ -117,21 +104,19 @@ class DeviceInfo:
class CpuAlloc: class CpuAlloc:
def __init__(self, rank_id: int): def __init__(self, rank_id: int):
self.rank_id = rank_id self.rank_id = rank_id
self.device_info: DeviceInfo = DeviceInfo() self.device_info: DeviceInfo = DeviceInfo()
self.cpu_node: Dict[int, int] = {} self.cpu_node: dict[int, int] = {}
self.numa_to_cpu_map: Dict[int, List[int]] = defaultdict(list) self.numa_to_cpu_map: dict[int, list[int]] = defaultdict(list)
self.npu_cpu_pool: Dict[int, List[int]] = {} self.npu_cpu_pool: dict[int, list[int]] = {}
self.assign_main: Dict[int, List[int]] = {} self.assign_main: dict[int, list[int]] = {}
self.assign_acl: Dict[int, List[int]] = {} self.assign_acl: dict[int, list[int]] = {}
self.assign_rel: Dict[int, List[int]] = {} self.assign_rel: dict[int, list[int]] = {}
@staticmethod @staticmethod
def get_threads_map( def get_threads_map(thread_message: str) -> dict[str, dict[str, list[str]]]:
thread_message: str) -> Dict[str, Dict[str, List[str]]]: threads_map: dict[str, dict[str, list[str]]] = {}
threads_map: Dict[str, Dict[str, List[str]]] = {}
for line in thread_message.splitlines(): for line in thread_message.splitlines():
parts = line.split() parts = line.split()
if len(parts) < 2: if len(parts) < 2:
@@ -144,40 +129,33 @@ class CpuAlloc:
else: else:
continue continue
if main_pid not in threads_map: if main_pid not in threads_map:
threads_map[main_pid] = { threads_map[main_pid] = {"acl_thread": [], "release_thread": []}
"acl_thread": [],
"release_thread": []
}
threads_map[main_pid][key].append(sub_pid) threads_map[main_pid][key].append(sub_pid)
return threads_map return threads_map
@staticmethod @staticmethod
def bind(pid: str, cpus: List[int], bind_sub_thread: bool) -> None: def bind(pid: str, cpus: list[int], bind_sub_thread: bool) -> None:
if cpus: if cpus:
cpu_list = ",".join(map(str, cpus)) cpu_list = ",".join(map(str, cpus))
if bind_sub_thread: if bind_sub_thread:
bind_result, return_code = execute_command( bind_result, return_code = execute_command(["taskset", "-acp", cpu_list, pid])
["taskset", "-acp", cpu_list, pid])
else: else:
bind_result, return_code = execute_command( bind_result, return_code = execute_command(["taskset", "-cp", cpu_list, pid])
["taskset", "-cp", cpu_list, pid])
if return_code != 0: if return_code != 0:
raise RuntimeError(f"Failed to bind {pid} to CPU {cpu_list}.") raise RuntimeError(f"Failed to bind {pid} to CPU {cpu_list}.")
def average_distribute( def average_distribute(self, groups: dict[str, list[int]]) -> dict[int, list[int]]:
self, groups: Dict[str, List[int]]) -> Dict[int, List[int]]: result: dict[int, list[int]] = {}
result: Dict[int, List[int]] = {}
for key, npu_list in groups.items(): for key, npu_list in groups.items():
cpu_list = sorted(self.npu_cpu_pool[npu_list[0]]) cpu_list = sorted(self.npu_cpu_pool[npu_list[0]])
cpu_num_per_npu = len(cpu_list) // len(npu_list) cpu_num_per_npu = len(cpu_list) // len(npu_list)
for i, npu in enumerate(npu_list): for i, npu in enumerate(npu_list):
start_index = i * cpu_num_per_npu start_index = i * cpu_num_per_npu
end_index = (i + 1) * cpu_num_per_npu if i < len( end_index = (i + 1) * cpu_num_per_npu if i < len(npu_list) - 1 else len(cpu_list)
npu_list) - 1 else len(cpu_list)
result[npu] = cpu_list[start_index:end_index] result[npu] = cpu_list[start_index:end_index]
return result return result
def extend_numa(self, cpu_list: List[int]) -> List[int]: def extend_numa(self, cpu_list: list[int]) -> list[int]:
if not cpu_list: if not cpu_list:
return [] return []
nodes = {self.cpu_node[c] for c in cpu_list} nodes = {self.cpu_node[c] for c in cpu_list}
@@ -203,9 +181,7 @@ class CpuAlloc:
self.cpu_node[cpu] = node self.cpu_node[cpu] = node
self.numa_to_cpu_map[node].append(cpu) self.numa_to_cpu_map[node].append(cpu)
if len(self.numa_to_cpu_map) == 0: if len(self.numa_to_cpu_map) == 0:
raise RuntimeError( raise RuntimeError("lscpu command output error, no NUMA node available. Please check!")
"lscpu command output error, no NUMA node available. Please check!"
)
def handle_no_affinity(self) -> None: def handle_no_affinity(self) -> None:
num_running_npu = len(self.device_info.running_npu_list) num_running_npu = len(self.device_info.running_npu_list)
@@ -219,10 +195,7 @@ class CpuAlloc:
index = 0 index = 0
for node in sorted(self.numa_to_cpu_map): for node in sorted(self.numa_to_cpu_map):
# Available CPUs on this NUMA (constrained by allowed_cpus) # Available CPUs on this NUMA (constrained by allowed_cpus)
cpus = [ cpus = [c for c in self.numa_to_cpu_map[node] if c in self.device_info.allowed_cpus]
c for c in self.numa_to_cpu_map[node]
if c in self.device_info.allowed_cpus
]
if not cpus: if not cpus:
continue continue
# The actual number of NPUs to be allocated on this NUMA. # The actual number of NPUs to be allocated on this NUMA.
@@ -251,19 +224,16 @@ class CpuAlloc:
return return
for npu in self.device_info.running_npu_list: for npu in self.device_info.running_npu_list:
base_cpu_list = [ base_cpu_list = [
cpu for cpu in self.device_info.npu_affinity.get(npu, []) cpu for cpu in self.device_info.npu_affinity.get(npu, []) if cpu in self.device_info.allowed_cpus
if cpu in self.device_info.allowed_cpus
] ]
if not base_cpu_list: if not base_cpu_list:
raise RuntimeError( raise RuntimeError("CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity.")
"CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity."
)
extra_cpu_list = self.extend_numa(base_cpu_list) extra_cpu_list = self.extend_numa(base_cpu_list)
self.npu_cpu_pool[npu] = extra_cpu_list self.npu_cpu_pool[npu] = extra_cpu_list
groups = defaultdict(list) groups = defaultdict(list)
for npu, cpus in self.npu_cpu_pool.items(): for npu, cpus in self.npu_cpu_pool.items():
groups[str(cpus)].append(npu) groups[str(cpus)].append(npu)
final: Dict[int, List[int]] = {} final: dict[int, list[int]] = {}
for key, npu_list in groups.items(): for key, npu_list in groups.items():
if len(npu_list) == 1: if len(npu_list) == 1:
final[npu_list[0]] = self.npu_cpu_pool[npu_list[0]] final[npu_list[0]] = self.npu_cpu_pool[npu_list[0]]
@@ -279,8 +249,8 @@ class CpuAlloc:
rel = [pool[-1]] rel = [pool[-1]]
else: else:
raise RuntimeError( raise RuntimeError(
"The number of CPUs is insufficient to bind to the NPUs. " "The number of CPUs is insufficient to bind to the NPUs. Each NPU requires at least 3 CPUs."
"Each NPU requires at least 3 CPUs.") )
self.assign_main[npu] = main self.assign_main[npu] = main
self.assign_acl[npu] = acl self.assign_acl[npu] = acl
self.assign_rel[npu] = rel self.assign_rel[npu] = rel
@@ -290,10 +260,8 @@ class CpuAlloc:
current_npu = self.device_info.running_npu_list[self.rank_id] current_npu = self.device_info.running_npu_list[self.rank_id]
main = " ".join(map(str, self.assign_main[current_npu])) main = " ".join(map(str, self.assign_main[current_npu]))
acl = " ".join(map(str, self.assign_acl[current_npu])) acl = " ".join(map(str, self.assign_acl[current_npu]))
rel = str(self.assign_rel[current_npu] rel = str(self.assign_rel[current_npu]) if self.assign_rel[current_npu] else ""
) if self.assign_rel[current_npu] else "" logger.info(f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]")
logger.info(
f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]")
def bind_threads(self) -> None: def bind_threads(self) -> None:
thread_message, _ = execute_command(["ps", "-Te"]) thread_message, _ = execute_command(["ps", "-Te"])
@@ -303,8 +271,7 @@ class CpuAlloc:
self.bind(main_pid, self.assign_main[current_npu], True) self.bind(main_pid, self.assign_main[current_npu], True)
for acl_thread in threads_map.get(main_pid, {}).get("acl_thread", []): for acl_thread in threads_map.get(main_pid, {}).get("acl_thread", []):
self.bind(acl_thread, self.assign_acl[current_npu], False) self.bind(acl_thread, self.assign_acl[current_npu], False)
for release_thread in threads_map.get(main_pid, for release_thread in threads_map.get(main_pid, {}).get("release_thread", []):
{}).get("release_thread", []):
self.bind(release_thread, self.assign_rel[current_npu], False) self.bind(release_thread, self.assign_rel[current_npu], False)
def run_all(self) -> None: def run_all(self) -> None:

View File

@@ -1,5 +1,4 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import torch import torch
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
@@ -7,26 +6,26 @@ from vllm.model_executor.layers.linear import LinearBase
@dataclass @dataclass
class FlashCommon3Context: class FlashCommon3Context:
gate: Optional[LinearBase] = None gate: LinearBase | None = None
topk_weights: Optional[torch.Tensor] = None topk_weights: torch.Tensor | None = None
topk_ids: Optional[torch.Tensor] = None topk_ids: torch.Tensor | None = None
row_idx: Optional[torch.Tensor] = None row_idx: torch.Tensor | None = None
shared_experts: Optional[torch.nn.Module] = None shared_experts: torch.nn.Module | None = None
shared_out: Optional[torch.Tensor] = None shared_out: torch.Tensor | None = None
_flash_common3_context: Optional[FlashCommon3Context] = None _flash_common3_context: FlashCommon3Context | None = None
def get_flash_common3_context() -> Optional[FlashCommon3Context]: def get_flash_common3_context() -> FlashCommon3Context | None:
return _flash_common3_context return _flash_common3_context
def set_flash_common3_context( def set_flash_common3_context(
topk_weights: Optional[torch.Tensor] = None, topk_weights: torch.Tensor | None = None,
topk_ids: Optional[torch.Tensor] = None, topk_ids: torch.Tensor | None = None,
shared_experts: Optional[torch.nn.Module] = None, shared_experts: torch.nn.Module | None = None,
shared_out: Optional[torch.Tensor] = None, shared_out: torch.Tensor | None = None,
): ):
global _flash_common3_context global _flash_common3_context
if _flash_common3_context is None: if _flash_common3_context is None:

View File

@@ -46,17 +46,20 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
if overload != "": if overload != "":
op_name = op_name + "." + overload op_name = op_name + "." + overload
schema_to_find = ns + "::" + op_name schema_to_find = ns + "::" + op_name
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key( meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
"Meta")
if schema_to_find in meta_impl_list: if schema_to_find in meta_impl_list:
return return
lib.impl(op_name, fn, "Meta") lib.impl(op_name, fn, "Meta")
def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor, def rotary_embedding_meta(
key: torch.Tensor, head_size: int, positions: torch.Tensor,
cos_sin_cache: torch.Tensor, is_neox: bool): query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
):
num_tokens = positions.numel() num_tokens = positions.numel()
query_hidden_size = query.numel() // num_tokens query_hidden_size = query.numel() // num_tokens
key_hidden_size = key.numel() // num_tokens key_hidden_size = key.numel() // num_tokens
@@ -68,38 +71,41 @@ def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
return query_dst, key_dst return query_dst, key_dst
def get_masked_input_and_mask_meta(input: torch.Tensor, def get_masked_input_and_mask_meta(
org_vocab_start_index: int, input: torch.Tensor,
org_vocab_end_index: int, org_vocab_start_index: int,
num_org_vocab_padding: int, org_vocab_end_index: int,
added_vocab_start_index: int, num_org_vocab_padding: int,
added_vocab_end_index: int): added_vocab_start_index: int,
added_vocab_end_index: int,
):
masked_input = torch.empty_like(input) masked_input = torch.empty_like(input)
mask = torch.empty_like(input).to(torch.bool) mask = torch.empty_like(input).to(torch.bool)
return masked_input, mask return masked_input, mask
def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, def bgmv_expand_meta(
indices: torch.Tensor, y: torch.Tensor, slice_offset: int, x: torch.Tensor, weight: torch.Tensor, indices: torch.Tensor, y: torch.Tensor, slice_offset: int, slice_size: int
slice_size: int): ):
y_out = torch.empty_like(y) y_out = torch.empty_like(y)
return y_out return y_out
def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, def sgmv_expand_meta(
lora_indices: torch.Tensor, seq_len: torch.Tensor, x: torch.Tensor,
y: torch.Tensor, slice_offset: int, slice_size: int): weight: torch.Tensor,
lora_indices: torch.Tensor,
seq_len: torch.Tensor,
y: torch.Tensor,
slice_offset: int,
slice_size: int,
):
y_out = torch.empty_like(y) y_out = torch.empty_like(y)
return y_out return y_out
register_meta_if_necessary("_C_ascend", "rotary_embedding", register_meta_if_necessary("_C_ascend", "rotary_embedding", rotary_embedding_meta)
rotary_embedding_meta) register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta)
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
get_masked_input_and_mask_meta)
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta) register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta) register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)

View File

@@ -15,9 +15,11 @@
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
from __future__ import annotations
import math import math
import os import os
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
from uuid import uuid4 from uuid import uuid4
import torch import torch
@@ -32,11 +34,21 @@ from vllm_ascend.ascend_config import init_ascend_config
# isort: off # isort: off
from vllm_ascend.utils import ( from vllm_ascend.utils import (
ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY, ASCEND_QUANTIZATION_METHOD,
COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config, COMPILATION_PASS_KEY,
enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model, COMPRESSED_TENSORS_METHOD,
is_vl_model, refresh_block_size, update_aclgraph_sizes, AscendDeviceType,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes) check_kv_extra_config,
enable_sp,
flashcomm2_enable,
get_ascend_device_type,
is_moe_model,
is_vl_model,
refresh_block_size,
update_aclgraph_sizes,
update_cudagraph_capture_sizes,
update_default_aclgraph_sizes,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
@@ -80,7 +92,6 @@ def config_deprecated_logging():
class NPUPlatform(Platform): class NPUPlatform(Platform):
_enum = PlatformEnum.OOT _enum = PlatformEnum.OOT
device_name: str = "npu" device_name: str = "npu"
device_type: str = "npu" device_type: str = "npu"
@@ -89,9 +100,7 @@ class NPUPlatform(Platform):
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
dispatch_key: str = "PrivateUse1" dispatch_key: str = "PrivateUse1"
supported_quantization: list[str] = [ supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
]
def is_sleep_mode_available(self) -> bool: def is_sleep_mode_available(self) -> bool:
return True return True
@@ -116,33 +125,29 @@ class NPUPlatform(Platform):
@classmethod @classmethod
def get_compile_backend(self) -> str: def get_compile_backend(self) -> str:
""" """
Get the custom compile backend. Previously, we used EagerAdaptor by default. Get the custom compile backend. Previously, we used EagerAdaptor by default.
To use graph fusion operations, we defined our own backend compiler. To use graph fusion operations, we defined our own backend compiler.
""" """
return "vllm_ascend.compilation.compiler_interface.AscendCompiler" return "vllm_ascend.compilation.compiler_interface.AscendCompiler"
@classmethod @classmethod
def pre_register_and_update(cls, def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) -> None:
parser: Optional[FlexibleArgumentParser] = None
) -> None:
# Adapt the global patch here. # Adapt the global patch here.
from vllm_ascend.utils import adapt_patch from vllm_ascend.utils import adapt_patch
adapt_patch(is_global_patch=True) adapt_patch(is_global_patch=True)
# For online serving, "ascend" quantization method is not a choice natively, # For online serving, "ascend" quantization method is not a choice natively,
# so we need to add "ascend" quantization method to quantization methods list # so we need to add "ascend" quantization method to quantization methods list
# and the user can enable quantization using "vllm serve --quantization ascend". # and the user can enable quantization using "vllm serve --quantization ascend".
if parser is not None: if parser is not None:
quant_action = parser._option_string_actions.get('--quantization') quant_action = parser._option_string_actions.get("--quantization")
if quant_action and hasattr(quant_action, if quant_action and hasattr(quant_action, "choices") and quant_action.choices:
'choices') and quant_action.choices:
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \ from vllm_ascend.quantization.compressed_tensors.compressed_tensors import AscendCompressedTensorsConfig # noqa: F401
AscendCompressedTensorsConfig # noqa: F401 from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401
from vllm_ascend.quantization.quant_config import \
AscendQuantConfig # noqa: F401
config_deprecated_logging() config_deprecated_logging()
@@ -169,8 +174,7 @@ class NPUPlatform(Platform):
if vllm_config.kv_transfer_config is not None: if vllm_config.kv_transfer_config is not None:
check_kv_extra_config(vllm_config) check_kv_extra_config(vllm_config)
if not getattr(vllm_config.kv_transfer_config, if not getattr(vllm_config.kv_transfer_config, "_engine_id_patched", False):
"_engine_id_patched", False):
vllm_config.kv_transfer_config.engine_id = f"{vllm_config.kv_transfer_config.engine_id}-{uuid4().hex}" vllm_config.kv_transfer_config.engine_id = f"{vllm_config.kv_transfer_config.engine_id}-{uuid4().hex}"
vllm_config.kv_transfer_config._engine_id_patched = True vllm_config.kv_transfer_config._engine_id_patched = True
from vllm.config import CompilationMode # noqa: E402 from vllm.config import CompilationMode # noqa: E402
@@ -181,24 +185,22 @@ class NPUPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
ascend_compilation_config = ascend_config.ascend_compilation_config ascend_compilation_config = ascend_config.ascend_compilation_config
if ascend_compilation_config: if ascend_compilation_config:
vllm_config.additional_config.setdefault( vllm_config.additional_config.setdefault("ascend_compilation_config", {}).update(
"ascend_compilation_config", {}).update( vars(ascend_compilation_config)
vars(ascend_compilation_config if not isinstance(ascend_compilation_config, dict)
) if not isinstance(ascend_compilation_config, dict) else ascend_compilation_config
else ascend_compilation_config) )
elif model_config and hasattr(model_config.hf_text_config, elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
"index_topk"): vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
vllm_config.cache_config.cache_dtype = str(
model_config.dtype).replace("torch.", "")
if model_config is None: if model_config is None:
logger.warning("Model config is missing. This may indicate " logger.warning("Model config is missing. This may indicate that we are running a test case")
"that we are running a test case")
enforce_eager = False enforce_eager = False
else: else:
enforce_eager = getattr(model_config, "enforce_eager", False) enforce_eager = getattr(model_config, "enforce_eager", False)
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
if enforce_eager: if enforce_eager:
logger.info("Compilation disabled, using eager mode by default") logger.info("Compilation disabled, using eager mode by default")
compilation_config.mode = CompilationMode.NONE compilation_config.mode = CompilationMode.NONE
@@ -207,12 +209,10 @@ class NPUPlatform(Platform):
compilation_config.cudagraph_num_of_warmups = 1 compilation_config.cudagraph_num_of_warmups = 1
if compilation_config.mode not in [ if compilation_config.mode not in [CompilationMode.NONE, CompilationMode.VLLM_COMPILE]:
CompilationMode.NONE, CompilationMode.VLLM_COMPILE
]:
logger.warning( logger.warning(
"NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", compilation_config.mode
compilation_config.mode) )
compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set cudaprah sizes before extending `compilation_config.splitting_ops` # set cudaprah sizes before extending `compilation_config.splitting_ops`
@@ -223,15 +223,18 @@ class NPUPlatform(Platform):
update_default_aclgraph_sizes(vllm_config) update_default_aclgraph_sizes(vllm_config)
# TODO delete graph size update here when compilation_config.pass_config.enable_sp # TODO delete graph size update here when compilation_config.pass_config.enable_sp
# is supported by vllm-ascend. # is supported by vllm-ascend.
if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \ if (
enable_sp(vllm_config): vllm_config.parallel_config.tensor_parallel_size > 1
and not vllm_config.model_config.enforce_eager
and enable_sp(vllm_config)
):
original_sizes = compilation_config.cudagraph_capture_sizes original_sizes = compilation_config.cudagraph_capture_sizes
sp_aclgraph_sizes = \ sp_aclgraph_sizes = vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
assert sp_aclgraph_sizes, ( assert sp_aclgraph_sizes, (
f"cudagraph_capture_sizes {original_sizes} does not contain" f"cudagraph_capture_sizes {original_sizes} does not contain"
f"values that are multiples of tp_size " f"values that are multiples of tp_size "
f"{vllm_config.parallel_config.tensor_parallel_size}") f"{vllm_config.parallel_config.tensor_parallel_size}"
)
if len(sp_aclgraph_sizes) != len(original_sizes): if len(sp_aclgraph_sizes) != len(original_sizes):
compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes) update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes)
@@ -243,9 +246,7 @@ class NPUPlatform(Platform):
# encoder-decoder models currently only support piecewise mode # encoder-decoder models currently only support piecewise mode
if model_config and model_config.is_encoder_decoder is True: if model_config and model_config.is_encoder_decoder is True:
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.warning( logger.warning("encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE ")
"encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE "
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# get custom compile backend for graph fusion # get custom compile backend for graph fusion
@@ -255,15 +256,14 @@ class NPUPlatform(Platform):
compilation_config.mode = CompilationMode.NONE compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False ascend_config.enable_npugraph_ex = False
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.info( logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
"PIECEWISE compilation enabled on NPU. use_inductor not supported - " assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
"using only ACL Graph mode") "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == "
assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \ "CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE" )
compilation_config.set_splitting_ops_for_v1( compilation_config.set_splitting_ops_for_v1(
all2all_backend=vllm_config.parallel_config.all2all_backend, all2all_backend=vllm_config.parallel_config.all2all_backend,
data_parallel_size=vllm_config.parallel_config. data_parallel_size=vllm_config.parallel_config.data_parallel_size,
data_parallel_size,
) )
compilation_config.use_inductor = False compilation_config.use_inductor = False
# NOTE: Theoretically, we should also add vllm::mla_forward in the attention ops. # NOTE: Theoretically, we should also add vllm::mla_forward in the attention ops.
@@ -275,11 +275,13 @@ class NPUPlatform(Platform):
compilation_config.splitting_ops.extend(["vllm::mla_forward"]) compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config) update_aclgraph_sizes(vllm_config)
ascend_config.enable_npugraph_ex = False ascend_config.enable_npugraph_ex = False
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ elif (
compilation_config.cudagraph_mode == CUDAGraphMode.FULL: compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
):
logger.info( logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode"
"using only ACL Graph mode") )
compilation_config.use_inductor = False compilation_config.use_inductor = False
compilation_config.splitting_ops = [] compilation_config.splitting_ops = []
warning_message = """\033[91m warning_message = """\033[91m
@@ -297,30 +299,31 @@ class NPUPlatform(Platform):
logger.warning(warning_message) logger.warning(warning_message)
else: else:
logger.info( logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE", "%s cudagraph_mode is not support on NPU. falling back to NONE", compilation_config.cudagraph_mode
compilation_config.cudagraph_mode) )
compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.cudagraph_mode = CUDAGraphMode.NONE
compilation_config.mode = CompilationMode.NONE compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False ascend_config.enable_npugraph_ex = False
# TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1 # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
# Then, we will have to discuss the error handling strategy and user experience # Then, we will have to discuss the error handling strategy and user experience
if compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \ if (
os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1": compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1"
):
raise ValueError( raise ValueError(
"ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. " "ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. "
"Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you " "Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you "
"need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — " "need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — "
"for example, check the plog files (default: $HOME/ascend/log/debug) " "for example, check the plog files (default: $HOME/ascend/log/debug) "
"for more information about runtime errors.") "for more information about runtime errors."
)
if parallel_config and parallel_config.worker_cls == "auto": if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
parallel_config.all2all_backend = "flashinfer_all2allv" parallel_config.all2all_backend = "flashinfer_all2allv"
if ascend_config.xlite_graph_config.enabled: if ascend_config.xlite_graph_config.enabled:
logger.info( logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite")
"openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite"
)
parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker" parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
else: else:
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
@@ -332,10 +335,9 @@ class NPUPlatform(Platform):
compilation_config.custom_ops = ["all"] compilation_config.custom_ops = ["all"]
if ascend_config.recompute_scheduler_enable: if ascend_config.recompute_scheduler_enable:
from vllm_ascend.core.recompute_scheduler import \ from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig
RecomputeSchedulerConfig
recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config( recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(vllm_config)
vllm_config)
vllm_config.scheduler_config = recompute_scheduler_config vllm_config.scheduler_config = recompute_scheduler_config
# Extend original scheduler_config to use SchedulerDynamicBatch. # Extend original scheduler_config to use SchedulerDynamicBatch.
@@ -346,9 +348,11 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config.enable_chunked_prefill = True vllm_config.scheduler_config.enable_chunked_prefill = True
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
if vllm_config.kv_transfer_config is not None and \ if (
cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \ vllm_config.kv_transfer_config is not None
parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1: and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
):
raise AssertionError( raise AssertionError(
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) " f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
f"and block_size({cache_config.block_size}) " f"and block_size({cache_config.block_size}) "
@@ -356,12 +360,14 @@ class NPUPlatform(Platform):
) )
if is_vl_model(vllm_config): if is_vl_model(vllm_config):
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \ if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))): int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
):
raise ValueError( raise ValueError(
"Currently, VL models doesn't support " "Currently, VL models doesn't support "
"FLASHCOMM in vllm-ascend. We will fix this in the future. " "FLASHCOMM in vllm-ascend. We will fix this in the future. "
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.") "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
)
@classmethod @classmethod
def import_kernels(cls) -> None: def import_kernels(cls) -> None:
@@ -377,14 +383,11 @@ class NPUPlatform(Platform):
if _CUSTOM_OP_REGISTERED: if _CUSTOM_OP_REGISTERED:
return return
CUR_DIR = os.path.dirname(os.path.realpath(__file__)) CUR_DIR = os.path.dirname(os.path.realpath(__file__))
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", "vllm-ascend")
"vllm-ascend")
if os.path.exists(CUSTOM_OPP_PATH): if os.path.exists(CUSTOM_OPP_PATH):
current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
"")
if current_cust_opp_path: if current_cust_opp_path:
os.environ[ os.environ["ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
"ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
else: else:
os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH
_CUSTOM_OP_REGISTERED = True _CUSTOM_OP_REGISTERED = True
@@ -393,22 +396,18 @@ class NPUPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, attn_selector_config): def get_attn_backend_cls(cls, selected_backend, attn_selector_config):
backend_map = { backend_map = {
(True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend", (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, False): (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend", (True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
} }
return backend_map[(attn_selector_config.use_mla, return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)]
attn_selector_config.use_sparse)]
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU" return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"
@classmethod @classmethod
def get_current_memory_usage(cls, def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
device: Optional[torch.types.Device] = None
) -> float:
torch.npu.reset_peak_memory_stats(device) torch.npu.reset_peak_memory_stats(device)
return torch.npu.max_memory_allocated(device) return torch.npu.max_memory_allocated(device)
@@ -457,32 +456,33 @@ class NPUPlatform(Platform):
Args: Args:
attn_metadata (dict[str, Any]): attention metadata for all layers. attn_metadata (dict[str, Any]): attention metadata for all layers.
vllm_config (VllmConfig): configuration of vllm. vllm_config (VllmConfig): configuration of vllm.
dp_metadata (DpMetada): metadata for data parallelism. dp_metadata (DpMetada): metadata for data parallelism.
lack of typehint because of circular import. lack of typehint because of circular import.
virtual_engine (int, optional): index of virtual engine. Defaults to 0. virtual_engine (int, optional): index of virtual engine. Defaults to 0.
num_tokens (int | None, optional): number of tokens. Defaults to None. num_tokens (int | None, optional): number of tokens. Defaults to None.
num_tokens_across_dp (torch.Tensor | None, optional): number of tokens num_tokens_across_dp (torch.Tensor | None, optional): number of tokens
across data parallelism.Defaults to None. across data parallelism.Defaults to None.
cudagraph_runtime_mode (CUDAGraphMode, optional): mode of cudagraph runtime. cudagraph_runtime_mode (CUDAGraphMode, optional): mode of cudagraph runtime.
Defaults to None.lack of typehint because of circular import. Defaults to None.lack of typehint because of circular import.
batch_descriptor (BatchDescriptor, optional): descriptor of batch. batch_descriptor (BatchDescriptor, optional): descriptor of batch.
Defaults to None. Defaults to None.
ubatch_slices (UBatchSlices, optional): slice info for dual batch. ubatch_slices (UBatchSlices, optional): slice info for dual batch.
Defaults to None. lack of typehint because of circular import Defaults to None. lack of typehint because of circular import
Returns: Returns:
dict[str, Any]: _description_ dict[str, Any]: _description_
""" """
# NOTE(Ronald1995): avoid circular import. # NOTE(Ronald1995): avoid circular import.
from vllm_ascend.ascend_forward_context import (get_mc2_mask, from vllm_ascend.ascend_forward_context import get_mc2_mask, select_moe_comm_method
select_moe_comm_method)
from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# NOTE(Ronald1995): avoid circular import, cudagraph_runtime_mode is # NOTE(Ronald1995): avoid circular import, cudagraph_runtime_mode is
# CUDAGraphMode.NONE in vllm, but we can't set CUDAGraphMode.NONE in # CUDAGraphMode.NONE in vllm, but we can't set CUDAGraphMode.NONE in
# argument default value, so we set it to None first, then set it to # argument default value, so we set it to None first, then set it to
# CUDAGraphMode.NONE here. # CUDAGraphMode.NONE here.
from vllm.config import CUDAGraphMode from vllm.config import CUDAGraphMode
if cudagraph_runtime_mode is None: if cudagraph_runtime_mode is None:
cudagraph_runtime_mode = CUDAGraphMode.NONE cudagraph_runtime_mode = CUDAGraphMode.NONE
# TODO(Ronald1995): model runner v1 still use ascend_forward_context, # TODO(Ronald1995): model runner v1 still use ascend_forward_context,
@@ -526,31 +526,26 @@ class NPUPlatform(Platform):
sp_enabled = enable_sp(vllm_config) and num_tokens is not None sp_enabled = enable_sp(vllm_config) and num_tokens is not None
mmrs_fusion = False mmrs_fusion = False
else: else:
sp_enabled = enable_sp(vllm_config) and \ sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
num_tokens is not None and num_tokens > 1000
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2 # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
flashcomm_v2_enabled = flashcomm2_enable( flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
) and tp_world_size > 1 and num_tokens is not None
pad_size = 0 pad_size = 0
if (sp_enabled or flashcomm_v2_enabled): if sp_enabled or flashcomm_v2_enabled:
pad_size = (tp_world_size - pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
(num_tokens % tp_world_size)) % tp_world_size
dp_world_size = get_dp_group().world_size dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and dp_metadata is not None: if dp_world_size > 1 and dp_metadata is not None:
max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item() max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
if (sp_enabled or flashcomm_v2_enabled): if sp_enabled or flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens pad_size = padded_length - num_tokens
else: else:
max_tokens_across_dp = num_tokens max_tokens_across_dp = num_tokens
if num_tokens is not None: if num_tokens is not None:
# NOTE: token num which need to pad to when mc2 # NOTE: token num which need to pad to when mc2
padded_num_tokens = math.ceil( padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
max_tokens_across_dp / tp_world_size) * tp_world_size
reserved_mc2_mask = get_mc2_mask() reserved_mc2_mask = get_mc2_mask()
if reserved_mc2_mask is not None: if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:padded_num_tokens] mc2_mask = reserved_mc2_mask[:padded_num_tokens]

View File

@@ -21,9 +21,9 @@ This module generates the service_profiling_symbols.yaml configuration file
to ~/.config/vllm_ascend/ directory. to ~/.config/vllm_ascend/ directory.
""" """
import contextlib
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Optional
import vllm import vllm
from vllm.logger import logger from vllm.logger import logger
@@ -120,7 +120,7 @@ SERVICE_PROFILING_SYMBOLS_YAML = """
def get_config_dir() -> Path: def get_config_dir() -> Path:
""" """
Get the vllm_ascend configuration directory path. Get the vllm_ascend configuration directory path.
Returns: Returns:
Path: The path to ~/.config/vllm_ascend/ directory. Path: The path to ~/.config/vllm_ascend/ directory.
""" """
@@ -129,32 +129,30 @@ def get_config_dir() -> Path:
return config_dir return config_dir
def _cleanup_temp_file(tmp_path: Optional[Path]) -> None: def _cleanup_temp_file(tmp_path: Path | None) -> None:
""" """
Clean up a temporary file if it exists. Clean up a temporary file if it exists.
Args: Args:
tmp_path: Path to the temporary file to clean up. tmp_path: Path to the temporary file to clean up.
""" """
if tmp_path is not None and tmp_path.exists(): if tmp_path is not None and tmp_path.exists():
try: with contextlib.suppress(OSError):
tmp_path.unlink() tmp_path.unlink()
except OSError:
pass # Ignore cleanup errors
def generate_service_profiling_config() -> Optional[Path]: def generate_service_profiling_config() -> Path | None:
""" """
Generate the service_profiling_symbols.yaml configuration file Generate the service_profiling_symbols.yaml configuration file
to ~/.config/vllm_ascend/ directory. to ~/.config/vllm_ascend/ directory.
If the configuration file already exists, this function will skip If the configuration file already exists, this function will skip
creating it and return the existing file path. creating it and return the existing file path.
If any error occurs during file creation, it will be logged but If any error occurs during file creation, it will be logged but
will not interrupt the execution. The function will return None will not interrupt the execution. The function will return None
to indicate that the file could not be created. to indicate that the file could not be created.
Returns: Returns:
Optional[Path]: The path to the generated (or existing) configuration file. Optional[Path]: The path to the generated (or existing) configuration file.
Returns None if file creation failed. Returns None if file creation failed.
@@ -170,9 +168,7 @@ def generate_service_profiling_config() -> Optional[Path]:
try: try:
config_dir.mkdir(parents=True, exist_ok=True) config_dir.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e: except (OSError, PermissionError) as e:
logger.error( logger.error(f"Failed to create configuration directory {config_dir}: {e}", exc_info=True)
f"Failed to create configuration directory {config_dir}: {e}",
exc_info=True)
return None return None
# Write the configuration file atomically using a temporary file # Write the configuration file atomically using a temporary file
@@ -180,13 +176,9 @@ def generate_service_profiling_config() -> Optional[Path]:
tmp_path = None tmp_path = None
try: try:
# Create a temporary file in the same directory for atomic write # Create a temporary file in the same directory for atomic write
with tempfile.NamedTemporaryFile(mode='w', with tempfile.NamedTemporaryFile(
encoding='utf-8', mode="w", encoding="utf-8", dir=config_dir, delete=False, suffix=".tmp", prefix=CONFIG_FILENAME + "."
dir=config_dir, ) as tmp_file:
delete=False,
suffix='.tmp',
prefix=CONFIG_FILENAME +
'.') as tmp_file:
tmp_file.write(SERVICE_PROFILING_SYMBOLS_YAML) tmp_file.write(SERVICE_PROFILING_SYMBOLS_YAML)
tmp_path = Path(tmp_file.name) tmp_path = Path(tmp_file.name)
@@ -194,8 +186,7 @@ def generate_service_profiling_config() -> Optional[Path]:
tmp_path.replace(config_file) tmp_path.replace(config_file)
return config_file return config_file
except (OSError, PermissionError) as e: except (OSError, PermissionError) as e:
logger.error(f"Failed to write configuration file {config_file}: {e}", logger.error(f"Failed to write configuration file {config_file}: {e}", exc_info=True)
exc_info=True)
return None return None
finally: finally:
# Clean up the temporary file if it wasn't successfully replaced # Clean up the temporary file if it wasn't successfully replaced

View File

@@ -17,6 +17,8 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py # Adapted from vllm-project/vllm/vllm/worker/worker.py
# #
from __future__ import annotations
import atexit import atexit
import functools import functools
import math import math
@@ -25,7 +27,7 @@ from contextlib import contextmanager, nullcontext
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any
import torch import torch
import torch_npu # noqa: F401 import torch_npu # noqa: F401
@@ -88,6 +90,7 @@ def acl_graph_print(*args):
resolving unexpected hangs. Usage: resolving unexpected hangs. Usage:
```python ```python
from vllm_ascend.utils import acl_graph_print from vllm_ascend.utils import acl_graph_print
... ...
acl_graph_print("Debug info") acl_graph_print("Debug info")
``` ```
@@ -113,8 +116,7 @@ def acl_graph_print(*args):
torch_npu.npu._subscribe_report(current_compute_stream) torch_npu.npu._subscribe_report(current_compute_stream)
_SUBSCRIBED_COMPUTE_STREAMS.add(current_compute_stream) _SUBSCRIBED_COMPUTE_STREAMS.add(current_compute_stream)
torch_npu.npu._launch_host_func(current_compute_stream, torch_npu.npu._launch_host_func(current_compute_stream, _print_callback_on_stream, args)
_print_callback_on_stream, args)
def _unregister_print_streams_on_exit(): def _unregister_print_streams_on_exit():
@@ -196,9 +198,7 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
# after: pad_dims: [0, 2, 0, 3] # after: pad_dims: [0, 2, 0, 3]
# return: (1, 2, 16, 16) # return: (1, 2, 16, 16)
return _custom_transpose( return _custom_transpose(_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1, 2).contiguous()
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
2).contiguous()
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor: def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
@@ -208,11 +208,9 @@ def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
tokens_pad = (num_tokens + 15) // 16 * 16 tokens_pad = (num_tokens + 15) // 16 * 16
max_seq_len_pad = (max_seq_len + 15) // 16 * 16 max_seq_len_pad = (max_seq_len + 15) // 16 * 16
mask_tensor_pad = \ mask_tensor_pad = torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
mask = mask_tensor_pad.reshape( mask = mask_tensor_pad.reshape((1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
return mask return mask
@@ -230,10 +228,7 @@ def aligned_16(tensor: torch.Tensor):
return tensor return tensor
# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros # Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
new_tensor = torch.zeros(n_aligned, new_tensor = torch.zeros(n_aligned, *tensor.shape[1:], dtype=tensor.dtype, device=tensor.device)
*tensor.shape[1:],
dtype=tensor.dtype,
device=tensor.device)
# Copy the original tensor to the first N positions of the new tensor # Copy the original tensor to the first N positions of the new tensor
new_tensor[:n] = tensor new_tensor[:n] = tensor
@@ -254,15 +249,15 @@ def enable_custom_op():
# isort: off # isort: off
# register custom ops into torch_library here # register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
# register the meta implementation for custom kernel if necessary # register the meta implementation for custom kernel if necessary
import vllm_ascend.meta_registration # type: ignore # noqa: F401 import vllm_ascend.meta_registration # type: ignore # noqa: F401
# isort: on # isort: on
_CUSTOM_OP_ENABLED = True _CUSTOM_OP_ENABLED = True
except ImportError: except ImportError:
_CUSTOM_OP_ENABLED = False _CUSTOM_OP_ENABLED = False
logger.warning( logger.warning("Warning: Failed to register custom ops, all custom ops will be disabled")
"Warning: Failed to register custom ops, all custom ops will be disabled"
)
return _CUSTOM_OP_ENABLED return _CUSTOM_OP_ENABLED
@@ -277,8 +272,7 @@ def find_hccl_library() -> str:
# manually load the hccl library # manually load the hccl library
if so_file: if so_file:
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", so_file)
so_file)
else: else:
if torch.version.cann is not None: if torch.version.cann is not None:
so_file = "libhccl.so" so_file = "libhccl.so"
@@ -318,6 +312,7 @@ def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
global _WEIGHT_PREFETCH_METHOD global _WEIGHT_PREFETCH_METHOD
if _WEIGHT_PREFETCH_METHOD is None: if _WEIGHT_PREFETCH_METHOD is None:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
_WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config) _WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
return _WEIGHT_PREFETCH_METHOD return _WEIGHT_PREFETCH_METHOD
@@ -364,6 +359,7 @@ def vllm_version_is(target_vllm_version: str):
vllm_version = envs_ascend.VLLM_VERSION vllm_version = envs_ascend.VLLM_VERSION
else: else:
import vllm import vllm
vllm_version = vllm.__version__ vllm_version = vllm.__version__
try: try:
return Version(vllm_version) == Version(target_vllm_version) return Version(vllm_version) == Version(target_vllm_version)
@@ -372,7 +368,8 @@ def vllm_version_is(target_vllm_version: str):
f"Invalid vllm version {vllm_version} found. A dev version of vllm " f"Invalid vllm version {vllm_version} found. A dev version of vllm "
"is installed probably. Set the environment variable VLLM_VERSION " "is installed probably. Set the environment variable VLLM_VERSION "
"to control it by hand. And please make sure the value follows the " "to control it by hand. And please make sure the value follows the "
"format of x.y.z.") "format of x.y.z."
)
def get_max_hidden_layers(hf_config) -> int: def get_max_hidden_layers(hf_config) -> int:
@@ -394,20 +391,19 @@ def get_max_hidden_layers(hf_config) -> int:
# Update cudagraph capture sizes for vllm config # Update cudagraph capture sizes for vllm config
def update_cudagraph_capture_sizes(vllm_config: VllmConfig, def update_cudagraph_capture_sizes(vllm_config: VllmConfig, cudagraph_capture_sizes: list[int]):
cudagraph_capture_sizes: List[int]): valid_max_size = cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
if (
valid_max_size = (cudagraph_capture_sizes[-1] vllm_config.compilation_config.max_cudagraph_capture_size is not None
if cudagraph_capture_sizes else 0) and vllm_config.compilation_config.max_cudagraph_capture_size != valid_max_size
if (vllm_config.compilation_config.max_cudagraph_capture_size is not None ):
and vllm_config.compilation_config.max_cudagraph_capture_size
!= valid_max_size):
if vllm_config.compilation_config.cudagraph_capture_sizes is not None: if vllm_config.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError( raise ValueError(
"customized max_cudagraph_capture_size" "customized max_cudagraph_capture_size"
f"(={vllm_config.compilation_config.max_cudagraph_capture_size}) " f"(={vllm_config.compilation_config.max_cudagraph_capture_size}) "
"should be consistent with the max value of " "should be consistent with the max value of "
f"cudagraph_capture_sizes(={valid_max_size})") f"cudagraph_capture_sizes(={valid_max_size})"
)
logger.warning( logger.warning(
"Truncating max_cudagraph_capture_size to %d", "Truncating max_cudagraph_capture_size to %d",
valid_max_size, valid_max_size,
@@ -415,12 +411,11 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
vllm_config.compilation_config.max_cudagraph_capture_size = valid_max_size vllm_config.compilation_config.max_cudagraph_capture_size = valid_max_size
if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len( if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(cudagraph_capture_sizes) < len(
cudagraph_capture_sizes) < len( vllm_config.compilation_config.cudagraph_capture_sizes
vllm_config.compilation_config.cudagraph_capture_sizes): ):
logger.warning( logger.warning(
("cudagraph_capture_sizes specified in compilation_config" ("cudagraph_capture_sizes specified in compilation_config %s is overridden by config %s"),
" %s is overridden by config %s"),
vllm_config.compilation_config.cudagraph_capture_sizes, vllm_config.compilation_config.cudagraph_capture_sizes,
cudagraph_capture_sizes, cudagraph_capture_sizes,
) )
@@ -433,26 +428,17 @@ def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool:
Check whether it is vLLM default capture sizes. Check whether it is vLLM default capture sizes.
""" """
max_cudagraph_capture_size = \ max_cudagraph_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
vllm_config.compilation_config.max_cudagraph_capture_size cudagraph_capture_sizes = [i for i in [1, 2, 4] if i <= max_cudagraph_capture_size]
cudagraph_capture_sizes = [
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
]
if max_cudagraph_capture_size >= 8: if max_cudagraph_capture_size >= 8:
# Step size 8 for small batch sizes, up to 256(not included) # Step size 8 for small batch sizes, up to 256(not included)
cudagraph_capture_sizes += list( cudagraph_capture_sizes += list(range(8, min(max_cudagraph_capture_size + 1, 256), 8))
range(8, min(max_cudagraph_capture_size + 1, 256), 8))
if max_cudagraph_capture_size >= 256: if max_cudagraph_capture_size >= 256:
# Step size 16 for larger batch sizes # Step size 16 for larger batch sizes
cudagraph_capture_sizes += list( cudagraph_capture_sizes += list(range(256, max_cudagraph_capture_size + 1, 16))
range(256, max_cudagraph_capture_size + 1, 16))
# in newer version, vLLM use ascending order of cudagraph_capture_sizes. # in newer version, vLLM use ascending order of cudagraph_capture_sizes.
target_cudagraph_capture_sizes = sorted(cudagraph_capture_sizes) target_cudagraph_capture_sizes = sorted(cudagraph_capture_sizes)
if target_cudagraph_capture_sizes == \ return target_cudagraph_capture_sizes == vllm_config.compilation_config.cudagraph_capture_sizes
vllm_config.compilation_config.cudagraph_capture_sizes:
return True
return False
def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None: def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
@@ -461,9 +447,11 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
are more friendly to ascend ops && hardware. are more friendly to ascend ops && hardware.
""" """
if vllm_config.model_config is None or \ if (
vllm_config.model_config.enforce_eager or \ vllm_config.model_config is None
not _is_default_capture_sizes(vllm_config): or vllm_config.model_config.enforce_eager
or not _is_default_capture_sizes(vllm_config)
):
return return
# modify the default capture_sizes for Qwen3-MoE models on dp settings. # modify the default capture_sizes for Qwen3-MoE models on dp settings.
@@ -471,16 +459,15 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# on special shapes. # on special shapes.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs. # replaced by npu_fused_infer_attention_score which does not contain such bugs.
if vllm_config.model_config and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe" \ if (
and vllm_config.parallel_config.tensor_parallel_size == 1 \ vllm_config.model_config
and vllm_config.parallel_config.data_parallel_size > 1 : and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe"
and vllm_config.parallel_config.tensor_parallel_size == 1
and vllm_config.parallel_config.data_parallel_size > 1
):
max_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size max_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [ new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [i for i in range(24, max_capture_size + 1, 8)]
i for i in range(24, max_capture_size + 1, 8) update_cudagraph_capture_sizes(vllm_config, new_cudagraph_capture_sizes)
]
update_cudagraph_capture_sizes(vllm_config,
new_cudagraph_capture_sizes)
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
@@ -498,8 +485,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Store original configuration and temporarily clear it # Store original configuration and temporarily clear it
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
original_sizes, compilation_config.cudagraph_capture_sizes = \ original_sizes, compilation_config.cudagraph_capture_sizes = compilation_config.cudagraph_capture_sizes, None
compilation_config.cudagraph_capture_sizes, None
# Calculate parallel configuration factor # Calculate parallel configuration factor
if not vllm_config.model_config: if not vllm_config.model_config:
@@ -510,7 +496,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
return return
hf_config = vllm_config.model_config.hf_text_config hf_config = vllm_config.model_config.hf_text_config
if hasattr(hf_config, 'num_hidden_layers'): if hasattr(hf_config, "num_hidden_layers"):
num_hidden_layers = hf_config.num_hidden_layers num_hidden_layers = hf_config.num_hidden_layers
else: else:
num_hidden_layers = get_max_hidden_layers(hf_config) num_hidden_layers = get_max_hidden_layers(hf_config)
@@ -519,24 +505,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Calculate maximum supported batch sizes considering model architecture # Calculate maximum supported batch sizes considering model architecture
resources_per_graph = num_hidden_layers + 1 resources_per_graph = num_hidden_layers + 1
# For suffix decoding, use the suffix path when no draft_model_config is provided. # For suffix decoding, use the suffix path when no draft_model_config is provided.
if (spec := vllm_config.speculative_config) and \ if (spec := vllm_config.speculative_config) and (draft := spec.draft_model_config):
(draft := spec.draft_model_config):
resources_per_graph += draft.hf_config.num_hidden_layers + 1 resources_per_graph += draft.hf_config.num_hidden_layers + 1
# TODO: Find out whether we need to take into account the pp_size # TODO: Find out whether we need to take into account the pp_size
num_comm_groups = sum(size > 1 for size in [ num_comm_groups = sum(
parallel_config.data_parallel_size, size > 1
parallel_config.tensor_parallel_size, for size in [
]) parallel_config.data_parallel_size,
parallel_config.tensor_parallel_size,
]
)
if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV': if os.getenv("HCCL_OP_EXPANSION_MODE") == "AIV":
# TODO: Find out whether we need to take into account the pp_size # TODO: Find out whether we need to take into account the pp_size
parallel_factor = 1 + num_comm_groups + int( parallel_factor = (
parallel_config.enable_expert_parallel) + int( 1
vllm_config.additional_config.get( + num_comm_groups
"multistream_overlap_shared_expert", False)) + int(parallel_config.enable_expert_parallel)
+ int(vllm_config.additional_config.get("multistream_overlap_shared_expert", False))
)
if is_moe_model(vllm_config): if is_moe_model(vllm_config):
parallel_factor += (parallel_config.data_parallel_size > 1) parallel_factor += parallel_config.data_parallel_size > 1
else: else:
# When AIV mode is enabled, the allreduce operator of the dense # When AIV mode is enabled, the allreduce operator of the dense
# layer model will occupy additional streams, which are buffered here. # layer model will occupy additional streams, which are buffered here.
@@ -546,11 +536,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Assume the following case: # Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19 # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / resources_per_graph / parallel_factor)
resources_per_graph / parallel_factor) logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
logger.info(
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
else: else:
# enable pcp or dcp will add new communication and consume additional approximately less than 100 streams # enable pcp or dcp will add new communication and consume additional approximately less than 100 streams
if parallel_config.prefill_context_parallel_size > 1: if parallel_config.prefill_context_parallel_size > 1:
@@ -562,18 +549,18 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Under this configuration, HCCL employs the FFTS+ method for execution unfolding, # Under this configuration, HCCL employs the FFTS+ method for execution unfolding,
# which adds only 1 concurrent stream without consuming collective communication execution unfolding streams. # which adds only 1 concurrent stream without consuming collective communication execution unfolding streams.
# On A3 hardware, HCCL defaults to the AICPU method. # On A3 hardware, HCCL defaults to the AICPU method.
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case). # This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication
# Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes. # domain on the device (worst case).
# Using the default collective communication unfolding method on A3 will lead to a significant reduction
# in the maximum supported sizes.
# Therefore, the calculation formula has been modified as follows: # Therefore, the calculation formula has been modified as follows:
# Assume the following case: # Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12 # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
max_num_batch_sizes = math.floor( max_num_batch_sizes = math.floor(
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / (MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / (1 + num_comm_groups * 2)
(1 + num_comm_groups * 2)) )
logger.info( logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
logger.warning( logger.warning(
"Currently, communication is performed using FFTS+ method, which reduces " "Currently, communication is performed using FFTS+ method, which reduces "
"the number of available streams and, as a result, limits the range of runtime " "the number of available streams and, as a result, limits the range of runtime "
@@ -598,26 +585,29 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
vllm_config.model_config.architectures[0], vllm_config.model_config.architectures[0],
num_hidden_layers, num_hidden_layers,
len(original_sizes), len(original_sizes),
len(compilation_config. len(
cudagraph_capture_sizes # type: ignore[arg-type] compilation_config.cudagraph_capture_sizes # type: ignore[arg-type]
)) ),
)
else: else:
# No adjustment needed # No adjustment needed
compilation_config.cudagraph_capture_sizes = original_sizes compilation_config.cudagraph_capture_sizes = original_sizes
logger.info( logger.info(
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes", "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
vllm_config.model_config.architectures[0], num_hidden_layers, vllm_config.model_config.architectures[0],
len(original_sizes)) num_hidden_layers,
len(original_sizes),
)
# TODO(wxy): Move to ops module # TODO(wxy): Move to ops module
def dispose_tensor(x: torch.Tensor): def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype)) x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
class ProfileExecuteDuration: class ProfileExecuteDuration:
_instance = None _instance = None
_observations: List[Tuple[str, Event, Event]] = [] _observations: list[tuple[str, Event, Event]] = []
_lock = Lock() _lock = Lock()
def __new__(cls): def __new__(cls):
@@ -645,8 +635,7 @@ class ProfileExecuteDuration:
observe_end = Event(enable_timing=True) observe_end = Event(enable_timing=True)
observe_end.record() observe_end.record()
with self._lock: with self._lock:
self._observations.append( self._observations.append((duration_tag, observe_start, observe_end))
(duration_tag, observe_start, observe_end))
def pop_captured_sync(self) -> dict: def pop_captured_sync(self) -> dict:
"""Pop and synchronize all events in the observation list""" """Pop and synchronize all events in the observation list"""
@@ -663,7 +652,7 @@ class ProfileExecuteDuration:
return durations return durations
def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): def register_ascend_customop(vllm_config: VllmConfig | None = None):
"""Register Ascend CustomOP """Register Ascend CustomOP
NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`,
@@ -675,23 +664,29 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE, from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE
AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
from vllm_ascend.ops.linear import (AscendColumnParallelLinear, from vllm_ascend.ops.linear import (
AscendMergedColumnParallelLinear, AscendColumnParallelLinear,
AscendQKVParallelLinear, AscendMergedColumnParallelLinear,
AscendReplicatedLinear, AscendQKVParallelLinear,
AscendRowParallelLinear) AscendReplicatedLinear,
AscendRowParallelLinear,
)
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
from vllm_ascend.ops.rotary_embedding import ( from vllm_ascend.ops.rotary_embedding import (
AscendApplyRotaryEmb, AscendDeepseekScalingRotaryEmbedding, AscendApplyRotaryEmb,
AscendMRotaryEmbedding, AscendRotaryEmbedding, AscendDeepseekScalingRotaryEmbedding,
AscendYaRNRotaryEmbedding) AscendMRotaryEmbedding,
AscendRotaryEmbedding,
AscendYaRNRotaryEmbedding,
)
from vllm_ascend.ops.vocab_parallel_embedding import ( from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead, AscendLogitsProcessor,
AscendVocabParallelEmbedding) AscendParallelLMHead,
AscendVocabParallelEmbedding,
)
global REGISTERED_ASCEND_OPS global REGISTERED_ASCEND_OPS
REGISTERED_ASCEND_OPS = { REGISTERED_ASCEND_OPS = {
@@ -738,6 +733,7 @@ _ascend_device_type = None
def _init_ascend_device_type(): def _init_ascend_device_type():
global _ascend_device_type global _ascend_device_type
from vllm_ascend import _build_info # type: ignore from vllm_ascend import _build_info # type: ignore
_ascend_device_type = AscendDeviceType[_build_info.__device_type__] _ascend_device_type = AscendDeviceType[_build_info.__device_type__]
@@ -758,7 +754,10 @@ def check_ascend_device_type():
else: else:
raise RuntimeError(f"Can not support soc_version: {soc_version}.") raise RuntimeError(f"Can not support soc_version: {soc_version}.")
assert _ascend_device_type == cur_device_type, f"Current device type: {cur_device_type} does not match the installed version's device type: {_ascend_device_type}, please check your installation package." assert _ascend_device_type == cur_device_type, (
f"Current device type: {cur_device_type} does not match the installed version's device type: "
f"{_ascend_device_type}, please check your installation package."
)
def get_ascend_device_type(): def get_ascend_device_type():
@@ -769,23 +768,19 @@ def get_ascend_device_type():
def lmhead_tp_enable() -> bool: def lmhead_tp_enable() -> bool:
return get_ascend_config( return get_ascend_config().finegrained_tp_config.lmhead_tensor_parallel_size > 0
).finegrained_tp_config.lmhead_tensor_parallel_size > 0
def embedding_tp_enable() -> bool: def embedding_tp_enable() -> bool:
return get_ascend_config( return get_ascend_config().finegrained_tp_config.embedding_tensor_parallel_size > 0
).finegrained_tp_config.embedding_tensor_parallel_size > 0
def oproj_tp_enable() -> bool: def oproj_tp_enable() -> bool:
return get_ascend_config( return get_ascend_config().finegrained_tp_config.oproj_tensor_parallel_size > 0
).finegrained_tp_config.oproj_tensor_parallel_size > 0
def mlp_tp_enable() -> bool: def mlp_tp_enable() -> bool:
return get_ascend_config( return get_ascend_config().finegrained_tp_config.mlp_tensor_parallel_size > 0
).finegrained_tp_config.mlp_tensor_parallel_size > 0
def matmul_allreduce_enable() -> bool: def matmul_allreduce_enable() -> bool:
@@ -797,30 +792,30 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
if _ENABLE_SP is None: if _ENABLE_SP is None:
if vllm_config is None: if vllm_config is None:
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
_ENABLE_SP = ( _ENABLE_SP = (
vllm_config.compilation_config.pass_config.enable_sp vllm_config.compilation_config.pass_config.enable_sp
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1 or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1 # Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility. # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')))) or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
)
if not _ENABLE_SP and enable_shared_expert_dp: if not _ENABLE_SP and enable_shared_expert_dp:
_ENABLE_SP = True _ENABLE_SP = True
logger.info( logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")
"shared_expert_dp requires enable_sp = True. has set enable_sp to True"
)
if not _ENABLE_SP: if not _ENABLE_SP:
return _ENABLE_SP return _ENABLE_SP
assert vllm_config.parallel_config.tensor_parallel_size > 1, \ assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1." "Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
)
assert ( assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
not is_moe_model(vllm_config) "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
or vllm_config.parallel_config.enable_expert_parallel )
), "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
return _ENABLE_SP return _ENABLE_SP
@@ -847,25 +842,21 @@ def is_drafter_moe_model(vllm_config: VllmConfig):
"""Checks if the drafter model is a MoE model by config""" """Checks if the drafter model is a MoE model by config"""
global _IS_DRAFTER_MOE_MODEL global _IS_DRAFTER_MOE_MODEL
if _IS_DRAFTER_MOE_MODEL is None: if _IS_DRAFTER_MOE_MODEL is None:
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \ model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config.to_dict()
.to_dict()
_IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs) _IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs)
return _IS_DRAFTER_MOE_MODEL return _IS_DRAFTER_MOE_MODEL
def speculative_enable_dispatch_gmm_combine_decode( def speculative_enable_dispatch_gmm_combine_decode(vllm_config: VllmConfig) -> bool:
vllm_config: VllmConfig) -> bool:
if vllm_config.speculative_config is None: if vllm_config.speculative_config is None:
return True return True
speculative_method = getattr(vllm_config.speculative_config, "method", speculative_method = getattr(vllm_config.speculative_config, "method", None)
None)
if speculative_method in [None, "ngram", "suffix"]: if speculative_method in [None, "ngram", "suffix"]:
return True return True
if speculative_method in ["eagle", "eagle3"]: if speculative_method in ["eagle", "eagle3"]:
return False return False
if speculative_method == "mtp": if speculative_method == "mtp":
mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, "mtp_quantize", None)
"mtp_quantize", None)
return mtp_quant_type == "w8a8_dynamic" return mtp_quant_type == "w8a8_dynamic"
return False return False
@@ -915,8 +906,8 @@ def weak_ref_tensor(tensor: Any) -> Any:
def weak_ref_tensors( def weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] tensors: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor],
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: ) -> torch.Tensor | list[Any] | tuple[Any] | Any:
""" """
Convenience function to create weak references to tensors, Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors. for single tensor, list of tensors or tuple of tensors.
@@ -936,17 +927,12 @@ def weak_ref_tensors(
return tuple(weak_ref_tensor(t) for t in tensors) return tuple(weak_ref_tensor(t) for t in tensors)
# For IntermediateTensors used in pipeline parallelism # For IntermediateTensors used in pipeline parallelism
if isinstance(tensors, IntermediateTensors): if isinstance(tensors, IntermediateTensors):
ret = IntermediateTensors({ ret = IntermediateTensors({key: weak_ref_tensor(val) for key, val in tensors.tensors.items()})
key: weak_ref_tensor(val)
for key, val in tensors.tensors.items()
})
return ret return ret
raise ValueError("Invalid type for tensors") raise ValueError("Invalid type for tensors")
def npu_stream_switch(target_stream: torch.npu.Stream, def npu_stream_switch(target_stream: torch.npu.Stream, *, enabled: bool = True):
*,
enabled: bool = True):
""" """
Switch to the target stream if enabled is True. Switch to the target stream if enabled is True.
Otherwise, do nothing. Otherwise, do nothing.
@@ -965,7 +951,7 @@ def create_hccl_pg_options(group_name: str):
return options return options
def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]: def get_hccl_config_for_pg_options(group_name: str) -> dict | None:
""" """
Get HCCL process group options for the given communication group name. Get HCCL process group options for the given communication group name.
@@ -981,9 +967,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
if group_name and "mc2" in group_name: if group_name and "mc2" in group_name:
return None return None
hccl_config_map = { hccl_config_map = {
"dp": { "dp": {"hccl_buffer_size": calculate_dp_buffer_size()},
"hccl_buffer_size": calculate_dp_buffer_size()
},
} }
return hccl_config_map.get(group_name, get_default_buffer_config()) return hccl_config_map.get(group_name, get_default_buffer_config())
@@ -998,6 +982,7 @@ def calculate_dp_buffer_size() -> int:
dp_size + 1 (flags: with_prefill) dp_size + 1 (flags: with_prefill)
""" """
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
dp_size = vllm_config.parallel_config.data_parallel_size dp_size = vllm_config.parallel_config.data_parallel_size
int32_size = torch.iinfo(torch.int32).bits // 8 int32_size = torch.iinfo(torch.int32).bits // 8
@@ -1009,8 +994,7 @@ def calculate_dp_buffer_size() -> int:
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and # and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
# significantly improve communication performance of MC2 ops dispatch/combine. # significantly improve communication performance of MC2 ops dispatch/combine.
def is_hierarchical_communication_enabled(): def is_hierarchical_communication_enabled():
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" return os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1"
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
def has_layer_idx(model_instance: torch.nn.Module) -> bool: def has_layer_idx(model_instance: torch.nn.Module) -> bool:
@@ -1019,8 +1003,7 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool:
global _HAS_LAYER_IDX global _HAS_LAYER_IDX
if _HAS_LAYER_IDX is None: if _HAS_LAYER_IDX is None:
_HAS_LAYER_IDX = hasattr(model_instance, "model") and \ _HAS_LAYER_IDX = hasattr(model_instance, "model") and hasattr(model_instance.model, "start_layer")
hasattr(model_instance.model, "start_layer")
return _HAS_LAYER_IDX return _HAS_LAYER_IDX
@@ -1042,20 +1025,17 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
if not flashcomm2_enable(): if not flashcomm2_enable():
return 0 return 0
logger.info( logger.info(f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}")
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
)
layer_sharding = ascend_config.layer_sharding or [] layer_sharding = ascend_config.layer_sharding or []
if layer_sharding: if layer_sharding:
if layer_sharding == ["o_proj"]: if layer_sharding == ["o_proj"]:
logger.info_once( logger.info_once("Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption.")
"Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
)
else: else:
raise ValueError( raise ValueError(
"FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! " "FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
f"Found invalid layer_sharding: {layer_sharding}") f"Found invalid layer_sharding: {layer_sharding}"
)
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1: if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
logger.warning_once( logger.warning_once(
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance." "It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
@@ -1066,32 +1046,35 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
) )
if global_tp_size <= flashcomm2_oproj_tp_size: if global_tp_size <= flashcomm2_oproj_tp_size:
raise AssertionError( raise AssertionError(
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed global tensor parallel size ({global_tp_size})" f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed "
f"global tensor parallel size ({global_tp_size})"
) )
if global_tp_size % flashcomm2_oproj_tp_size != 0: if global_tp_size % flashcomm2_oproj_tp_size != 0:
raise AssertionError( raise AssertionError(
f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})" f"Global tensor parallel size ({global_tp_size}) must be divisible by "
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
) )
if vllm_config.kv_transfer_config is None: if vllm_config.kv_transfer_config is None:
logger.warning_once( logger.warning_once(
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment may lead to decode performance degradation." "It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment "
"may lead to decode performance degradation."
) )
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer: if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError( raise AssertionError(
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments." "FLASHCOMM2 primarily targets P-scenario deployments, with additional support "
"for hybrid deployment scenarios. It is not applicable in D-scenario environments."
) )
return flashcomm2_oproj_tp_size return flashcomm2_oproj_tp_size
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]: def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain. # Reorganize batch_ids so that, after the all2all and reduce-scatter operation,
# each batch_id corresponds to the rank_id within the DP domain.
# For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2, # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
# the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]]. # the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
flashcomm2_otp_size = get_ascend_config( flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
).flashcomm2_oproj_tensor_parallel_size num_oproj_tensor_parallel_groups: int = global_tp_size // flashcomm2_otp_size
num_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)
reorgnized_batch_ids = [] reorgnized_batch_ids = []
for i in range(num_oproj_tensor_parallel_groups): for i in range(num_oproj_tensor_parallel_groups):
@@ -1122,11 +1105,9 @@ def refresh_block_size(vllm_config):
return return
# TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups. # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
if not model_config.hf_text_config.model_type == "qwen3_next" and cache_config.block_size != 128: if model_config.hf_text_config.model_type != "qwen3_next" and cache_config.block_size != 128:
if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill: if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill:
logger.info( logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.")
"Block size is set to 128 if prefix cache or chunked prefill is enabled."
)
cache_config.block_size = 128 cache_config.block_size = 128
@@ -1138,7 +1119,6 @@ def dispose_layer(layer: Any):
def check_kv_extra_config(vllm_config): def check_kv_extra_config(vllm_config):
def _check(name: str, config: dict): def _check(name: str, config: dict):
tp_key = "tp_size" tp_key = "tp_size"
dp_key = "dp_size" dp_key = "dp_size"
@@ -1148,24 +1128,21 @@ def check_kv_extra_config(vllm_config):
if config_tp != vllm_tp: if config_tp != vllm_tp:
raise ValueError( raise ValueError(
f"KV transfer '{name}' config has a conflicting tensor parallel size. " f"KV transfer '{name}' config has a conflicting tensor parallel size. "
f"Expected {vllm_tp}, but got {config_tp}.") f"Expected {vllm_tp}, but got {config_tp}."
)
if dp_key in config: if dp_key in config:
config_dp = config[dp_key] config_dp = config[dp_key]
vllm_dp = vllm_config.parallel_config.data_parallel_size vllm_dp = vllm_config.parallel_config.data_parallel_size
if config_dp != vllm_dp: if config_dp != vllm_dp:
raise ValueError( raise ValueError(
f"KV transfer '{name}' config has a conflicting data parallel size. " f"KV transfer '{name}' config has a conflicting data parallel size. "
f"Expected {vllm_dp}, but got {config_dp}.") f"Expected {vllm_dp}, but got {config_dp}."
)
if vllm_config.kv_transfer_config.is_kv_producer: if vllm_config.kv_transfer_config.is_kv_producer:
_check( _check("prefill", vllm_config.kv_transfer_config.get_from_extra_config("prefill", {}))
"prefill",
vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {}))
if vllm_config.kv_transfer_config.is_kv_consumer: if vllm_config.kv_transfer_config.is_kv_consumer:
_check( _check("decode", vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
"decode",
vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
def singleton(cls): def singleton(cls):
@@ -1179,17 +1156,17 @@ def singleton(cls):
return get_instance return get_instance
#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1 # TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32.
# and subsequent updates will introduce new interfaces. --zzhx1
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def enable_dsa_cp() -> bool: def enable_dsa_cp() -> bool:
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
is_ds_v32 = hasattr( is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config, "hf_text_config") and hasattr( vllm_config.model_config.hf_text_config, "index_topk"
vllm_config.model_config.hf_text_config, "index_topk") )
if is_ds_v32 and enable_sp(): return bool(is_ds_v32 and enable_sp())
return True
return False
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
@@ -1197,6 +1174,7 @@ def enable_dsa_cp_with_layer_shard() -> bool:
if not enable_dsa_cp(): if not enable_dsa_cp():
return False return False
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
return is_prefill_instance return is_prefill_instance