init
3 binary files changed (contents not shown).
82
vllm_vacc/vllm/model_executor/layers/fused_moe/fused_moe.py
Normal file
@@ -0,0 +1,82 @@
import torch
from typing import Optional, Tuple


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    dtype = hidden_states.dtype

    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, -1)
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)
    return topk_weights, selected_experts


def grouped_topk_with_itype(hidden_states: torch.Tensor,
                            gating_output: torch.Tensor,
                            topk: int,
                            renormalize: bool,
                            num_expert_group: int = 0,
                            topk_group: int = 0,
                            scoring_func: str = "softmax",
                            e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    try:
        from torch_vacc.vacc.custom_ops import fused_moe_preprocess
        return fused_moe_preprocess(gating_output, e_score_correction_bias)
    except Exception as e:
        print(f"fused group topk run fail, now use unfused group topk: {e}")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (scores.view(num_token, num_expert_group,
                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
    else:
        group_scores = scores.view(num_token, num_expert_group,
                                   -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(),
                                    float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int32)
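A minimal usage sketch of grouped_topk_with_itype: the sizes below are made up for illustration, and on a machine without torch_vacc the fused_moe_preprocess import fails so the function falls back to the unfused path shown above. The import path assumes this file overlays vllm's fused_moe package, as its directory layout suggests.

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk_with_itype

num_tokens, hidden_size, num_experts = 4, 16, 8   # illustrative sizes only
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.float16)
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32)

# Keep the best 2 of 4 expert groups, then pick the top-2 experts inside them.
topk_weights, topk_ids = grouped_topk_with_itype(
    hidden_states,
    gating_output,
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
    scoring_func="sigmoid",
    e_score_correction_bias=torch.zeros(num_experts))
print(topk_weights.shape, topk_ids.dtype)   # torch.Size([4, 2]) torch.int32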
715
vllm_vacc/vllm/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,715 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig)
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op

if current_platform.is_cuda_alike():
    from .fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore
if current_platform.is_tpu():
    # the iterative moe implementation is used until the moe_pallas is fixed
    from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore

logger = init_logger(__name__)


class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"

def FusedMoE_init_(
    self,
    num_experts: int,  # Global number of experts
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = False,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    tp_size: Optional[int] = None,
    ep_size: Optional[int] = None,
    dp_size: Optional[int] = None,
    prefix: str = "",
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    num_redundant_experts: int = 0,
    has_bias: bool = False,
    is_sequence_parallel=False,
):
    from vllm.model_executor.layers.fused_moe import FusedMoE
    super(FusedMoE, self).__init__()

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype

    vllm_config = get_current_vllm_config()

    # FIXME (varun): We should have a better way of inferring the activation
    # datatype. This works for now as the tensor datatype entering the MoE
    # operation is typically unquantized (i.e. float16/bfloat16).
    if vllm_config.model_config is not None:
        moe_in_dtype = vllm_config.model_config.dtype
    else:
        # TODO (bnell): This is a hack to get test_mixtral_moe to work
        # since model_config is not set in the pytest test.
        moe_in_dtype = params_dtype

    tp_size_ = (tp_size if tp_size is not None else
                get_tensor_model_parallel_world_size())
    dp_size_ = (dp_size
                if dp_size is not None else get_dp_group().world_size)

    self.moe_parallel_config: FusedMoEParallelConfig = (
        FusedMoEParallelConfig.make(
            tp_size_=tp_size_,
            dp_size_=dp_size_,
            vllm_parallel_config=vllm_config.parallel_config))

    self.global_num_experts = num_experts + num_redundant_experts

    # For smuggling this layer into the fused moe custom op
    compilation_config = vllm_config.compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError("Duplicate layer name: {}".format(prefix))
    compilation_config.static_forward_context[prefix] = self
    self.layer_name = prefix

    self.enable_eplb = enable_eplb
    self.expert_load_view: Optional[torch.Tensor] = None
    self.logical_to_physical_map: Optional[torch.Tensor] = None
    self.logical_replica_count: Optional[torch.Tensor] = None

    # Determine expert maps
    if self.use_ep:
        if self.enable_eplb:
            assert self.global_num_experts % self.ep_size == 0, \
                "EPLB currently only supports even distribution of " \
                "experts across ranks."
        else:
            assert num_redundant_experts == 0, \
                "Redundant experts are only supported with EPLB."
        self.local_num_experts, self.expert_map = determine_expert_map(
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            global_num_experts=self.global_num_experts)
    else:
        self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                   None)

    self.top_k = top_k

    assert intermediate_size % self.tp_size == 0
    self.hidden_size = hidden_size
    self.intermediate_size_per_partition = intermediate_size // self.tp_size
    self.reduce_results = reduce_results
    self.renormalize = renormalize
    self.use_grouped_topk = use_grouped_topk
    if self.use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
    self.custom_routing_function = custom_routing_function
    self.scoring_func = scoring_func
    self.e_score_correction_bias = e_score_correction_bias
    self.apply_router_weight_on_input = apply_router_weight_on_input
    self.activation = activation

    if self.scoring_func != "softmax" and not self.use_grouped_topk:
        raise ValueError("Only softmax scoring function is supported for "
                         "non-grouped topk.")

    moe = FusedMoEConfig(
        num_experts=self.global_num_experts,
        experts_per_token=top_k,
        hidden_dim=hidden_size,
        num_local_experts=self.local_num_experts,
        moe_parallel_config=self.moe_parallel_config,
        in_dtype=moe_in_dtype,
        max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
        has_bias=has_bias,
    )
    self.moe_config = moe
    self.quant_config = quant_config

    # Note: get_quant_method will look at the layer's local_num_experts
    # for heuristic purposes, so it must be initialized first.
    quant_method: Optional[QuantizeMethodBase] = None
    quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
                    else quant_config.get_quant_method(self, prefix))

    assert quant_method is not None
    assert isinstance(quant_method, FusedMoEMethodBase)
    self.quant_method = quant_method

    if self.enable_eplb:
        from vllm.model_executor.layers.quantization.fp8 import (
            Fp8MoEMethod)
        if not isinstance(quant_method, Fp8MoEMethod):
            # TODO: Add support for additional quantization methods.
            # The implementation for other quantization methods does not
            # contain essential differences, but the current quant API
            # design causes duplicated work when extending to new
            # quantization methods, so I'm leaving it for now.
            # If you plan to add support for more quantization methods,
            # please refer to the implementation in `Fp8MoEMethod`.
            raise NotImplementedError("EPLB is only supported for FP8 "
                                      "quantization for now.")

    moe_quant_params = {
        "num_experts": self.local_num_experts,
        "hidden_size": hidden_size,
        "intermediate_size_per_partition":
        self.intermediate_size_per_partition,
        "params_dtype": params_dtype,
        "weight_loader": self.weight_loader,
    }
    # need full intermediate size pre-sharding for WNA16 act order
    if (self.quant_method.__class__.__name__
            in ("GPTQMarlinMoEMethod",
                "CompressedTensorsWNA16MarlinMoEMethod",
                "CompressedTensorsWNA16MoEMethod")):
        moe_quant_params["intermediate_size_full"] = intermediate_size

    self.quant_method.create_weights(layer=self, **moe_quant_params)

    # self.scale_n = self.quant_method.scale_n
    # self.scale_k = self.quant_method.scale_k
    self.scale_n = 1
    self.scale_k = 1
    self.scale_n_prefill = 1
    if hasattr(self.quant_method, "scale_n") and hasattr(self.quant_method, "scale_k"):
        self.scale_n = self.quant_method.scale_n
        self.scale_k = self.quant_method.scale_k
    if hasattr(self.quant_method, "scale_n_prefill"):
        self.scale_n_prefill = self.quant_method.scale_n_prefill

    # Chunked all2all staging tensor
    self.batched_hidden_states: Optional[torch.Tensor] = None
    self.batched_router_logits: Optional[torch.Tensor] = None
    if (self.moe_parallel_config.use_pplx_kernels
            or self.moe_parallel_config.use_deepep_ll_kernels):
        self.batched_hidden_states = torch.zeros(
            (moe.max_num_tokens, self.hidden_size),
            dtype=moe.in_dtype,
            device=torch.cuda.current_device())

        # Note here we use `num_experts` which is logical expert count
        self.batched_router_logits = torch.zeros(
            (moe.max_num_tokens, num_experts),
            dtype=moe.in_dtype,
            device=torch.cuda.current_device())

class FusedMoE(torch.nn.Module):

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
                  load_full: bool = False, expert_id=0):
        # `load_full` is accepted because weight_loader() passes it on the
        # BitsAndBytes path; when set, the loaded weight is already sharded
        # and must not be narrowed again.
        if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1, n, 1, k]
            if self.scale_n_prefill != self.scale_n and hasattr(self, 'w13_weight_scale_inv_prefill'):
                loaded_weight0 = loaded_weight.repeat(self.scale_n_prefill, 1, self.scale_k, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_n_prefill, k_v * self.scale_k])
                shard_size = self.w13_weight_scale_inv_prefill.data[expert_id].shape[shard_dim] // 2
                loaded_weight0 = loaded_weight0.narrow(shard_dim, shard_size * tp_rank, shard_size)

                if shard_id == "w1":
                    self.w13_weight_scale_inv_prefill.data[expert_id, :loaded_weight0.shape[0], :loaded_weight0.shape[1]] = loaded_weight0
                elif shard_id == "w3":
                    self.w13_weight_scale_inv_prefill.data[expert_id, -loaded_weight0.shape[0]:, -loaded_weight0.shape[1]:] = loaded_weight0
                else:
                    raise ValueError(f"unexpected shard_id: {shard_id}")

            loaded_weight = loaded_weight.repeat(self.scale_n, 1, self.scale_k, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_n, k_v * self.scale_k])

        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)

        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False,
                 expert_id=0):
        if self.scale_n > 1 and len(loaded_weight.shape) == 2 and torch.finfo(loaded_weight.dtype).bits > 8:
            n_v, k_v = loaded_weight.shape
            loaded_weight = loaded_weight.reshape(1, n_v, 1, k_v)  # [1, n, 1, k]
            if self.scale_n_prefill != self.scale_n and hasattr(self, 'w2_weight_scale_inv_prefill'):
                loaded_weight0 = loaded_weight.repeat(self.scale_k, 1, self.scale_n_prefill, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_k, k_v * self.scale_n_prefill])
                shard_size = self.w2_weight_scale_inv_prefill.data[expert_id].shape[shard_dim]
                if not load_full:
                    loaded_weight0 = loaded_weight0.narrow(shard_dim,
                                                           shard_size * tp_rank,
                                                           shard_size)
                self.w2_weight_scale_inv_prefill.data[expert_id] = loaded_weight0

            loaded_weight = loaded_weight.repeat(self.scale_k, 1, self.scale_n, 1).permute(1, 0, 3, 2).reshape([n_v * self.scale_k, k_v * self.scale_n])

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      weight_name: str,
                      shard_id: str,
                      expert_id: int,
                      return_success=True) -> Optional[bool]:

        if self.quant_config and self.quant_config.get_name() == "mxfp4":
            # (FIXME) for gpt-oss all experts are combined
            if "bias" in weight_name:
                dim1 = loaded_weight.shape[1]
                param.data[:, :dim1].copy_(loaded_weight)
            else:
                dim1 = loaded_weight.shape[1]
                dim2 = loaded_weight.shape[2]
                param.data[:, :dim1, :dim2].copy_(loaded_weight)
            return True if return_success else None

        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return False if return_success else None

        quant_method_name = self.quant_method.__class__.__name__
        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        if self.quant_method.__class__.__name__ in (
                "CompressedTensorsWNA16MarlinMoEMethod",
                "CompressedTensorsWNA16MoEMethod"):
            loaded_weight = loaded_weight.t().contiguous()

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size_per_partition is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()
            param.data.copy_(loaded_weight)
            return True if return_success else None

        # Case for BitsAndBytes
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        if use_bitsandbytes_4bit:
            shard_dim = 0

            expert_data = param.data[expert_id]
            if shard_id == "w2":
                expert_data.copy_(loaded_weight)
            elif shard_id in ("w1", "w3"):
                # BNB inflight quantization has already sharded the weights
                full_load = True
                self._load_w13(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full=full_load,
                )
            return True if return_success else None

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = int(not shard_dim)

        full_load = len(loaded_weight.shape) == 3
        if full_load:
            shard_dim += 1

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if shard_id in ["w1", "w3"]:
                final_shape[1] *= 2
            final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        expert_data = param.data if full_load else param.data[expert_id]
        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(shard_dim=0,
                             shard_id=shard_id,
                             loaded_weight=loaded_weight,
                             expert_data=expert_data,
                             tp_rank=self.tp_rank)
            return True if return_success else None

        # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
        if "ModelOpt" in quant_method_name:
            # Determine per-tensor weight scale patterns based on variant
            # Use the dedicated method instead of brittle string matching
            uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
            )

            # Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
            # weights scales.
            # Input scales are always per-tensor.
            # Weight scales: FP4 uses "weight_scale_2" and FP8 uses
            # "weight_scale" for per-tensor scales.
            is_per_tensor = ("weight_scale_2" in weight_name
                             if uses_weight_scale_2 else "weight_scale"
                             in weight_name) or "input_scale" in weight_name
            if is_per_tensor:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
                return True if return_success else None

            # If the weight is w13_weight_scale and w13_weight_scales are
            # combined into single loaded_weight, call
            # _load_combined_w13_weight_scale() to load it.
            # This is checked by comparing the hidden_out dims of the
            # loaded_weight and the param.
            if "w13_weight_scale" in weight_name:
                loaded_weight_hidden_out = loaded_weight.shape[-2]
                param_hidden_out = param.data.shape[-2] * self.tp_size
                if loaded_weight_hidden_out == param_hidden_out:
                    self._load_combined_w13_weight_scale(
                        shard_dim=shard_dim,
                        loaded_weight=loaded_weight,
                        param=param,
                        tp_rank=self.tp_rank,
                    )
                    return True if return_success else None

            # For other weights, call _load_model_weight_or_group_weight_scale()
            # to load it.
            if "weight" in weight_name:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
                return True if return_success else None

        # Case weight scales, zero_points and offset
        if ("scale" in weight_name or "zero" in weight_name
                or "offset" in weight_name):
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
            elif quant_method in [
                    FusedMoeWeightScaleSupported.GROUP.value,
                    FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full_w2=getattr(param, "load_full_w2", False),
                    expert_id=expert_id)
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                WEIGHT_SCALE_SUPPORTED = [
                    e.value for e in FusedMoeWeightScaleSupported
                ]
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return True if return_success else None

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=self.tp_rank)
            return True if return_success else None

        return False if return_success else None

    def _load_model_weight_or_group_weight_scale(self,
                                                 shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int,
                                                 load_full_w2: bool = False,
                                                 expert_id: int = 0):
        """
        Load grouped weight scales for group quantization or model weights
        :param shard_dim: dimension to shard
        :param expert_data: parameter for a particular expert
        :param shard_id: either w1, w2, or w3
        :param loaded_weight: checkpoint weight to load into the param
        :param tp_rank: tensor parallel rank
        :param load_full_w2: whether or not the w2 loaded should be sharded.
        """
        if shard_id == "w2":
            # In the case where we have actorder/g_idx, we do not partition the
            # w2 scales, as indicated by `load_full` argument, for all tp cases
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank,
                          load_full=load_full_w2,
                          expert_id=expert_id)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank,
                           expert_id=expert_id)


class UnquantizedFusedMoEMethod():

    # NOTE: defined without `self` on purpose; apply() below fetches this from
    # the class and calls it as a plain function with keyword arguments.
    def forward_vacc(
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:

        hidden_states = x
        orig_shape = hidden_states.shape
        hidden_size = hidden_states.shape[-1]
        num_tokens = hidden_states.shape[:-1].numel()
        dtype = hidden_states.dtype
        intermediate_size = layer.w2_weight.shape[-1]
        gating_output = router_logits

        hidden_states = hidden_states.view(num_tokens, hidden_size)
        gating_output = gating_output.view(num_tokens, global_num_experts)
        topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(dtype)

        if expert_map is not None:
            topk_ids = expert_map[topk_ids]

        final_hidden_states = torch.zeros_like(hidden_states)
        sel_experts = topk_ids.shape[1]
        if hidden_states.shape[0] == 1:
            for id in range(sel_experts):
                expert_idx = topk_ids[0][id]
                expert_w1 = layer.w13_weight[expert_idx].contiguous()
                expert_w2 = layer.w2_weight[expert_idx].contiguous()

                expert_weights = topk_weights[0][id].to(hidden_states.dtype)

                x = hidden_states
                x = F.linear(x, expert_w1)
                gate = F.silu(x[:, :intermediate_size])
                x = x[:, intermediate_size:] * gate
                x = F.linear(x, expert_w2)

                current_hidden_states = x * expert_weights
                current_hidden_states = current_hidden_states.to(x.dtype)
                final_hidden_states += current_hidden_states
        else:
            for expert_idx in range(global_num_experts):
                # topk_ids [tokens, experts] => sample:[10, 8]
                # expert_mask [tokens, experts] => sample:[10, 8]
                expert_mask = topk_ids == expert_idx

                idx = torch.where(expert_mask)[0]
                if idx.numel() == 0:
                    continue

                expert_w1 = layer.w13_weight[expert_idx].contiguous()
                expert_w2 = layer.w2_weight[expert_idx].contiguous()

                # [seq, experts]
                expert_weights = (
                    topk_weights.masked_select(expert_mask)
                    .unsqueeze(1)
                    .to(hidden_states.dtype)
                )

                x = hidden_states[idx]
                x = F.linear(x, expert_w1)
                gate = F.silu(x[:, :intermediate_size])
                x = x[:, intermediate_size:] * gate
                x = F.linear(x, expert_w2)

                current_hidden_states = x * expert_weights
                current_hidden_states = current_hidden_states.to(x.dtype)
                # final_hidden_states[idx] += current_hidden_states
                final_hidden_states.index_add_(0, idx, current_hidden_states)
        return final_hidden_states.view(orig_shape)  # type: ignore

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        custom_forward = self.forward
        if x.device.type == "vacc":
            custom_forward = UnquantizedFusedMoEMethod.forward_vacc

        return custom_forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input)
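forward_vacc above computes the MoE output expert by expert: route with softmax top-k, then for each selected expert run the stacked gate/up projection (w13), apply SiLU to the gate half, multiply by the up half, project back down with w2, and accumulate into the output weighted by the routing weight. A small self-contained sketch of that reference computation, with made-up sizes and random weights and no vLLM dependencies, for readers who want to check the math:

import torch
import torch.nn.functional as F

tokens, hidden, inter, n_experts, top_k = 5, 8, 16, 4, 2   # toy sizes
x = torch.randn(tokens, hidden)
w13 = torch.randn(n_experts, 2 * inter, hidden)   # gate and up projections stacked
w2 = torch.randn(n_experts, hidden, inter)        # down projection
logits = torch.randn(tokens, n_experts)

weights, ids = logits.softmax(dim=-1).topk(top_k, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)      # renormalize

out = torch.zeros_like(x)
for e in range(n_experts):
    mask = ids == e                       # which (token, slot) picked expert e
    token_idx = torch.where(mask)[0]
    if token_idx.numel() == 0:
        continue
    h = F.linear(x[token_idx], w13[e])    # [m, 2*inter]
    gate = F.silu(h[:, :inter])           # SiLU on the gate half
    h = h[:, inter:] * gate               # up half * silu(gate)
    h = F.linear(h, w2[e])                # back to hidden size
    out.index_add_(0, token_idx, h * weights.masked_select(mask).unsqueeze(1))
print(out.shape)  # torch.Size([5, 8])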
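FusedMoE_init_, the partial FusedMoE class, and UnquantizedFusedMoEMethod.forward_vacc read like overrides intended to be grafted onto vLLM's own fused_moe layer rather than a standalone module: FusedMoE_init_ takes self and calls super(FusedMoE, self).__init__(), and apply() only dispatches to forward_vacc when x.device.type == "vacc". How the plugin attaches them is not shown in this commit; the sketch below is purely hypothetical wiring under that assumption, and the module name vacc_fused_moe_layer is a stand-in.

# Hypothetical wiring, not part of this commit: attach the vacc overrides to
# vLLM's FusedMoE at plugin import time. Names below are assumptions.
from vllm.model_executor.layers.fused_moe import FusedMoE

import vacc_fused_moe_layer as patched  # stand-in for the layer.py added above

FusedMoE.__init__ = patched.FusedMoE_init_
FusedMoE._load_w13 = patched.FusedMoE._load_w13
FusedMoE._load_w2 = patched.FusedMoE._load_w2
FusedMoE.weight_loader = patched.FusedMoE.weight_loader
FusedMoE._load_model_weight_or_group_weight_scale = (
    patched.FusedMoE._load_model_weight_or_group_weight_scale)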