[Bugfix] Fix broken CI (#2825)

### What this PR does / why we need it?
1. Initial support for disabling TP, for integration with
[vllm-commit](https://github.com/vllm-project/vllm/pull/23024)
2. [vllm@commit](https://github.com/vllm-project/vllm/pull/23673) now
uses `bytes` to store the `BlockHash`, reducing GC overhead; this PR adds
the corresponding integration.

- vLLM version: main
- vLLM main:
e40827280b

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-09-10 13:29:29 +08:00
committed by GitHub
parent aa4d2a91ed
commit 22b425765a
3 changed files with 21 additions and 8 deletions

View File

@@ -8,6 +8,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.output import SchedulerOutput
@@ -38,7 +39,7 @@ def create_requests(
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
block_size: int = 3,
hash_fn=hash,
hash_fn=sha256,
):
init_none_hash(hash_fn)
prompt_logprobs = PROMPT_LOGPROBS

View File

@@ -10,6 +10,7 @@ import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
@@ -129,10 +130,10 @@ def create_request(
"""Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
init_none_hash(sha256)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
block_hasher = get_request_block_hasher(block_size, sha256)
kv_transfer_params: Optional[dict[str, Any]] = None

View File

@@ -62,6 +62,7 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.comm_group = None
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
@@ -88,7 +89,8 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
@@ -137,6 +139,7 @@ class AscendRowParallelLinear(RowParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
if prefix.find("down_proj") != -1 and mlp_tp_enable():
comm_group = get_mlp_tp_group()
@@ -156,6 +159,7 @@ class AscendRowParallelLinear(RowParallelLinear):
self.forward_type = "normal"
self.comm_group = comm_group
# TODO: check for disable_tp
self.tp_size = self.comm_group.world_size
self.tp_rank = self.comm_group.rank_in_group
@@ -171,7 +175,8 @@ class AscendRowParallelLinear(RowParallelLinear):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -392,6 +397,7 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
comm_group = get_mlp_tp_group()
@@ -403,6 +409,7 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
comm_group = get_tp_group()
self.forward_type = "normal_tp"
self.comm_group = comm_group
# TODO: check for disable_tp
self.tp_rank = comm_group.rank_in_group
self.tp_size = comm_group.world_size
@@ -418,7 +425,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
def forward(
self,
@@ -498,6 +506,7 @@ class AscendQKVParallelLinear(QKVParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
if dense_optim_enable():
self.forward_type = "dense_optim"
@@ -511,6 +520,7 @@ class AscendQKVParallelLinear(QKVParallelLinear):
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
# TODO: check for disable_tp
tp_size = self.comm_group.world_size
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
@@ -537,7 +547,8 @@ class AscendQKVParallelLinear(QKVParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
def forward(
self,
@@ -611,4 +622,4 @@ class AscendLinearBase(LinearBase):
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
self.disable_tp = disable_tp