[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
### Does this PR introduce _any_ user-facing change?
No. However, during this migration we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
def register_ascend_customop(vllm_config: VllmConfig | None = None):
^^^^^^^^^^^^^^^^^
E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular imports:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it evaluates
the annotation expression `VllmConfig | None` immediately.
Since `VllmConfig` is bound to `None` at runtime, the expression
effectively becomes `None | None`. The `|` of PEP 604 union syntax is
implemented on type objects (classes), but not on `NoneType` instances,
so the evaluation fails with the `TypeError` shown above. The standalone
sketch below reproduces the failure.
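A minimal reproduction under the same conditions (hypothetical standalone code, not the actual project files):
```python
VllmConfig = None  # runtime placeholder, as in the TYPE_CHECKING pattern above


# Without postponed evaluation, the annotation is evaluated when the `def`
# statement runs, so Python computes `None | None` and raises:
#   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```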
**2. Solution:**
To keep the modern `|` union syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**, so annotations are stored as strings rather than
evaluated at module load.
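A condensed sketch of how an affected module now looks (illustrative, not the exact file contents):
```python
from __future__ import annotations  # first statement after the module docstring

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # placeholder kept so other modules can still import the symbol


# The annotation is now stored as the string "VllmConfig | None",
# so `None | None` is never computed at import time.
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```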
**3. Impact and Benefits:**
- With postponed evaluation enabled, Python no longer executes the
`VllmConfig | None` operation during module load. Instead, it stores the
annotation as a string literal, avoiding the `None | None` evaluation
entirely.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly, so we get the modern syntax without sacrificing type
safety or runtime stability.
- The only side effect is that `__annotations__` now contains strings
instead of type objects (see the snippet below). Since these modules do
not use runtime type enforcement or reflection, this change has no
negative impact on existing functionality.
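To illustrate that side effect (a standalone example, assuming Python 3.10+):
```python
from __future__ import annotations

import typing


def f(x: int | None = None):
    ...


print(f.__annotations__)         # {'x': 'int | None'} -- strings, not type objects
print(typing.get_type_hints(f))  # resolved back to real type objects on demand
```
If runtime introspection is ever needed, `typing.get_type_hints` still resolves the string annotations to actual types.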
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 11b6af5280
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
```diff
@@ -1,21 +1,25 @@
 import math
 from contextlib import contextmanager
 from enum import Enum
-from typing import Any, Optional
+from typing import Any

 import torch
 from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.distributed import (get_dp_group, get_ep_group,
-                              get_tensor_model_parallel_world_size)
-from vllm.forward_context import (BatchDescriptor, get_forward_context,
-                                  set_forward_context)
+from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
+from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
-                               get_ascend_device_type, has_layer_idx,
-                               is_drafter_moe_model, is_moe_model,
-                               speculative_enable_dispatch_gmm_combine_decode)
+from vllm_ascend.utils import (
+    AscendDeviceType,
+    enable_sp,
+    flashcomm2_enable,
+    get_ascend_device_type,
+    has_layer_idx,
+    is_drafter_moe_model,
+    is_moe_model,
+    speculative_enable_dispatch_gmm_combine_decode,
+)


 class MoECommType(Enum):
@@ -27,36 +31,36 @@ class MoECommType(Enum):

 @contextmanager
 def set_ascend_forward_context(
-        attn_metadata: Any,
-        vllm_config: VllmConfig,
-        virtual_engine: int = 0,
-        num_tokens: int = 0,
-        num_tokens_across_dp: Optional[torch.Tensor] = None,
-        in_profile_run: bool = False,
-        num_actual_tokens: Optional[int] = None,
-        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None,
-        model_instance: torch.nn.Module = None,
-        is_draft_model=False):
+    attn_metadata: Any,
+    vllm_config: VllmConfig,
+    virtual_engine: int = 0,
+    num_tokens: int = 0,
+    num_tokens_across_dp: torch.Tensor | None = None,
+    in_profile_run: bool = False,
+    num_actual_tokens: int | None = None,
+    aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+    batch_descriptor: BatchDescriptor | None = None,
+    model_instance: torch.nn.Module = None,
+    is_draft_model=False,
+):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
     """
     with set_forward_context(
-            attn_metadata,
-            vllm_config,
-            virtual_engine=virtual_engine,
-            num_tokens=num_tokens,
-            num_tokens_across_dp=num_tokens_across_dp,
-            cudagraph_runtime_mode=aclgraph_runtime_mode,
-            batch_descriptor=batch_descriptor,
+        attn_metadata,
+        vllm_config,
+        virtual_engine=virtual_engine,
+        num_tokens=num_tokens,
+        num_tokens_across_dp=num_tokens_across_dp,
+        cudagraph_runtime_mode=aclgraph_runtime_mode,
+        batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()

-        from vllm_ascend.ops.fused_moe.moe_comm_method import \
-            get_moe_comm_method
-        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
-                                               is_draft_model)
+        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
+
+        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config, is_draft_model)
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

@@ -68,31 +72,28 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False

-        # set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
-        # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
-        # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
+        # set for sequence parallelism, 1000 is the batch size concurrency threshold
+        # for enabling the flashcomm_v1 or sequence_parallelism feature.
+        # Currently, it is an empirical value. In normal scenarios, if the concurrency
+        # exceeds this threshold, the performance benefits can be maximized.
+        # Conversely, if the concurrency is below the threshold,
         # the performance may degrade due to the switching of communication methods.
         mmrs_fusion = True
         # main model and drafter model may have different architecture
-        is_context_moe_model = is_drafter_moe_model(vllm_config) \
-            if is_draft_model else is_moe_model(vllm_config)
+        is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
         if is_context_moe_model:
             sp_enabled = enable_sp(vllm_config) and num_tokens is not None
             mmrs_fusion = False
         else:
-            sp_enabled = enable_sp(vllm_config) and \
-                num_tokens is not None and num_tokens > 1000
+            sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
         forward_context.mmrs_fusion = mmrs_fusion
         forward_context.num_tokens = num_tokens
         forward_context.sp_enabled = sp_enabled
         # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
-        forward_context.flashcomm_v2_enabled = flashcomm2_enable(
-        ) and tp_world_size > 1 and num_tokens is not None
+        forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None

-        if (forward_context.sp_enabled
-                or forward_context.flashcomm_v2_enabled):
-            pad_size = (tp_world_size -
-                        (num_tokens % tp_world_size)) % tp_world_size
+        if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
+            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
             forward_context.pad_size = pad_size

         # set this for rope forward_oot using
@@ -106,9 +107,12 @@ def set_ascend_forward_context(

         # TODO(rjg-lyh): refactor mlp weight prefetch method
         # set for mlp weight prefetch
-        prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
-            forward_context.layer_idx is not None and \
-            num_tokens is not None and num_tokens < 500
+        prefetch_mlp_enabled = (
+            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP
+            and forward_context.layer_idx is not None
+            and num_tokens is not None
+            and num_tokens < 500
+        )
         if prefetch_mlp_enabled:
             forward_context.prefetch_mlp_gate_up_proj = False
             forward_context.prefetch_mlp_down_proj = False
@@ -121,12 +125,9 @@ def set_ascend_forward_context(

         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
-            max_tokens_across_dp = \
-                forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
-            if (forward_context.sp_enabled
-                    or forward_context.flashcomm_v2_enabled):
-                padded_length = (max_tokens_across_dp + tp_world_size -
-                                 1) // tp_world_size * tp_world_size
+            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
+            if forward_context.sp_enabled or forward_context.flashcomm_v2_enabled:
+                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                 pad_size = padded_length - num_tokens
                 forward_context.padded_length = padded_length
                 forward_context.pad_size = pad_size
@@ -139,12 +140,10 @@ def set_ascend_forward_context(
             if num_actual_tokens is None:
                 num_actual_tokens = num_tokens
             # NOTE: token num which need to pad to when mc2
-            forward_context.padded_num_tokens = math.ceil(
-                max_tokens_across_dp / tp_world_size) * tp_world_size
+            forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
             reserved_mc2_mask = get_mc2_mask()
             if reserved_mc2_mask is not None:
-                mc2_mask = reserved_mc2_mask[:forward_context.
-                                             padded_num_tokens]
+                mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
                 mc2_mask[:num_actual_tokens] = True
                 mc2_mask[num_actual_tokens:] = False
                 forward_context.mc2_mask = mc2_mask
@@ -155,14 +154,13 @@ def set_ascend_forward_context(
             pass


-_mc2_tokens_capacity: Optional[int] = None
-_reserved_mc2_mask: Optional[torch.Tensor] = None
-_sin: Optional[torch.Tensor] = None
-_cos: Optional[torch.Tensor] = None
+_mc2_tokens_capacity: int | None = None
+_reserved_mc2_mask: torch.Tensor | None = None
+_sin: torch.Tensor | None = None
+_cos: torch.Tensor | None = None


-def set_mc2_tokens_capacity(vllm_config, max_num_reqs,
-                            uniform_decode_query_len):
+def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
     global _mc2_tokens_capacity
     if _mc2_tokens_capacity is not None:
         return
@@ -186,9 +184,7 @@ def set_mc2_mask(vllm_config, device):
     if _reserved_mc2_mask is not None:
         return
     if is_moe_model(vllm_config):
-        _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(),
-                                         dtype=torch.bool,
-                                         device=device)
+        _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
     else:
         _reserved_mc2_mask = None

@@ -197,9 +193,7 @@ def get_mc2_mask():
     return _reserved_mc2_mask


-def select_moe_comm_method(num_tokens: int,
-                           vllm_config: VllmConfig,
-                           is_draft_model=False) -> Optional[MoECommType]:
+def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None:
     """Select the MoE communication method according to parallel settings,
     device generation, token count, and quantization.

@@ -227,16 +221,19 @@ def select_moe_comm_method(num_tokens: int,
     mc2_tokens_capacity = get_mc2_tokens_capacity()
     soc_version = get_ascend_device_type()
     quant_type = getattr(
-        vllm_config.model_config.hf_text_config, 'moe_quantize',
-        getattr(vllm_config.model_config.hf_text_config, 'quantize', None))
+        vllm_config.model_config.hf_text_config,
+        "moe_quantize",
+        getattr(vllm_config.model_config.hf_text_config, "quantize", None),
+    )

-    if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group(
-    ).world_size == 1:
+    if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
         moe_comm_type = MoECommType.ALLGATHER
     elif soc_version in {AscendDeviceType.A2}:
-        if (num_tokens <= mc2_tokens_capacity
-                and vllm_config.parallel_config.world_size_across_dp /
-                vllm_config.parallel_config.pipeline_parallel_size >= 16):
+        if (
+            num_tokens <= mc2_tokens_capacity
+            and vllm_config.parallel_config.world_size_across_dp / vllm_config.parallel_config.pipeline_parallel_size
+            >= 16
+        ):
             moe_comm_type = MoECommType.MC2
         else:
             moe_comm_type = MoECommType.ALLGATHER
@@ -246,15 +243,13 @@ def select_moe_comm_method(num_tokens: int,
     # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
     # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
     fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
-    dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (
-        not is_draft_model) and (not dynamic_eplb)
+    dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
     if num_tokens <= mc2_tokens_capacity:
         fused_decode_enable = fused_mc2_enable
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
-            fused_decode_enable = fused_mc2_enable and \
-                speculative_enable_dispatch_gmm_combine_decode(vllm_config)
+            fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
         moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
     else:
         fused_prefill_enable = fused_mc2_enable
```