From 22b425765a36ab507ca23625409c767dabb4371b Mon Sep 17 00:00:00 2001 From: Li Wang Date: Wed, 10 Sep 2025 13:29:29 +0800 Subject: [PATCH] [Bugfix] Fix broken CI (#2825) ### What this PR does / why we need it? 1. Add initial support for the `disable_tp` argument to integrate with [vllm-commit](https://github.com/vllm-project/vllm/pull/23024) 2. [vllm@commit](https://github.com/vllm-project/vllm/pull/23673) now uses `bytes` to store the `BlockHash` to reduce GC overhead; this PR adds the integration by switching the tests' hash function from `hash` to `sha256` - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/e40827280b225bf0a0797cc9842fc3cdfea8ebdf --------- Signed-off-by: wangli --- tests/ut/core/test_scheduler.py | 3 ++- tests/ut/kv_connector/utils.py | 5 +++-- vllm_ascend/ops/linear.py | 21 ++++++++++++++++----- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py index af78531..c2e21c0 100644 --- a/tests/ut/core/test_scheduler.py +++ b/tests/ut/core/test_scheduler.py @@ -8,6 +8,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import PlaceholderRange from vllm.sampling_params import SamplingParams +from vllm.utils import sha256 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.output import SchedulerOutput @@ -38,7 +39,7 @@ def create_requests( max_tokens: int = 16, stop_token_ids: Optional[list[int]] = None, block_size: int = 3, - hash_fn=hash, + hash_fn=sha256, ): init_none_hash(hash_fn) prompt_logprobs = PROMPT_LOGPROBS diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py index 3676e87..d1bf01f 100644 --- a/tests/ut/kv_connector/utils.py +++ b/tests/ut/kv_connector/utils.py @@ -10,6 +10,7 @@ import torch from vllm import SamplingParams from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) +from 
vllm.utils import sha256 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.scheduler import Scheduler @@ -129,10 +130,10 @@ def create_request( """Make dummy request for testing.""" global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(sha256) _none_hash_initialized = True - block_hasher = get_request_block_hasher(block_size, hash) + block_hasher = get_request_block_hasher(block_size, sha256) kv_transfer_params: Optional[dict[str, Any]] = None diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 8bb7b85..6bf9676 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -62,6 +62,7 @@ class AscendColumnParallelLinear(ColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.comm_group = None if prefix.find("gate_up_proj") != -1 and mlp_tp_enable(): @@ -88,7 +89,8 @@ class AscendColumnParallelLinear(ColumnParallelLinear): params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.gather_output = gather_output @@ -137,6 +139,7 @@ class AscendRowParallelLinear(RowParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): if prefix.find("down_proj") != -1 and mlp_tp_enable(): comm_group = get_mlp_tp_group() @@ -156,6 +159,7 @@ class AscendRowParallelLinear(RowParallelLinear): self.forward_type = "normal" self.comm_group = comm_group + # TODO: check for disable_tp self.tp_size = self.comm_group.world_size self.tp_rank = self.comm_group.rank_in_group @@ -171,7 +175,8 @@ class AscendRowParallelLinear(RowParallelLinear): params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -392,6 +397,7 @@ class 
AscendMergedColumnParallelLinear(MergedColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): if prefix.find("gate_up_proj") != -1 and mlp_tp_enable(): comm_group = get_mlp_tp_group() @@ -403,6 +409,7 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): comm_group = get_tp_group() self.forward_type = "normal_tp" self.comm_group = comm_group + # TODO: check for disable_tp self.tp_rank = comm_group.rank_in_group self.tp_size = comm_group.world_size @@ -418,7 +425,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def forward( self, @@ -498,6 +506,7 @@ class AscendQKVParallelLinear(QKVParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): if dense_optim_enable(): self.forward_type = "dense_optim" @@ -511,6 +520,7 @@ class AscendQKVParallelLinear(QKVParallelLinear): total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. + # TODO: check for disable_tp tp_size = self.comm_group.world_size self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: @@ -537,7 +547,8 @@ class AscendQKVParallelLinear(QKVParallelLinear): params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def forward( self, @@ -611,4 +622,4 @@ class AscendLinearBase(LinearBase): self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias - self.disable_tp = disable_tp \ No newline at end of file + self.disable_tp = disable_tp