[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
def register_ascend_customop(vllm_config: VllmConfig | None = None):
^^^^^^^^^^^^^^^^^
E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None # Placeholder assigned at runtime
```
When Python parses the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it attempts
to evaluate the expression `VllmConfig | None`.
Since `VllmConfig` is assigned `None` at runtime, the expression
effectively becomes `None | None`. In Python, `None` is an instance of
`NoneType`. While the `|` operator is implemented for Type objects
(classes), it is not supported for `NoneType` instances, leading to the
`TypeError` shown above.
**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
**3. Impact and Benefits:**
- By enabling `annotations`, Python no longer executes the `VllmConfig |
None` operation during module load. Instead, it stores the annotation as
a string literal, completely avoiding the `None | None` calculation.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` will now return strings
instead of type objects. Since this module does not use runtime type
enforcement or reflection, this change has zero negative impact on
existing functionality.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -17,6 +17,8 @@
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import math
|
||||
@@ -25,7 +27,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
@@ -88,6 +90,7 @@ def acl_graph_print(*args):
|
||||
resolving unexpected hangs. Usage:
|
||||
```python
|
||||
from vllm_ascend.utils import acl_graph_print
|
||||
|
||||
...
|
||||
acl_graph_print("Debug info")
|
||||
```
|
||||
@@ -113,8 +116,7 @@ def acl_graph_print(*args):
|
||||
torch_npu.npu._subscribe_report(current_compute_stream)
|
||||
_SUBSCRIBED_COMPUTE_STREAMS.add(current_compute_stream)
|
||||
|
||||
torch_npu.npu._launch_host_func(current_compute_stream,
|
||||
_print_callback_on_stream, args)
|
||||
torch_npu.npu._launch_host_func(current_compute_stream, _print_callback_on_stream, args)
|
||||
|
||||
|
||||
def _unregister_print_streams_on_exit():
|
||||
@@ -196,9 +198,7 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
|
||||
# after: pad_dims: [0, 2, 0, 3]
|
||||
|
||||
# return: (1, 2, 16, 16)
|
||||
return _custom_transpose(
|
||||
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
|
||||
2).contiguous()
|
||||
return _custom_transpose(_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1, 2).contiguous()
|
||||
|
||||
|
||||
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
|
||||
@@ -208,11 +208,9 @@ def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
|
||||
tokens_pad = (num_tokens + 15) // 16 * 16
|
||||
max_seq_len_pad = (max_seq_len + 15) // 16 * 16
|
||||
|
||||
mask_tensor_pad = \
|
||||
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
|
||||
mask_tensor_pad = torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
|
||||
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
|
||||
mask = mask_tensor_pad.reshape(
|
||||
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
|
||||
mask = mask_tensor_pad.reshape((1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
|
||||
return mask
|
||||
|
||||
|
||||
@@ -230,10 +228,7 @@ def aligned_16(tensor: torch.Tensor):
|
||||
return tensor
|
||||
|
||||
# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
|
||||
new_tensor = torch.zeros(n_aligned,
|
||||
*tensor.shape[1:],
|
||||
dtype=tensor.dtype,
|
||||
device=tensor.device)
|
||||
new_tensor = torch.zeros(n_aligned, *tensor.shape[1:], dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
# Copy the original tensor to the first N positions of the new tensor
|
||||
new_tensor[:n] = tensor
|
||||
@@ -254,15 +249,15 @@ def enable_custom_op():
|
||||
# isort: off
|
||||
# register custom ops into torch_library here
|
||||
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
||||
|
||||
# register the meta implementation for custom kernel if necessary
|
||||
import vllm_ascend.meta_registration # type: ignore # noqa: F401
|
||||
|
||||
# isort: on
|
||||
_CUSTOM_OP_ENABLED = True
|
||||
except ImportError:
|
||||
_CUSTOM_OP_ENABLED = False
|
||||
logger.warning(
|
||||
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
||||
)
|
||||
logger.warning("Warning: Failed to register custom ops, all custom ops will be disabled")
|
||||
return _CUSTOM_OP_ENABLED
|
||||
|
||||
|
||||
@@ -277,8 +272,7 @@ def find_hccl_library() -> str:
|
||||
|
||||
# manually load the hccl library
|
||||
if so_file:
|
||||
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
|
||||
so_file)
|
||||
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s", so_file)
|
||||
else:
|
||||
if torch.version.cann is not None:
|
||||
so_file = "libhccl.so"
|
||||
@@ -318,6 +312,7 @@ def set_weight_prefetch_method(weight_prefetch_config: WeightPrefetchConfig):
|
||||
global _WEIGHT_PREFETCH_METHOD
|
||||
if _WEIGHT_PREFETCH_METHOD is None:
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
|
||||
_WEIGHT_PREFETCH_METHOD = WeightPrefetchMethod(weight_prefetch_config)
|
||||
return _WEIGHT_PREFETCH_METHOD
|
||||
|
||||
@@ -364,6 +359,7 @@ def vllm_version_is(target_vllm_version: str):
|
||||
vllm_version = envs_ascend.VLLM_VERSION
|
||||
else:
|
||||
import vllm
|
||||
|
||||
vllm_version = vllm.__version__
|
||||
try:
|
||||
return Version(vllm_version) == Version(target_vllm_version)
|
||||
@@ -372,7 +368,8 @@ def vllm_version_is(target_vllm_version: str):
|
||||
f"Invalid vllm version {vllm_version} found. A dev version of vllm "
|
||||
"is installed probably. Set the environment variable VLLM_VERSION "
|
||||
"to control it by hand. And please make sure the value follows the "
|
||||
"format of x.y.z.")
|
||||
"format of x.y.z."
|
||||
)
|
||||
|
||||
|
||||
def get_max_hidden_layers(hf_config) -> int:
|
||||
@@ -394,20 +391,19 @@ def get_max_hidden_layers(hf_config) -> int:
|
||||
|
||||
|
||||
# Update cudagraph capture sizes for vllm config
|
||||
def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
|
||||
cudagraph_capture_sizes: List[int]):
|
||||
|
||||
valid_max_size = (cudagraph_capture_sizes[-1]
|
||||
if cudagraph_capture_sizes else 0)
|
||||
if (vllm_config.compilation_config.max_cudagraph_capture_size is not None
|
||||
and vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
!= valid_max_size):
|
||||
def update_cudagraph_capture_sizes(vllm_config: VllmConfig, cudagraph_capture_sizes: list[int]):
|
||||
valid_max_size = cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0
|
||||
if (
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size is not None
|
||||
and vllm_config.compilation_config.max_cudagraph_capture_size != valid_max_size
|
||||
):
|
||||
if vllm_config.compilation_config.cudagraph_capture_sizes is not None:
|
||||
raise ValueError(
|
||||
"customized max_cudagraph_capture_size"
|
||||
f"(={vllm_config.compilation_config.max_cudagraph_capture_size}) "
|
||||
"should be consistent with the max value of "
|
||||
f"cudagraph_capture_sizes(={valid_max_size})")
|
||||
f"cudagraph_capture_sizes(={valid_max_size})"
|
||||
)
|
||||
logger.warning(
|
||||
"Truncating max_cudagraph_capture_size to %d",
|
||||
valid_max_size,
|
||||
@@ -415,12 +411,11 @@ def update_cudagraph_capture_sizes(vllm_config: VllmConfig,
|
||||
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size = valid_max_size
|
||||
|
||||
if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(
|
||||
cudagraph_capture_sizes) < len(
|
||||
vllm_config.compilation_config.cudagraph_capture_sizes):
|
||||
if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len(cudagraph_capture_sizes) < len(
|
||||
vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
):
|
||||
logger.warning(
|
||||
("cudagraph_capture_sizes specified in compilation_config"
|
||||
" %s is overridden by config %s"),
|
||||
("cudagraph_capture_sizes specified in compilation_config %s is overridden by config %s"),
|
||||
vllm_config.compilation_config.cudagraph_capture_sizes,
|
||||
cudagraph_capture_sizes,
|
||||
)
|
||||
@@ -433,26 +428,17 @@ def _is_default_capture_sizes(vllm_config: VllmConfig) -> bool:
|
||||
Check whether it is vLLM default capture sizes.
|
||||
"""
|
||||
|
||||
max_cudagraph_capture_size = \
|
||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
cudagraph_capture_sizes = [
|
||||
i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
|
||||
]
|
||||
max_cudagraph_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
cudagraph_capture_sizes = [i for i in [1, 2, 4] if i <= max_cudagraph_capture_size]
|
||||
if max_cudagraph_capture_size >= 8:
|
||||
# Step size 8 for small batch sizes, up to 256(not included)
|
||||
cudagraph_capture_sizes += list(
|
||||
range(8, min(max_cudagraph_capture_size + 1, 256), 8))
|
||||
cudagraph_capture_sizes += list(range(8, min(max_cudagraph_capture_size + 1, 256), 8))
|
||||
if max_cudagraph_capture_size >= 256:
|
||||
# Step size 16 for larger batch sizes
|
||||
cudagraph_capture_sizes += list(
|
||||
range(256, max_cudagraph_capture_size + 1, 16))
|
||||
cudagraph_capture_sizes += list(range(256, max_cudagraph_capture_size + 1, 16))
|
||||
# in newer version, vLLM use ascending order of cudagraph_capture_sizes.
|
||||
target_cudagraph_capture_sizes = sorted(cudagraph_capture_sizes)
|
||||
if target_cudagraph_capture_sizes == \
|
||||
vllm_config.compilation_config.cudagraph_capture_sizes:
|
||||
return True
|
||||
|
||||
return False
|
||||
return target_cudagraph_capture_sizes == vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
|
||||
|
||||
def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
@@ -461,9 +447,11 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
are more friendly to ascend ops && hardware.
|
||||
"""
|
||||
|
||||
if vllm_config.model_config is None or \
|
||||
vllm_config.model_config.enforce_eager or \
|
||||
not _is_default_capture_sizes(vllm_config):
|
||||
if (
|
||||
vllm_config.model_config is None
|
||||
or vllm_config.model_config.enforce_eager
|
||||
or not _is_default_capture_sizes(vllm_config)
|
||||
):
|
||||
return
|
||||
|
||||
# modify the default capture_sizes for Qwen3-MoE models on dp settings.
|
||||
@@ -471,16 +459,15 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
# on special shapes.
|
||||
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
||||
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
||||
if vllm_config.model_config and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe" \
|
||||
and vllm_config.parallel_config.tensor_parallel_size == 1 \
|
||||
and vllm_config.parallel_config.data_parallel_size > 1 :
|
||||
|
||||
if (
|
||||
vllm_config.model_config
|
||||
and vllm_config.model_config.hf_text_config.model_type == "qwen3_moe"
|
||||
and vllm_config.parallel_config.tensor_parallel_size == 1
|
||||
and vllm_config.parallel_config.data_parallel_size > 1
|
||||
):
|
||||
max_capture_size = vllm_config.compilation_config.max_cudagraph_capture_size
|
||||
new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [
|
||||
i for i in range(24, max_capture_size + 1, 8)
|
||||
]
|
||||
update_cudagraph_capture_sizes(vllm_config,
|
||||
new_cudagraph_capture_sizes)
|
||||
new_cudagraph_capture_sizes = [1, 2, 5, 10, 15, 20] + [i for i in range(24, max_capture_size + 1, 8)]
|
||||
update_cudagraph_capture_sizes(vllm_config, new_cudagraph_capture_sizes)
|
||||
|
||||
|
||||
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
@@ -498,8 +485,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
|
||||
# Store original configuration and temporarily clear it
|
||||
compilation_config = vllm_config.compilation_config
|
||||
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
||||
compilation_config.cudagraph_capture_sizes, None
|
||||
original_sizes, compilation_config.cudagraph_capture_sizes = compilation_config.cudagraph_capture_sizes, None
|
||||
|
||||
# Calculate parallel configuration factor
|
||||
if not vllm_config.model_config:
|
||||
@@ -510,7 +496,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
|
||||
return
|
||||
hf_config = vllm_config.model_config.hf_text_config
|
||||
if hasattr(hf_config, 'num_hidden_layers'):
|
||||
if hasattr(hf_config, "num_hidden_layers"):
|
||||
num_hidden_layers = hf_config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = get_max_hidden_layers(hf_config)
|
||||
@@ -519,24 +505,28 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
# Calculate maximum supported batch sizes considering model architecture
|
||||
resources_per_graph = num_hidden_layers + 1
|
||||
# For suffix decoding, use the suffix path when no draft_model_config is provided.
|
||||
if (spec := vllm_config.speculative_config) and \
|
||||
(draft := spec.draft_model_config):
|
||||
if (spec := vllm_config.speculative_config) and (draft := spec.draft_model_config):
|
||||
resources_per_graph += draft.hf_config.num_hidden_layers + 1
|
||||
|
||||
# TODO: Find out whether we need to take into account the pp_size
|
||||
num_comm_groups = sum(size > 1 for size in [
|
||||
parallel_config.data_parallel_size,
|
||||
parallel_config.tensor_parallel_size,
|
||||
])
|
||||
num_comm_groups = sum(
|
||||
size > 1
|
||||
for size in [
|
||||
parallel_config.data_parallel_size,
|
||||
parallel_config.tensor_parallel_size,
|
||||
]
|
||||
)
|
||||
|
||||
if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
|
||||
if os.getenv("HCCL_OP_EXPANSION_MODE") == "AIV":
|
||||
# TODO: Find out whether we need to take into account the pp_size
|
||||
parallel_factor = 1 + num_comm_groups + int(
|
||||
parallel_config.enable_expert_parallel) + int(
|
||||
vllm_config.additional_config.get(
|
||||
"multistream_overlap_shared_expert", False))
|
||||
parallel_factor = (
|
||||
1
|
||||
+ num_comm_groups
|
||||
+ int(parallel_config.enable_expert_parallel)
|
||||
+ int(vllm_config.additional_config.get("multistream_overlap_shared_expert", False))
|
||||
)
|
||||
if is_moe_model(vllm_config):
|
||||
parallel_factor += (parallel_config.data_parallel_size > 1)
|
||||
parallel_factor += parallel_config.data_parallel_size > 1
|
||||
else:
|
||||
# When AIV mode is enabled, the allreduce operator of the dense
|
||||
# layer model will occupy additional streams, which are buffered here.
|
||||
@@ -546,11 +536,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
# Assume the following case:
|
||||
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
|
||||
# According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
|
||||
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
|
||||
resources_per_graph / parallel_factor)
|
||||
logger.info(
|
||||
"Calculated maximum supported batch sizes for ACL graph: %s",
|
||||
max_num_batch_sizes)
|
||||
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / resources_per_graph / parallel_factor)
|
||||
logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
|
||||
else:
|
||||
# enable pcp or dcp will add new communication and consume additional approximately less than 100 streams
|
||||
if parallel_config.prefill_context_parallel_size > 1:
|
||||
@@ -562,18 +549,18 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
# Under this configuration, HCCL employs the FFTS+ method for execution unfolding,
|
||||
# which adds only 1 concurrent stream without consuming collective communication execution unfolding streams.
|
||||
# On A3 hardware, HCCL defaults to the AICPU method.
|
||||
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case).
|
||||
# Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes.
|
||||
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication
|
||||
# domain on the device (worst case).
|
||||
# Using the default collective communication unfolding method on A3 will lead to a significant reduction
|
||||
# in the maximum supported sizes.
|
||||
# Therefore, the calculation formula has been modified as follows:
|
||||
# Assume the following case:
|
||||
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
|
||||
# According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
|
||||
max_num_batch_sizes = math.floor(
|
||||
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph /
|
||||
(1 + num_comm_groups * 2))
|
||||
logger.info(
|
||||
"Calculated maximum supported batch sizes for ACL graph: %s",
|
||||
max_num_batch_sizes)
|
||||
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / (1 + num_comm_groups * 2)
|
||||
)
|
||||
logger.info("Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)
|
||||
logger.warning(
|
||||
"Currently, communication is performed using FFTS+ method, which reduces "
|
||||
"the number of available streams and, as a result, limits the range of runtime "
|
||||
@@ -598,26 +585,29 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||
vllm_config.model_config.architectures[0],
|
||||
num_hidden_layers,
|
||||
len(original_sizes),
|
||||
len(compilation_config.
|
||||
cudagraph_capture_sizes # type: ignore[arg-type]
|
||||
))
|
||||
len(
|
||||
compilation_config.cudagraph_capture_sizes # type: ignore[arg-type]
|
||||
),
|
||||
)
|
||||
else:
|
||||
# No adjustment needed
|
||||
compilation_config.cudagraph_capture_sizes = original_sizes
|
||||
logger.info(
|
||||
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
||||
vllm_config.model_config.architectures[0], num_hidden_layers,
|
||||
len(original_sizes))
|
||||
vllm_config.model_config.architectures[0],
|
||||
num_hidden_layers,
|
||||
len(original_sizes),
|
||||
)
|
||||
|
||||
|
||||
# TODO(wxy): Move to ops module
|
||||
def dispose_tensor(x: torch.Tensor):
|
||||
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
|
||||
x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
|
||||
|
||||
|
||||
class ProfileExecuteDuration:
|
||||
_instance = None
|
||||
_observations: List[Tuple[str, Event, Event]] = []
|
||||
_observations: list[tuple[str, Event, Event]] = []
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls):
|
||||
@@ -645,8 +635,7 @@ class ProfileExecuteDuration:
|
||||
observe_end = Event(enable_timing=True)
|
||||
observe_end.record()
|
||||
with self._lock:
|
||||
self._observations.append(
|
||||
(duration_tag, observe_start, observe_end))
|
||||
self._observations.append((duration_tag, observe_start, observe_end))
|
||||
|
||||
def pop_captured_sync(self) -> dict:
|
||||
"""Pop and synchronize all events in the observation list"""
|
||||
@@ -663,7 +652,7 @@ class ProfileExecuteDuration:
|
||||
return durations
|
||||
|
||||
|
||||
def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
def register_ascend_customop(vllm_config: VllmConfig | None = None):
|
||||
"""Register Ascend CustomOP
|
||||
|
||||
NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`,
|
||||
@@ -675,23 +664,29 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
|
||||
AscendSharedFusedMoE)
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE
|
||||
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
AscendReplicatedLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.linear import (
|
||||
AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
AscendReplicatedLinear,
|
||||
AscendRowParallelLinear,
|
||||
)
|
||||
from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
|
||||
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
AscendApplyRotaryEmb, AscendDeepseekScalingRotaryEmbedding,
|
||||
AscendMRotaryEmbedding, AscendRotaryEmbedding,
|
||||
AscendYaRNRotaryEmbedding)
|
||||
AscendApplyRotaryEmb,
|
||||
AscendDeepseekScalingRotaryEmbedding,
|
||||
AscendMRotaryEmbedding,
|
||||
AscendRotaryEmbedding,
|
||||
AscendYaRNRotaryEmbedding,
|
||||
)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import (
|
||||
AscendLogitsProcessor, AscendParallelLMHead,
|
||||
AscendVocabParallelEmbedding)
|
||||
AscendLogitsProcessor,
|
||||
AscendParallelLMHead,
|
||||
AscendVocabParallelEmbedding,
|
||||
)
|
||||
|
||||
global REGISTERED_ASCEND_OPS
|
||||
REGISTERED_ASCEND_OPS = {
|
||||
@@ -738,6 +733,7 @@ _ascend_device_type = None
|
||||
def _init_ascend_device_type():
|
||||
global _ascend_device_type
|
||||
from vllm_ascend import _build_info # type: ignore
|
||||
|
||||
_ascend_device_type = AscendDeviceType[_build_info.__device_type__]
|
||||
|
||||
|
||||
@@ -758,7 +754,10 @@ def check_ascend_device_type():
|
||||
else:
|
||||
raise RuntimeError(f"Can not support soc_version: {soc_version}.")
|
||||
|
||||
assert _ascend_device_type == cur_device_type, f"Current device type: {cur_device_type} does not match the installed version's device type: {_ascend_device_type}, please check your installation package."
|
||||
assert _ascend_device_type == cur_device_type, (
|
||||
f"Current device type: {cur_device_type} does not match the installed version's device type: "
|
||||
f"{_ascend_device_type}, please check your installation package."
|
||||
)
|
||||
|
||||
|
||||
def get_ascend_device_type():
|
||||
@@ -769,23 +768,19 @@ def get_ascend_device_type():
|
||||
|
||||
|
||||
def lmhead_tp_enable() -> bool:
|
||||
return get_ascend_config(
|
||||
).finegrained_tp_config.lmhead_tensor_parallel_size > 0
|
||||
return get_ascend_config().finegrained_tp_config.lmhead_tensor_parallel_size > 0
|
||||
|
||||
|
||||
def embedding_tp_enable() -> bool:
|
||||
return get_ascend_config(
|
||||
).finegrained_tp_config.embedding_tensor_parallel_size > 0
|
||||
return get_ascend_config().finegrained_tp_config.embedding_tensor_parallel_size > 0
|
||||
|
||||
|
||||
def oproj_tp_enable() -> bool:
|
||||
return get_ascend_config(
|
||||
).finegrained_tp_config.oproj_tensor_parallel_size > 0
|
||||
return get_ascend_config().finegrained_tp_config.oproj_tensor_parallel_size > 0
|
||||
|
||||
|
||||
def mlp_tp_enable() -> bool:
|
||||
return get_ascend_config(
|
||||
).finegrained_tp_config.mlp_tensor_parallel_size > 0
|
||||
return get_ascend_config().finegrained_tp_config.mlp_tensor_parallel_size > 0
|
||||
|
||||
|
||||
def matmul_allreduce_enable() -> bool:
|
||||
@@ -797,30 +792,30 @@ def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool:
|
||||
if _ENABLE_SP is None:
|
||||
if vllm_config is None:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
_ENABLE_SP = (
|
||||
vllm_config.compilation_config.pass_config.enable_sp
|
||||
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
|
||||
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
|
||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
|
||||
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0")))
|
||||
)
|
||||
|
||||
if not _ENABLE_SP and enable_shared_expert_dp:
|
||||
_ENABLE_SP = True
|
||||
logger.info(
|
||||
"shared_expert_dp requires enable_sp = True. has set enable_sp to True"
|
||||
)
|
||||
logger.info("shared_expert_dp requires enable_sp = True. has set enable_sp to True")
|
||||
|
||||
if not _ENABLE_SP:
|
||||
return _ENABLE_SP
|
||||
|
||||
assert vllm_config.parallel_config.tensor_parallel_size > 1, \
|
||||
assert vllm_config.parallel_config.tensor_parallel_size > 1, (
|
||||
"Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
|
||||
)
|
||||
|
||||
assert (
|
||||
not is_moe_model(vllm_config)
|
||||
or vllm_config.parallel_config.enable_expert_parallel
|
||||
), "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
|
||||
assert not is_moe_model(vllm_config) or vllm_config.parallel_config.enable_expert_parallel, (
|
||||
"Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
|
||||
)
|
||||
|
||||
return _ENABLE_SP
|
||||
|
||||
@@ -847,25 +842,21 @@ def is_drafter_moe_model(vllm_config: VllmConfig):
|
||||
"""Checks if the drafter model is a MoE model by config"""
|
||||
global _IS_DRAFTER_MOE_MODEL
|
||||
if _IS_DRAFTER_MOE_MODEL is None:
|
||||
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config \
|
||||
.to_dict()
|
||||
model_configs = vllm_config.speculative_config.draft_model_config.hf_text_config.to_dict()
|
||||
_IS_DRAFTER_MOE_MODEL = _is_contain_expert(model_configs)
|
||||
return _IS_DRAFTER_MOE_MODEL
|
||||
|
||||
|
||||
def speculative_enable_dispatch_gmm_combine_decode(
|
||||
vllm_config: VllmConfig) -> bool:
|
||||
def speculative_enable_dispatch_gmm_combine_decode(vllm_config: VllmConfig) -> bool:
|
||||
if vllm_config.speculative_config is None:
|
||||
return True
|
||||
speculative_method = getattr(vllm_config.speculative_config, "method",
|
||||
None)
|
||||
speculative_method = getattr(vllm_config.speculative_config, "method", None)
|
||||
if speculative_method in [None, "ngram", "suffix"]:
|
||||
return True
|
||||
if speculative_method in ["eagle", "eagle3"]:
|
||||
return False
|
||||
if speculative_method == "mtp":
|
||||
mtp_quant_type = getattr(vllm_config.model_config.hf_text_config,
|
||||
"mtp_quantize", None)
|
||||
mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, "mtp_quantize", None)
|
||||
return mtp_quant_type == "w8a8_dynamic"
|
||||
return False
|
||||
|
||||
@@ -915,8 +906,8 @@ def weak_ref_tensor(tensor: Any) -> Any:
|
||||
|
||||
|
||||
def weak_ref_tensors(
|
||||
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
|
||||
tensors: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor],
|
||||
) -> torch.Tensor | list[Any] | tuple[Any] | Any:
|
||||
"""
|
||||
Convenience function to create weak references to tensors,
|
||||
for single tensor, list of tensors or tuple of tensors.
|
||||
@@ -936,17 +927,12 @@ def weak_ref_tensors(
|
||||
return tuple(weak_ref_tensor(t) for t in tensors)
|
||||
# For IntermediateTensors used in pipeline parallelism
|
||||
if isinstance(tensors, IntermediateTensors):
|
||||
ret = IntermediateTensors({
|
||||
key: weak_ref_tensor(val)
|
||||
for key, val in tensors.tensors.items()
|
||||
})
|
||||
ret = IntermediateTensors({key: weak_ref_tensor(val) for key, val in tensors.tensors.items()})
|
||||
return ret
|
||||
raise ValueError("Invalid type for tensors")
|
||||
|
||||
|
||||
def npu_stream_switch(target_stream: torch.npu.Stream,
|
||||
*,
|
||||
enabled: bool = True):
|
||||
def npu_stream_switch(target_stream: torch.npu.Stream, *, enabled: bool = True):
|
||||
"""
|
||||
Switch to the target stream if enabled is True.
|
||||
Otherwise, do nothing.
|
||||
@@ -965,7 +951,7 @@ def create_hccl_pg_options(group_name: str):
|
||||
return options
|
||||
|
||||
|
||||
def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
|
||||
def get_hccl_config_for_pg_options(group_name: str) -> dict | None:
|
||||
"""
|
||||
Get HCCL process group options for the given communication group name.
|
||||
|
||||
@@ -981,9 +967,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
|
||||
if group_name and "mc2" in group_name:
|
||||
return None
|
||||
hccl_config_map = {
|
||||
"dp": {
|
||||
"hccl_buffer_size": calculate_dp_buffer_size()
|
||||
},
|
||||
"dp": {"hccl_buffer_size": calculate_dp_buffer_size()},
|
||||
}
|
||||
return hccl_config_map.get(group_name, get_default_buffer_config())
|
||||
|
||||
@@ -998,6 +982,7 @@ def calculate_dp_buffer_size() -> int:
|
||||
dp_size + 1 (flags: with_prefill)
|
||||
"""
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
int32_size = torch.iinfo(torch.int32).bits // 8
|
||||
@@ -1009,8 +994,7 @@ def calculate_dp_buffer_size() -> int:
|
||||
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
|
||||
# significantly improve communication performance of MC2 ops dispatch/combine.
|
||||
def is_hierarchical_communication_enabled():
|
||||
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
|
||||
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
|
||||
return os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1"
|
||||
|
||||
|
||||
def has_layer_idx(model_instance: torch.nn.Module) -> bool:
|
||||
@@ -1019,8 +1003,7 @@ def has_layer_idx(model_instance: torch.nn.Module) -> bool:
|
||||
|
||||
global _HAS_LAYER_IDX
|
||||
if _HAS_LAYER_IDX is None:
|
||||
_HAS_LAYER_IDX = hasattr(model_instance, "model") and \
|
||||
hasattr(model_instance.model, "start_layer")
|
||||
_HAS_LAYER_IDX = hasattr(model_instance, "model") and hasattr(model_instance.model, "start_layer")
|
||||
return _HAS_LAYER_IDX
|
||||
|
||||
|
||||
@@ -1042,20 +1025,17 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
||||
if not flashcomm2_enable():
|
||||
return 0
|
||||
|
||||
logger.info(
|
||||
f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}"
|
||||
)
|
||||
logger.info(f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size}")
|
||||
|
||||
layer_sharding = ascend_config.layer_sharding or []
|
||||
if layer_sharding:
|
||||
if layer_sharding == ["o_proj"]:
|
||||
logger.info_once(
|
||||
"Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption."
|
||||
)
|
||||
logger.info_once("Enable FLASHCOMM2 with o_proj layer sharding for reduced memory consumption.")
|
||||
else:
|
||||
raise ValueError(
|
||||
"FLASHCOMM2 only supports 'o_proj' as the sole layer sharding configuration! "
|
||||
f"Found invalid layer_sharding: {layer_sharding}")
|
||||
f"Found invalid layer_sharding: {layer_sharding}"
|
||||
)
|
||||
if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
|
||||
logger.warning_once(
|
||||
"It is recommended to enable FLASHCOMM1 simultaneously when starting FLASHCOMM2 for optimal performance."
|
||||
@@ -1066,32 +1046,35 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
|
||||
)
|
||||
if global_tp_size <= flashcomm2_oproj_tp_size:
|
||||
raise AssertionError(
|
||||
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed global tensor parallel size ({global_tp_size})"
|
||||
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size}) cannot exceed "
|
||||
f"global tensor parallel size ({global_tp_size})"
|
||||
)
|
||||
if global_tp_size % flashcomm2_oproj_tp_size != 0:
|
||||
raise AssertionError(
|
||||
f"Global tensor parallel size ({global_tp_size}) must be divisible by flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
|
||||
f"Global tensor parallel size ({global_tp_size}) must be divisible by "
|
||||
f"flashcomm2_oproj_tensor_parallel_size ({flashcomm2_oproj_tp_size})"
|
||||
)
|
||||
if vllm_config.kv_transfer_config is None:
|
||||
logger.warning_once(
|
||||
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment may lead to decode performance degradation."
|
||||
"It is recommended to enable FLASHCOMM2 in P-scenario deployments, enable it in hybrid deployment "
|
||||
"may lead to decode performance degradation."
|
||||
)
|
||||
if vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise AssertionError(
|
||||
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support for hybrid deployment scenarios. It is not applicable in D-scenario environments."
|
||||
"FLASHCOMM2 primarily targets P-scenario deployments, with additional support "
|
||||
"for hybrid deployment scenarios. It is not applicable in D-scenario environments."
|
||||
)
|
||||
|
||||
return flashcomm2_oproj_tp_size
|
||||
|
||||
|
||||
def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
|
||||
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain.
|
||||
# Reorganize batch_ids so that, after the all2all and reduce-scatter operation,
|
||||
# each batch_id corresponds to the rank_id within the DP domain.
|
||||
# For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
|
||||
# the reorganized batch_ids will be [[batch0, batch8], [batch1, batch9], ..., [batch7, batch15]].
|
||||
flashcomm2_otp_size = get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size
|
||||
num_oproj_tensor_parallel_groups: int = (global_tp_size //
|
||||
flashcomm2_otp_size)
|
||||
flashcomm2_otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
|
||||
num_oproj_tensor_parallel_groups: int = global_tp_size // flashcomm2_otp_size
|
||||
|
||||
reorgnized_batch_ids = []
|
||||
for i in range(num_oproj_tensor_parallel_groups):
|
||||
@@ -1122,11 +1105,9 @@ def refresh_block_size(vllm_config):
|
||||
return
|
||||
|
||||
# TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
|
||||
if not model_config.hf_text_config.model_type == "qwen3_next" and cache_config.block_size != 128:
|
||||
if model_config.hf_text_config.model_type != "qwen3_next" and cache_config.block_size != 128:
|
||||
if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill:
|
||||
logger.info(
|
||||
"Block size is set to 128 if prefix cache or chunked prefill is enabled."
|
||||
)
|
||||
logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.")
|
||||
cache_config.block_size = 128
|
||||
|
||||
|
||||
@@ -1138,7 +1119,6 @@ def dispose_layer(layer: Any):
|
||||
|
||||
|
||||
def check_kv_extra_config(vllm_config):
|
||||
|
||||
def _check(name: str, config: dict):
|
||||
tp_key = "tp_size"
|
||||
dp_key = "dp_size"
|
||||
@@ -1148,24 +1128,21 @@ def check_kv_extra_config(vllm_config):
|
||||
if config_tp != vllm_tp:
|
||||
raise ValueError(
|
||||
f"KV transfer '{name}' config has a conflicting tensor parallel size. "
|
||||
f"Expected {vllm_tp}, but got {config_tp}.")
|
||||
f"Expected {vllm_tp}, but got {config_tp}."
|
||||
)
|
||||
if dp_key in config:
|
||||
config_dp = config[dp_key]
|
||||
vllm_dp = vllm_config.parallel_config.data_parallel_size
|
||||
if config_dp != vllm_dp:
|
||||
raise ValueError(
|
||||
f"KV transfer '{name}' config has a conflicting data parallel size. "
|
||||
f"Expected {vllm_dp}, but got {config_dp}.")
|
||||
f"Expected {vllm_dp}, but got {config_dp}."
|
||||
)
|
||||
|
||||
if vllm_config.kv_transfer_config.is_kv_producer:
|
||||
_check(
|
||||
"prefill",
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"prefill", {}))
|
||||
_check("prefill", vllm_config.kv_transfer_config.get_from_extra_config("prefill", {}))
|
||||
if vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
_check(
|
||||
"decode",
|
||||
vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
|
||||
_check("decode", vllm_config.kv_transfer_config.get_from_extra_config("decode", {}))
|
||||
|
||||
|
||||
def singleton(cls):
|
||||
@@ -1179,17 +1156,17 @@ def singleton(cls):
|
||||
return get_instance
|
||||
|
||||
|
||||
#TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32. and subsequent updates will introduce new interfaces. --zzhx1
|
||||
# TODO: Temporarily use enable_sp to enable the dsa_cp feature of ds32.
|
||||
# and subsequent updates will introduce new interfaces. --zzhx1
|
||||
@lru_cache(maxsize=1)
|
||||
def enable_dsa_cp() -> bool:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
is_ds_v32 = hasattr(
|
||||
vllm_config.model_config, "hf_text_config") and hasattr(
|
||||
vllm_config.model_config.hf_text_config, "index_topk")
|
||||
if is_ds_v32 and enable_sp():
|
||||
return True
|
||||
return False
|
||||
is_ds_v32 = hasattr(vllm_config.model_config, "hf_text_config") and hasattr(
|
||||
vllm_config.model_config.hf_text_config, "index_topk"
|
||||
)
|
||||
return bool(is_ds_v32 and enable_sp())
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -1197,6 +1174,7 @@ def enable_dsa_cp_with_layer_shard() -> bool:
|
||||
if not enable_dsa_cp():
|
||||
return False
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
|
||||
return is_prefill_instance
|
||||
|
||||
Reference in New Issue
Block a user