[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
def register_ascend_customop(vllm_config: VllmConfig | None = None):
^^^^^^^^^^^^^^^^^
E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None # Placeholder assigned at runtime
```
When Python parses the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it attempts
to evaluate the expression `VllmConfig | None`.
Since `VllmConfig` is assigned `None` at runtime, the expression
effectively becomes `None | None`. In Python, `None` is an instance of
`NoneType`. While the `|` operator is implemented for Type objects
(classes), it is not supported for `NoneType` instances, leading to the
`TypeError` shown above.
**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
**3. Impact and Benefits:**
- By enabling `annotations`, Python no longer executes the `VllmConfig |
None` operation during module load. Instead, it stores the annotation as
a string literal, completely avoiding the `None | None` calculation.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` will now return strings
instead of type objects. Since this module does not use runtime type
enforcement or reflection, this change has zero negative impact on
existing functionality.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import logger
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
@@ -32,18 +32,13 @@ class AscendConfig:
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
|
||||
xlite_graph_config = additional_config.get("xlite_graph_config", {})
|
||||
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
|
||||
vllm_config)
|
||||
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, vllm_config)
|
||||
|
||||
ascend_compilation_config = additional_config.get(
|
||||
"ascend_compilation_config", {})
|
||||
self.ascend_compilation_config = AscendCompilationConfig(
|
||||
**ascend_compilation_config)
|
||||
ascend_compilation_config = additional_config.get("ascend_compilation_config", {})
|
||||
self.ascend_compilation_config = AscendCompilationConfig(**ascend_compilation_config)
|
||||
|
||||
finegrained_tp_config = additional_config.get("finegrained_tp_config",
|
||||
{})
|
||||
self.finegrained_tp_config = FinegrainedTPConfig(
|
||||
finegrained_tp_config, vllm_config)
|
||||
finegrained_tp_config = additional_config.get("finegrained_tp_config", {})
|
||||
self.finegrained_tp_config = FinegrainedTPConfig(finegrained_tp_config, vllm_config)
|
||||
|
||||
eplb_config = additional_config.get("eplb_config", {})
|
||||
self.eplb_config = EplbConfig(eplb_config)
|
||||
@@ -51,10 +46,8 @@ class AscendConfig:
|
||||
# Dump / PrecisionDebugger configuration
|
||||
self.dump_config_path = additional_config.get("dump_config_path", None)
|
||||
|
||||
weight_prefetch_config = additional_config.get(
|
||||
"weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(
|
||||
weight_prefetch_config)
|
||||
weight_prefetch_config = additional_config.get("weight_prefetch_config", {})
|
||||
self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config)
|
||||
self.layer_sharding = additional_config.get("layer_sharding", None)
|
||||
logger.info_once(
|
||||
f"Linear layer sharding enabled with config: {self.layer_sharding}. "
|
||||
@@ -62,30 +55,25 @@ class AscendConfig:
|
||||
"using it without these features may result in significant performance degradation."
|
||||
)
|
||||
|
||||
self.enable_shared_expert_dp = additional_config.get(
|
||||
"enable_shared_expert_dp",
|
||||
False) and vllm_config.parallel_config.enable_expert_parallel
|
||||
self.enable_shared_expert_dp = (
|
||||
additional_config.get("enable_shared_expert_dp", False)
|
||||
and vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
if self.enable_shared_expert_dp:
|
||||
from vllm_ascend.utils import enable_sp
|
||||
assert enable_sp(vllm_config=vllm_config,
|
||||
enable_shared_expert_dp=True)
|
||||
self.multistream_overlap_shared_expert = additional_config.get(
|
||||
"multistream_overlap_shared_expert", False)
|
||||
self.multistream_overlap_gate = additional_config.get(
|
||||
"multistream_overlap_gate", False)
|
||||
self.recompute_scheduler_enable = additional_config.get(
|
||||
"recompute_scheduler_enable", False)
|
||||
self.enable_cpu_binding = additional_config.get(
|
||||
"enable_cpu_binding", False)
|
||||
|
||||
assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True)
|
||||
self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False)
|
||||
self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False)
|
||||
self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)
|
||||
self.enable_cpu_binding = additional_config.get("enable_cpu_binding", False)
|
||||
|
||||
self.pd_tp_ratio = 1
|
||||
self.pd_head_ratio = 1
|
||||
self.num_head_replica = 1
|
||||
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
|
||||
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"prefill", {"tp_size": 1})["tp_size"]
|
||||
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"decode", {"tp_size": 1})["tp_size"]
|
||||
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {"tp_size": 1})["tp_size"]
|
||||
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config("decode", {"tp_size": 1})["tp_size"]
|
||||
assert prefill_tp_size % decode_tp_size == 0, "Prefill TP size must be divisible by Decode TP size."
|
||||
self.pd_tp_ratio = prefill_tp_size // decode_tp_size
|
||||
if self.pd_tp_ratio > 1:
|
||||
@@ -106,36 +94,29 @@ class AscendConfig:
|
||||
)
|
||||
|
||||
if self.pd_tp_ratio == 0:
|
||||
raise AssertionError(
|
||||
"Only support P node tp size lagger then D node tp size")
|
||||
self.SLO_limits_for_dynamic_batch = additional_config.get(
|
||||
"SLO_limits_for_dynamic_batch", -1)
|
||||
raise AssertionError("Only support P node tp size lagger then D node tp size")
|
||||
self.SLO_limits_for_dynamic_batch = additional_config.get("SLO_limits_for_dynamic_batch", -1)
|
||||
from vllm_ascend.utils import get_flashcomm2_config_and_validate
|
||||
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(
|
||||
self, vllm_config)
|
||||
self.enable_npugraph_ex = additional_config.get(
|
||||
"enable_npugraph_ex", False)
|
||||
|
||||
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_config_and_validate(self, vllm_config)
|
||||
self.enable_npugraph_ex = additional_config.get("enable_npugraph_ex", False)
|
||||
# We find that _npu_paged_attention still performs better than
|
||||
# npu_fused_infer_attention_score in some cases. We allow to execute
|
||||
# _npu_paged_attention in this cases. This should be removed once
|
||||
# npu_fused_infer_attention_score performs better on all scenarios.
|
||||
self.pa_shape_list = additional_config.get("pa_shape_list", [])
|
||||
|
||||
self.enable_async_exponential = bool(
|
||||
additional_config.get("enable_async_exponential", False))
|
||||
self.enable_async_exponential = bool(additional_config.get("enable_async_exponential", False))
|
||||
|
||||
self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
|
||||
if self.enable_kv_nz:
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_text_config,
|
||||
"index_topk")
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk")
|
||||
if not vllm_config.model_config.is_deepseek_mla or use_sparse:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is only supported for mla currently.")
|
||||
if vllm_config.kv_transfer_config is None \
|
||||
or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise RuntimeError("enable_kv_nz is only supported for mla currently.")
|
||||
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise NotImplementedError(
|
||||
"enable_kv_nz is only supported in pd scenario and can "
|
||||
"only be used in D node.")
|
||||
"enable_kv_nz is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
|
||||
|
||||
class FinegrainedTPConfig:
|
||||
@@ -144,40 +125,28 @@ class FinegrainedTPConfig:
|
||||
"""
|
||||
|
||||
def __init__(self, finegrained_tp_config: dict, vllm_config):
|
||||
self.oproj_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"oproj_tensor_parallel_size", 0)
|
||||
self.lmhead_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"lmhead_tensor_parallel_size", 0)
|
||||
self.embedding_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"embedding_tensor_parallel_size", 0)
|
||||
self.mlp_tensor_parallel_size = finegrained_tp_config.get(
|
||||
"mlp_tensor_parallel_size", 0)
|
||||
self.oproj_tensor_parallel_size = finegrained_tp_config.get("oproj_tensor_parallel_size", 0)
|
||||
self.lmhead_tensor_parallel_size = finegrained_tp_config.get("lmhead_tensor_parallel_size", 0)
|
||||
self.embedding_tensor_parallel_size = finegrained_tp_config.get("embedding_tensor_parallel_size", 0)
|
||||
self.mlp_tensor_parallel_size = finegrained_tp_config.get("mlp_tensor_parallel_size", 0)
|
||||
|
||||
enabled_configs = []
|
||||
if self.oproj_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}"
|
||||
)
|
||||
# dummy_run does not run the entire attention module in eager mode,, so the o_proj tp split can only be used in graph mode.
|
||||
enabled_configs.append(f"oproj_tensor_parallel_size={self.oproj_tensor_parallel_size}")
|
||||
# dummy_run does not run the entire attention module in eager mode,
|
||||
# so the o_proj tp split can only be used in graph mode.
|
||||
if vllm_config.model_config.enforce_eager is True:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in graph mode"
|
||||
)
|
||||
raise AssertionError("oproj_tensor_parallel_size is only supported in graph mode")
|
||||
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
|
||||
raise AssertionError(
|
||||
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
|
||||
)
|
||||
if self.lmhead_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}"
|
||||
)
|
||||
enabled_configs.append(f"lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size}")
|
||||
if self.embedding_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}"
|
||||
)
|
||||
enabled_configs.append(f"embedding_tensor_parallel_size={self.embedding_tensor_parallel_size}")
|
||||
if self.mlp_tensor_parallel_size > 0:
|
||||
enabled_configs.append(
|
||||
f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
|
||||
enabled_configs.append(f"mlp_tensor_parallel_size={self.mlp_tensor_parallel_size}")
|
||||
module_tp_sizes = [
|
||||
self.oproj_tensor_parallel_size,
|
||||
self.lmhead_tensor_parallel_size,
|
||||
@@ -186,11 +155,9 @@ class FinegrainedTPConfig:
|
||||
]
|
||||
for module_tp_size in module_tp_sizes:
|
||||
if module_tp_size > 0 and vllm_config.parallel_config.data_parallel_size % module_tp_size != 0:
|
||||
raise AssertionError(
|
||||
"module tp sizes must divide data_parallel_size")
|
||||
raise AssertionError("module tp sizes must divide data_parallel_size")
|
||||
if any(size > 0 for size in module_tp_sizes) and enabled_configs:
|
||||
logger.info(
|
||||
f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
|
||||
logger.info(f"finegrained_tp_config enabled: {', '.join(enabled_configs)}")
|
||||
|
||||
|
||||
class AscendCompilationConfig:
|
||||
@@ -202,13 +169,10 @@ class AscendCompilationConfig:
|
||||
deployed on Ascend platforms.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
fuse_norm_quant: bool = True,
|
||||
fuse_qknorm_rope: bool = False,
|
||||
**kwargs):
|
||||
def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = False, **kwargs):
|
||||
"""
|
||||
Initialize the configuration.
|
||||
|
||||
|
||||
Args:
|
||||
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
|
||||
When set to True, the system will optimize norm and quant operations.
|
||||
@@ -236,7 +200,8 @@ class XliteGraphConfig:
|
||||
)
|
||||
if vllm_config.parallel_config.pipeline_parallel_size > 1:
|
||||
raise RuntimeError(
|
||||
"Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1."
|
||||
"Xlite graph mode is not compatible with pipeline parallelism. "
|
||||
"Please set pipeline_parallel_size to 1."
|
||||
)
|
||||
if vllm_config.cache_config.block_size != 128:
|
||||
raise RuntimeError(
|
||||
@@ -254,21 +219,19 @@ class WeightPrefetchConfig:
|
||||
"qkv": 1.0,
|
||||
"o": 1.0,
|
||||
},
|
||||
"moe": {
|
||||
"gate_up": 0.8
|
||||
}
|
||||
"moe": {"gate_up": 0.8},
|
||||
}
|
||||
|
||||
def __init__(self, weight_prefetch_config: dict):
|
||||
self.enabled = weight_prefetch_config.get("enabled", False)
|
||||
self.prefetch_ratio = weight_prefetch_config.get(
|
||||
"prefetch_ratio", self.prefetch_ratio)
|
||||
self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio)
|
||||
|
||||
|
||||
class EplbConfig:
|
||||
"""
|
||||
Configuration Object for xlite_graph_config from additional_config
|
||||
"""
|
||||
|
||||
_defaults = {
|
||||
"dynamic_eplb": False,
|
||||
"expert_map_path": None,
|
||||
@@ -276,10 +239,12 @@ class EplbConfig:
|
||||
"algorithm_execution_interval": 30,
|
||||
"expert_map_record_path": None,
|
||||
"num_redundant_experts": 0,
|
||||
"eplb_policy_type": 1
|
||||
"eplb_policy_type": 1,
|
||||
}
|
||||
|
||||
def __init__(self, user_config: dict = {}):
|
||||
def __init__(self, user_config: dict | None = None):
|
||||
if user_config is None:
|
||||
user_config = {}
|
||||
self.config = self._defaults.copy()
|
||||
if user_config and isinstance(user_config, dict):
|
||||
for key, value in user_config.items():
|
||||
@@ -307,27 +272,21 @@ class EplbConfig:
|
||||
raise TypeError("The expert_map_record_path is not json.")
|
||||
dirname = os.path.dirname(self.expert_map_record_path)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
for key in [
|
||||
"expert_heat_collection_interval",
|
||||
"algorithm_execution_interval", "num_redundant_experts"
|
||||
]:
|
||||
for key in ["expert_heat_collection_interval", "algorithm_execution_interval", "num_redundant_experts"]:
|
||||
if not isinstance(self.config[key], int):
|
||||
raise TypeError(f"{key} must be an integer")
|
||||
if self.config[key] < 0: # type: ignore
|
||||
raise ValueError(
|
||||
f"{key} must greater than 0; got {self.config[key]} instead"
|
||||
)
|
||||
raise ValueError(f"{key} must greater than 0; got {self.config[key]} instead")
|
||||
if self.eplb_policy_type not in [0, 1, 2, 3]:
|
||||
raise ValueError("eplb_policy_type must in [0, 1, 2, 3]")
|
||||
|
||||
|
||||
_ASCEND_CONFIG: Optional[AscendConfig] = None
|
||||
_ASCEND_CONFIG: AscendConfig | None = None
|
||||
|
||||
|
||||
def init_ascend_config(vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
refresh = additional_config.get("refresh",
|
||||
False) if additional_config else False
|
||||
refresh = additional_config.get("refresh", False) if additional_config else False
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is not None and not refresh:
|
||||
return _ASCEND_CONFIG
|
||||
@@ -343,7 +302,5 @@ def clear_ascend_config():
|
||||
def get_ascend_config():
|
||||
global _ASCEND_CONFIG
|
||||
if _ASCEND_CONFIG is None:
|
||||
raise RuntimeError(
|
||||
"Ascend config is not initialized. Please call init_ascend_config first."
|
||||
)
|
||||
raise RuntimeError("Ascend config is not initialized. Please call init_ascend_config first.")
|
||||
return _ASCEND_CONFIG
|
||||
|
||||
Reference in New Issue
Block a user