[Bugfix][Model] Fix fusedmoe and make modelrunner_v1 compatible with latest vllm (#867)

### What this PR does / why we need it?
This PR fixes the CI failure caused by recent changes in vLLM:
1. add `moe_config` for `fused_moe`
2. adapt to the KV cache group change from vLLM; vllm-ascend doesn't support
this feature yet, so this is just a quick fix for backward compatibility (see
the sketch below)

fix: #872
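For reference, the backward-compatibility approach used throughout the diff is a plain version gate on `vllm_version_is` from `vllm_ascend.utils`: keep the vLLM 0.8.5 / 0.8.5.post1 code path and add the new-vLLM path beside it. A minimal sketch of the pattern, assuming a recent vLLM where block tables are grouped per KV cache group (the helper name `block_table_cpu_tensor` is illustrative and not part of this PR):

```python
from vllm_ascend.utils import vllm_version_is


def block_table_cpu_tensor(input_batch):
    """Illustrative helper: fetch the CPU block table across vLLM versions."""
    if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
        # vLLM 0.8.5.x: InputBatch holds a single BlockTable.
        return input_batch.block_table.get_cpu_tensor()
    # Newer vLLM: block tables are grouped per KV cache group; vllm-ascend
    # supports only one group for now, so group 0 is used directly.
    return input_batch.block_table[0].get_cpu_tensor()
```

The fused_moe change follows the same shape: on 0.8.5.x `AscendUnquantizedFusedMoEMethod()` is constructed without arguments, while on newer vLLM a `MoEConfig` built via `FusedMoEParallelConfig.make(...)` is passed in.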

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Mengqing Cao
2025-05-16 12:14:55 +08:00
committed by GitHub
parent fd515cd60b
commit 7a325b2e2d
4 changed files with 137 additions and 79 deletions


@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import vllm_version_is
 class AscendAttentionBackend(AttentionBackend):
@@ -140,8 +141,15 @@ class AscendAttentionMetadataBuilder:
     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(


@@ -16,6 +16,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 if TYPE_CHECKING:
@@ -238,8 +239,12 @@ class AscendMLAMetadataBuilder:
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = (self.runner.input_batch.block_table[0].
+                           get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
@@ -795,4 +800,4 @@ class AscendMLAImpl(MLAAttentionImpl):
         output[:num_decode_tokens] = self._forward_decode(
             decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
             kv_cache, attn_metadata)
-        return output_padded
+        return output_padded


@@ -20,12 +20,22 @@ from typing import Callable, Optional
 import torch
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizeMethodBase
+from vllm_ascend.utils import vllm_version_is
+if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+    from vllm.model_executor.layers.fused_moe.layer import (
+        FusedMoEParallelConfig, MoEConfig)
+else:
+    MoEConfig = None
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
@@ -437,8 +447,11 @@ def select_experts(
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, moe: MoEConfig = None):
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            super().__init__()
+        else:
+            super().__init__(moe=moe)
         vllm_config = get_current_vllm_config()
         ep_group = get_ep_group()
@@ -535,37 +548,54 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 class AscendFusedMoE(FusedMoE):
-    def __init__(self,
-                 num_experts,
-                 top_k,
-                 hidden_size,
-                 intermediate_size,
-                 params_dtype=None,
-                 reduce_results=False,
-                 renormalize=True,
-                 use_grouped_topk=False,
-                 num_expert_group=None,
-                 topk_group=None,
-                 quant_config=None,
-                 tp_size=None,
-                 ep_size=None,
-                 dp_size=None,
-                 prefix="",
-                 custom_routing_function=None,
-                 scoring_func="softmax",
-                 e_score_correction_bias=None,
-                 activation="silu"):
+    def __init__(
+        self,
+        num_experts: int,  # Global number of experts
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        dp_size: Optional[int] = None,
+        prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+    ):
         # TODO: This could not initialize FusedMoE baseclass,
         # fixme and make __init__() of AscendFusedMoE more clear
         super(FusedMoE, self).__init__()
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
-        self.ep_size = get_ep_group().world_size
-        self.tp_size = get_etp_group().world_size
-        self.dp_size = (dp_size
-                        if dp_size is not None else get_dp_group().world_size)
-        self.dp_rank = (0
-                        if self.dp_size == 1 else get_dp_group().rank_in_group)
+        vllm_config = get_current_vllm_config()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.ep_size = get_ep_group().world_size
+            self.tp_size = get_etp_group().world_size
+            self.dp_size = (dp_size if dp_size is not None else
+                            get_dp_group().world_size)
+            self.dp_rank = (0 if self.dp_size == 1 else
+                            get_dp_group().rank_in_group)
+        else:
+            self.moe_parallel_config: FusedMoEParallelConfig = (
+                FusedMoEParallelConfig.make(
+                    tp_size_=(tp_size if tp_size is not None else
+                              get_tensor_model_parallel_world_size()),
+                    dp_size_=(dp_size if dp_size is not None else
+                              get_dp_group().world_size),
+                    vllm_parallel_config=vllm_config.parallel_config))
+            self.moe_parallel_config.ep_size = get_ep_group().world_size
         self.top_k = top_k
         self.num_experts = num_experts
@@ -590,27 +620,55 @@ class AscendFusedMoE(FusedMoE):
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            self.tp_rank = get_etp_group().rank_in_group
-            self.ep_rank = get_ep_group().rank_in_group
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = get_etp_group().rank_in_group
+                self.ep_rank = get_ep_group().rank_in_group
+            else:
+                self.moe_parallel_config.tp_rank = get_etp_group(
+                ).rank_in_group
+                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
         else:
             # Adjust TP size for DP attention
             # haven't test its functionality yet, may remove in the future
-            self.tp_rank = self.tp_size * self.dp_rank
-            self.ep_rank = 0
-            self.tp_size = self.tp_size * self.dp_size
-            self.ep_size = 1
-            self.local_num_experts = self.global_num_experts
-            self.expert_map = None
+            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+                self.tp_rank = self.tp_size * self.dp_rank
+                self.ep_rank = 0
+                self.tp_size = self.tp_size * self.dp_size
+                self.ep_size = 1
+            else:
+                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+                self.moe_parallel_config.ep_rank = 0
+                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+                self.moe_parallel_config.ep_size = 1
+            self.local_num_experts, self.expert_map = (self.global_num_experts,
+                                                       None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                AscendUnquantizedFusedMoEMethod())
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            if quant_config is None:
+                self.quant_method: Optional[QuantizeMethodBase] = (
+                    AscendUnquantizedFusedMoEMethod())
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
+            moe = MoEConfig(
+                num_experts=self.global_num_experts,
+                experts_per_token=top_k,
+                hidden_dim=hidden_size,
+                num_local_experts=self.local_num_experts,
+                moe_parallel_config=self.moe_parallel_config,
+                # TODO (bnell): this needs to be fixed for quantized types.
+                in_dtype=params_dtype,
+            )
+            if quant_config is None:
+                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
         local_num_experts = torch.sum(self.expert_map != -1) \


@@ -111,8 +111,10 @@ class NPUModelRunner:
         self.scheduler_config = vllm_config.scheduler_config
         self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
         self.device = device
+        self.is_multimodal_model = self.model_config.is_multimodal_model
+        self.block_size = vllm_config.cache_config.block_size
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                            self.block_size)
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
@@ -155,24 +157,6 @@ class NPUModelRunner:
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
-        self.attn_backend = get_attn_backend(
-            self.head_size,
-            self.dtype,
-            self.kv_cache_dtype,
-            self.block_size,
-            self.model_config.is_attention_free,
-            use_mla=self.model_config.use_mla,
-        )
-        if self.attn_backend is None:
-            error_msg = (
-                f"Error with get_att_backend: {self.head_size=}, "
-                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
-                f"{self.model_config.is_attention_free=}, "
-                f"{self.model_config.use_mla=}")
-            logger.error(error_msg)
-            raise NotImplementedError(
-                "Non-Attention backend is not supported by V1 GPUModelRunner.")
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
@@ -205,17 +189,6 @@ class NPUModelRunner:
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
-        else:
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_blocks_per_req=self.max_num_blocks_per_req,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-            )
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
@@ -562,7 +535,10 @@ class NPUModelRunner:
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        else:
+            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -976,6 +952,17 @@ class NPUModelRunner:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
+        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+                kv_cache_config=kv_cache_config,
+            )
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
             for layer_name in kv_cache_group.layer_names: