Initial commit for vLLM-Kunlun Plugin
This commit is contained in:
310
vllm_kunlun/ops/fused_moe/layer.py
Normal file
310
vllm_kunlun/ops/fused_moe/layer.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#
|
||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Author: Dong Xinyu, Chen Zhennan, Bao Qian, Yuan Jizhong
|
||||
# Email: dongxinyu03@baidu.com
|
||||
# This file is a part of the vllm-kunlun project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""layer.py"""
|
||||
import torch
|
||||
from typing import Callable, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE as VllmFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase as VllmFusedMoEMethodBase
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
UnquantizedFusedMoEMethod as VllmUnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig)
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_kunlun.ops.quantization.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod
|
||||
|
||||
|
||||
class FusedMoEMethodBase(VllmFusedMoEMethodBase):
|
||||
"""FusedMoEMethodBase"""
|
||||
moe: FusedMoEConfig
|
||||
|
||||
@CustomOp.register("vllm_kunlun_unquantized_fused_moe")
|
||||
class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
"""UnquantizedFusedMoEMethod"""
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
linear_weights: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""apply"""
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
|
||||
|
||||
return self.forward_kunlun(x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
linear_weights=linear_weights)
|
||||
|
||||
def forward_kunlun(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None
|
||||
) -> torch.Tensor:
|
||||
"""forward_kunlun"""
|
||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||
|
||||
if self.moe.use_ep:
|
||||
return ops.fused_moe_ep(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
linear_weights,
|
||||
self.moe.ep_rank,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
else:
|
||||
return ops.fused_moe(x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
linear_weights,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group
|
||||
)
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
"""FusedMoE"""
|
||||
def __init__(self,
|
||||
num_experts: int, # Global number of experts
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = False,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = 0,
|
||||
topk_group: Optional[int] = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
dp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts, # Global number of experts
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=reduce_results,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
ep_size=ep_size,
|
||||
dp_size=dp_size,
|
||||
prefix=prefix,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
activation=activation,
|
||||
enable_eplb=enable_eplb,
|
||||
num_redundant_experts=num_redundant_experts,
|
||||
)
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
model_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
# TODO (bnell): This is a hack to get test_mixtral_moe to work
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig.make(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
quant_method: Optional[QuantizeMethodBase] = None
|
||||
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
|
||||
else quant_config.get_quant_method(self, prefix))
|
||||
|
||||
assert quant_method is not None
|
||||
# assert isinstance(quant_method, FusedMoEMethodBase)
|
||||
self.quant_method = quant_method
|
||||
|
||||
if self.enable_eplb:
|
||||
from vllm_kunlun.ops.quantization.fp8 import (
|
||||
Fp8MoEMethod)
|
||||
if not isinstance(quant_method, Fp8MoEMethod):
|
||||
# TODO: Add support for additional quantization methods.
|
||||
# The implementation for other quantization methods does not
|
||||
# contain essential differences, but the current quant API
|
||||
# design causes duplicated work when extending to new
|
||||
# quantization methods, so I'm leaving it for now.
|
||||
# If you plan to add support for more quantization methods,
|
||||
# please refer to the implementation in `Fp8MoEMethod`.
|
||||
raise NotImplementedError("EPLB is only supported for FP8 "
|
||||
"quantization for now.")
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor = None,
|
||||
linear_weights: torch.Tensor = None):
|
||||
"""forward"""
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we will
|
||||
# switch to using the moe_forward custom op.
|
||||
if current_platform.is_tpu():
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
else:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[self.layer_name]
|
||||
assert self.quant_method is not None
|
||||
return self.forward_impl(hidden_states, router_logits, linear_weights)
|
||||
# return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||
# self.layer_name)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor = None):
|
||||
"""forward_impl"""
|
||||
assert self.quant_method is not None
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels):
|
||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||
|
||||
do_naive_dispatch_combine: bool = (
|
||||
self.dp_size > 1
|
||||
and not self.moe_parallel_config.use_deepep_ht_kernels)
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
linear_weights=linear_weights
|
||||
)
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
final_hidden_states = get_ep_group().combine(final_hidden_states)
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
# Default set to False. (May have to add shared expert outputs.
|
||||
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
Reference in New Issue
Block a user