This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,82 @@
import torch
from typing import Optional, Tuple
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, -1)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
return topk_weights, selected_experts
def grouped_topk_with_itype(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
try:
from torch_vacc.vacc.custom_ops import fused_moe_preprocess
return fused_moe_preprocess(gating_output, e_score_correction_bias)
except Exception as e:
print(f"fused group topk run fail, now use unfused group topk: {e}")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int32)

View File

@@ -0,0 +1,715 @@
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
def FusedMoE_init_(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
):
from vllm.model_executor.layers.fused_moe import FusedMoE
super(FusedMoE, self).__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
vllm_config = get_current_vllm_config()
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
# operation is typically unquantized (i.e. float16/bfloat16).
if vllm_config.model_config is not None:
moe_in_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
moe_in_dtype = params_dtype
tp_size_ = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=tp_size_,
dp_size_=dp_size_,
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts + num_redundant_experts
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.enable_eplb = enable_eplb
self.expert_load_view: Optional[torch.Tensor] = None
self.logical_to_physical_map: Optional[torch.Tensor] = None
self.logical_replica_count: Optional[torch.Tensor] = None
# Determine expert maps
if self.use_ep:
if self.enable_eplb:
assert self.global_num_experts % self.ep_size == 0, \
"EPLB currently only supports even distribution of " \
"experts across ranks."
else:
assert num_redundant_experts == 0, \
"Redundant experts are only supported with EPLB."
self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts)
else:
self.local_num_experts, self.expert_map = (self.global_num_experts,
None)
self.top_k = top_k
assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
)
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.")
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
# self.scale_n = self.quant_method.scale_n
# self.scale_k = self.quant_method.scale_k
self.scale_n = 1
self.scale_k = 1
self.scale_n_prefill = 1
if hasattr(self.quant_method, "scale_n") and hasattr(self.quant_method, "scale_k"):
self.scale_n = self.quant_method.scale_n
self.scale_k = self.quant_method.scale_k
if hasattr(self.quant_method, "scale_n_prefill"):
self.scale_n_prefill = self.quant_method.scale_n_prefill
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
class FusedMoE(torch.nn.Module):
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
expert_id=0):
#print("w13 shape is:", expert_data.shape, loaded_weight.shape)
if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
n_v, k_v = loaded_weight.shape
loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v) #[1, n, 1, k]
if self.scale_n_prefill != self.scale_n and hasattr(self, 'w13_weight_scale_inv_prefill'):
loaded_weight0 = loaded_weight.repeat(self.scale_n_prefill,1,self.scale_k,1).permute(1,0,3,2).reshape([n_v * self.scale_n_prefill, k_v * self.scale_k])
shard_size = self.w13_weight_scale_inv_prefill.data[expert_id].shape[shard_dim] // 2
loaded_weight0 = loaded_weight0.narrow(shard_dim, shard_size * tp_rank, shard_size)
if shard_id == "w1":
self.w13_weight_scale_inv_prefill.data[expert_id, :loaded_weight0.shape[0], :loaded_weight0.shape[1]] = loaded_weight0
elif shard_id == "w3":
self.w13_weight_scale_inv_prefill.data[expert_id, -loaded_weight0.shape[0]:, -loaded_weight0.shape[1]:] = loaded_weight0
else:
raise ValueError('error shard_id: ',shard_id)
loaded_weight = loaded_weight.repeat(self.scale_n,1,self.scale_k,1).permute(1,0,3,2).reshape([n_v * self.scale_n, k_v * self.scale_k])
#print("w13 repeat shape is:", expert_data.shape, loaded_weight.shape)
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False,
expert_id=0):
#print("w2 shape is:", expert_data.shape, loaded_weight.shape)
if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
n_v, k_v = loaded_weight.shape
loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v) #[1, n, 1, k]
if self.scale_n_prefill != self.scale_n and hasattr(self, 'w2_weight_scale_inv_prefill'):
loaded_weight0 = loaded_weight.repeat(self.scale_k,1,self.scale_n_prefill,1).permute(1,0,3,2).reshape([n_v * self.scale_k, k_v * self.scale_n_prefill])
shard_size = self.w2_weight_scale_inv_prefill.data[expert_id].shape[shard_dim]
if not load_full:
loaded_weight0 = loaded_weight0.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
self.w2_weight_scale_inv_prefill.data[expert_id] = loaded_weight0
#print("loaded_weight:", loaded_weight.shape)
loaded_weight = loaded_weight.repeat(self.scale_k,1,self.scale_n,1).permute(1,0,3,2).reshape([n_v * self.scale_k, k_v * self.scale_n])
#print("w2 repeat shape is:", expert_data.shape, loaded_weight.shape)
#if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int, return_success=True) -> None:
if self.quant_config and self.quant_config.get_name() == "mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)
else:
dim1 = loaded_weight.shape[1]
dim2 = loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return True if return_success else None
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return False if return_success else None
quant_method_name = self.quant_method.__class__.__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
if self.quant_method.__class__.__name__ in (
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod"):
loaded_weight = loaded_weight.t().contiguous()
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
# WEIGHT_SCALE_SUPPORTED = [
# e.value for e in FusedMoeWeightScaleSupported
# ]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
param.data.copy_(loaded_weight)
return True if return_success else None
# Case for BitsAndBytes
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit:
shard_dim = 0
expert_data = param.data[expert_id]
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
# BNB inflight quantization has already sharded the weights
full_load = True
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
load_full=full_load,
)
return True if return_success else None
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = int(not shard_dim)
full_load = len(loaded_weight.shape) == 3
if full_load:
shard_dim += 1
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]:
final_shape[1] *= 2
final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
param.materialize(final_shape, dtype=loaded_weight.dtype)
expert_data = param.data if full_load else param.data[expert_id]
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return True if return_success else None
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
# Determine per-tensor weight scale patterns based on variant
# Use the dedicated method instead of brittle string matching
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
)
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
# weights scales.
# Input scales are always per-tensor.
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
# "weight_scale" for per-tensor scales.
is_per_tensor = ("weight_scale_2" in weight_name
if uses_weight_scale_2 else "weight_scale"
in weight_name) or "input_scale" in weight_name
if is_per_tensor:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
return True if return_success else None
# If the weight is w13_weight_scale and w13_weight_scales are
# combined into single loaded_weight, call
# _load_combined_w13_weight_scale() to load it.
# This is checked by comparing the hidden_out dims of the
# loaded_weight and the param.
if "w13_weight_scale" in weight_name:
loaded_weight_hidden_out = loaded_weight.shape[-2]
param_hidden_out = param.data.shape[-2] * self.tp_size
if loaded_weight_hidden_out == param_hidden_out:
self._load_combined_w13_weight_scale(
shard_dim=shard_dim,
loaded_weight=loaded_weight,
param=param,
tp_rank=self.tp_rank,
)
return True if return_success else None
# For other weights, call _load_model_weight_or_group_weight_scale()
# to load it.
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
# Case weight scales, zero_points and offset
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
]:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
load_full_w2=getattr(param, "load_full_w2", False),
expert_id=expert_id)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return True if return_success else None
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return True if return_success else None
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
return False if return_success else None
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full_w2: bool = False,
expert_id: int = 0):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
# In the case where we have actorder/g_idx, we do not partition the
# w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
load_full=load_full_w2,
expert_id=expert_id)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
expert_id=expert_id)
class UnquantizedFusedMoEMethod():
def forward_vacc(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
hidden_states = x
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
intermediate_size = layer.w2_weight.shape[-1]
gating_output=router_logits
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, global_num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
final_hidden_states = torch.zeros_like(hidden_states)
sel_experts = topk_ids.shape[1]
if hidden_states.shape[0] == 1:
for id in range(sel_experts):
expert_idx = topk_ids[0][id]
expert_w1 = layer.w13_weight[expert_idx].contiguous()
expert_w2 = layer.w2_weight[expert_idx].contiguous()
expert_weights = topk_weights[0][id].to(hidden_states.dtype)
x = hidden_states
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
final_hidden_states += current_hidden_states
else:
for expert_idx in range(global_num_experts):
# topk_ids [tokens, experts] => sample:[10, 8]
# expert_mask [tokens, experts] => sample:[10, 8]
expert_mask = topk_ids == expert_idx
idx = torch.where(expert_mask)[0]
if idx.numel() == 0:
continue
expert_w1 = layer.w13_weight[expert_idx].contiguous()
expert_w2 = layer.w2_weight[expert_idx].contiguous()
# [seq, experts]
expert_weights = (
topk_weights.masked_select(expert_mask)
.unsqueeze(1)
.to(hidden_states.dtype)
)
x = hidden_states[idx]
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
# final_hidden_states[idx] += current_hidden_states
final_hidden_states.index_add_(0, idx, current_hidden_states)
return final_hidden_states.view(orig_shape) # type: ignore
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
custom_forward = self.forward
if x.device.type == "vacc":
custom_forward = UnquantizedFusedMoEMethod.forward_vacc
return custom_forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)