[Bugfix][Model] Fix fusedmoe and make modelrunner_v1 compatible with latest vllm (#867)
### What this PR does / why we need it? this PR fix CI failure broken by vllm. 1. add moe_config for fused_moe 2. adjust the change for kv cache group from vllm. currently vllm-ascend doesn't support this feature. this is just a quick fix for backward compatibility fix: #872 --------- Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -20,12 +20,22 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizeMethodBase
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoEParallelConfig, MoEConfig)
|
||||
else:
|
||||
MoEConfig = None
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||
@@ -437,8 +447,11 @@ def select_experts(
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, moe: MoEConfig = None):
|
||||
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
|
||||
super().__init__()
|
||||
else:
|
||||
super().__init__(moe=moe)
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
ep_group = get_ep_group()
|
||||
@@ -535,37 +548,54 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
|
||||
def __init__(self,
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype=None,
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
use_grouped_topk=False,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
quant_config=None,
|
||||
tp_size=None,
|
||||
ep_size=None,
|
||||
dp_size=None,
|
||||
prefix="",
|
||||
custom_routing_function=None,
|
||||
scoring_func="softmax",
|
||||
e_score_correction_bias=None,
|
||||
activation="silu"):
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int, # Global number of experts
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = False,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
dp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
):
|
||||
# TODO: This could not initialize FusedMoE baseclass,
|
||||
# fixme and make __init__() of AscendFusedMoE more clear
|
||||
super(FusedMoE, self).__init__()
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
self.ep_size = get_ep_group().world_size
|
||||
self.tp_size = get_etp_group().world_size
|
||||
self.dp_size = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
self.dp_rank = (0
|
||||
if self.dp_size == 1 else get_dp_group().rank_in_group)
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
|
||||
self.ep_size = get_ep_group().world_size
|
||||
self.tp_size = get_etp_group().world_size
|
||||
self.dp_size = (dp_size if dp_size is not None else
|
||||
get_dp_group().world_size)
|
||||
self.dp_rank = (0 if self.dp_size == 1 else
|
||||
get_dp_group().rank_in_group)
|
||||
else:
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
tp_size_=(tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size()),
|
||||
dp_size_=(dp_size if dp_size is not None else
|
||||
get_dp_group().world_size),
|
||||
vllm_parallel_config=vllm_config.parallel_config))
|
||||
|
||||
self.moe_parallel_config.ep_size = get_ep_group().world_size
|
||||
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
@@ -590,27 +620,55 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size,
|
||||
get_ep_group().rank_in_group, self.global_num_experts)
|
||||
self.tp_rank = get_etp_group().rank_in_group
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
|
||||
self.tp_rank = get_etp_group().rank_in_group
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
else:
|
||||
self.moe_parallel_config.tp_rank = get_etp_group(
|
||||
).rank_in_group
|
||||
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
|
||||
|
||||
else:
|
||||
# Adjust TP size for DP attention
|
||||
# haven't test its functionality yet, may remove in the future
|
||||
self.tp_rank = self.tp_size * self.dp_rank
|
||||
self.ep_rank = 0
|
||||
self.tp_size = self.tp_size * self.dp_size
|
||||
self.ep_size = 1
|
||||
self.local_num_experts = self.global_num_experts
|
||||
self.expert_map = None
|
||||
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
|
||||
self.tp_rank = self.tp_size * self.dp_rank
|
||||
self.ep_rank = 0
|
||||
self.tp_size = self.tp_size * self.dp_size
|
||||
self.ep_size = 1
|
||||
else:
|
||||
self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
|
||||
self.moe_parallel_config.ep_rank = 0
|
||||
self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
|
||||
self.moe_parallel_config.ep_size = 1
|
||||
|
||||
self.local_num_experts, self.expert_map = (self.global_num_experts,
|
||||
None)
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
AscendUnquantizedFusedMoEMethod())
|
||||
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
AscendUnquantizedFusedMoEMethod())
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
moe = MoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
||||
|
||||
Reference in New Issue
Block a user