[Lint]Style: Convert vllm-ascend/compilation to ruff format (#5912)
### What this PR does / why we need it?
Convert `vllm-ascend/compilation` to ruff format.
### Does this PR introduce _any_ user-facing change?
During this migration, we encountered some **errors** in our CI and
testing environments, such as:
```
vllm_ascend/utils.py:653: in <module>
def register_ascend_customop(vllm_config: VllmConfig | None = None):
^^^^^^^^^^^^^^^^^
E TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
```
**1. Root Cause Analysis:**
The project uses a common pattern to break circular dependencies:
```python
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # Placeholder assigned at runtime
```
When Python executes the function definition `def
register_ascend_customop(vllm_config: VllmConfig | None)`, it evaluates
the annotation expression `VllmConfig | None` immediately.
Since `VllmConfig` has been assigned `None` at runtime, the expression
effectively becomes `None | None`. The `|` operator is implemented for
type objects (classes) to build unions, but `NoneType` instances do not
support it, so evaluating `None | None` raises the `TypeError` shown above.
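A minimal, self-contained sketch of the failure (a hypothetical module that mirrors the pattern, not the actual vllm-ascend source):
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # runtime placeholder that keeps the symbol importable

# The annotation below is evaluated eagerly when the `def` statement runs,
# so it reduces to `None | None` and raises:
#   TypeError: unsupported operand type(s) for |: 'NoneType' and 'NoneType'
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```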
**2. Solution:**
To maintain the modern `|` syntax required by our new linting standards
while preserving our dependency management strategy, I have introduced:
```python
from __future__ import annotations
```
at the top of the affected files. This enables **Postponed Evaluation of
Annotations (PEP 563)**.
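With the future import in place, the same pattern imports cleanly; a minimal sketch under the same assumptions as above:
```python
from __future__ import annotations  # PEP 563: annotations are stored as strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None  # runtime placeholder, still importable by other modules

# The annotation is kept as the string "VllmConfig | None" and is never
# evaluated at definition time, so importing the module succeeds.
def register_ascend_customop(vllm_config: VllmConfig | None = None):
    ...
```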
**3. Impact and Benefits:**
- With the `annotations` future import, Python no longer evaluates the
`VllmConfig | None` expression during module load. Instead, it stores the
annotation as a string literal, so the `None | None` operation is never executed.
- We can keep the `VllmConfig = None` placeholders. This ensures that
other modules can still import these symbols without triggering an
`ImportError`, maintaining a stable dependency graph.
- IDEs and static type checkers (MyPy/Pyright) continue to resolve the
types correctly. This allows us to use modern syntax without sacrificing
type safety or runtime stability.
- The only side effect is that `__annotations__` now contains strings
instead of type objects (see the small example below). Since these modules
do not rely on runtime type enforcement or reflection, this change has zero
negative impact on existing functionality.
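For illustration, the difference is only visible through reflection; continuing the sketch above:
```python
# With postponed evaluation, the annotation survives as a plain string:
print(register_ascend_customop.__annotations__)
# -> {'vllm_config': 'VllmConfig | None'}
```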
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 11b6af5280
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
@@ -15,9 +15,11 @@
 # This file is a part of the vllm-ascend project.
 #
 
+from __future__ import annotations
+
 import math
 import os
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
 import torch
@@ -32,11 +34,21 @@ from vllm_ascend.ascend_config import init_ascend_config
 
 # isort: off
 from vllm_ascend.utils import (
-    ASCEND_QUANTIZATION_METHOD, COMPILATION_PASS_KEY,
-    COMPRESSED_TENSORS_METHOD, AscendDeviceType, check_kv_extra_config,
-    enable_sp, flashcomm2_enable, get_ascend_device_type, is_moe_model,
-    is_vl_model, refresh_block_size, update_aclgraph_sizes,
-    update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
+    ASCEND_QUANTIZATION_METHOD,
+    COMPILATION_PASS_KEY,
+    COMPRESSED_TENSORS_METHOD,
+    AscendDeviceType,
+    check_kv_extra_config,
+    enable_sp,
+    flashcomm2_enable,
+    get_ascend_device_type,
+    is_moe_model,
+    is_vl_model,
+    refresh_block_size,
+    update_aclgraph_sizes,
+    update_cudagraph_capture_sizes,
+    update_default_aclgraph_sizes,
+)
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
@@ -80,7 +92,6 @@ def config_deprecated_logging():
 
 
 class NPUPlatform(Platform):
-
     _enum = PlatformEnum.OOT
     device_name: str = "npu"
     device_type: str = "npu"
@@ -89,9 +100,7 @@ class NPUPlatform(Platform):
     device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
     dispatch_key: str = "PrivateUse1"
 
-    supported_quantization: list[str] = [
-        ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
-    ]
+    supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]
 
     def is_sleep_mode_available(self) -> bool:
         return True
@@ -116,33 +125,29 @@ class NPUPlatform(Platform):
     @classmethod
     def get_compile_backend(self) -> str:
         """
         Get the custom compile backend. Previously, we used EagerAdaptor by default.
         To use graph fusion operations, we defined our own backend compiler.
         """
         return "vllm_ascend.compilation.compiler_interface.AscendCompiler"
 
     @classmethod
-    def pre_register_and_update(cls,
-                                parser: Optional[FlexibleArgumentParser] = None
-                                ) -> None:
+    def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) -> None:
         # Adapt the global patch here.
         from vllm_ascend.utils import adapt_patch
 
         adapt_patch(is_global_patch=True)
 
         # For online serving, "ascend" quantization method is not a choice natively,
         # so we need to add "ascend" quantization method to quantization methods list
         # and the user can enable quantization using "vllm serve --quantization ascend".
         if parser is not None:
-            quant_action = parser._option_string_actions.get('--quantization')
-            if quant_action and hasattr(quant_action,
-                                        'choices') and quant_action.choices:
+            quant_action = parser._option_string_actions.get("--quantization")
+            if quant_action and hasattr(quant_action, "choices") and quant_action.choices:
                 if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
                     quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
 
-        from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
-            AscendCompressedTensorsConfig  # noqa: F401
-        from vllm_ascend.quantization.quant_config import \
-            AscendQuantConfig  # noqa: F401
+        from vllm_ascend.quantization.compressed_tensors.compressed_tensors import AscendCompressedTensorsConfig  # noqa: F401
+        from vllm_ascend.quantization.quant_config import AscendQuantConfig  # noqa: F401
 
         config_deprecated_logging()
 
@@ -169,8 +174,7 @@ class NPUPlatform(Platform):
 
         if vllm_config.kv_transfer_config is not None:
             check_kv_extra_config(vllm_config)
-            if not getattr(vllm_config.kv_transfer_config,
-                           "_engine_id_patched", False):
+            if not getattr(vllm_config.kv_transfer_config, "_engine_id_patched", False):
                 vllm_config.kv_transfer_config.engine_id = f"{vllm_config.kv_transfer_config.engine_id}-{uuid4().hex}"
                 vllm_config.kv_transfer_config._engine_id_patched = True
         from vllm.config import CompilationMode  # noqa: E402
@@ -181,24 +185,22 @@ class NPUPlatform(Platform):
         cache_config = vllm_config.cache_config
         ascend_compilation_config = ascend_config.ascend_compilation_config
         if ascend_compilation_config:
-            vllm_config.additional_config.setdefault(
-                "ascend_compilation_config", {}).update(
-                    vars(ascend_compilation_config
-                         ) if not isinstance(ascend_compilation_config, dict)
-                    else ascend_compilation_config)
+            vllm_config.additional_config.setdefault("ascend_compilation_config", {}).update(
+                vars(ascend_compilation_config)
+                if not isinstance(ascend_compilation_config, dict)
+                else ascend_compilation_config
+            )
 
-        elif model_config and hasattr(model_config.hf_text_config,
-                                      "index_topk"):
-            vllm_config.cache_config.cache_dtype = str(
-                model_config.dtype).replace("torch.", "")
+        elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
+            vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
         if model_config is None:
-            logger.warning("Model config is missing. This may indicate "
-                           "that we are running a test case")
+            logger.warning("Model config is missing. This may indicate that we are running a test case")
             enforce_eager = False
         else:
             enforce_eager = getattr(model_config, "enforce_eager", False)
 
         from vllm.config.compilation import CUDAGraphMode
 
         if enforce_eager:
             logger.info("Compilation disabled, using eager mode by default")
             compilation_config.mode = CompilationMode.NONE
@@ -207,12 +209,10 @@ class NPUPlatform(Platform):
 
             compilation_config.cudagraph_num_of_warmups = 1
 
-        if compilation_config.mode not in [
-                CompilationMode.NONE, CompilationMode.VLLM_COMPILE
-        ]:
+        if compilation_config.mode not in [CompilationMode.NONE, CompilationMode.VLLM_COMPILE]:
             logger.warning(
-                "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE",
-                compilation_config.mode)
+                "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE", compilation_config.mode
+            )
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
         # set cudaprah sizes before extending `compilation_config.splitting_ops`
@@ -223,15 +223,18 @@ class NPUPlatform(Platform):
         update_default_aclgraph_sizes(vllm_config)
         # TODO delete graph size update here when compilation_config.pass_config.enable_sp
         # is supported by vllm-ascend.
-        if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \
-                enable_sp(vllm_config):
+        if (
+            vllm_config.parallel_config.tensor_parallel_size > 1
+            and not vllm_config.model_config.enforce_eager
+            and enable_sp(vllm_config)
+        ):
             original_sizes = compilation_config.cudagraph_capture_sizes
-            sp_aclgraph_sizes = \
-                vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
+            sp_aclgraph_sizes = vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
             assert sp_aclgraph_sizes, (
                 f"cudagraph_capture_sizes {original_sizes} does not contain"
                 f"values that are multiples of tp_size "
-                f"{vllm_config.parallel_config.tensor_parallel_size}")
+                f"{vllm_config.parallel_config.tensor_parallel_size}"
+            )
             if len(sp_aclgraph_sizes) != len(original_sizes):
                 compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
                 update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes)
@@ -243,9 +246,7 @@ class NPUPlatform(Platform):
         # encoder-decoder models currently only support piecewise mode
         if model_config and model_config.is_encoder_decoder is True:
             if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-                logger.warning(
-                    "encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE "
-                )
+                logger.warning("encoder-decoder model doesn't support FULL_DECODE_ONLY, fallback to PIECEWISE ")
                 compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
         # get custom compile backend for graph fusion
@@ -255,15 +256,14 @@ class NPUPlatform(Platform):
             compilation_config.mode = CompilationMode.NONE
             ascend_config.enable_npugraph_ex = False
         elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
-            logger.info(
-                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
-                "using only ACL Graph mode")
-            assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \
-                "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
+            logger.info("PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode")
+            assert compilation_config.mode == CompilationMode.VLLM_COMPILE, (
+                "When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == "
+                "CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
+            )
             compilation_config.set_splitting_ops_for_v1(
                 all2all_backend=vllm_config.parallel_config.all2all_backend,
-                data_parallel_size=vllm_config.parallel_config.
-                data_parallel_size,
+                data_parallel_size=vllm_config.parallel_config.data_parallel_size,
             )
             compilation_config.use_inductor = False
             # NOTE: Theoretically, we should also add vllm::mla_forward in the attention ops.
@@ -275,11 +275,13 @@ class NPUPlatform(Platform):
                 compilation_config.splitting_ops.extend(["vllm::mla_forward"])
             update_aclgraph_sizes(vllm_config)
             ascend_config.enable_npugraph_ex = False
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
-                compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
+        elif (
+            compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+            or compilation_config.cudagraph_mode == CUDAGraphMode.FULL
+        ):
             logger.info(
-                "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
-                "using only ACL Graph mode")
+                "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode"
+            )
             compilation_config.use_inductor = False
             compilation_config.splitting_ops = []
             warning_message = """\033[91m
@@ -297,30 +299,31 @@ class NPUPlatform(Platform):
             logger.warning(warning_message)
         else:
             logger.info(
-                "%s cudagraph_mode is not support on NPU. falling back to NONE",
-                compilation_config.cudagraph_mode)
+                "%s cudagraph_mode is not support on NPU. falling back to NONE", compilation_config.cudagraph_mode
+            )
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
             compilation_config.mode = CompilationMode.NONE
             ascend_config.enable_npugraph_ex = False
 
         # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
         # Then, we will have to discuss the error handling strategy and user experience
-        if compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \
-                os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1":
+        if (
+            compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            and os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1"
+        ):
             raise ValueError(
                 "ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. "
                 "Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you "
                 "need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — "
                 "for example, check the plog files (default: $HOME/ascend/log/debug) "
-                "for more information about runtime errors.")
+                "for more information about runtime errors."
+            )
 
         if parallel_config and parallel_config.worker_cls == "auto":
             # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
             parallel_config.all2all_backend = "flashinfer_all2allv"
             if ascend_config.xlite_graph_config.enabled:
-                logger.info(
-                    "openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite"
-                )
+                logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite")
                 parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
             else:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
@@ -332,10 +335,9 @@ class NPUPlatform(Platform):
             compilation_config.custom_ops = ["all"]
 
         if ascend_config.recompute_scheduler_enable:
-            from vllm_ascend.core.recompute_scheduler import \
-                RecomputeSchedulerConfig
-            recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
-                vllm_config)
+            from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig
+
+            recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(vllm_config)
             vllm_config.scheduler_config = recompute_scheduler_config
 
         # Extend original scheduler_config to use SchedulerDynamicBatch.
@@ -346,9 +348,11 @@ class NPUPlatform(Platform):
             vllm_config.scheduler_config.enable_chunked_prefill = True
             vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
 
-        if vllm_config.kv_transfer_config is not None and \
-                cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \
-                parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1:
+        if (
+            vllm_config.kv_transfer_config is not None
+            and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
+            and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
+        ):
             raise AssertionError(
                 f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
                 f"and block_size({cache_config.block_size}) "
@@ -356,12 +360,14 @@ class NPUPlatform(Platform):
             )
 
         if is_vl_model(vllm_config):
-            if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \
-                    bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))):
+            if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
+                int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
+            ):
                 raise ValueError(
                     "Currently, VL models doesn't support "
                     "FLASHCOMM in vllm-ascend. We will fix this in the future. "
-                    "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.")
+                    "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0."
+                )
 
     @classmethod
     def import_kernels(cls) -> None:
@@ -377,14 +383,11 @@ class NPUPlatform(Platform):
         if _CUSTOM_OP_REGISTERED:
             return
         CUR_DIR = os.path.dirname(os.path.realpath(__file__))
-        CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors",
-                                       "vllm-ascend")
+        CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", "vllm-ascend")
         if os.path.exists(CUSTOM_OPP_PATH):
-            current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH",
-                                                   "")
+            current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
             if current_cust_opp_path:
-                os.environ[
-                    "ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
+                os.environ["ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
             else:
                 os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH
         _CUSTOM_OP_REGISTERED = True
@@ -393,22 +396,18 @@ class NPUPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend, attn_selector_config):
         backend_map = {
             (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
-            (False, False):
-            "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
+            (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
             (True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
         }
 
-        return backend_map[(attn_selector_config.use_mla,
-                            attn_selector_config.use_sparse)]
+        return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)]
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"
 
     @classmethod
-    def get_current_memory_usage(cls,
-                                 device: Optional[torch.types.Device] = None
-                                 ) -> float:
+    def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
         torch.npu.reset_peak_memory_stats(device)
         return torch.npu.max_memory_allocated(device)
 
@@ -457,32 +456,33 @@ class NPUPlatform(Platform):
         Args:
             attn_metadata (dict[str, Any]): attention metadata for all layers.
            vllm_config (VllmConfig): configuration of vllm.
            dp_metadata (DpMetada): metadata for data parallelism.
                lack of typehint because of circular import.
            virtual_engine (int, optional): index of virtual engine. Defaults to 0.
            num_tokens (int | None, optional): number of tokens. Defaults to None.
            num_tokens_across_dp (torch.Tensor | None, optional): number of tokens
                across data parallelism.Defaults to None.
            cudagraph_runtime_mode (CUDAGraphMode, optional): mode of cudagraph runtime.
                Defaults to None.lack of typehint because of circular import.
            batch_descriptor (BatchDescriptor, optional): descriptor of batch.
                Defaults to None.
            ubatch_slices (UBatchSlices, optional): slice info for dual batch.
                Defaults to None. lack of typehint because of circular import
 
        Returns:
            dict[str, Any]: _description_
        """
        # NOTE(Ronald1995): avoid circular import.
-        from vllm_ascend.ascend_forward_context import (get_mc2_mask,
-                                                        select_moe_comm_method)
+        from vllm_ascend.ascend_forward_context import get_mc2_mask, select_moe_comm_method
        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
        from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
 
        # NOTE(Ronald1995): avoid circular import, cudagraph_runtime_mode is
        # CUDAGraphMode.NONE in vllm, but we can't set CUDAGraphMode.NONE in
        # argument default value, so we set it to None first, then set it to
        # CUDAGraphMode.NONE here.
        from vllm.config import CUDAGraphMode
 
        if cudagraph_runtime_mode is None:
            cudagraph_runtime_mode = CUDAGraphMode.NONE
        # TODO(Ronald1995): model runner v1 still use ascend_forward_context,
@@ -526,31 +526,26 @@ class NPUPlatform(Platform):
             sp_enabled = enable_sp(vllm_config) and num_tokens is not None
             mmrs_fusion = False
         else:
-            sp_enabled = enable_sp(vllm_config) and \
-                num_tokens is not None and num_tokens > 1000
+            sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
 
         # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
-        flashcomm_v2_enabled = flashcomm2_enable(
-        ) and tp_world_size > 1 and num_tokens is not None
+        flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
         pad_size = 0
-        if (sp_enabled or flashcomm_v2_enabled):
-            pad_size = (tp_world_size -
-                        (num_tokens % tp_world_size)) % tp_world_size
+        if sp_enabled or flashcomm_v2_enabled:
+            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
 
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and dp_metadata is not None:
             max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
-            if (sp_enabled or flashcomm_v2_enabled):
-                padded_length = (max_tokens_across_dp + tp_world_size -
-                                 1) // tp_world_size * tp_world_size
+            if sp_enabled or flashcomm_v2_enabled:
+                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                 pad_size = padded_length - num_tokens
         else:
             max_tokens_across_dp = num_tokens
 
         if num_tokens is not None:
             # NOTE: token num which need to pad to when mc2
-            padded_num_tokens = math.ceil(
-                max_tokens_across_dp / tp_world_size) * tp_world_size
+            padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
             reserved_mc2_mask = get_mc2_mask()
             if reserved_mc2_mask is not None:
                 mc2_mask = reserved_mc2_mask[:padded_num_tokens]