[Bugfix] Fix broken CI (#2825)
### What this PR does / why we need it?
1. Initial support for disabling TP (tensor parallelism), integrating with
[vllm-commit](https://github.com/vllm-project/vllm/pull/23024)
2. [vllm@commit](https://github.com/vllm-project/vllm/pull/23673) now
uses `bytes` to save the `BlockHash` to reduce GC overhead; this PR adds
the integration
- vLLM version: main
- vLLM main:
e40827280b
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@@ -38,7 +39,7 @@ def create_requests(
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
block_size: int = 3,
|
||||
hash_fn=hash,
|
||||
hash_fn=sha256,
|
||||
):
|
||||
init_none_hash(hash_fn)
|
||||
prompt_logprobs = PROMPT_LOGPROBS
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
@@ -129,10 +130,10 @@ def create_request(
|
||||
"""Make dummy request for testing."""
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
init_none_hash(sha256)
|
||||
_none_hash_initialized = True
|
||||
|
||||
block_hasher = get_request_block_hasher(block_size, hash)
|
||||
block_hasher = get_request_block_hasher(block_size, sha256)
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.comm_group = None
|
||||
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
|
||||
@@ -88,7 +89,8 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
@@ -137,6 +139,7 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
if prefix.find("down_proj") != -1 and mlp_tp_enable():
|
||||
comm_group = get_mlp_tp_group()
|
||||
@@ -156,6 +159,7 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
self.forward_type = "normal"
|
||||
self.comm_group = comm_group
|
||||
|
||||
# TODO: check for disable_tp
|
||||
self.tp_size = self.comm_group.world_size
|
||||
self.tp_rank = self.comm_group.rank_in_group
|
||||
|
||||
@@ -171,7 +175,8 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
@@ -392,6 +397,7 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
|
||||
comm_group = get_mlp_tp_group()
|
||||
@@ -403,6 +409,7 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
comm_group = get_tp_group()
|
||||
self.forward_type = "normal_tp"
|
||||
self.comm_group = comm_group
|
||||
# TODO: check for disable_tp
|
||||
self.tp_rank = comm_group.rank_in_group
|
||||
self.tp_size = comm_group.world_size
|
||||
|
||||
@@ -418,7 +425,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -498,6 +506,7 @@ class AscendQKVParallelLinear(QKVParallelLinear):
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
if dense_optim_enable():
|
||||
self.forward_type = "dense_optim"
|
||||
@@ -511,6 +520,7 @@ class AscendQKVParallelLinear(QKVParallelLinear):
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
# TODO: check for disable_tp
|
||||
tp_size = self.comm_group.world_size
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
@@ -537,7 +547,8 @@ class AscendQKVParallelLinear(QKVParallelLinear):
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -611,4 +622,4 @@ class AscendLinearBase(LinearBase):
|
||||
self.quant_method = quant_config.get_quant_method(self,
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
self.disable_tp = disable_tp
|
||||
self.disable_tp = disable_tp
|
||||
|
||||
Reference in New Issue
Block a user