[CI] upgrade to vllm 0.9.0 (#959)

Upgrade to vLLM 0.9.0.
vLLM 0.8.5 and 0.8.5.post1 are no longer supported.
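With 0.8.5 support dropped, the vllm_version_is guards in the plugin collapse to their 0.9.0 branch. A minimal sketch of the pattern this commit deletes, condensed from the first hunk below (not the full import list):

    # Before: imports were gated on the installed vLLM version.
    from vllm_ascend.utils import vllm_version_is

    if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
        from vllm.model_executor.layers.fused_moe.layer import (
            FusedMoEParallelConfig, MoEConfig)
    else:
        MoEConfig = None

    # After: vLLM 0.9.0 is assumed, so the same classes are imported unconditionally.
    from vllm.model_executor.layers.fused_moe.layer import (
        FusedMoEParallelConfig, MoEConfig)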

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-05-28 21:18:41 +08:00
Committed by: GitHub
Parent: e2a0c19cea
Commit: f6e5decc10
16 changed files with 79 additions and 146 deletions


@@ -26,18 +26,10 @@ from vllm.distributed import (GroupCoordinator,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm_ascend.utils import vllm_version_is
-if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-    from vllm.model_executor.layers.fused_moe.layer import (
-        FusedMoEParallelConfig, MoEConfig)
-else:
-    MoEConfig = None
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
+    FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
+    determine_expert_map)
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizationConfig
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
@@ -587,10 +579,8 @@ def select_experts(
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
     def __init__(self, moe: MoEConfig = None):
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            super().__init__()
-        else:
-            super().__init__(moe=moe)
+        super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()
         ep_group = get_ep_group()
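The Ascend method now has a single construction path: it always forwards a MoEConfig to the upstream UnquantizedFusedMoEMethod from vLLM 0.9.0, so a config must be built before the method is created. A rough sketch of that call site, using the field names from the last hunk of this file (the literal values are illustrative, not taken from this commit):

    import torch

    moe = MoEConfig(
        num_experts=64,               # global expert count (illustrative)
        experts_per_token=8,          # routed top-k (illustrative)
        hidden_dim=4096,              # model hidden size (illustrative)
        num_local_experts=8,          # experts owned by this EP rank (illustrative)
        moe_parallel_config=parallel_cfg,  # a FusedMoEParallelConfig, built as in the next hunk
        in_dtype=torch.bfloat16,      # parameter dtype (illustrative)
    )
    quant_method = AscendUnquantizedFusedMoEMethod(moe)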
@@ -731,24 +721,17 @@ class AscendFusedMoE(FusedMoE):
         params_dtype = torch.get_default_dtype()
         vllm_config = get_current_vllm_config()
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            self.ep_size = get_ep_group().world_size
-            self.tp_size = get_etp_group().world_size
-            self.dp_size = (dp_size if dp_size is not None else
-                            get_dp_group().world_size)
-            self.dp_rank = (0 if self.dp_size == 1 else
-                            get_dp_group().rank_in_group)
-        else:
-            self.moe_parallel_config: FusedMoEParallelConfig = (
-                FusedMoEParallelConfig.make(
-                    tp_size_=(tp_size if tp_size is not None else
-                              get_tensor_model_parallel_world_size()),
-                    dp_size_=(dp_size if dp_size is not None else
-                              get_dp_group().world_size),
-                    vllm_parallel_config=vllm_config.parallel_config))
-            self.moe_parallel_config.ep_size = get_ep_group().world_size
-            self.moe_parallel_config.tp_size = get_etp_group().world_size
+        self.moe_parallel_config: FusedMoEParallelConfig = (
+            FusedMoEParallelConfig.make(
+                tp_size_=(tp_size if tp_size is not None else
+                          get_tensor_model_parallel_world_size()),
+                dp_size_=(dp_size if dp_size is not None else
+                          get_dp_group().world_size),
+                vllm_parallel_config=vllm_config.parallel_config))
+        self.moe_parallel_config.ep_size = get_ep_group().world_size
+        self.moe_parallel_config.tp_size = get_etp_group().world_size
         self.top_k = top_k
         self.num_experts = num_experts
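Outside the diff noise, the surviving logic resolves the parallel layout in two steps: FusedMoEParallelConfig.make derives the TP/DP sizes from the explicit arguments or the current distributed groups, and the plugin then overwrites the EP/TP sizes with its own expert-parallel and expert-tensor-parallel groups. A sketch of the same flow, assuming the Ascend process groups are already initialized and no explicit sizes were passed in:

    parallel_cfg = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),  # fallback when tp_size is None
        dp_size_=get_dp_group().world_size,               # fallback when dp_size is None
        vllm_parallel_config=vllm_config.parallel_config)
    # Ascend override: expert parallelism comes from the plugin's own groups.
    parallel_cfg.ep_size = get_ep_group().world_size
    parallel_cfg.tp_size = get_etp_group().world_size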
@@ -773,54 +756,39 @@ class AscendFusedMoE(FusedMoE):
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = get_etp_group().rank_in_group
-                self.ep_rank = get_ep_group().rank_in_group
-            else:
-                self.moe_parallel_config.tp_rank = get_etp_group(
-                ).rank_in_group
-                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
         else:
             # Adjust TP size for DP attention
             # haven't test its functionality yet, may remove in the future
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = self.tp_size * self.dp_rank
-                self.ep_rank = 0
-                self.tp_size = self.tp_size * self.dp_size
-                self.ep_size = 1
-            else:
-                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-                self.moe_parallel_config.ep_rank = 0
-                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-                self.moe_parallel_config.ep_size = 1
+            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+            self.moe_parallel_config.ep_rank = 0
+            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+            self.moe_parallel_config.ep_size = 1
             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            if quant_config is None:
-                self.quant_method: Optional[QuantizeMethodBase] = (
-                    AscendUnquantizedFusedMoEMethod())
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
-        else:
-            moe = MoEConfig(
-                num_experts=self.global_num_experts,
-                experts_per_token=top_k,
-                hidden_dim=hidden_size,
-                num_local_experts=self.local_num_experts,
-                moe_parallel_config=self.moe_parallel_config,
-                # TODO (bnell): this needs to be fixed for quantized types.
-                in_dtype=params_dtype,
-            )
-            if quant_config is None:
-                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
+        moe = MoEConfig(
+            num_experts=self.global_num_experts,
+            experts_per_token=top_k,
+            hidden_dim=hidden_size,
+            num_local_experts=self.local_num_experts,
+            moe_parallel_config=self.moe_parallel_config,
+            # TODO (bnell): this needs to be fixed for quantized types.
+            in_dtype=params_dtype,
+        )
+        if quant_config is None:
+            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None