[Lint] Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
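As a small illustration of what the conversion does, ruff format collapses the previous yapf-style wrapped calls onto single lines within the 120-column limit. The snippet below is adapted from the diff; `additional_config` is just a stand-in dict for the example:
```python
additional_config: dict = {}  # stand-in for vllm_config.additional_config

# Before (yapf-style wrapping):
ascend_compilation_config = additional_config.get(
    "ascend_compilation_config", {})

# After (ruff format, single line under the 120-column limit):
ascend_compilation_config = additional_config.get("ascend_compilation_config", {})
```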
### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
    def register_ascend_customop(vllm_config: VllmConfig | None = None):
                                               ^^^^^^^^^^^^^^^^^
E   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def register_ascend_customop(vllm_config: VllmConfig | None = None)` at import time, it eagerly evaluates the annotation expression `VllmConfig | None`.
Since `VllmConfig` is bound to `None` at runtime, the expression effectively becomes `None | None`. The `|` union operator (PEP 604) is implemented for type objects (classes), but `None` is an instance of `NoneType`, not a class, so the evaluation fails with the `TypeError` shown above. A minimal reproduction is sketched below.
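The following standalone snippet (an illustrative sketch, not code from the repository; the function name is hypothetical) reproduces the failure:
```python
VllmConfig = None  # runtime placeholder, mirroring the TYPE_CHECKING pattern above

try:
    # Without postponed evaluation, the parameter annotation is evaluated right
    # here, at function-definition time, and becomes `None | None`.
    def register_example(cfg: VllmConfig | None = None) -> None:
        pass
except TypeError as exc:
    print(exc)  # unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```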
**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
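Put together, an affected module looks roughly like this after the change (a sketch; the function body is elided and only illustrates the pattern):
```python
from __future__ import annotations  # PEP 563: postponed evaluation of annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # runtime placeholder; no longer evaluated in annotations


def register_ascend_customop(vllm_config: VllmConfig | None = None) -> None:
    # Importing this module no longer evaluates `VllmConfig | None`,
    # so the TypeError shown above can no longer occur.
    ...
```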
**3. Impact and Benefits:**
- With the `annotations` future import enabled, Python no longer evaluates `VllmConfig | None` at module load. The annotation is stored as a string literal instead, so the `None | None` evaluation never happens.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` now contains strings instead of type objects (see the short demo after this list). Since these modules do not rely on runtime type enforcement or reflection, this change has no negative impact on existing functionality.
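A quick demonstration of that side effect (standalone example, not repository code):
```python
from __future__ import annotations

import typing


def f(x: int | None = None) -> None:
    return None


# Annotations are kept as strings rather than evaluated objects:
print(f.__annotations__)  # {'x': 'int | None', 'return': 'None'}

# They can still be resolved to real types on demand when reflection is needed:
print(typing.get_type_hints(f))  # {'x': int | None, 'return': <class 'NoneType'>} (repr may vary by Python version)
```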
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 11b6af5280
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
@@ -49,7 +49,25 @@ line-length = 120
# Folder to be modified
exclude = [
    "tests/**",
    "vllm_ascend/**",
    "vllm_ascend/_cann_ops_custom",
    "vllm_ascend/attention",
    "vllm_ascend/core",
    "vllm_ascend/device",
    "vllm_ascend/device_allocator",
    "vllm_ascend/distributed",
    "vllm_ascend/eplb",
    "vllm_ascend/kv_offload",
    "vllm_ascend/lora",
    "vllm_ascend/model_loader",
    "vllm_ascend/ops",
    "vllm_ascend/patch",
    "vllm_ascend/quantization",
    "vllm_ascend/sample",
    "vllm_ascend/spec_decode",
    "vllm_ascend/worker",
    "vllm_ascend/xlite",
    "vllm_ascend/envs.py",
    "vllm_ascend/batch_invariant.py",
]

[tool.ruff.lint]

@@ -24,14 +24,17 @@ def register():

def register_connector():
    from vllm_ascend.distributed.kv_transfer import register_connector

    register_connector()


def register_model_loader():
    from .model_loader.netloader import register_netloader

    register_netloader()


def register_service_profiling():
    from .profiling_config import generate_service_profiling_config

    generate_service_profiling_config()

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from vllm.logger import logger
from vllm.triton_utils import HAS_TRITON
@@ -32,18 +32,13 @@ class AscendConfig:
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
|
||||
xlite_graph_config = additional_config.get("xlite_graph_config", {})
|
||||
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
|
||||
vllm_config)
|
||||
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, vllm_config)
|
||||
|
||||
ascend_compilation_config = additional_config.get(
|
||||
"ascend_compilation_config", {})
|
||||
self.ascend_compilation_config = AscendCompilationConfig(
|
||||
**ascend_compilation_config)
|
||||
ascend_compilation_config = additional_config.get("ascend_compilation_config", {})
|
||||
self.ascend_compilation_config = AscendCompilationConfig(**ascend_compilation_config)
|
||||
|
||||
finegrained_tp_config = additional_config.get("finegrained_tp_config",
|
||||
{})
|
||||
self.finegrained_tp_config = FinegrainedTPConfig(
|
||||
finegrained_tp_config, vllm_config)
|
||||
finegrained_tp_config = additional_config.get("finegrained_tp_config", {})
|
||||
self.finegrained_tp_config = FinegrainedTPConfig(finegrained_tp_config, vllm_config)
|
||||
|
||||
eplb_config = additional_config.get("eplb_config", {})
|
||||
self.eplb_config = EplbConfig(eplb_config)
|
||||
@@ -51,10 +46,8 @@ class AscendConfig:
|
||||
# Dump / PrecisionDebugger configuration
|
||||
self.dump_config_path = additional_config.get("dump_config_path", None)
|
||||
|
||||
weight_prefetch_config = additional_config.get(
|
||||
"weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(
|
||||
weight_prefetch_config)
|
||||
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
|
||||
self.layer_sharding = additional_config.get("layer_sharding", None)
|
||||
logger.info_once(
|
||||
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
|
||||
@@ -62,30 +55,25 @@ class AscendConfig:
|
||||
"using it without these features may result in significant performance degradation."
|
||||
)
|
||||
|
||||
self.enable_shared_expert_dp = additional_config.get(
|
||||
"enable_shared_expert_dp",
|
||||
False) and vllm_config.parallel_config.enable_expert_parallel
|
||||
self.enable_shared_expert_dp = (
|
||||
additional_config.get("enable_shared_expert_dp", False)
|
||||
and vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
if self.enable_shared_expert_dp:
|
||||
from vllm_ascend.utils import enable_sp
|
||||
assert enable_sp(vllm_config=vllm_config,
|
||||
enable_shared_expert_dp=True)
|
||||
self.multistream_overlap_shared_expert = additional_config.get(
|
||||
"multistream_overlap_shared_expert", False)
|
||||
self.multistream_overlap_gate = additional_config.get(
|
||||
"multistream_overlap_gate", False)
|
||||
self.recompute_scheduler_enable = additional_config.get(
|
||||
"recompute_scheduler_enable", False)
|
||||
self.enable_cpu_binding = additional_config.get(
|
||||
"enable_cpu_binding", False)
|
||||
|
||||
assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True)
|
||||
self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False)
|
||||
self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False)
|
||||
self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)
|
||||
self.enable_cpu_binding = additional_config.get("enable_cpu_binding", False)
|
||||
|
||||
self.pd_tp_ratio = 1
|
||||
self.pd_head_ratio = 1
|
||||
self.num_head_replica = 1
|
||||
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
|
||||
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"prefill", {"tp_size": 1})["tp_size"]
|
||||
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"decode", {"tp_size": 1})["tp_size"]
|
||||
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {"tp_size": 1})["tp_size"]
|
||||
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("decode", {"tp_size": 1})["tp_size"]
|
||||
assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size."
|
||||
self.pd_tp_ratio = prefill_tp_size // decode_tp_size
|
||||
if self.pd_tp_ratio > 1:
|
||||
@@ -106,36 +94,29 @@ class AscendConfig:
|
||||
)
|
||||
|
||||
if self.pd_tp_ratio == 0:
|
||||
raise AssertionError(
|
||||
"Only support P node tp size lagger then D node tp size")
|
||||
self.SLO_limits_for_dynamic_batch = additional_config.get(
|
||||
"SLO_limits_for_dynamic_batch", -1)
|
||||
raise AssertionError("Only support P node tp size lagger then D node tp size")
|
||||
self.SLO_limits_for_dynamic_batch = additional_config.get("SLO_limits_for_dynamic_batch", -1)
|
||||
from vllm_ascend.utils import get_flashcomm2_config_and_validate
|
||||
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
|
||||
self, vllm_config)
|
||||
self.enable_npugraph_ex = additional_config.get(
|
||||
"enable_npugraph_ex", False)
|
||||
|
||||
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)
|
||||
self.enable_npugraph_ex = additional_config.get("enable_npugraph_ex", False)
|
||||
# We find that _npu_paged_attention still performs better than
|
||||
# npu_fused_infer_attention_score in some cases. We allow to execute
|
||||
# _npu_paged_attention in this cases. This should be removed once
|
||||
# npu_fused_infer_attention_score performs better on all scenarios.
|
||||
self.pa_shape_list = additional_config.get("pa_shape_list", [])
|
||||
|
||||
self.enable_async_exponential = bool(
|
||||
additional_config.get("enable_async_exponential", False))
|
||||
self.enable_async_exponential = bool(additional_config.get("enable_async_exponential", False))
|
||||
|
||||
self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
|
||||
if self.enable_kv_nz:
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_text_config,
|
||||
"index_topk")
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
|
||||
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is only supported for mla currently.")
|
||||
if vllm_config.kv_transfer_config is None \
|
||||
or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise RuntimeError("enable_kv_nz is only supported for mla currently.")
|
||||
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise NotImplementedError(
|
||||
"enable_kv_nz is only supported in pd scenario and can "
|
||||
"only be used in D node.")
|
||||
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
|
||||
|
||||
class FinegrainedTPConfig:
|
||||
@@ -144,40 +125,28 @@ class FinegrainedTPConfig:
|
||||
"""
|
||||
|
||||
def __init__(self, finegrained_tp_config: dict, vllm_config):
|
||||
self.oproj_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"oproj_tensor_parallel_size", 0)
|
||||
self.lmhead_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"lmhead_tensor_parallel_size", 0)
|
||||
self.embedding_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"embedding_tensor_parallel_size", 0)
|
||||
self.mlp_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"mlp_tensor_parallel_size", 0)
|
||||
self.oproj_tensor_parallel_size = finegrained_tp_config.get("oproj_tensor_parallel_size", 0)
|
||||
self.lmhead_tensor_parallel_size = finegrained_tp_config.get("lmhead_tensor_parallel_size", 0)
|
||||
self.embedding_tensor_parallel_size = finegrained_tp_config.get("embedding_tensor_parallel_size", 0)
|
||||
self.mlp_tensor_parallel_size = finegrained_tp_config.get("mlp_tensor_parallel_size", 0)
|
||||
|
||||
enabled_configs = []
|
||||
if self.oproj_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}"
|
||||
)
|
||||
# dummy_run does not run the entire attention module in eager mode,, so the o_proj tp split can only be used in graph mode.
|
||||
enabled_configs.append(f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}")
|
||||
# dummy_run does not run the entire attention module in eager mode,
|
||||
# so the o_proj tp split can only be used in graph mode.
|
||||
if vllm_config.model_config.enforce_eager is True:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in graph mode"
|
||||
)
|
||||
raise AssertionError("oproj_tensor_parallel_size is only supported in graph mode")
|
||||
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
if self.lmhead_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}"
|
||||
)
|
||||
enabled_configs.append(f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}")
|
||||
if self.embedding_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}"
|
||||
)
|
||||
enabled_configs.append(f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}")
|
||||
if self.mlp_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
|
||||
enabled_configs.append(f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
|
||||
module_tp_sizes = [
|
||||
self.oproj_tensor_parallel_size,
|
||||
self.lmhead_tensor_parallel_size,
|
||||
@@ -186,11 +155,9 @@ class FinegrainedTPConfig:
|
||||
]
|
||||
for module_tp_size in module_tp_sizes:
|
||||
if module_tp_size > 0 and vllm_config.parallel_config.data_parallel_size % module_tp_size != 0:
|
||||
raise AssertionError(
|
||||
"module tp sizes must divide data_parallel_size")
|
||||
raise AssertionError("module tp sizes must divide data_parallel_size")
|
||||
if any(size > 0 for size in module_tp_sizes) and enabled_configs:
|
||||
logger.info(
|
||||
f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
|
||||
logger.info(f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
|
||||
|
||||
|
||||
class AscendCompilationConfig:
|
||||
@@ -202,10 +169,7 @@ class AscendCompilationConfig:
|
||||
deployed on Ascend platforms.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
fuse_norm_quant: bool = True,
|
||||
fuse_qknorm_rope: bool = False,
|
||||
**kwargs):
|
||||
def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, **kwargs):
|
||||
"""
|
||||
Initialize the configuration.
|
||||
|
||||
@@ -236,7 +200,8 @@ class XliteGraphConfig:
|
||||
)
|
||||
if vllm_config.parallel_config.pipeline_parallel_size > 1:
|
||||
raise RuntimeError(
|
||||
"Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1."
|
||||
"Xlite graph mode is not compatible with pipeline parallelism. "
|
||||
"Please set pipeline_parallel_size to 1."
|
||||
)
|
||||
if vllm_config.cache_config.block_size != 128:
|
||||
raise RuntimeError(
|
||||
@@ -254,21 +219,19 @@ class WeightPrefetchConfig:
|
||||
"qkv": 1.0,
|
||||
"o": 1.0,
|
||||
},
|
||||
"moe": {
|
||||
"gate_up": 0.8
|
||||
}
|
||||
"moe": {"gate_up": 0.8},
|
||||
}
|
||||
|
||||
def __init__(self, weight_prefetch_config: dict):
|
||||
self.enabled = weight_prefetch_config.get("enabled", False)
|
||||
self.prefetch_ratio = weight_prefetch_config.get(
|
||||
"prefetch_ratio", self.prefetch_ratio)
|
||||
self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
|
||||
|
||||
|
||||
class EplbConfig:
|
||||
"""
|
||||
Configuration Object for xlite_graph_config from additional_config
|
||||
"""
|
||||
|
||||
_defaults = {
|
||||
"dynamic_eplb": False,
|
||||
"expert_map_path": None,
|
||||
@@ -276,10 +239,12 @@ class EplbConfig:
|
||||
"algorithm_execution_interval": 30,
|
||||
"expert_map_record_path": None,
|
||||
"num_redundant_experts": 0,
|
||||
"eplb_policy_type": 1
|
||||
"eplb_policy_type": 1,
|
||||
}
|
||||
|
||||
def __init__(self, user_config: dict = {}):
|
||||
def __init__(self, user_config: dict | None = None):
|
||||
if user_config is None:
|
||||
user_config = {}
|
||||
self.config = self._defaults.copy()
|
||||
if user_config and isinstance(user_config, dict):
|
||||
for key, value in user_config.items():
|
||||
@@ -307,27 +272,21 @@ class EplbConfig:
|
||||
raise TypeError("The expert_map_record_path is not json.")
|
||||
dirname = os.path.dirname(self.expert_map_record_path)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
for key in [
|
||||
"expert_heat_collection_interval",
|
||||
"algorithm_execution_interval", "num_redundant_experts"
|
||||
]:
|
||||
for key in ["expert_heat_collection_interval", "algorithm_execution_interval", "num_redundant_experts"]:
|
||||
if not isinstance(self.config[key], int):
|
||||
raise TypeError(f"{key} must be an integer")
|
||||
if self.config[key] < 0: # type: ignore
|
||||
raise ValueError(
|
||||
f"{key} must greater than 0; got {self.config[key]} instead"
|
||||
)
|
||||
raise ValueError(f"{key} must greater than 0; got {self.config[key]} instead")
|
||||
if self.eplb_policy_type not in [0, 1, 2, 3]:
|
||||
raise ValueError("eplb_policy_type must in [0, 1, 2, 3]")
|
||||
|
||||
|
||||
_ASCEND_CONFIG: Optional[AscendConfig] = None
|
||||
_ASCEND_CONFIG: AscendConfig | None = None
|
||||
|
||||
|
||||
def init_ascend_config(vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
refresh = additional_config.get("refresh",
|
||||
False) if additional_config else False
|
||||
refresh = additional_config.get("refresh", False) if additional_config else False
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is not None and not refresh:
|
||||
return _ASCEND_CONFIG
|
||||
@@ -343,7 +302,5 @@ def clear_ascend_config():
|
||||
def get_ascend_config():
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is None:
|
||||
raise RuntimeError(
|
||||
"Ascend config is not initialized. Please call init_ascend_config first."
|
||||
)
|
||||
raise RuntimeError("Ascend config is not initialized. Please call init_ascend_config first.")
|
||||
return _ASCEND_CONFIG
|
||||
|
||||
@@ -1,21 +1,25 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional
from typing import Any

import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (get_dp_group, get_ep_group,
                              get_tensor_model_parallel_world_size)
from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                  set_forward_context)
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
                               get_ascend_device_type, has_layer_idx,
                               is_drafter_moe_model, is_moe_model,
                               speculative_enable_dispatch_gmm_combine_decode)
from vllm_ascend.utils import (
    AscendDeviceType,
    enable_sp,
    flashcomm2_enable,
    get_ascend_device_type,
    has_layer_idx,
    is_drafter_moe_model,
    is_moe_model,
    speculative_enable_dispatch_gmm_combine_decode,
)


class MoECommType(Enum):
@@ -31,13 +35,14 @@ def set_ascend_forward_context(
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: int = 0,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
num_tokens_across_dp: torch.Tensor | None = None,
|
||||
in_profile_run: bool = False,
|
||||
num_actual_tokens: Optional[int] = None,
|
||||
num_actual_tokens: int | None = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
model_instance: torch.nn.Module = None,
|
||||
is_draft_model=False):
|
||||
is_draft_model=False,
|
||||
):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
We add some additional param into forward_context.
|
||||
@@ -53,10 +58,9 @@ def set_ascend_forward_context(
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import \
|
||||
get_moe_comm_method
|
||||
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
|
||||
is_draft_model)
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
|
||||
|
||||
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
|
||||
forward_context.moe_comm_type = moe_comm_type
|
||||
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
||||
|
||||
@@ -68,31 +72,28 @@ def set_ascend_forward_context(
|
||||
# due to multiple warmups before actual capturing
|
||||
forward_context.capturing = False
|
||||
|
||||
# set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
|
||||
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
|
||||
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
|
||||
# set for sequence parallelism, 1000 is the batch size concurrency threshold
|
||||
# for enabling the flashcomm_v1 or sequence_parallelism feature.
|
||||
# Currently, it is an empirical value. In normal scenarios, if the concurrency
|
||||
# exceeds this threshold, the performance benefits can be maximized.
|
||||
# Conversely, if the concurrency is below the threshold,
|
||||
# the performance may degrade due to the switching of communication methods.
|
||||
mmrs_fusion = True
|
||||
# main model and drafter model may have different architecture
|
||||
is_context_moe_model = is_drafter_moe_model(vllm_config) \
|
||||
if is_draft_model else is_moe_model(vllm_config)
|
||||
is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
|
||||
if is_context_moe_model:
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
|
||||
mmrs_fusion = False
|
||||
else:
|
||||
sp_enabled = enable_sp(vllm_config) and \
|
||||
num_tokens is not None and num_tokens > 1000
|
||||
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
|
||||
forward_context.mmrs_fusion = mmrs_fusion
|
||||
forward_context.num_tokens = num_tokens
|
||||
forward_context.sp_enabled = sp_enabled
|
||||
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
|
||||
forward_context.flashcomm_v2_enabled = flashcomm2_enable(
|
||||
) and tp_world_size > 1 and num_tokens is not None
|
||||
forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
|
||||
|
||||
if (forward_context.sp_enabled
|
||||
or forward_context.flashcomm_v2_enabled):
|
||||
pad_size = (tp_world_size -
|
||||
(num_tokens % tp_world_size)) % tp_world_size
|
||||
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
|
||||
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
|
||||
forward_context.pad_size = pad_size
|
||||
|
||||
# set this for rope forward_oot using
|
||||
@@ -106,9 +107,12 @@ def set_ascend_forward_context(
|
||||
|
||||
# TODO(rjg-lyh): refactor mlp weight prefetch method
|
||||
# set for mlp weight prefetch
|
||||
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
|
||||
forward_context.layer_idx is not None and \
|
||||
num_tokens is not None and num_tokens < 500
|
||||
prefetch_mlp_enabled = (
|
||||
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
|
||||
and forward_context.layer_idx is not None
|
||||
and num_tokens is not None
|
||||
and num_tokens < 500
|
||||
)
|
||||
if prefetch_mlp_enabled:
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
@@ -121,12 +125,9 @@ def set_ascend_forward_context(
|
||||
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
max_tokens_across_dp = \
|
||||
forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if (forward_context.sp_enabled
|
||||
or forward_context.flashcomm_v2_enabled):
|
||||
padded_length = (max_tokens_across_dp + tp_world_size -
|
||||
1) // tp_world_size * tp_world_size
|
||||
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
|
||||
if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
|
||||
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
|
||||
pad_size = padded_length - num_tokens
|
||||
forward_context.padded_length = padded_length
|
||||
forward_context.pad_size = pad_size
|
||||
@@ -139,12 +140,10 @@ def set_ascend_forward_context(
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
# NOTE: token num which need to pad to when mc2
|
||||
forward_context.padded_num_tokens = math.ceil(
|
||||
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||
forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||
reserved_mc2_mask = get_mc2_mask()
|
||||
if reserved_mc2_mask is not None:
|
||||
mc2_mask = reserved_mc2_mask[:forward_context.
|
||||
padded_num_tokens]
|
||||
mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
|
||||
mc2_mask[:num_actual_tokens] = True
|
||||
mc2_mask[num_actual_tokens:] = False
|
||||
forward_context.mc2_mask = mc2_mask
|
||||
@@ -155,14 +154,13 @@ def set_ascend_forward_context(
|
||||
pass
|
||||
|
||||
|
||||
_mc2_tokens_capacity: Optional[int] = None
|
||||
_reserved_mc2_mask: Optional[torch.Tensor] = None
|
||||
_sin: Optional[torch.Tensor] = None
|
||||
_cos: Optional[torch.Tensor] = None
|
||||
_mc2_tokens_capacity: int | None = None
|
||||
_reserved_mc2_mask: torch.Tensor | None = None
|
||||
_sin: torch.Tensor | None = None
|
||||
_cos: torch.Tensor | None = None
|
||||
|
||||
|
||||
def set_mc2_tokens_capacity(vllm_config, max_num_reqs,
|
||||
uniform_decode_query_len):
|
||||
def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
|
||||
global _mc2_tokens_capacity
|
||||
if _mc2_tokens_capacity is not None:
|
||||
return
|
||||
@@ -186,9 +184,7 @@ def set_mc2_mask(vllm_config, device):
|
||||
if _reserved_mc2_mask is not None:
|
||||
return
|
||||
if is_moe_model(vllm_config):
|
||||
_reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(),
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
_reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
|
||||
else:
|
||||
_reserved_mc2_mask = None
|
||||
|
||||
@@ -197,9 +193,7 @@ def get_mc2_mask():
|
||||
return _reserved_mc2_mask
|
||||
|
||||
|
||||
def select_moe_comm_method(num_tokens: int,
|
||||
vllm_config: VllmConfig,
|
||||
is_draft_model=False) -> Optional[MoECommType]:
|
||||
def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None:
|
||||
"""Select the MoE communication method according to parallel settings,
|
||||
device generation, token count, and quantization.
|
||||
|
||||
@@ -227,16 +221,19 @@ def select_moe_comm_method(num_tokens: int,
|
||||
mc2_tokens_capacity = get_mc2_tokens_capacity()
|
||||
soc_version = get_ascend_device_type()
|
||||
quant_type = getattr(
|
||||
vllm_config.model_config.hf_text_config, 'moe_quantize',
|
||||
getattr(vllm_config.model_config.hf_text_config, 'quantize', None))
|
||||
vllm_config.model_config.hf_text_config,
|
||||
"moe_quantize",
|
||||
getattr(vllm_config.model_config.hf_text_config, "quantize", None),
|
||||
)
|
||||
|
||||
if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group(
|
||||
).world_size == 1:
|
||||
if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
elif soc_version in {AscendDeviceType.A2}:
|
||||
if (num_tokens <= mc2_tokens_capacity
|
||||
and vllm_config.parallel_config.world_size_across_dp /
|
||||
vllm_config.parallel_config.pipeline_parallel_size >= 16):
|
||||
if (
|
||||
num_tokens <= mc2_tokens_capacity
|
||||
and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size
|
||||
>= 16
|
||||
):
|
||||
moe_comm_type = MoECommType.MC2
|
||||
else:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
@@ -246,15 +243,13 @@ def select_moe_comm_method(num_tokens: int,
|
||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
|
||||
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (
|
||||
not is_draft_model) and (not dynamic_eplb)
|
||||
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
|
||||
if num_tokens <= mc2_tokens_capacity:
|
||||
fused_decode_enable = fused_mc2_enable
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
fused_decode_enable = fused_mc2_enable and \
|
||||
speculative_enable_dispatch_gmm_combine_decode(vllm_config)
|
||||
fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
|
||||
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
|
||||
else:
|
||||
fused_prefill_enable = fused_mc2_enable
|
||||
|
||||
@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from collections.abc import Callable
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any
from unittest.mock import patch

import numpy as np
@@ -27,12 +28,12 @@ from ..utils import weak_ref_tensors
|
||||
@dataclasses.dataclass
|
||||
class ACLGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
aclgraph: Optional[torch.npu.NPUGraph] = None
|
||||
output: Optional[Any] = None
|
||||
aclgraph: torch.npu.NPUGraph | None = None
|
||||
output: Any | None = None
|
||||
|
||||
# for aclgraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
input_addresses: list[int] | None = None
|
||||
|
||||
|
||||
class ACLGraphWrapper:
|
||||
@@ -60,11 +61,13 @@ class ACLGraphWrapper:
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
cudagraph_options: CUDAGraphOptions | None = None,
|
||||
):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
@@ -83,15 +86,13 @@ class ACLGraphWrapper:
|
||||
self.aclgraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# aclgraphs for.
|
||||
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
|
||||
= {}
|
||||
self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"aclgraph wrapper: {self.runnable}")
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
@@ -102,8 +103,7 @@ class ACLGraphWrapper:
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
||||
aclgraph_runtime_mode != self.runtime_mode:
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without aclgraphs.
|
||||
# We do not trigger capture/replay if the runtime mode is not
|
||||
@@ -114,8 +114,7 @@ class ACLGraphWrapper:
|
||||
|
||||
if batch_descriptor not in self.concrete_aclgraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_aclgraph_entries[batch_descriptor] = \
|
||||
ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||
self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||
|
||||
entry = self.concrete_aclgraph_entries[batch_descriptor]
|
||||
|
||||
@@ -125,14 +124,11 @@ class ACLGraphWrapper:
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug("Capturing a aclgraph on (%s,%s)",
|
||||
self.runtime_mode.name, entry.batch_descriptor)
|
||||
logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
|
||||
# validate that aclgraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
|
||||
entry.input_addresses = input_addresses
|
||||
aclgraph = torch.npu.NPUGraph()
|
||||
|
||||
@@ -145,8 +141,7 @@ class ACLGraphWrapper:
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.npu.empty_cache", lambda: None))
|
||||
stack.enter_context(patch("torch.npu.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
forward_context.capturing = True
|
||||
@@ -183,13 +178,12 @@ class ACLGraphWrapper:
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for aclgraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}")
|
||||
f"got {new_input_addresses}"
|
||||
)
|
||||
|
||||
logger.info_once("Replaying aclgraph")
|
||||
# In async scheduling or multi-threaded (MT) scenarios when graph mode is FULL, it is possible that
|
||||
@@ -209,8 +203,7 @@ def weak_ref_workspaces(params):
|
||||
for num_tokens in params.workspaces:
|
||||
if params.workspaces[num_tokens] is None:
|
||||
continue
|
||||
params.workspaces[num_tokens] = weak_ref_tensors(
|
||||
params.workspaces[num_tokens])
|
||||
params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])
|
||||
|
||||
|
||||
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
@@ -254,9 +247,11 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
out=output,
|
||||
)
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
@@ -265,7 +260,8 @@ def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
workspace=workspace,
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -287,13 +283,24 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(query, key_cache, value, block_tables, attn_mask, block_size,
|
||||
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
|
||||
attn_output, softmax_lse) = param
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value,
|
||||
block_tables,
|
||||
attn_mask,
|
||||
block_size,
|
||||
seq_lens,
|
||||
query_start_loc,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
) = param
|
||||
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens_list
|
||||
actual_seq_lengths_q = forward_context.attn_metadata[
|
||||
key].actual_seq_lengths_q
|
||||
actual_seq_lengths_q = forward_context.attn_metadata[key].actual_seq_lengths_q
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
@@ -317,16 +324,14 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape,
|
||||
vllm_config):
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config):
|
||||
if using_paged_attention(runtime_shape, vllm_config):
|
||||
_update_attn_pa_params(update_stream, forward_context, runtime_shape)
|
||||
else:
|
||||
_update_attn_fia_params(update_stream, forward_context, runtime_shape)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
speculative_config):
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
|
||||
if forward_context.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
else:
|
||||
@@ -340,36 +345,39 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||
attn_mask, sparse_mode, scale, block_table, block_size,
|
||||
seq_lens_list, actual_seq_lengths, attn_output,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[
|
||||
key].decode.seq_lens_list
|
||||
if speculative_config and speculative_config.method == "mtp" \
|
||||
and not forward_context.is_draft_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
(
|
||||
q_nope,
|
||||
k_nope,
|
||||
q_pe,
|
||||
k_pe,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
input_layout,
|
||||
attn_mask,
|
||||
sparse_mode,
|
||||
scale,
|
||||
block_table,
|
||||
block_size,
|
||||
seq_lens_list,
|
||||
actual_seq_lengths,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
) = param
|
||||
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
||||
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
|
||||
spec_multiple = speculative_config.num_speculative_tokens + 1
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
runtime_shape // spec_multiple - len(seq_lens_list))
|
||||
actual_seq_lengths = [
|
||||
spec_multiple * (i + 1)
|
||||
for i in range(runtime_shape // spec_multiple)
|
||||
]
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list))
|
||||
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)]
|
||||
elif forward_context.is_draft_model:
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
block_table = forward_context.attn_metadata[
|
||||
key].decode.block_table
|
||||
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
|
||||
block_table = forward_context.attn_metadata[key].decode.block_table
|
||||
# TODO: This is a hack and should be fixed in the future.
|
||||
if speculative_config.disable_padded_drafter_batch:
|
||||
block_table = block_table[: len(actual_seq_lengths)]
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
len(actual_seq_lengths) - len(seq_lens_list))
|
||||
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
|
||||
else:
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list))
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
@@ -391,7 +399,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
actual_seq_lengths_kv=seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse])
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -408,29 +417,35 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
|
||||
block_table, block_size, actual_seq_lengths_kv,
|
||||
actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
|
||||
pcp_rank, dcp_rank) = param
|
||||
attn_metadata = forward_context.attn_metadata[key]
|
||||
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
|
||||
(
|
||||
q_nope,
|
||||
k_nope,
|
||||
value,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_table,
|
||||
block_size,
|
||||
actual_seq_lengths_kv,
|
||||
actual_seq_lengths_q,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
dcp_size,
|
||||
pcp_rank,
|
||||
dcp_rank]
|
||||
dcp_rank,
|
||||
) = param
|
||||
attn_metadata = forward_context.attn_metadata[key]
|
||||
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
|
||||
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
||||
if pad_length > 0:
|
||||
pad_tensor = np.zeros(pad_length,
|
||||
dtype=actual_seq_lengths_kv.dtype)
|
||||
actual_seq_lengths_kv = np.concatenate(
|
||||
[actual_seq_lengths_kv, pad_tensor])
|
||||
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
|
||||
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
|
||||
|
||||
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
|
||||
attn_metadata
|
||||
.
|
||||
num_decode_tokens]
|
||||
if (runtime_shape - len(actual_seq_lengths_q)):
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + [
|
||||
actual_seq_lengths_q[-1]
|
||||
] * (runtime_shape - len(actual_seq_lengths_q))
|
||||
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decode_tokens]
|
||||
if runtime_shape - len(actual_seq_lengths_q):
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + [actual_seq_lengths_q[-1]] * (
|
||||
runtime_shape - len(actual_seq_lengths_q)
|
||||
)
|
||||
if dcp_size > 1:
|
||||
num_heads = num_heads * dcp_size
|
||||
|
||||
@@ -453,14 +468,14 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
actual_seq_lengths=actual_seq_lengths_q,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse])
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
runtime_shape):
|
||||
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
if forward_context.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
else:
|
||||
@@ -474,8 +489,19 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
|
||||
scale, num_kv_heads, attn_output, softmax_lse) = param
|
||||
(
|
||||
q_nope,
|
||||
q_pe,
|
||||
k_nope,
|
||||
k_pe,
|
||||
block_table,
|
||||
seq_len,
|
||||
num_heads,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
) = param
|
||||
|
||||
decode_meta = forward_context.attn_metadata[key].decode
|
||||
seq_len = decode_meta.cp_seq_len
|
||||
@@ -484,9 +510,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
# to avoid irregular attn_mask shape,
|
||||
# so there's no need to divide runtime_shape by spec_multiple
|
||||
pad_length = runtime_shape - len(seq_len)
|
||||
pad_tensor = torch.zeros(pad_length,
|
||||
dtype=seq_len.dtype,
|
||||
device=seq_len.device)
|
||||
pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
|
||||
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
|
||||
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
@@ -505,7 +529,8 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
calc_type="calc_type_ring",
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
output=attn_output,
|
||||
lse=softmax_lse)
|
||||
lse=softmax_lse,
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -519,7 +544,7 @@ class GraphParams:
|
||||
attn_params: dict[int, list[tuple]]
|
||||
|
||||
|
||||
_graph_params: Optional[GraphParams] = None
|
||||
_graph_params: GraphParams | None = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
@@ -527,14 +552,10 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
if _graph_params is not None:
|
||||
raise ValueError("Graph parameters have already been set!")
|
||||
_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: None for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
@@ -548,7 +569,7 @@ def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
|
||||
_draft_graph_params: Optional[GraphParams] = None
|
||||
_draft_graph_params: GraphParams | None = None
|
||||
|
||||
|
||||
def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
@@ -556,14 +577,10 @@ def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
if _draft_graph_params is not None:
|
||||
raise ValueError("DraftGraph parameters have already been set!")
|
||||
_draft_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: None for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
{size: [] for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@
# limitations under the License.
#
import functools
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any

import torch
import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import (graph_returns_tuple,
                                        make_graph_return_tuple)
from torch._inductor.compile_fx import graph_returns_tuple, make_graph_return_tuple
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
@@ -32,15 +32,11 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import COMPILATION_PASS_KEY
|
||||
|
||||
|
||||
def compile_fx(graph: GraphModule, example_inputs: list,
|
||||
inner_compile: Callable, decompositions: dict) -> Callable:
|
||||
recursive_compile_fx = functools.partial(compile_fx,
|
||||
inner_compile=inner_compile,
|
||||
decompositions=decompositions)
|
||||
def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable:
|
||||
recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions)
|
||||
|
||||
if not graph_returns_tuple(graph):
|
||||
return make_graph_return_tuple(graph, example_inputs,
|
||||
recursive_compile_fx)
|
||||
return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx)
|
||||
return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
|
||||
|
||||
|
||||
@@ -49,9 +45,8 @@ def fusion_pass_compile(
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
def compile_inner(graph, example_inputs):
|
||||
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
|
||||
graph = current_pass_manager(graph)
|
||||
@@ -74,8 +69,8 @@ def npugraph_ex_compile(
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
# When currently using the FULL_DECODE_ONLY mode,
|
||||
# the piecewise compilation level slicing process
|
||||
# in vllm is also encountered.
|
||||
@@ -87,9 +82,7 @@ def npugraph_ex_compile(
|
||||
output_node = fx_graph.output_node()
|
||||
with fx_graph.inserting_before(output_node):
|
||||
return_value = output_node.args[0]
|
||||
tuple_node = fx_graph.create_node("call_function",
|
||||
tuple,
|
||||
args=([return_value], ))
|
||||
tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
|
||||
output_node.args = (tuple_node,)
|
||||
graph.recompile()
|
||||
|
||||
@@ -119,6 +112,7 @@ class AscendCompiler(CompilerInterface):
|
||||
This class provides a method to compile a PyTorch FX graph module with
|
||||
specific configurations for graph fusion and decomposition.
|
||||
"""
|
||||
|
||||
name = "AscendCompiler"
|
||||
|
||||
def compile(
|
||||
@@ -127,13 +121,10 @@ class AscendCompiler(CompilerInterface):
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.enable_npugraph_ex:
|
||||
return npugraph_ex_compile(graph, example_inputs, compiler_config,
|
||||
compile_range, key)
|
||||
return npugraph_ex_compile(graph, example_inputs, compiler_config, compile_range, key)
|
||||
else:
|
||||
return fusion_pass_compile(graph, example_inputs, compiler_config,
|
||||
compile_range, key)
|
||||
return fusion_pass_compile(graph, example_inputs, compiler_config, compile_range, key)
|
||||
|
||||
@@ -48,13 +48,13 @@ class GraphFusionPassManager:
|
||||
|
||||
def configure(self, config: VllmConfig):
|
||||
# By default, we enable the graph fusion and quantization fusion pass.
|
||||
self.ascend_compilation_config: dict = config.additional_config.get(
|
||||
"ascend_compilation_config", {})
|
||||
self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {})
|
||||
if self.ascend_compilation_config.get("fuse_norm_quant", True):
|
||||
from .passes.norm_quant_fusion_pass import \
|
||||
AddRMSNormQuantFusionPass
|
||||
from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
|
||||
|
||||
self.passes.append(AddRMSNormQuantFusionPass(config))
|
||||
|
||||
if self.ascend_compilation_config.get("fuse_qknorm_rope", True):
|
||||
from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
|
||||
|
||||
self.passes.append(QKNormRopeFusionPass(config))
|
||||
|
||||
@@ -48,7 +48,8 @@ def _extra_stream_scope_check(match: Match) -> bool:
|
||||
logger.debug(
|
||||
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
|
||||
f"Multiple streams found: {non_default_streams}. "
|
||||
f"Fusion is not supported for cross-stream operations.")
|
||||
f"Fusion is not supported for cross-stream operations."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -57,24 +58,29 @@ def _extra_stream_scope_check(match: Match) -> bool:
|
||||
@functools.lru_cache(None)
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
@@ -82,10 +88,12 @@ def replacement_add_rms_norm_quant(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon)
|
||||
epsilon=epsilon,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
@@ -103,33 +111,39 @@ def replacement_add_rms_norm_quant(epsilon):
|
||||
|
||||
import torchair
|
||||
|
||||
torchair.register_replacement(search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=get_inputs(),
|
||||
extra_check=_extra_stream_scope_check)
|
||||
torchair.register_replacement(
|
||||
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
|
||||
)
|
||||
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant_with_bias(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuantWithBias fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for AddRMSNormQuantWithBias fusion.
|
||||
"""
|
||||
@@ -137,11 +151,13 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon,
|
||||
beta=bias)
|
||||
beta=bias,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
@@ -156,40 +172,41 @@ def replacement_add_rms_norm_quant_with_bias(epsilon):
|
||||
rmsnorm_bias = torch.randn(4, device="npu")
|
||||
scale = torch.ones(4, device="npu")
|
||||
offset = torch.zeros(4, device="npu")
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset,
|
||||
rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]
|
||||
|
||||
import torchair
|
||||
|
||||
torchair.register_replacement(search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=get_inputs(),
|
||||
extra_check=_extra_stream_scope_check)
|
||||
torchair.register_replacement(
|
||||
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
|
||||
)
|
||||
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant_sp_pattern(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuantSPPattern fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuantSPPattern fusion.
|
||||
"""
|
||||
@@ -197,14 +214,15 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon)
|
||||
epsilon=epsilon,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
def get_inputs():
|
||||
@@ -220,34 +238,40 @@ def replacement_add_rms_norm_quant_sp_pattern(epsilon):
|
||||
|
||||
import torchair
|
||||
|
||||
torchair.register_replacement(search_fn=pattern,
|
||||
replace_fn=replacement,
|
||||
example_inputs=get_inputs(),
|
||||
extra_check=_extra_stream_scope_check)
|
||||
torchair.register_replacement(
|
||||
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
|
||||
)
|
||||
|
||||
|
||||
# The replacement registered here will be actually executed after AOT.
|
||||
def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuantSPPatternWithBias fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, epsilon)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
|
||||
torch.qint8, -1, False)
|
||||
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
offset: torch.Tensor, bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuantSPPatternWithBias fusion.
|
||||
"""
|
||||
@@ -255,15 +279,16 @@ def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
|
||||
rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
|
||||
1. / scale,
|
||||
# The inverse of scale is required by npu_add_rms_norm_quant kernel
|
||||
# which is opposite to the npu_quantize kernel.
|
||||
1.0 / scale,
|
||||
offset,
|
||||
epsilon=epsilon,
|
||||
beta=bias)
|
||||
beta=bias,
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
def get_inputs():
|
||||
@@ -276,25 +301,19 @@ replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon):
rmsnorm_bias = torch.randn(4, device="npu")
scale = torch.ones(4, device="npu")
offset = torch.zeros(4, device="npu")
return [
rms_norm_input, residual, rms_norm_weight, scale, offset,
rmsnorm_bias
]
return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias]

import torchair

torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
torchair.register_replacement(
search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check
)


# register converter for pass
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
logger.info(
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
)
logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}")
replacement_add_rms_norm_quant(eps)
replacement_add_rms_norm_quant_with_bias(eps)
replacement_add_rms_norm_quant_sp_pattern(eps)
@@ -25,7 +25,6 @@ from vllm.logger import logger
|
||||
|
||||
|
||||
class AddRMSNormQuantPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -41,50 +40,48 @@ class AddRMSNormQuantPattern:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -101,54 +98,51 @@ class AddRMSNormQuantPatternWithBias:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPattern:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -164,53 +158,50 @@ class AddRMSNormQuantSPPattern:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class AddRMSNormQuantSPPatternWithBias:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
@@ -227,53 +218,50 @@ class AddRMSNormQuantSPPatternWithBias:
|
||||
scale = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
|
||||
offset = torch.zeros(4, device="npu", dtype=self.dtype)
|
||||
return [
|
||||
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
|
||||
offset, rmsnorm_bias
|
||||
]
|
||||
return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def pattern(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Pattern for AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
|
||||
rms_norm_weight, self.eps)
|
||||
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
|
||||
out0 = output[0]
|
||||
out1 = output[2]
|
||||
out0 = out0 + bias
|
||||
out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale,
|
||||
scale_reciprocal,
|
||||
offset)
|
||||
quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
|
||||
return quantized_output, out1
|
||||
|
||||
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor, offset: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
def replacement(
|
||||
rms_norm_input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_norm_weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
scale_reciprocal: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Replacement for the AddRMSNormQuant fusion.
|
||||
"""
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input,
|
||||
residual,
|
||||
rms_norm_weight,
|
||||
scale,
|
||||
offset,
|
||||
epsilon=self.eps,
|
||||
beta=bias)
|
||||
output = torch.ops.npu.npu_add_rms_norm_quant(
|
||||
rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias
|
||||
)
|
||||
quantized_output = output[0]
|
||||
out1 = output[2]
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
quantized_output, True)
|
||||
quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True)
|
||||
return quantized_output, out1
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)

class AddRMSNormQuantFusionPass(VllmInductorPass):
@@ -283,25 +271,19 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):

def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
pass_name="rmsnorm_quant_fusion_pass")
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="rmsnorm_quant_fusion_pass")

dtype = vllm_config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.debug("Quant fusion not enabled: unsupported dtype %s",
dtype)
logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype)
return

common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config,
eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)

def __call__(self, graph: torch.fx.Graph):
self.begin()
@@ -17,8 +17,7 @@
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -27,13 +26,7 @@ from vllm.logger import logger

class QKNormRopeFusionPattern:

def __init__(self,
vllm_config,
head_dim,
num_heads,
num_kv_heads,
eps=1e-6):
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
self.vllm_config = vllm_config
self.head_dim = head_dim
self.num_heads = num_heads
@@ -45,65 +38,38 @@ class QKNormRopeFusionPattern:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
qkv = torch.empty(T,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
q_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
k_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
cos = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
sin = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
return [qkv, q_weight, k_weight, cos, sin]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
|
||||
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
|
||||
self.eps)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
|
||||
self.eps)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
|
||||
q_flat = q_norm_out.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_norm_out.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
|
||||
q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
def replacement(
|
||||
qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -115,22 +81,16 @@ class QKNormRopeFusionPattern:
|
||||
q_bias=None,
|
||||
k_bias=None,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
cos=cos,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
def __init__(self,
|
||||
vllm_config,
|
||||
head_dim,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
eps=1e-6):
|
||||
def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6):
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@@ -142,71 +102,55 @@ class QKNormRopeFusionPatternWithBias:
|
||||
|
||||
def get_inputs(self):
|
||||
T = 5
|
||||
qkv = torch.empty(T,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
q_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
k_weight = torch.empty(self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu")
|
||||
q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
cos = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
sin = torch.empty(1,
|
||||
T,
|
||||
1,
|
||||
self.head_dim,
|
||||
dtype=torch.bfloat16,
|
||||
device="npu")
|
||||
cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu")
|
||||
|
||||
return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
def pattern(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight,
|
||||
self.eps)
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
|
||||
q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps)
|
||||
q_normed = q_norm_out + q_bias
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||
self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight,
|
||||
self.eps)
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
|
||||
k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps)
|
||||
k_normed = k_norm_out + k_bias
|
||||
|
||||
q_flat = q_normed.view(q.shape)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
k_flat = k_normed.view(k.shape)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1,
|
||||
self.head_dim)
|
||||
k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim)
|
||||
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(
|
||||
q_reshape, k_reshape, cos, sin)
|
||||
q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin)
|
||||
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor, q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor, cos: torch.Tensor,
|
||||
sin: torch.Tensor):
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor,
|
||||
k_bias: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
results = torch.ops.vllm.qkv_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=q_weight,
|
||||
@@ -218,11 +162,11 @@ class QKNormRopeFusionPatternWithBias:
|
||||
q_bias=q_bias,
|
||||
k_bias=k_bias,
|
||||
cos=cos,
|
||||
sin=sin)
|
||||
sin=sin,
|
||||
)
|
||||
return results
|
||||
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(),
|
||||
pm.fwd_only, pm_pass)
|
||||
pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
class QKNormRopeFusionPass(VllmInductorPass):
|
||||
@@ -232,44 +176,38 @@ class QKNormRopeFusionPass(VllmInductorPass):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="qknorm_rope_fusion_pass")
|
||||
self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(pass_name="qknorm_rope_fusion_pass")
|
||||
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion not enabled: unsupported dtype %s",
|
||||
dtype)
|
||||
logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype)
|
||||
return
|
||||
|
||||
# use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
|
||||
vllm_config, Attention)
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
if len(attn_layers) == 0:
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion enabled, but no Attention layers were discovered."
|
||||
)
|
||||
logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.")
|
||||
return
|
||||
layer = next(iter(attn_layers.values()))
|
||||
for epsilon in [1e-6, 1e-5]:
|
||||
if layer.head_size != 128:
|
||||
logger.debug(
|
||||
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
|
||||
layer.head_size)
|
||||
logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size)
|
||||
continue
|
||||
QKNormRopeFusionPattern(vllm_config=vllm_config,
|
||||
QKNormRopeFusionPattern(
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon).register(
|
||||
self.pattern_match_passes)
|
||||
eps=epsilon,
|
||||
).register(self.pattern_match_passes)
|
||||
|
||||
QKNormRopeFusionPatternWithBias(vllm_config=vllm_config,
|
||||
QKNormRopeFusionPatternWithBias(
|
||||
vllm_config=vllm_config,
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon).register(
|
||||
self.pattern_match_passes)
|
||||
eps=epsilon,
|
||||
).register(self.pattern_match_passes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
|
||||
@@ -1,10 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import subprocess
from collections import defaultdict
from typing import Dict, List, Tuple

import psutil
from vllm.logger import logger
@@ -13,26 +11,22 @@ ALLOWED_CPUS_PATH = "/proc/self/status"
ASCEND_RT_VISIBLE_DEVICES = os.getenv("ASCEND_RT_VISIBLE_DEVICES")


def execute_command(cmd: List[str]) -> Tuple[str, int]:
with subprocess.Popen(cmd,
shell=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as p:
def execute_command(cmd: list[str]) -> tuple[str, int]:
with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
out, _ = p.communicate(timeout=1000)
return out.decode(), p.returncode

class DeviceInfo:
|
||||
|
||||
def __init__(self):
|
||||
self.npu_map_info: Dict[str, Dict[str, str]] = self.get_npu_map_info()
|
||||
self.allowed_cpus: List[int] = self.parse_allowed_cpus()
|
||||
self.running_npu_list: List[int] = self.get_running_npus()
|
||||
self.npu_affinity: Dict[int, List[int]] = self.parse_topo_affinity()
|
||||
self.npu_map_info: dict[str, dict[str, str]] = self.get_npu_map_info()
|
||||
self.allowed_cpus: list[int] = self.parse_allowed_cpus()
|
||||
self.running_npu_list: list[int] = self.get_running_npus()
|
||||
self.npu_affinity: dict[int, list[int]] = self.parse_topo_affinity()
|
||||
|
||||
@staticmethod
|
||||
def expand_cpu_list(allowed_list_str: str) -> List[int]:
|
||||
allowed_cpus_list: List[int] = []
|
||||
def expand_cpu_list(allowed_list_str: str) -> list[int]:
|
||||
allowed_cpus_list: list[int] = []
|
||||
for per_range in allowed_list_str.split(","):
|
||||
if "-" in per_range:
|
||||
start_cpu, end_cpu = map(int, per_range.split("-"))
|
||||
@@ -42,8 +36,8 @@ class DeviceInfo:
|
||||
return allowed_cpus_list
|
||||
|
||||
@staticmethod
|
||||
def get_npu_map_info() -> Dict[str, Dict[str, str]]:
|
||||
npu_map_info: Dict[str, Dict[str, str]] = {}
|
||||
def get_npu_map_info() -> dict[str, dict[str, str]]:
|
||||
npu_map_info: dict[str, dict[str, str]] = {}
|
||||
npu_info, _ = execute_command(["npu-smi", "info", "-m"])
|
||||
npu_map = npu_info.strip().split("\n")[1:]
|
||||
for line in npu_map:
|
||||
@@ -55,7 +49,7 @@ class DeviceInfo:
|
||||
npu_map_info[npu_id][chip_id] = chip_logic_id
|
||||
return npu_map_info
|
||||
|
||||
def get_running_npus(self) -> List[int]:
|
||||
def get_running_npus(self) -> list[int]:
|
||||
npu_message, _ = execute_command(["npu-smi", "info"])
|
||||
in_proc_section = False
|
||||
running_npu_set = set()
|
||||
@@ -76,36 +70,29 @@ class DeviceInfo:
|
||||
continue
|
||||
chip_logic_id = self.npu_map_info.get(npu_id, {}).get(chip_id)
|
||||
if not chip_logic_id or not chip_logic_id.isdigit():
|
||||
raise RuntimeError(
|
||||
"Failed to get correct chip_logic_id from command 'npu-smi info -m'."
|
||||
)
|
||||
raise RuntimeError("Failed to get correct chip_logic_id from command 'npu-smi info -m'.")
|
||||
running_npu_set.add(int(chip_logic_id))
|
||||
if ASCEND_RT_VISIBLE_DEVICES:
|
||||
devices_str = ASCEND_RT_VISIBLE_DEVICES
|
||||
devices_list = [int(x) for x in devices_str.split(",")]
|
||||
running_npu_set = set(devices_list) & running_npu_set
|
||||
if not running_npu_set:
|
||||
raise RuntimeError(
|
||||
"Can not get running npu info, you can use BIND_CPU=0 to skip."
|
||||
)
|
||||
raise RuntimeError("Can not get running npu info, you can use BIND_CPU=0 to skip.")
|
||||
return sorted(running_npu_set)
|
||||
|
||||
def parse_allowed_cpus(self) -> List[int]:
|
||||
def parse_allowed_cpus(self) -> list[int]:
|
||||
if not os.path.exists(ALLOWED_CPUS_PATH):
|
||||
return []
|
||||
with open(ALLOWED_CPUS_PATH) as f:
|
||||
for line in f:
|
||||
if line.startswith("Cpus_allowed_list"):
|
||||
return self.expand_cpu_list(line.split()[1])
|
||||
raise RuntimeError(
|
||||
"Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file."
|
||||
)
|
||||
raise RuntimeError("Can not found specific 'Cpus_allowed_list' in the '/proc/self/status' file.")
|
||||
|
||||
def parse_topo_affinity(self) -> Dict[int, List[int]]:
|
||||
def parse_topo_affinity(self) -> dict[int, list[int]]:
|
||||
chip_logic_id = 0
|
||||
affinity: Dict[int, List[int]] = {}
|
||||
affinity_message, _ = execute_command(
|
||||
["npu-smi", "info", "-t", "topo"])
|
||||
affinity: dict[int, list[int]] = {}
|
||||
affinity_message, _ = execute_command(["npu-smi", "info", "-t", "topo"])
|
||||
for line in affinity_message.splitlines():
|
||||
if line.startswith("NPU"):
|
||||
parts = line.split()
|
||||
@@ -117,21 +104,19 @@ class DeviceInfo:
|
||||
|
||||
|
||||
class CpuAlloc:
|
||||
|
||||
def __init__(self, rank_id: int):
|
||||
self.rank_id = rank_id
|
||||
self.device_info: DeviceInfo = DeviceInfo()
|
||||
self.cpu_node: Dict[int, int] = {}
|
||||
self.numa_to_cpu_map: Dict[int, List[int]] = defaultdict(list)
|
||||
self.npu_cpu_pool: Dict[int, List[int]] = {}
|
||||
self.assign_main: Dict[int, List[int]] = {}
|
||||
self.assign_acl: Dict[int, List[int]] = {}
|
||||
self.assign_rel: Dict[int, List[int]] = {}
|
||||
self.cpu_node: dict[int, int] = {}
|
||||
self.numa_to_cpu_map: dict[int, list[int]] = defaultdict(list)
|
||||
self.npu_cpu_pool: dict[int, list[int]] = {}
|
||||
self.assign_main: dict[int, list[int]] = {}
|
||||
self.assign_acl: dict[int, list[int]] = {}
|
||||
self.assign_rel: dict[int, list[int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def get_threads_map(
|
||||
thread_message: str) -> Dict[str, Dict[str, List[str]]]:
|
||||
threads_map: Dict[str, Dict[str, List[str]]] = {}
|
||||
def get_threads_map(thread_message: str) -> dict[str, dict[str, list[str]]]:
|
||||
threads_map: dict[str, dict[str, list[str]]] = {}
|
||||
for line in thread_message.splitlines():
|
||||
parts = line.split()
|
||||
if len(parts) < 2:
|
||||
@@ -144,40 +129,33 @@ class CpuAlloc:
|
||||
else:
|
||||
continue
|
||||
if main_pid not in threads_map:
|
||||
threads_map[main_pid] = {
|
||||
"acl_thread": [],
|
||||
"release_thread": []
|
||||
}
|
||||
threads_map[main_pid] = {"acl_thread": [], "release_thread": []}
|
||||
threads_map[main_pid][key].append(sub_pid)
|
||||
return threads_map
|
||||
|
||||
@staticmethod
|
||||
def bind(pid: str, cpus: List[int], bind_sub_thread: bool) -> None:
|
||||
def bind(pid: str, cpus: list[int], bind_sub_thread: bool) -> None:
|
||||
if cpus:
|
||||
cpu_list = ",".join(map(str, cpus))
|
||||
if bind_sub_thread:
|
||||
bind_result, return_code = execute_command(
|
||||
["taskset", "-acp", cpu_list, pid])
|
||||
bind_result, return_code = execute_command(["taskset", "-acp", cpu_list, pid])
|
||||
else:
|
||||
bind_result, return_code = execute_command(
|
||||
["taskset", "-cp", cpu_list, pid])
|
||||
bind_result, return_code = execute_command(["taskset", "-cp", cpu_list, pid])
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"Failed to bind {pid} to CPU {cpu_list}.")
|
||||
|
||||
def average_distribute(
|
||||
self, groups: Dict[str, List[int]]) -> Dict[int, List[int]]:
|
||||
result: Dict[int, List[int]] = {}
|
||||
def average_distribute(self, groups: dict[str, list[int]]) -> dict[int, list[int]]:
|
||||
result: dict[int, list[int]] = {}
|
||||
for key, npu_list in groups.items():
|
||||
cpu_list = sorted(self.npu_cpu_pool[npu_list[0]])
|
||||
cpu_num_per_npu = len(cpu_list) // len(npu_list)
|
||||
for i, npu in enumerate(npu_list):
|
||||
start_index = i * cpu_num_per_npu
|
||||
end_index = (i + 1) * cpu_num_per_npu if i < len(
|
||||
npu_list) - 1 else len(cpu_list)
|
||||
end_index = (i + 1) * cpu_num_per_npu if i < len(npu_list) - 1 else len(cpu_list)
|
||||
result[npu] = cpu_list[start_index:end_index]
|
||||
return result
|
||||
|
||||
def extend_numa(self, cpu_list: List[int]) -> List[int]:
|
||||
def extend_numa(self, cpu_list: list[int]) -> list[int]:
|
||||
if not cpu_list:
|
||||
return []
|
||||
nodes = {self.cpu_node[c] for c in cpu_list}
|
||||
@@ -203,9 +181,7 @@ class CpuAlloc:
|
||||
self.cpu_node[cpu] = node
|
||||
self.numa_to_cpu_map[node].append(cpu)
|
||||
if len(self.numa_to_cpu_map) == 0:
|
||||
raise RuntimeError(
|
||||
"lscpu command output error, no NUMA node available. Please check!"
|
||||
)
|
||||
raise RuntimeError("lscpu command output error, no NUMA node available. Please check!")
|
||||
|
||||
def handle_no_affinity(self) -> None:
|
||||
num_running_npu = len(self.device_info.running_npu_list)
|
||||
@@ -219,10 +195,7 @@ class CpuAlloc:
|
||||
index = 0
|
||||
for node in sorted(self.numa_to_cpu_map):
|
||||
# Available CPUs on this NUMA (constrained by allowed_cpus)
|
||||
cpus = [
|
||||
c for c in self.numa_to_cpu_map[node]
|
||||
if c in self.device_info.allowed_cpus
|
||||
]
|
||||
cpus = [c for c in self.numa_to_cpu_map[node] if c in self.device_info.allowed_cpus]
|
||||
if not cpus:
|
||||
continue
|
||||
# The actual number of NPUs to be allocated on this NUMA.
|
||||
@@ -251,19 +224,16 @@ class CpuAlloc:
|
||||
return
|
||||
for npu in self.device_info.running_npu_list:
|
||||
base_cpu_list = [
|
||||
cpu for cpu in self.device_info.npu_affinity.get(npu, [])
|
||||
if cpu in self.device_info.allowed_cpus
|
||||
cpu for cpu in self.device_info.npu_affinity.get(npu, []) if cpu in self.device_info.allowed_cpus
|
||||
]
|
||||
if not base_cpu_list:
|
||||
raise RuntimeError(
|
||||
"CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity."
|
||||
)
|
||||
raise RuntimeError("CPUs available in 'Cpus_allowed_list' conflict with NUMA affinity.")
|
||||
extra_cpu_list = self.extend_numa(base_cpu_list)
|
||||
self.npu_cpu_pool[npu] = extra_cpu_list
|
||||
groups = defaultdict(list)
|
||||
for npu, cpus in self.npu_cpu_pool.items():
|
||||
groups[str(cpus)].append(npu)
|
||||
final: Dict[int, List[int]] = {}
|
||||
final: dict[int, list[int]] = {}
|
||||
for key, npu_list in groups.items():
|
||||
if len(npu_list) == 1:
|
||||
final[npu_list[0]] = self.npu_cpu_pool[npu_list[0]]
|
||||
@@ -279,8 +249,8 @@ class CpuAlloc:
|
||||
rel = [pool[-1]]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"The number of CPUs is insufficient to bind to the NPUs. "
|
||||
"Each NPU requires at least 3 CPUs.")
|
||||
"The number of CPUs is insufficient to bind to the NPUs. Each NPU requires at least 3 CPUs."
|
||||
)
|
||||
self.assign_main[npu] = main
|
||||
self.assign_acl[npu] = acl
|
||||
self.assign_rel[npu] = rel
|
||||
@@ -290,10 +260,8 @@ class CpuAlloc:
|
||||
current_npu = self.device_info.running_npu_list[self.rank_id]
|
||||
main = " ".join(map(str, self.assign_main[current_npu]))
|
||||
acl = " ".join(map(str, self.assign_acl[current_npu]))
|
||||
rel = str(self.assign_rel[current_npu]
|
||||
) if self.assign_rel[current_npu] else ""
|
||||
logger.info(
|
||||
f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]")
|
||||
rel = str(self.assign_rel[current_npu]) if self.assign_rel[current_npu] else ""
|
||||
logger.info(f"NPU{current_npu}: main=[{main}] acl=[{acl}] release=[{rel}]")
|
||||
|
||||
def bind_threads(self) -> None:
|
||||
thread_message, _ = execute_command(["ps", "-Te"])
|
||||
@@ -303,8 +271,7 @@ class CpuAlloc:
|
||||
self.bind(main_pid, self.assign_main[current_npu], True)
|
||||
for acl_thread in threads_map.get(main_pid, {}).get("acl_thread", []):
|
||||
self.bind(acl_thread, self.assign_acl[current_npu], False)
|
||||
for release_thread in threads_map.get(main_pid,
|
||||
{}).get("release_thread", []):
|
||||
for release_thread in threads_map.get(main_pid, {}).get("release_thread", []):
|
||||
self.bind(release_thread, self.assign_rel[current_npu], False)
|
||||
|
||||
def run_all(self) -> None:
|
||||
|
||||
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional

import torch
from vllm.model_executor.layers.linear import LinearBase
@@ -7,26 +6,26 @@ from vllm.model_executor.layers.linear import LinearBase

@dataclass
class FlashCommon3Context:
gate: Optional[LinearBase] = None
topk_weights: Optional[torch.Tensor] = None
topk_ids: Optional[torch.Tensor] = None
row_idx: Optional[torch.Tensor] = None
shared_experts: Optional[torch.nn.Module] = None
shared_out: Optional[torch.Tensor] = None
gate: LinearBase | None = None
topk_weights: torch.Tensor | None = None
topk_ids: torch.Tensor | None = None
row_idx: torch.Tensor | None = None
shared_experts: torch.nn.Module | None = None
shared_out: torch.Tensor | None = None


_flash_common3_context: Optional[FlashCommon3Context] = None
_flash_common3_context: FlashCommon3Context | None = None


def get_flash_common3_context() -> Optional[FlashCommon3Context]:
def get_flash_common3_context() -> FlashCommon3Context | None:
return _flash_common3_context


def set_flash_common3_context(
topk_weights: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.nn.Module] = None,
shared_out: Optional[torch.Tensor] = None,
topk_weights: torch.Tensor | None = None,
topk_ids: torch.Tensor | None = None,
shared_experts: torch.nn.Module | None = None,
shared_out: torch.Tensor | None = None,
):
global _flash_common3_context
if _flash_common3_context is None:

@@ -46,17 +46,20 @@ def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
|
||||
if overload != "":
|
||||
op_name = op_name + "." + overload
|
||||
schema_to_find = ns + "::" + op_name
|
||||
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
|
||||
"Meta")
|
||||
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
|
||||
if schema_to_find in meta_impl_list:
|
||||
return
|
||||
lib.impl(op_name, fn, "Meta")
|
||||
|
||||
|
||||
def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
|
||||
key: torch.Tensor, head_size: int,
|
||||
cos_sin_cache: torch.Tensor, is_neox: bool):
|
||||
|
||||
def rotary_embedding_meta(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
):
|
||||
num_tokens = positions.numel()
|
||||
query_hidden_size = query.numel() // num_tokens
|
||||
key_hidden_size = key.numel() // num_tokens
|
||||
@@ -68,38 +71,41 @@ def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
|
||||
return query_dst, key_dst
|
||||
|
||||
|
||||
def get_masked_input_and_mask_meta(input: torch.Tensor,
|
||||
def get_masked_input_and_mask_meta(
|
||||
input: torch.Tensor,
|
||||
org_vocab_start_index: int,
|
||||
org_vocab_end_index: int,
|
||||
num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int):
|
||||
|
||||
added_vocab_end_index: int,
|
||||
):
|
||||
masked_input = torch.empty_like(input)
|
||||
mask = torch.empty_like(input).to(torch.bool)
|
||||
|
||||
return masked_input, mask
|
||||
|
||||
|
||||
def bgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
|
||||
indices: torch.Tensor, y: torch.Tensor, slice_offset: int,
|
||||
slice_size: int):
|
||||
|
||||
def bgmv_expand_meta(
|
||||
x: torch.Tensor, weight: torch.Tensor, indices: torch.Tensor, y: torch.Tensor, slice_offset: int, slice_size: int
|
||||
):
|
||||
y_out = torch.empty_like(y)
|
||||
return y_out
|
||||
|
||||
|
||||
def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
|
||||
lora_indices: torch.Tensor, seq_len: torch.Tensor,
|
||||
y: torch.Tensor, slice_offset: int, slice_size: int):
|
||||
|
||||
def sgmv_expand_meta(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
lora_indices: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
):
|
||||
y_out = torch.empty_like(y)
|
||||
return y_out
|
||||
|
||||
|
||||
register_meta_if_necessary("_C_ascend", "rotary_embedding",
|
||||
rotary_embedding_meta)
|
||||
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
|
||||
get_masked_input_and_mask_meta)
|
||||
register_meta_if_necessary("_C_ascend", "rotary_embedding", rotary_embedding_meta)
|
||||
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta)
|
||||
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
|
||||
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)
|
||||
|
||||
@@ -15,9 +15,11 @@
# This file is a part of the vllm-ascend project.
#

from __future__ import annotations

import math
import os
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from uuid import uuid4

import torch

@@ -32,11 +34,21 @@ from vllm_ascend.ascend_config import init_ascend_config

# isort: off
from vllm_ascend.utils import (
ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY,
COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config,
enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model,
is_vl_model, refresh_block_size, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
ASCEND_QUANTIZATION_METHOD,
COMPILATION_PASS_KEY,
COMPRESSED_TENSORS_METHOD,
AscendDeviceType,
check_kv_extra_config,
enable_sp,
flashcomm2_enable,
get_ascend_device_type,
is_moe_model,
is_vl_model,
refresh_block_size,
update_aclgraph_sizes,
update_cudagraph_capture_sizes,
update_default_aclgraph_sizes,
)

if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -80,7 +92,6 @@ def config_deprecated_logging():


class NPUPlatform(Platform):

_enum = PlatformEnum.OOT
device_name: str = "npu"
device_type: str = "npu"
@@ -89,9 +100,7 @@ class NPUPlatform(Platform):
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
dispatch_key: str = "PrivateUse1"

supported_quantization: list[str] = [
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
]
supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]

def is_sleep_mode_available(self) -> bool:
return True
@@ -122,27 +131,23 @@ class NPUPlatform(Platform):
return "vllm_ascend.compilation.compiler_interface.AscendCompiler"

@classmethod
def pre_register_and_update(cls,
parser: Optional[FlexibleArgumentParser] = None
) -> None:
def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) -> None:
# Adapt the global patch here.
from vllm_ascend.utils import adapt_patch

adapt_patch(is_global_patch=True)

# For online serving, "ascend" quantization method is not a choice natively,
# so we need to add "ascend" quantization method to quantization methods list
# and the user can enable quantization using "vllm serve --quantization ascend".
if parser is not None:
quant_action = parser._option_string_actions.get('--quantization')
if quant_action and hasattr(quant_action,
'choices') and quant_action.choices:
quant_action = parser._option_string_actions.get("--quantization")
if quant_action and hasattr(quant_action, "choices") and quant_action.choices:
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)

from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import \
AscendQuantConfig # noqa: F401
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401

config_deprecated_logging()

@@ -169,8 +174,7 @@ class NPUPlatform(Platform):

if vllm_config.kv_transfer_config is not None:
check_kv_extra_config(vllm_config)
if not getattr(vllm_config.kv_transfer_config,
"_engine_id_patched", False):
if not getattr(vllm_config.kv_transfer_config, "_engine_id_patched", False):
vllm_config.kv_transfer_config.engine_id = f"{vllm_config.kv_transfer_config.engine_id}-{uuid4().hex}"
vllm_config.kv_transfer_config._engine_id_patched = True
from vllm.config import CompilationMode # noqa: E402
@@ -181,24 +185,22 @@ class NPUPlatform(Platform):
cache_config = vllm_config.cache_config
ascend_compilation_config = ascend_config.ascend_compilation_config
if ascend_compilation_config:
vllm_config.additional_config.setdefault(
"ascend_compilation_config", {}).update(
vars(ascend_compilation_config
) if not isinstance(ascend_compilation_config, dict)
else ascend_compilation_config)
vllm_config.additional_config.setdefault("ascend_compilation_config", {}).update(
vars(ascend_compilation_config)
if not isinstance(ascend_compilation_config, dict)
else ascend_compilation_config
)

elif model_config and hasattr(model_config.hf_text_config,
"index_topk"):
vllm_config.cache_config.cache_dtype = str(
model_config.dtype).replace("torch.", "")
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
if model_config is None:
logger.warning("Model config is missing. This may indicate "
"that we are running a test case")
logger.warning("Model config is missing. This may indicate that we are running a test case")
enforce_eager = False
else:
enforce_eager = getattr(model_config, "enforce_eager", False)

from vllm.config.compilation import CUDAGraphMode

if enforce_eager:
logger.info("Compilation disabled, using eager mode by default")
compilation_config.mode = CompilationMode.NONE
@@ -207,12 +209,10 @@ class NPUPlatform(Platform):

compilation_config.cudagraph_num_of_warmups = 1

if compilation_config.mode not in [
CompilationMode.NONE, CompilationMode.VLLM_COMPILE
]:
if compilation_config.mode not in [CompilationMode.NONE, CompilationMode.VLLM_COMPILE]:
logger.warning(
"NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE",
compilation_config.mode)
"NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", compilation_config.mode
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

# set cudaprah sizes before extending `compilation_config.splitting_ops`
@@ -223,15 +223,18 @@ class NPUPlatform(Platform):
update_default_aclgraph_sizes(vllm_config)
# TODO delete graph size update here when compilation_config.pass_config.enable_sp
# is supported by vllm-ascend.
if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \
enable_sp(vllm_config):
if (
vllm_config.parallel_config.tensor_parallel_size > 1
and not vllm_config.model_config.enforce_eager
and enable_sp(vllm_config)
):
original_sizes = compilation_config.cudagraph_capture_sizes
sp_aclgraph_sizes = \
vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
sp_aclgraph_sizes = vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
assert sp_aclgraph_sizes, (
f"cudagraph_capture_sizes {original_sizes} does not contain"
f"values that are multiples of tp_size "
f"{vllm_config.parallel_config.tensor_parallel_size}")
f"{vllm_config.parallel_config.tensor_parallel_size}"
)
if len(sp_aclgraph_sizes) != len(original_sizes):
compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes)
@@ -243,9 +246,7 @@ class NPUPlatform(Platform):
# encoder-decoder models currently only support piecewise mode
if model_config and model_config.is_encoder_decoder is True:
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.warning(
"encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE "
)
logger.warning("encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE ")
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# get custom compile backend for graph fusion
@@ -255,15 +256,14 @@ class NPUPlatform(Platform):
compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == "
"CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
)
compilation_config.set_splitting_ops_for_v1(
all2all_backend=vllm_config.parallel_config.all2all_backend,
data_parallel_size=vllm_config.parallel_config.
data_parallel_size,
data_parallel_size=vllm_config.parallel_config.data_parallel_size,
)
compilation_config.use_inductor = False
# NOTE: Theoretically, we should also add vllm::mla_forward in the attention ops.
@@ -275,11 +275,13 @@ class NPUPlatform(Platform):
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config)
ascend_config.enable_npugraph_ex = False
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
elif (
compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
):
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode"
)
compilation_config.use_inductor = False
compilation_config.splitting_ops = []
warning_message = """\033[91m
@@ -297,30 +299,31 @@ class NPUPlatform(Platform):
logger.warning(warning_message)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
compilation_config.cudagraph_mode)
"%s cudagraph_mode is not support on NPU. falling back to NONE", compilation_config.cudagraph_mode
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
compilation_config.mode = CompilationMode.NONE
ascend_config.enable_npugraph_ex = False

# TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
# Then, we will have to discuss the error handling strategy and user experience
if compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \
os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1":
if (
compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1"
):
raise ValueError(
"ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. "
"Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you "
"need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — "
"for example, check the plog files (default: $HOME/ascend/log/debug) "
"for more information about runtime errors.")
"for more information about runtime errors."
)

if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
parallel_config.all2all_backend = "flashinfer_all2allv"
if ascend_config.xlite_graph_config.enabled:
logger.info(
"openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite"
)
logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite")
parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
@@ -332,10 +335,9 @@ class NPUPlatform(Platform):
compilation_config.custom_ops = ["all"]

if ascend_config.recompute_scheduler_enable:
from vllm_ascend.core.recompute_scheduler import \
RecomputeSchedulerConfig
recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
vllm_config)
from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig

recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(vllm_config)
vllm_config.scheduler_config = recompute_scheduler_config

# Extend original scheduler_config to use SchedulerDynamicBatch.
@@ -346,9 +348,11 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config.enable_chunked_prefill = True
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch

if vllm_config.kv_transfer_config is not None and \
cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \
parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1:
if (
vllm_config.kv_transfer_config is not None
and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
):
raise AssertionError(
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
f"and block_size({cache_config.block_size}) "
@@ -356,12 +360,14 @@ class NPUPlatform(Platform):
)

if is_vl_model(vllm_config):
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \
bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))):
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
):
raise ValueError(
"Currently, VL models doesn't support "
"FLASHCOMM in vllm-ascend. We will fix this in the future. "
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.")
"Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
)

@classmethod
def import_kernels(cls) -> None:
@@ -377,14 +383,11 @@ class NPUPlatform(Platform):
if _CUSTOM_OP_REGISTERED:
return
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors",
"vllm-ascend")
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", "vllm-ascend")
if os.path.exists(CUSTOM_OPP_PATH):
current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH",
"")
current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
if current_cust_opp_path:
os.environ[
"ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
os.environ["ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
else:
os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH
_CUSTOM_OP_REGISTERED = True
@@ -393,22 +396,18 @@ class NPUPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend, attn_selector_config):
backend_map = {
(True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
}

return backend_map[(attn_selector_config.use_mla,
attn_selector_config.use_sparse)]
return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)]

@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"

@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
torch.npu.reset_peak_memory_stats(device)
return torch.npu.max_memory_allocated(device)

@@ -474,15 +473,16 @@ class NPUPlatform(Platform):
dict[str, Any]: _description_
"""
# NOTE(Ronald1995): avoid circular import.
from vllm_ascend.ascend_forward_context import (get_mc2_mask,
select_moe_comm_method)
from vllm_ascend.ascend_forward_context import get_mc2_mask, select_moe_comm_method
from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size

# NOTE(Ronald1995): avoid circular import, cudagraph_runtime_mode is
# CUDAGraphMode.NONE in vllm, but we can't set CUDAGraphMode.NONE in
# argument default value, so we set it to None first, then set it to
# CUDAGraphMode.NONE here.
from vllm.config import CUDAGraphMode

if cudagraph_runtime_mode is None:
cudagraph_runtime_mode = CUDAGraphMode.NONE
# TODO(Ronald1995): model runner v1 still use ascend_forward_context,
@@ -526,31 +526,26 @@ class NPUPlatform(Platform):
sp_enabled = enable_sp(vllm_config) and num_tokens is not None
mmrs_fusion = False
else:
sp_enabled = enable_sp(vllm_config) and \
num_tokens is not None and num_tokens > 1000
sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000

# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
flashcomm_v2_enabled = flashcomm2_enable(
) and tp_world_size > 1 and num_tokens is not None
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
pad_size = 0
if (sp_enabled or flashcomm_v2_enabled):
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
if sp_enabled or flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and dp_metadata is not None:
max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
if (sp_enabled or flashcomm_v2_enabled):
padded_length = (max_tokens_across_dp + tp_world_size -
1) // tp_world_size * tp_world_size
if sp_enabled or flashcomm_v2_enabled:
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
else:
max_tokens_across_dp = num_tokens

if num_tokens is not None:
# NOTE: token num which need to pad to when mc2
padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
reserved_mc2_mask = get_mc2_mask()
if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:padded_num_tokens]

@@ -21,9 +21,9 @@ This module generates the service_profiling_symbols.yaml configuration file
to ~/.config/vllm_ascend/ directory.
"""

import contextlib
import tempfile
from pathlib import Path
from typing import Optional

import vllm
from vllm.logger import logger
@@ -129,7 +129,7 @@ def get_config_dir() -> Path:
return config_dir


def _cleanup_temp_file(tmp_path: Optional[Path]) -> None:
def _cleanup_temp_file(tmp_path: Path | None) -> None:
"""
Clean up a temporary file if it exists.

@@ -137,13 +137,11 @@ def _cleanup_temp_file(tmp_path: Optional[Path]) -> None:
tmp_path: Path to the temporary file to clean up.
"""
if tmp_path is not None and tmp_path.exists():
try:
with contextlib.suppress(OSError):
tmp_path.unlink()
except OSError:
pass # Ignore cleanup errors


def generate_service_profiling_config() -> Optional[Path]:
def generate_service_profiling_config() -> Path | None:
"""
Generate the service_profiling_symbols.yaml configuration file
to ~/.config/vllm_ascend/ directory.
@@ -170,9 +168,7 @@ def generate_service_profiling_config() -> Optional[Path]:
try:
config_dir.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e:
logger.error(
f"Failed to create configuration directory {config_dir}: {e}",
exc_info=True)
logger.error(f"Failed to create configuration directory {config_dir}: {e}", exc_info=True)
return None

# Write the configuration file atomically using a temporary file
@@ -180,13 +176,9 @@ def generate_service_profiling_config() -> Optional[Path]:
tmp_path = None
try:
# Create a temporary file in the same directory for atomic write
with tempfile.NamedTemporaryFile(mode='w',
encoding='utf-8',
dir=config_dir,
delete=False,
suffix='.tmp',
prefix=CONFIG_FILENAME +
'.') as tmp_file:
with tempfile.NamedTemporaryFile(
mode="w", encoding="utf-8", dir=config_dir, delete=False, suffix=".tmp", prefix=CONFIG_FILENAME + "."
) as tmp_file:
tmp_file.write(SERVICE_PROFILING_SYMBOLS_YAML)
tmp_path = Path(tmp_file.name)

@@ -194,8 +186,7 @@ def generate_service_profiling_config() -> Optional[Path]:
tmp_path.replace(config_file)
return config_file
except (OSError, PermissionError) as e:
logger.error(f"Failed to write configuration file {config_file}: {e}",
exc_info=True)
logger.error(f"Failed to write configuration file {config_file}: {e}", exc_info=True)
return None
finally:
# Clean up the temporary file if it wasn't successfully replaced

@@ -17,6 +17,8 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#

from __future__ import annotations

import atexit
import functools
import math
@@ -25,7 +27,7 @@ from contextlib import contextmanager, nullcontext
from enum import Enum
from functools import lru_cache
from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any

import torch
import torch_npu # noqa: F401
@@ -88,6 +90,7 @@ def acl_graph_print(*args):
resolving unexpected hangs. Usage:
```python
from vllm_ascend.utils import acl_graph_print

...
acl_graph_print("Debug info")
```
@@ -113,8 +116,7 @@ def acl_graph_print(*args):
torch_npu.npu._subscribe_report(current_compute_stream)
_SUBSCRIBED_COMPUTE_STREAMS.add(current_compute_stream)

torch_npu.npu._launch_host_func(current_compute_stream,
_print_callback_on_stream, args)
torch_npu.npu._launch_host_func(current_compute_stream, _print_callback_on_stream, args)


def _unregister_print_streams_on_exit():
@@ -196,9 +198,7 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
# after: pad_dims: [0, 2, 0, 3]

# return: (1, 2, 16, 16)
return _custom_transpose(
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
2).contiguous()
return _custom_transpose(_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1, 2).contiguous()


def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
@@ -208,11 +208,9 @@ def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
tokens_pad = (num_tokens + 15) // 16 * 16
max_seq_len_pad = (max_seq_len + 15) // 16 * 16

mask_tensor_pad = \
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
mask_tensor_pad = torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
mask = mask_tensor_pad.reshape(
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
mask = mask_tensor_pad.reshape((1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
return mask


@@ -230,10 +228,7 @@ def aligned_16(tensor: torch.Tensor):
return tensor

# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
new_tensor = torch.zeros(n_aligned,
*tensor.shape[1:],
dtype=tensor.dtype,
device=tensor.device)
new_tensor = torch.zeros(n_aligned, *tensor.shape[1:], dtype=tensor.dtype, device=tensor.device)

# Copy the original tensor to the first N positions of the new tensor
new_tensor[:n] = tensor
@@ -254,15 +249,15 @@ def enable_custom_op():
# isort: off
# register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401

# register the meta implementation for custom kernel if necessary
import vllm_ascend.meta_registration # type: ignore # noqa: F401

# isort: on
_CUSTOM_OP_ENABLED = True
except ImportError:
_CUSTOM_OP_ENABLED = False
logger.warning(
"Warning: Failed to register custom ops, all custom ops will be disabled"
)
logger.warning("Warning: Failed to register custom ops, all custom ops will be disabled")
return _CUSTOM_OP_ENABLED


@@ -277,8 +272,7 @@ def find_hccl_library() -> str:

# manually load the hccl library
if so_file:
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
so_file)
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", so_file)
else:
if torch.version.cann is not None:
so_file = "libhccl.so"
@@ -318,6 +312,7 @@ def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
global _WEIGHT_PREFETCH_METHOD
if _WEIGHT_PREFETCH_METHOD is None:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod

_WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
return _WEIGHT_PREFETCH_METHOD

@@ -364,6 +359,7 @@ def vllm_version_is(target_vllm_version: str):
vllm_version = envs_ascend.VLLM_VERSION
else:
import vllm

vllm_version = vllm.__version__
try:
return Version(vllm_version) == Version(target_vllm_version)
@@ -372,7 +368,8 @@ def vllm_version_is(target_vllm_version: str):
f"Invalid vllm version {vllm_version} found. A dev version of vllm "
"is installed probably. Set the environment variable VLLM_VERSION "
"to control it by hand. And please make sure the value follows the "
"format of x.y.z.")
"format of x.y.z."
)


def get_max_hidden_layers(hf_config) -> int:
@@ -394,20 +391,19 @@ def get_max_hidden_layers(hf_config) -> int:


# Update cudagraph capture sizes for vllm config
def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
cudagraph_capture_sizes: List[int]):

valid_max_size = (cudagraph_capture_sizes[-1]
if cudagraph_capture_sizes else 0)
if (vllm_config.compilation_config.max_cudagraph_capture_size is not None
and vllm_config.compilation_config.max_cudagraph_capture_size
!= valid_max_size):
def update_cudagraph_capture_sizes(vllm_config: VllmConfig, cudagraph_capture_sizes: list[int]):
valid_max_size = cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
if (
vllm_config.compilation_config.max_cudagraph_capture_size is not None
and vllm_config.compilation_config.max_cudagraph_capture_size != valid_max_size
):
if vllm_config.compilation_config.cudagraph_capture_sizes is not None:
raise ValueError(
"customized max_cudagraph_capture_size"
f"(={vllm_config.compilation_config.max_cudagraph_capture_size}) "
"should be consistent with the max value of "
f"cudagraph_capture_sizes(={valid_max_size})")
f"cudagraph_capture_sizes(={valid_max_size})"
)
logger.warning(
"Truncating max_cudagraph_capture_size to %d",
valid_max_size,
@@ -415,12 +411,11 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,

vllm_config.compilation_config.max_cudagraph_capture_size = valid_max_size

if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(
cudagraph_capture_sizes) < len(
vllm_config.compilation_config.cudagraph_capture_sizes):
if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(cudagraph_capture_sizes) < len(
vllm_config.compilation_config.cudagraph_capture_sizes
):
logger.warning(
("cudagraph_capture_sizes specified in compilation_config"
" %s is overridden by config %s"),
("cudagraph_capture_sizes specified in compilation_config %s is overridden by config %s"),
vllm_config.compilation_config.cudagraph_capture_sizes,
cudagraph_capture_sizes,
)
@@ -433,26 +428,17 @@ def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool:
Check whether it is vLLM default capture sizes.
"""

max_cudagraph_capture_size = \
vllm_config.compilation_config.max_cudagraph_capture_size
cudagraph_capture_sizes = [
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
]
max_cudagraph_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
cudagraph_capture_sizes = [i for i in [1, 2, 4] if i <= max_cudagraph_capture_size]
if max_cudagraph_capture_size >= 8:
# Step size 8 for small batch sizes, up to 256(not included)
cudagraph_capture_sizes += list(
range(8, min(max_cudagraph_capture_size + 1, 256), 8))
cudagraph_capture_sizes += list(range(8, min(max_cudagraph_capture_size + 1, 256), 8))
if max_cudagraph_capture_size >= 256:
# Step size 16 for larger batch sizes
cudagraph_capture_sizes += list(
range(256, max_cudagraph_capture_size + 1, 16))
cudagraph_capture_sizes += list(range(256, max_cudagraph_capture_size + 1, 16))
# in newer version, vLLM use ascending order of cudagraph_capture_sizes.
target_cudagraph_capture_sizes = sorted(cudagraph_capture_sizes)
if target_cudagraph_capture_sizes == \
vllm_config.compilation_config.cudagraph_capture_sizes:
return True

return False
return target_cudagraph_capture_sizes == vllm_config.compilation_config.cudagraph_capture_sizes


def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
@@ -461,9 +447,11 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
are more friendly to ascend ops && hardware.
"""

if vllm_config.model_config is None or \
vllm_config.model_config.enforce_eager or \
not _is_default_capture_sizes(vllm_config):
if (
vllm_config.model_config is None
or vllm_config.model_config.enforce_eager
or not _is_default_capture_sizes(vllm_config)
):
return

# modify the default capture_sizes for Qwen3-MoE models on dp settings.
@@ -471,16 +459,15 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# on special shapes.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
if vllm_config.model_config and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe" \
and vllm_config.parallel_config.tensor_parallel_size == 1 \
and vllm_config.parallel_config.data_parallel_size > 1 :

if (
vllm_config.model_config
and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe"
and vllm_config.parallel_config.tensor_parallel_size == 1
and vllm_config.parallel_config.data_parallel_size > 1
):
max_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [
i for i in range(24, max_capture_size + 1, 8)
]
update_cudagraph_capture_sizes(vllm_config,
new_cudagraph_capture_sizes)
new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [i for i in range(24, max_capture_size + 1, 8)]
update_cudagraph_capture_sizes(vllm_config, new_cudagraph_capture_sizes)


def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
@@ -498,8 +485,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

# Store original configuration and temporarily clear it
compilation_config = vllm_config.compilation_config
original_sizes, compilation_config.cudagraph_capture_sizes = \
compilation_config.cudagraph_capture_sizes, None
original_sizes, compilation_config.cudagraph_capture_sizes = compilation_config.cudagraph_capture_sizes, None

# Calculate parallel configuration factor
if not vllm_config.model_config:
@@ -510,7 +496,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

return
hf_config = vllm_config.model_config.hf_text_config
if hasattr(hf_config, 'num_hidden_layers'):
if hasattr(hf_config, "num_hidden_layers"):
num_hidden_layers = hf_config.num_hidden_layers
else:
num_hidden_layers = get_max_hidden_layers(hf_config)
@@ -519,24 +505,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Calculate maximum supported batch sizes considering model architecture
resources_per_graph = num_hidden_layers + 1
# For suffix decoding, use the suffix path when no draft_model_config is provided.
if (spec := vllm_config.speculative_config) and \
(draft := spec.draft_model_config):
if (spec := vllm_config.speculative_config) and (draft := spec.draft_model_config):
resources_per_graph += draft.hf_config.num_hidden_layers + 1

# TODO: Find out whether we need to take into account the pp_size
num_comm_groups = sum(size > 1 for size in [
num_comm_groups = sum(
size > 1
for size in [
parallel_config.data_parallel_size,
parallel_config.tensor_parallel_size,
])
]
)

if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
if os.getenv("HCCL_OP_EXPANSION_MODE") == "AIV":
# TODO: Find out whether we need to take into account the pp_size
parallel_factor = 1 + num_comm_groups + int(
parallel_config.enable_expert_parallel) + int(
vllm_config.additional_config.get(
"multistream_overlap_shared_expert", False))
parallel_factor = (
1
+ num_comm_groups
+ int(parallel_config.enable_expert_parallel)
+ int(vllm_config.additional_config.get("multistream_overlap_shared_expert", False))
)
if is_moe_model(vllm_config):
parallel_factor += (parallel_config.data_parallel_size > 1)
parallel_factor += parallel_config.data_parallel_size > 1
else:
# When AIV mode is enabled, the allreduce operator of the dense
# layer model will occupy additional streams, which are buffered here.
@@ -546,11 +536,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
resources_per_graph / parallel_factor)
logger.info(
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / resources_per_graph / parallel_factor)
logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
else:
# enable pcp or dcp will add new communication and consume additional approximately less than 100 streams
if parallel_config.prefill_context_parallel_size > 1:
@@ -562,18 +549,18 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Under this configuration, HCCL employs the FFTS+ method for execution unfolding,
# which adds only 1 concurrent stream without consuming collective communication execution unfolding streams.
# On A3 hardware, HCCL defaults to the AICPU method.
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case).
# Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes.
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication
# domain on the device (worst case).
# Using the default collective communication unfolding method on A3 will lead to a significant reduction
# in the maximum supported sizes.
# Therefore, the calculation formula has been modified as follows:
# Assume the following case:
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
# According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
max_num_batch_sizes = math.floor(
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph /
(1 + num_comm_groups * 2))
logger.info(
"Calculated maximum supported batch sizes for ACL graph: %s",
max_num_batch_sizes)
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / (1 + num_comm_groups * 2)
)
logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
logger.warning(
"Currently, communication is performed using FFTS+ method, which reduces "
"the number of available streams and, as a result, limits the range of runtime "
@@ -598,16 +585,19 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
vllm_config.model_config.architectures[0],
num_hidden_layers,
len(original_sizes),
len(compilation_config.
cudagraph_capture_sizes # type: ignore[arg-type]
))
len(
compilation_config.cudagraph_capture_sizes # type: ignore[arg-type]
),
)
else:
# No adjustment needed
compilation_config.cudagraph_capture_sizes = original_sizes
logger.info(
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
vllm_config.model_config.architectures[0], num_hidden_layers,
len(original_sizes))
vllm_config.model_config.architectures[0],
num_hidden_layers,
len(original_sizes),
)


# TODO(wxy): Move to ops module
@@ -617,7 +607,7 @@ def dispose_tensor(x: torch.Tensor):

class ProfileExecuteDuration:
_instance = None
_observations: List[Tuple[str, Event, Event]] = []
_observations: list[tuple[str, Event, Event]] = []
_lock = Lock()

def __new__(cls):
@@ -645,8 +635,7 @@ class ProfileExecuteDuration:
observe_end = Event(enable_timing=True)
observe_end.record()
with self._lock:
self._observations.append(
(duration_tag, observe_start, observe_end))
self._observations.append((duration_tag, observe_start, observe_end))

def pop_captured_sync(self) -> dict:
"""Pop and synchronize all events in the observation list"""
@@ -663,7 +652,7 @@ class ProfileExecuteDuration:
return durations


def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
def register_ascend_customop(vllm_config: VllmConfig | None = None):
"""Register Ascend CustomOP

NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`,
@@ -675,23 +664,29 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm.model_executor.custom_op import CustomOp

from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear)
AscendRowParallelLinear,
)
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
from vllm_ascend.ops.rotary_embedding import (
AscendApplyRotaryEmb, AscendDeepseekScalingRotaryEmbedding,
AscendMRotaryEmbedding, AscendRotaryEmbedding,
AscendYaRNRotaryEmbedding)
AscendApplyRotaryEmb,
AscendDeepseekScalingRotaryEmbedding,
AscendMRotaryEmbedding,
AscendRotaryEmbedding,
AscendYaRNRotaryEmbedding,
)
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead,
AscendVocabParallelEmbedding)
AscendLogitsProcessor,
AscendParallelLMHead,
AscendVocabParallelEmbedding,
)

global REGISTERED_ASCEND_OPS
REGISTERED_ASCEND_OPS = {
@@ -738,6 +733,7 @@ _ascend_device_type = None
def _init_ascend_device_type():
global _ascend_device_type
from vllm_ascend import _build_info # type: ignore

_ascend_device_type = AscendDeviceType[_build_info.__device_type__]


@@ -758,7 +754,10 @@ def check_ascend_device_type():
else:
raise RuntimeError(f"Can not support soc_version: {soc_version}.")

assert _ascend_device_type == cur_device_type, f"Current device type: {cur_device_type} does not match the installed version's device type: {_ascend_device_type}, please check your installation package."
assert _ascend_device_type == cur_device_type, (
f"Current device type: {cur_device_type} does not match the installed version's device type: "
f"{_ascend_device_type}, please check your installation package."
)


def get_ascend_device_type():
@@ -769,23 +768,19 @@ def get_ascend_device_type():


def lmhead_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.lmhead_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.lmhead_tensor_parallel_size > 0


def embedding_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.embedding_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.embedding_tensor_parallel_size > 0


def oproj_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.oproj_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.oproj_tensor_parallel_size > 0


def mlp_tp_enable() -> bool:
return get_ascend_config(
).finegrained_tp_config.mlp_tensor_parallel_size > 0
return get_ascend_config().finegrained_tp_config.mlp_tensor_parallel_size > 0


def matmul_allreduce_enable() -> bool:
@@ -797,30 +792,30 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
if _ENABLE_SP is None:
if vllm_config is None:
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
_ENABLE_SP = (
vllm_config.compilation_config.pass_config.enable_sp
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
)

if not _ENABLE_SP and enable_shared_expert_dp:
_ENABLE_SP = True
logger.info(
"shared_expert_dp requires enable_sp = True. has set enable_sp to True"
)
logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")

if not _ENABLE_SP:
return _ENABLE_SP

assert vllm_config.parallel_config.tensor_parallel_size > 1, \
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
)

assert (
not is_moe_model(vllm_config)
or vllm_config.parallel_config.enable_expert_parallel
), "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
"Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
)

return _ENABLE_SP

@@ -847,25 +842,21 @@ def is_drafter_moe_model(vllm_config: VllmConfig):
"""Checks if the drafter model is a MoE model by config"""
global _IS_DRAFTER_MOE_MODEL
if _IS_DRAFTER_MOE_MODEL is None:
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \
.to_dict()
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config.to_dict()
_IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs)
return _IS_DRAFTER_MOE_MODEL


def speculative_enable_dispatch_gmm_combine_decode(
vllm_config: VllmConfig) -> bool:
def speculative_enable_dispatch_gmm_combine_decode(vllm_config: VllmConfig) -> bool:
if vllm_config.speculative_config is None:
return True
speculative_method = getattr(vllm_config.speculative_config, "method",
None)
speculative_method = getattr(vllm_config.speculative_config, "method", None)
if speculative_method in [None, "ngram", "suffix"]:
return True
if speculative_method in ["eagle", "eagle3"]:
return False
if speculative_method == "mtp":
mtp_quant_type = getattr(vllm_config.model_config.hf_text_config,
"mtp_quantize", None)
mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, "mtp_quantize", None)
return mtp_quant_type == "w8a8_dynamic"
return False

@@ -915,8 +906,8 @@ def weak_ref_tensor(tensor: Any) -> Any:


def weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
tensors: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor],
) -> torch.Tensor | list[Any] | tuple[Any] | Any:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
@@ -936,17 +927,12 @@ def weak_ref_tensors(
return tuple(weak_ref_tensor(t) for t in tensors)
# For IntermediateTensors used in pipeline parallelism
if isinstance(tensors, IntermediateTensors):
ret = IntermediateTensors({
key: weak_ref_tensor(val)
for key, val in tensors.tensors.items()
})
ret = IntermediateTensors({key: weak_ref_tensor(val) for key, val in tensors.tensors.items()})
return ret
raise ValueError("Invalid type for tensors")


def npu_stream_switch(target_stream: torch.npu.Stream,
*,
enabled: bool = True):
def npu_stream_switch(target_stream: torch.npu.Stream, *, enabled: bool = True):
"""
Switch to the target stream if enabled is True.
Otherwise, do nothing.
@@ -965,7 +951,7 @@ def create_hccl_pg_options(group_name: str):
return options


def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
def get_hccl_config_for_pg_options(group_name: str) -> dict | None:
"""
Get HCCL process group options for the given communication group name.

@@ -981,9 +967,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
if group_name and "mc2" in group_name:
return None
hccl_config_map = {
"dp": {
"hccl_buffer_size": calculate_dp_buffer_size()
},
"dp": {"hccl_buffer_size": calculate_dp_buffer_size()},
}
return hccl_config_map.get(group_name, get_default_buffer_config())

@@ -998,6 +982,7 @@ def calculate_dp_buffer_size() -> int:
dp_size + 1 (flags: with_prefill)
"""
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
dp_size = vllm_config.parallel_config.data_parallel_size
int32_size = torch.iinfo(torch.int32).bits // 8
@@ -1009,8 +994,7 @@ def calculate_dp_buffer_size() -> int:
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
# significantly improve communication performance of MC2 ops dispatch/combine.
def is_hierarchical_communication_enabled():
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
return os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1"


def has_layer_idx(model_instance: torch.nn.Module) -> bool:
@@ -1019,8 +1003,7 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool:

global _HAS_LAYER_IDX
if _HAS_LAYER_IDX is None:
_HAS_LAYER_IDX = hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer")
_HAS_LAYER_IDX = hasattr(model_instance, "model") and hasattr(model_instance.model, "start_layer")
return _HAS_LAYER_IDX


@@ -1042,20 +1025,17 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
if not flashcomm2_enable():
return 0

logger.info(
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
)
logger.info(f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}")

layer_sharding = ascend_config.layer_sharding or []
if layer_sharding:
if layer_sharding == ["o_proj"]:
logger.info_once(
"Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
)
logger.info_once("Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption.")
else:
raise ValueError(
"FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
f"Found invalid layer_sharding: {layer_sharding}")
f"Found invalid layer_sharding: {layer_sharding}"
)
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
logger.warning_once(
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
@@ -1066,32 +1046,35 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
)
if global_tp_size <= flashcomm2_oproj_tp_size:
raise AssertionError(
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed global tensor parallel size ({global_tp_size})"
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed "
f"global tensor parallel size ({global_tp_size})"
)
if global_tp_size % flashcomm2_oproj_tp_size != 0:
raise AssertionError(
f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
f"Global tensor parallel size ({global_tp_size}) must be divisible by "
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
)
if vllm_config.kv_transfer_config is None:
logger.warning_once(
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment may lead to decode performance degradation."
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment "
"may lead to decode performance degradation."
)
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support "
"for hybrid deployment scenarios. It is not applicable in D-scenario environments."
)

return flashcomm2_oproj_tp_size


def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain.
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation,
# each batch_id corresponds to the rank_id within the DP domain.
# For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
# the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
flashcomm2_otp_size = get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size
num_oproj_tensor_parallel_groups: int = (global_tp_size //
flashcomm2_otp_size)
flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
num_oproj_tensor_parallel_groups: int = global_tp_size // flashcomm2_otp_size

reorgnized_batch_ids = []
for i in range(num_oproj_tensor_parallel_groups):
@@ -1122,11 +1105,9 @@ def refresh_block_size(vllm_config):
return

# TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
if not model_config.hf_text_config.model_type == "qwen3_next" and cache_config.block_size != 128:
if model_config.hf_text_config.model_type != "qwen3_next" and cache_config.block_size != 128:
if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill:
logger.info(
"Block size is set to 128 if prefix cache or chunked prefill is enabled."
)
logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.")
cache_config.block_size = 128


@@ -1138,7 +1119,6 @@ def dispose_layer(layer: Any):


def check_kv_extra_config(vllm_config):

def _check(name: str, config: dict):
tp_key = "tp_size"
dp_key = "dp_size"
@@ -1148,24 +1128,21 @@ def check_kv_extra_config(vllm_config):
if config_tp != vllm_tp:
raise ValueError(
f"KV transfer '{name}' config has a conflicting tensor parallel size. "
f"Expected {vllm_tp}, but got {config_tp}.")
f"Expected {vllm_tp}, but got {config_tp}."
)
if dp_key in config:
config_dp = config[dp_key]
vllm_dp = vllm_config.parallel_config.data_parallel_size
if config_dp != vllm_dp:
raise ValueError(
f"KV transfer '{name}' config has a conflicting data parallel size. "
f"Expected {vllm_dp}, but got {config_dp}.")
f"Expected {vllm_dp}, but got {config_dp}."
)

if vllm_config.kv_transfer_config.is_kv_producer:
_check(
"prefill",
vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {}))
_check("prefill", vllm_config.kv_transfer_config.get_from_extra_config("prefill", {}))
if vllm_config.kv_transfer_config.is_kv_consumer:
_check(
"decode",
vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
_check("decode", vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))


def singleton(cls):
@@ -1179,17 +1156,17 @@ def singleton(cls):
return get_instance


#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
# TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32.
# and subsequent updates will introduce new interfaces. --zzhx1
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
is_ds_v32 = hasattr(
vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk")
if is_ds_v32 and enable_sp():
return True
return False
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk"
)
return bool(is_ds_v32 and enable_sp())


@lru_cache(maxsize=1)
@@ -1197,6 +1174,7 @@ def enable_dsa_cp_with_layer_shard() -> bool:
if not enable_dsa_cp():
return False
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
return is_prefill_instance