[Bugfix] Fix broken CI (#2825)

### What this PR does / why we need it?
1. Initial support for disabling TP, for integration with
[vllm commit](https://github.com/vllm-project/vllm/pull/23024)
2. [vllm commit](https://github.com/vllm-project/vllm/pull/23673) now
uses `bytes` to store the `BlockHash`, reducing GC overhead; this PR adds
the corresponding integration

- vLLM version: main
- vLLM main:
e40827280b

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-09-10 13:29:29 +08:00
committed by GitHub
parent aa4d2a91ed
commit 22b425765a
3 changed files with 21 additions and 8 deletions

View File

@@ -8,6 +8,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash) init_none_hash)
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
@@ -38,7 +39,7 @@ def create_requests(
max_tokens: int = 16, max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None, stop_token_ids: Optional[list[int]] = None,
block_size: int = 3, block_size: int = 3,
hash_fn=hash, hash_fn=sha256,
): ):
init_none_hash(hash_fn) init_none_hash(hash_fn)
prompt_logprobs = PROMPT_LOGPROBS prompt_logprobs = PROMPT_LOGPROBS

View File

@@ -10,6 +10,7 @@ import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig) ModelConfig, SchedulerConfig, VllmConfig)
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash) init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
@@ -129,10 +130,10 @@ def create_request(
"""Make dummy request for testing.""" """Make dummy request for testing."""
global _none_hash_initialized global _none_hash_initialized
if not _none_hash_initialized: if not _none_hash_initialized:
init_none_hash(hash) init_none_hash(sha256)
_none_hash_initialized = True _none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash) block_hasher = get_request_block_hasher(block_size, sha256)
kv_transfer_params: Optional[dict[str, Any]] = None kv_transfer_params: Optional[dict[str, Any]] = None

View File

@@ -62,6 +62,7 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False,
): ):
self.comm_group = None self.comm_group = None
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable(): if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
@@ -88,7 +89,8 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
params_dtype, params_dtype,
quant_config, quant_config,
prefix, prefix,
return_bias=return_bias) return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output self.gather_output = gather_output
@@ -137,6 +139,7 @@ class AscendRowParallelLinear(RowParallelLinear):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False,
): ):
if prefix.find("down_proj") != -1 and mlp_tp_enable(): if prefix.find("down_proj") != -1 and mlp_tp_enable():
comm_group = get_mlp_tp_group() comm_group = get_mlp_tp_group()
@@ -156,6 +159,7 @@ class AscendRowParallelLinear(RowParallelLinear):
self.forward_type = "normal" self.forward_type = "normal"
self.comm_group = comm_group self.comm_group = comm_group
# TODO: check for disable_tp
self.tp_size = self.comm_group.world_size self.tp_size = self.comm_group.world_size
self.tp_rank = self.comm_group.rank_in_group self.tp_rank = self.comm_group.rank_in_group
@@ -171,7 +175,8 @@ class AscendRowParallelLinear(RowParallelLinear):
params_dtype, params_dtype,
quant_config, quant_config,
prefix, prefix,
return_bias=return_bias) return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results self.reduce_results = reduce_results
@@ -392,6 +397,7 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False,
): ):
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable(): if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
comm_group = get_mlp_tp_group() comm_group = get_mlp_tp_group()
@@ -403,6 +409,7 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
comm_group = get_tp_group() comm_group = get_tp_group()
self.forward_type = "normal_tp" self.forward_type = "normal_tp"
self.comm_group = comm_group self.comm_group = comm_group
# TODO: check for disable_tp
self.tp_rank = comm_group.rank_in_group self.tp_rank = comm_group.rank_in_group
self.tp_size = comm_group.world_size self.tp_size = comm_group.world_size
@@ -418,7 +425,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias,
disable_tp=disable_tp)
def forward( def forward(
self, self,
@@ -498,6 +506,7 @@ class AscendQKVParallelLinear(QKVParallelLinear):
prefix: str = "", prefix: str = "",
*, *,
return_bias: bool = True, return_bias: bool = True,
disable_tp: bool = False,
): ):
if dense_optim_enable(): if dense_optim_enable():
self.forward_type = "dense_optim" self.forward_type = "dense_optim"
@@ -511,6 +520,7 @@ class AscendQKVParallelLinear(QKVParallelLinear):
total_num_kv_heads = total_num_heads total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
# TODO: check for disable_tp
tp_size = self.comm_group.world_size tp_size = self.comm_group.world_size
self.num_heads = divide(self.total_num_heads, tp_size) self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads: if tp_size >= self.total_num_kv_heads:
@@ -537,7 +547,8 @@ class AscendQKVParallelLinear(QKVParallelLinear):
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
return_bias=return_bias) return_bias=return_bias,
disable_tp=disable_tp)
def forward( def forward(
self, self,
@@ -611,4 +622,4 @@ class AscendLinearBase(LinearBase):
self.quant_method = quant_config.get_quant_method(self, self.quant_method = quant_config.get_quant_method(self,
prefix=prefix) prefix=prefix)
self.return_bias = return_bias self.return_bias = return_bias
self.disable_tp = disable_tp self.disable_tp = disable_tp