Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
967
vllm/model_executor/layers/fused_moe/layer.py
Normal file
967
vllm/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,967 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fused_moe import fused_experts
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
if current_platform.is_tpu():
|
||||
# the iterative moe implementation is used until the moe_pallas is fixed
|
||||
from .moe_torch_iterative import fused_moe as fused_moe_pallas
|
||||
else:
|
||||
fused_moe_pallas = None # type: ignore
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
GROUP = "group"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@CustomOp.register("unquantized_fused_moe")
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
|
||||
# Pad the weight tensor. This is an optimization on ROCm platform, which
|
||||
# can benefit from tensors located far enough from one another in memory
|
||||
if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
|
||||
and weight.stride(-1) == 1
|
||||
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
|
||||
num_pad = 256 // weight.element_size()
|
||||
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
|
||||
torch.cuda.empty_cache()
|
||||
return weight
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w13_weight.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
# Lazy import to avoid importing triton.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
if current_platform.is_cpu():
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
use_prepack=envs.VLLM_CPU_MOE_PREPACK,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("CPU MOE only supports x86 arch.")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
extra_residual: torch.Tensor = None,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
out = self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input) * routed_scaling_factor
|
||||
if extra_residual is not None:
|
||||
out += extra_residual
|
||||
return out
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
assert activation == "silu", f"{activation} is not supported."
|
||||
assert apply_router_weight_on_input is False
|
||||
return layer.ipex_fusion(
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
|
||||
def forward_hpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
assert custom_routing_function is None
|
||||
assert layer is not None
|
||||
assert apply_router_weight_on_input is False
|
||||
if scoring_func != "softmax":
|
||||
raise NotImplementedError(
|
||||
"Only softmax scoring function is supported for HPU.")
|
||||
if e_score_correction_bias is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert score correction bias is not supported for HPU.")
|
||||
return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
|
||||
router_logits, top_k)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert not use_grouped_topk
|
||||
assert num_expert_group is None
|
||||
assert topk_group is None
|
||||
assert custom_routing_function is None
|
||||
assert apply_router_weight_on_input is False
|
||||
if scoring_func != "softmax":
|
||||
raise NotImplementedError(
|
||||
"Only softmax scoring function is supported for TPU.")
|
||||
if e_score_correction_bias is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert score correction bias is not supported for TPU.")
|
||||
assert activation == "silu", f"{activation} is not supported for TPU."
|
||||
return fused_moe_pallas(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk=top_k,
|
||||
gating_output=router_logits,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
renormalize=renormalize)
|
||||
|
||||
forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda
|
||||
|
||||
|
||||
def determine_expert_map(
|
||||
ep_size: int, ep_rank: int,
|
||||
global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Calculates how many experts should be assigned to each rank for EP and
|
||||
creates a mapping from global to local expert index. Experts are
|
||||
distributed evenly across ranks. Any remaining are assigned to the
|
||||
last rank.
|
||||
|
||||
Args:
|
||||
ep_size (int): The size of the expert parallel group
|
||||
global_num_experts (int): The total number of experts in the model.
|
||||
|
||||
Returns:
|
||||
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
|
||||
- local_num_experts (int): The number of experts assigned
|
||||
to the current rank.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor of shape
|
||||
(global_num_experts,) mapping from global to local index.
|
||||
Contains -1 for experts not assigned to the current rank.
|
||||
Returns None if ep_size is 1.
|
||||
"""
|
||||
assert ep_size > 0
|
||||
if ep_size == 1:
|
||||
return (global_num_experts, None)
|
||||
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
|
||||
# Create a expert map for the local experts
|
||||
if ep_rank < (ep_size - 1):
|
||||
# Each non-last rank gets local_num_experts experts.
|
||||
expert_map[ep_rank * local_num_experts:
|
||||
(ep_rank + 1) * local_num_experts] = \
|
||||
torch.arange(0, local_num_experts, dtype=torch.int32)
|
||||
else:
|
||||
# All remaining experts are assigned to the last rank.
|
||||
local_num_experts = (global_num_experts - ep_rank * local_num_experts)
|
||||
|
||||
expert_map[-local_num_experts:] = \
|
||||
torch.arange(0, local_num_experts, dtype=torch.int32)
|
||||
return (local_num_experts, expert_map)
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
||||
w13) and RowParallelLinear weights (down_proj/ w2).
|
||||
|
||||
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
|
||||
copy that naming convention here and handle any remapping in the
|
||||
load_weights function in each model implementation.
|
||||
|
||||
Args:
|
||||
num_experts: Number of experts in the model
|
||||
top_k: Number of experts selected for each token
|
||||
hidden_size: Input hidden state size of the transformer
|
||||
intermediate_size: Intermediate size of the experts
|
||||
params_dtype: Data type for the parameters.
|
||||
reduce_results: Whether to all all_reduce on the output of the layer
|
||||
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int, # Global number of experts
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = False,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
dp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Note: here we guard against accessing the TP and DP groups when
|
||||
# uninitialized (this happens when testing)
|
||||
self.tp_size = (tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size())
|
||||
tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
|
||||
self.dp_size = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
self.dp_rank = (0
|
||||
if self.dp_size == 1 else get_dp_group().rank_in_group)
|
||||
self.global_num_experts = num_experts
|
||||
|
||||
# Use expert parallelism instead of tensor parallelism?
|
||||
vllm_config = get_current_vllm_config()
|
||||
use_ep = (vllm_config.parallel_config.enable_expert_parallel
|
||||
and self.tp_size > 1)
|
||||
|
||||
# For smuggling this layer into the fused moe custom op
|
||||
# self.use_direct_call = self.dp_size == 1
|
||||
self.use_direct_call = True
|
||||
if not self.use_direct_call:
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError("Duplicate layer name: {}".format(prefix))
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
|
||||
if use_ep:
|
||||
# Set TP size to 1 to adjust for EP and adjust EP size and rank
|
||||
# for DP attention.
|
||||
self.ep_rank = tp_rank + self.tp_size * self.dp_rank
|
||||
self.tp_rank = 0
|
||||
self.ep_size = self.tp_size * self.dp_size
|
||||
self.tp_size = 1
|
||||
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
global_num_experts=self.global_num_experts)
|
||||
else:
|
||||
# Adjust TP size for DP attention
|
||||
self.tp_rank = tp_rank + self.tp_size * self.dp_rank
|
||||
self.ep_rank = 0
|
||||
self.tp_size = self.tp_size * self.dp_size
|
||||
self.ep_size = 1
|
||||
self.local_num_experts = self.global_num_experts
|
||||
self.expert_map = None
|
||||
self.top_k = top_k
|
||||
self.global_num_experts = num_experts
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.scoring_func = scoring_func
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
self.activation = activation
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
if current_platform.is_hpu():
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
UnquantizedFusedMoEMethod())
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def _load_per_tensor_weight_scale(self, shard_id: str,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
expert_id: int):
|
||||
param_data = param.data
|
||||
# for per tensor weight quantization
|
||||
if shard_id in ("w1", "w3"):
|
||||
# We have to keep the weight scales of w1 and w3 because
|
||||
# we need to re-quantize w1/w3 weights after weight loading.
|
||||
idx = 0 if shard_id == "w1" else 1
|
||||
param_data[expert_id][idx] = loaded_weight
|
||||
# If we are in the row parallel case (down_proj)
|
||||
elif shard_id == "w2":
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def _load_model_weight_or_group_weight_scale(self,
|
||||
shard_dim: int,
|
||||
expert_data: torch.Tensor,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full_w2: bool = False):
|
||||
"""
|
||||
Load grouped weight scales for group quantization or model weights
|
||||
:param shard_dim: dimension to shard
|
||||
:param expert_data: parameter for a particular expert
|
||||
:param shard_id: either w1, w2, or w3
|
||||
:param loaded_weight: checkpoint weight to load into the param
|
||||
:param tp_rank: tensor parallel rank
|
||||
:param load_full_w2: whether or not the w2 loaded should be sharded.
|
||||
"""
|
||||
if shard_id == "w2":
|
||||
# In the case where we have actorder/g_idx, we do not partition the
|
||||
# w2 scales, as indicated by `load_full` argument, for all tp cases
|
||||
self._load_w2(shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
load_full=load_full_w2)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
|
||||
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
|
||||
shard_dim: int, shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int):
|
||||
# for per channel weight quantization
|
||||
if shard_id == "w2":
|
||||
expert_data.copy_(loaded_weight)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
|
||||
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
|
||||
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
||||
shard_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim,
|
||||
shard_size * tp_rank,
|
||||
shard_size)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_single_value(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, expert_id: int):
|
||||
param_data = param.data
|
||||
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
|
||||
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
|
||||
|
||||
if shard_id == "w2":
|
||||
self._load_w2(shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
else:
|
||||
assert shard_id in ("w1", "w3")
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
|
||||
if self.expert_map is None:
|
||||
return expert_id
|
||||
return self.expert_map[expert_id].item()
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, weight_name: str,
|
||||
shard_id: str, expert_id: int) -> None:
|
||||
|
||||
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
|
||||
if expert_id == -1:
|
||||
return
|
||||
|
||||
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||
# against known CompressionFormat enum values that have this quality
|
||||
loaded_weight = loaded_weight.t().contiguous() if (
|
||||
self.quant_method.__class__.__name__
|
||||
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
|
||||
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
f"got {shard_id}.")
|
||||
|
||||
WEIGHT_SCALE_SUPPORTED = [
|
||||
e.value for e in FusedMoeWeightScaleSupported
|
||||
]
|
||||
# Fetch the dim to shard the parameter/loaded weight
|
||||
# based on the shard id. This will be whatever
|
||||
# dimension intermediate_size_per_partition is used.
|
||||
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
||||
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
param.weight_type = loaded_weight.item()
|
||||
param.data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
# is_transposed: if the dim to shard the weight
|
||||
# should be flipped. Required by GPTQ, compressed-tensors
|
||||
# should be whatever dimension intermediate_size_per_partition is
|
||||
is_transposed = getattr(param, "is_transposed", False)
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if is_transposed:
|
||||
shard_dim = int(not shard_dim)
|
||||
shard_dim_force = getattr(param, "shard_dim", None)
|
||||
shard_dim = shard_dim_force if shard_dim_force is not None else shard_dim
|
||||
|
||||
full_load = len(loaded_weight.shape) == 3
|
||||
if full_load:
|
||||
shard_dim += 1
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
final_shape = list(loaded_weight.shape)
|
||||
if shard_id in ["w1", "w3"]:
|
||||
final_shape[1] *= 2
|
||||
final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
|
||||
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||
|
||||
expert_data = param.data if full_load else param.data[expert_id]
|
||||
# Case input scale: input_scale loading is only supported for fp8
|
||||
if "input_scale" in weight_name:
|
||||
# this is needed for compressed-tensors only
|
||||
loaded_weight = loaded_weight.to(param.data.device)
|
||||
|
||||
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
||||
loaded_weight).abs() > 1e-5:
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param.data[expert_id]} "
|
||||
f"vs. {loaded_weight}")
|
||||
|
||||
self._load_single_value(param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
return
|
||||
|
||||
# Case g_idx
|
||||
if "g_idx" in weight_name:
|
||||
self._load_g_idx(shard_dim=0,
|
||||
shard_id=shard_id,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank)
|
||||
return
|
||||
|
||||
# Case weight scales, zero_points and offset
|
||||
if ("scale" in weight_name or "zero" in weight_name
|
||||
or "offset" in weight_name):
|
||||
# load the weight scales and zp based on the quantization scheme
|
||||
# supported weight scales/zp can be found in
|
||||
# FusedMoeWeightScaleSupported
|
||||
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
|
||||
# specific to each case
|
||||
quant_method = getattr(param, "quant_method", None)
|
||||
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
||||
self._load_per_channel_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank)
|
||||
elif quant_method in [
|
||||
FusedMoeWeightScaleSupported.GROUP.value,
|
||||
FusedMoeWeightScaleSupported.BLOCK.value,
|
||||
]:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank,
|
||||
load_full_w2=getattr(param, "load_full_w2", False))
|
||||
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
||||
self._load_per_tensor_weight_scale(shard_id=shard_id,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
|
||||
return
|
||||
|
||||
# Case weight_shape
|
||||
if "weight_shape" in weight_name:
|
||||
# only required by compressed-tensors
|
||||
self._load_single_value(param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id)
|
||||
return
|
||||
|
||||
# Case model weights
|
||||
if "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_ids: Optional[torch.Tensor] = None,):
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from ixformer.inference.functions import moe_grouped_topk as grouped_topk
|
||||
|
||||
# DeekSeekv2 uses grouped_top_k
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
topk_ids=topk_ids,
|
||||
)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(get_dp_group().world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
extra_residual: torch.Tensor = None,
|
||||
routed_scaling_factor: float = 1.0):
|
||||
if self.use_direct_call:
|
||||
return self.forward_impl(hidden_states, router_logits, extra_residual, routed_scaling_factor)
|
||||
else:
|
||||
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||
self.layer_name)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
extra_residual: torch.Tensor = None,
|
||||
routed_scaling_factor: float = 1.0):
|
||||
assert self.quant_method is not None
|
||||
|
||||
if self.dp_size > 1:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_dp_cpu)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
extra_residual=extra_residual,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.dp_size > 1:
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
|
||||
all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
|
||||
final_hidden_states = all_hidden_states[start:end, :]
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
# Default set to False. (May have to add shared expert outputs.)
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int) -> List[Tuple[str, str, int, str]]:
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
("experts.w13_" if weight_name
|
||||
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
|
||||
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
|
||||
for expert_id in range(num_experts) for shard_id, weight_name in [
|
||||
("w1", ckpt_gate_proj_name),
|
||||
("w2", ckpt_down_proj_name),
|
||||
("w3", ckpt_up_proj_name),
|
||||
]
|
||||
]
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
||||
s = (
|
||||
f"global_num_experts={self.global_num_experts}, "
|
||||
f"local_num_experts={self.local_num_experts}, "
|
||||
f"top_k={self.top_k}, "
|
||||
f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501
|
||||
f"tp_size={self.tp_size},\n"
|
||||
f"ep_size={self.ep_size}, "
|
||||
f"reduce_results={self.reduce_results}, "
|
||||
f"renormalize={self.renormalize}, "
|
||||
f"use_grouped_topk={self.use_grouped_topk}")
|
||||
|
||||
if self.use_grouped_topk:
|
||||
s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501
|
||||
|
||||
s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.quant_method is not None
|
||||
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="moe_forward",
|
||||
op_func=moe_forward,
|
||||
mutates_args=[],
|
||||
fake_impl=moe_forward_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
Reference in New Issue
Block a user