[2/N] MoE Refactor: Unify weight loader and quant methods (#8397)
@@ -30,13 +30,13 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -62,8 +62,6 @@ use_flashinfer_trtllm_moe = (
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul

from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

if _use_aiter:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
@@ -162,7 +160,7 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
return tile_tokens_dim


class EPMoE(torch.nn.Module):
class EPMoE(FusedMoE):
"""
MoE Expert Parallel Impl

@@ -184,51 +182,60 @@ class EPMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = None,
use_per_token_if_dynamic: bool = True,
):
super().__init__()
super().__init__(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
top_k=top_k,
layer_id=layer_id,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
enable_ep_moe=True,
skip_quant=True,
)

if params_dtype is None:
params_dtype = torch.get_default_dtype()

self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
self.tp_rank = get_tensor_model_parallel_rank()

self.layer_id = layer_id
self.num_experts = num_experts
assert self.num_experts % self.tp_size == 0
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
self.num_local_experts, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1

self.top_k = top_k
self.intermediate_size = intermediate_size
self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
self.use_per_token_if_dynamic = use_per_token_if_dynamic

# TODO(ch-wan): move quant preparation to FusedMoE
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()
)
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.use_w4afp8 = False
self.w13_input_scale = None
self.w2_input_scale = None
self.w13_weight_scale = None
self.w2_weight_scale = None
elif isinstance(quant_config, W4AFp8Config):
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
quant_config
)
self.use_w4afp8 = True
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.fp8_dtype = torch.float8_e4m3fn
self.w13_input_scale = None
self.w2_input_scale = None
self.w13_weight_scale = None
self.w2_weight_scale = None
self.activation_scheme = quant_config.moe_activation_scheme
else:
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
quant_config
)
elif isinstance(quant_config, Fp8Config):
self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
self.use_fp8_w8a8 = True
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
@@ -238,11 +245,13 @@ class EPMoE(torch.nn.Module):
)
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
self.use_w4afp8 = False
else:
raise ValueError(f"Unsupported quant_config: {quant_config}")

self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
num_experts_per_partition=self.num_experts_per_partition,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size,
params_dtype=params_dtype,
@@ -251,19 +260,6 @@ class EPMoE(torch.nn.Module):

self.grouped_gemm_runner = None

self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
# Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
@@ -282,8 +278,8 @@ class EPMoE(torch.nn.Module):
Contains global_num_experts for experts not assigned to the current rank.
Returns None if ep_size is 1.
"""
ep_size = self.tp_size
ep_rank = self.tp_rank
ep_size = self.ep_size
ep_rank = self.ep_rank
global_num_experts = self.num_experts

assert ep_size > 0
@@ -293,7 +289,7 @@ class EPMoE(torch.nn.Module):
local_num_experts = global_num_experts // ep_size

expert_map = torch.full(
(global_num_experts,), self.num_experts, dtype=torch.int32
(global_num_experts,), global_num_experts, dtype=torch.int32
)
if ep_rank < (ep_size - 1):
expert_map[
@@ -318,6 +314,20 @@ class EPMoE(torch.nn.Module):
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):

self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

assert self.quant_method is not None
assert self.activation == "silu"
hidden_states_shape = hidden_states.shape
@@ -457,7 +467,10 @@ class EPMoE(torch.nn.Module):
return output

def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
return self.quant_method.apply(self, hidden_states, topk_output)

def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):

topk_weights, topk_ids, _ = topk_output

hidden_states_shape = hidden_states.shape
@@ -470,53 +483,11 @@ class EPMoE(torch.nn.Module):
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)

if self.use_w4afp8:
local_topk_ids = topk_ids
if self.expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
self.expert_map[topk_ids] != self.num_experts,
self.expert_map[topk_ids],
self.num_experts,
)

output = cutlass_w4a8_moe(
self.start_expert_id,
self.end_expert_id,
self.num_experts,
hidden_states,
self.w13_weight,
self.w2_weight,
self.w13_weight_scale_inv,
self.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.quant_method.a_strides1,
self.quant_method.b_strides1,
self.quant_method.c_strides1,
self.quant_method.a_strides2,
self.quant_method.b_strides2,
self.quant_method.c_strides2,
self.quant_method.s_strides13,
self.quant_method.s_strides2,
self.quant_method.expert_offsets,
self.quant_method.problem_sizes1,
self.quant_method.problem_sizes2,
self.w13_input_scale,
self.w2_input_scale,
)
return output

if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,
use_flashinfer=False, # TODO: use flashinfer
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
num_experts = self.num_experts

reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
topk_ids, self.num_experts
topk_ids,
num_experts,
)

gateup_input = torch.empty(
@@ -524,7 +495,7 @@ class EPMoE(torch.nn.Module):
device=hidden_states.device,
dtype=(
self.fp8_dtype
if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
if self.use_fp8_w8a8 and not self.use_block_quant
else hidden_states.dtype
),
)
@@ -535,7 +506,7 @@ class EPMoE(torch.nn.Module):
else:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.repeat(self.num_local_experts)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
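The dynamic branch above derives a single per-expert FP8 activation scale from the running maximum of the hidden states. A minimal standalone sketch of that scale computation (the helper name is illustrative, not part of the diff):

import torch

def dynamic_fp8_input_scale(hidden_states, num_local_experts, fp8_dtype=torch.float8_e4m3fn):
    # One scale per local expert, all taken from the global max of the activations,
    # mirroring the .repeat(num_local_experts) / torch.finfo(fp8_dtype).max pattern above.
    max_value = torch.max(hidden_states).repeat(num_local_experts).to(torch.float32)
    return max_value / torch.finfo(fp8_dtype).max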
@@ -576,7 +547,7 @@ class EPMoE(torch.nn.Module):
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
self.num_local_experts,
device=hidden_states_device,
dtype=torch.int64,
)
@@ -586,17 +557,13 @@ class EPMoE(torch.nn.Module):
b=self.w13_weight,
c=None,
c_dtype=hidden_states_dtype,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
scale_b=self.w13_weight_scale,
block_shape=self.block_shape,
)
del gateup_input
@@ -653,7 +620,7 @@ class EPMoE(torch.nn.Module):
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
else:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
self.num_local_experts,
dtype=torch.float32,
device=hidden_states_device,
)
@@ -669,17 +636,13 @@ class EPMoE(torch.nn.Module):
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
scale_b=self.w2_weight_scale,
block_shape=self.block_shape,
)
del down_input
@@ -782,107 +745,14 @@ class EPMoE(torch.nn.Module):
return
expert_id = expert_id - self.start_expert_id

if shard_id not in ("w1", "w2", "w3"):
raise ValueError(
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
)

# Special case for fp8 scales.
if "scale" in weight_name:
self._load_fp8_scale(
param.data,
loaded_weight,
weight_name,
shard_id,
expert_id,
)
return

# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if use_flashinfer_trtllm_moe:
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
else:
actual_shard_id = shard_id

if actual_shard_id == "w2":
param.data[expert_id] = loaded_weight
elif actual_shard_id == "w1":
param.data[expert_id][: self.intermediate_size, :] = loaded_weight
elif actual_shard_id == "w3":
param.data[expert_id][self.intermediate_size :, :] = loaded_weight
else:
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")

def _load_fp8_scale(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
param_data = param.data

# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if self.use_w4afp8:
if shard_id == "w1":
param_data[expert_id][0] = loaded_weight
elif shard_id == "w3":
param_data[expert_id][1] = loaded_weight
else:
param_data[expert_id] = loaded_weight
return

if (
(shard_id == "w1" or shard_id == "w3")
and param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}"
)
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
if self.use_block_quant:
if use_flashinfer_trtllm_moe:
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
else:
actual_shard_id = shard_id

block_n, block_k = self.block_shape[0], self.block_shape[1]

if actual_shard_id == "w1":
param_data[expert_id][
: (self.intermediate_size + block_n - 1) // block_n, :
] = loaded_weight
elif actual_shard_id == "w3":
param_data[expert_id][
(self.intermediate_size + block_n - 1) // block_n :, :
] = loaded_weight
else: # w2
param_data[expert_id] = loaded_weight
elif self.use_w4afp8:
if shard_id == "w1":
param_data[expert_id][: self.intermediate_size, :] = loaded_weight
elif shard_id == "w3":
param_data[expert_id][self.intermediate_size :, :] = loaded_weight
else:
param_data[expert_id] = loaded_weight
# If we are in merged column case (gate_up_proj)
else:
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight

# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return


class DeepEPMoE(EPMoE):
@@ -932,13 +802,13 @@ class DeepEPMoE(EPMoE):
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter:
# expert_mask is of size (self.num_experts_per_partition + 1),
# expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
# for instance, if we have 4 experts on this rank, we would have a expert_mask like:
# self.expert_mask = [1, 1, 1, 1, 0]
# idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
self.expert_mask = torch.zeros(
(self.num_experts_per_partition + 1),
(self.num_local_experts + 1),
device=torch.cuda.current_device(),
dtype=torch.int,
)
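The comment above describes the aiter workaround: invalid ids become num_local_experts and are masked out. A hedged sketch of how such a mask can be built (the helper name and the final ones-fill are assumptions; the hunk itself only shows the zero-initialised buffer):

import torch

def build_aiter_expert_mask(num_local_experts, device="cuda"):
    # Slot `num_local_experts` is the invalid bucket that replaces deepep's -1 ids.
    mask = torch.zeros(num_local_experts + 1, dtype=torch.int, device=device)
    mask[:-1] = 1  # e.g. 4 local experts -> [1, 1, 1, 1, 0]
    return mask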
@@ -1011,13 +881,13 @@ class DeepEPMoE(EPMoE):
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.repeat(self.num_local_experts)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
self.num_local_experts,
device=hidden_states.device,
dtype=torch.int64,
)
@@ -1029,7 +899,7 @@ class DeepEPMoE(EPMoE):
b=self.w13_weight,
c=None,
c_dtype=hidden_states.dtype,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
@@ -1063,7 +933,7 @@ class DeepEPMoE(EPMoE):
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
self.num_local_experts,
dtype=torch.float32,
device=hidden_states_device,
)
@@ -1076,7 +946,7 @@ class DeepEPMoE(EPMoE):
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
self.num_local_experts - 1,
BLOCK_SIZE=512,
)
else:
@@ -1096,7 +966,7 @@ class DeepEPMoE(EPMoE):
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_experts_per_partition,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
@@ -1121,9 +991,9 @@ class DeepEPMoE(EPMoE):
return hidden_states
# in original deepep, idx == -1 meaning invalid and will not be processed.
# aiter does not accept -1, we use a expert mask to make these idx invalid
# (idx == num_experts_per_partition) meaning not used in aiter fused_moe
# (idx == num_local_experts) meaning not used in aiter fused_moe
topk_idx_copy = topk_idx.to(torch.int32)
topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts

return fused_moe(
hidden_states,

@@ -77,6 +77,7 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
skip_quant: Optional[bool] = False,
):
super().__init__()

@@ -99,9 +100,6 @@ class FusedMoE(torch.nn.Module):

self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
if enable_ep_moe:
assert (
self.enable_flashinfer_cutlass_moe
), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
# Create a expert map for the local experts
assert num_experts % self.ep_size == 0
self.local_num_experts = num_experts // self.ep_size
self.num_local_experts = num_experts // self.ep_size
self.expert_map[
self.ep_rank
* self.local_num_experts : (self.ep_rank + 1)
* self.local_num_experts
] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
* self.num_local_experts : (self.ep_rank + 1)
* self.num_local_experts
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
else:
self.ep_size = 1
self.ep_rank = 0
self.local_num_experts = num_experts
self.num_local_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
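The EP branch above splits the experts evenly across ranks and records a global-to-local index map. A minimal standalone sketch of that construction; FusedMoE uses -1 for non-local experts, while EPMoE.determine_expert_map uses global_num_experts as the sentinel:

import torch

def build_expert_map(num_experts, ep_rank, ep_size, sentinel=-1):
    assert num_experts % ep_size == 0
    num_local = num_experts // ep_size
    expert_map = torch.full((num_experts,), sentinel, dtype=torch.int32)
    expert_map[ep_rank * num_local : (ep_rank + 1) * num_local] = torch.arange(
        num_local, dtype=torch.int32
    )
    return expert_map  # e.g. 8 experts, rank 1 of 2 -> [-1, -1, -1, -1, 0, 1, 2, 3]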
@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)

if skip_quant:
return

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
@@ -149,7 +150,7 @@ class FusedMoE(torch.nn.Module):
self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
num_experts=self.local_num_experts,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
# FIXME: figure out which intermediate_size to use
intermediate_size=self.intermediate_size_per_partition,
@@ -378,6 +379,23 @@ class FusedMoE(torch.nn.Module):
if expert_id == -1:
return

self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)

def _weight_loader_impl(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:

# TP rank is set to 0 if EP is enabled
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
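The hunk above is the heart of the unification: the public weight_loader only filters out experts owned by other ranks and then defers to the shared _weight_loader_impl. A toy sketch of that split (the class and attribute names here are illustrative only):

import torch

class LoaderSketch:
    def __init__(self, start_expert_id, num_local_experts):
        self.start_expert_id = start_expert_id
        self.num_local_experts = num_local_experts

    def weight_loader(self, param, loaded_weight, weight_name, shard_id, expert_id):
        local_id = expert_id - self.start_expert_id
        if local_id < 0 or local_id >= self.num_local_experts:
            return  # expert lives on another rank, nothing to load
        self._weight_loader_impl(param, loaded_weight, weight_name, shard_id, local_id)

    def _weight_loader_impl(self, param, loaded_weight, weight_name, shard_id, expert_id):
        # The real implementation also dispatches on shard_id and scale names.
        param.data[expert_id].copy_(loaded_weight)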
@@ -398,6 +416,10 @@ class FusedMoE(torch.nn.Module):
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
)

# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if getattr(self, "use_flashinfer_trtllm_moe", False):
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
@@ -605,37 +627,3 @@ class FusedMoE(torch.nn.Module):
("w3", ckpt_up_proj_name),
]
]

def _load_fp8_scale(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
param_data = param.data

# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if (
param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
):
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}"
)
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
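The removed FusedMoE._load_fp8_scale above (its logic now lives behind the unified loader) enforces two rules: w1/w3 input scales of a layer must agree, and w13 weight scales are kept per shard so the weights can be re-quantized afterwards. A compact sketch of those rules (buffer shapes are assumptions):

import torch

def load_fp8_scale(param_data, loaded_weight, weight_name, shard_id, expert_id):
    if "input_scale" in weight_name:
        current = param_data[expert_id]
        if current != 1 and (current - loaded_weight).abs() > 1e-5:
            raise ValueError("input_scales of w1 and w3 of a layer must be equal")
        param_data[expert_id] = loaded_weight
    elif "weight_scale" in weight_name:
        if shard_id in ("w1", "w3"):
            # Keep both scales; w1/w3 are re-quantized to one shared scale later.
            param_data[expert_id][0 if shard_id == "w1" else 1] = loaded_weight
        else:  # w2 / down_proj
            param_data[expert_id] = loaded_weight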
@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

if isinstance(layer, LinearBase):
@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, EPMoE):
return Fp8EPMoEMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts, dtype=torch.float32, device=w13_weight.device
layer.num_local_experts,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.num_experts):
for expert in range(layer.num_local_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
for expert_id in range(layer.num_local_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
for expert_id in range(layer.num_local_experts):
start = 0
max_w13_scale_fp8 = max_w13_scales[expert_id]
for shard_id in range(2):
@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):

# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
for expert_id in range(layer.num_experts):
for expert_id in range(layer.num_local_experts):
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]

@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

if isinstance(layer, EPMoE):
layer.w13_weight_scale = (
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
)
layer.w2_weight_scale = (
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
)
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)

if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return None


class Fp8EPMoEMethod(Fp8MoEMethod):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.

Args:
quant_config: The quantization config.
"""

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None

def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn

tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}."
)

# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

# WEIGHT_SCALES
if self.block_quant:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
else:
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)

w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)

w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)

w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)

else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:

# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts_per_partition,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)

for expert in range(layer.num_experts_per_partition):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return

# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
layer.w13_weight_scale = torch.nn.Parameter(
torch.max(layer.w13_weight_scale, dim=1).values,
requires_grad=False,
)
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(
w13_weight, requires_grad=False
)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
return

def apply(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError


class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.

@@ -24,6 +24,7 @@ from sglang.srt.utils import (
)

if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import TopKOutput

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -194,6 +195,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:

from sglang.srt.layers.moe.ep_moe.layer import EPMoE

if isinstance(layer, EPMoE):
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)

return self.forward(
x=x,
layer=layer,
@@ -354,69 +364,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError("The TPU backend currently does not support MoE.")

forward_native = forward_cpu


class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):

def create_weights(
self,
layer: torch.nn.Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

# scale
layer.register_parameter("w13_input_scale", None)
layer.register_parameter("w13_weight_scale", None)

ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)

w2_input_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)

w2_weight_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

def apply(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from torch.nn import Module
@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs

if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)
@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
elif isinstance(layer, EPMoE):
return W4AFp8MoEMethod(self)
return None

@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
layer: EPMoE,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
num_experts,
intermediate_size * 2,
hidden_size // 2,
dtype=torch.int8,
@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
num_experts,
hidden_size,
intermediate_size // 2,
dtype=torch.int8,
@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts_per_partition,
num_experts,
2 * intermediate_size,
hidden_size // self.quant_config.group_size,
dtype=torch.float32,
@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts_per_partition,
num_experts,
hidden_size,
intermediate_size // self.quant_config.group_size,
dtype=torch.float32,
@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

# Input scales
w13_input_scale = torch.nn.Parameter(
torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
torch.ones((num_experts, 2), dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)

w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
torch.ones(num_experts, dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
device = layer.w13_weight.device

self.a_strides1 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
hidden_size,
device=device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
2 * intermediate_size,
device=device,
dtype=torch.int64,
)
self.a_strides2 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
intermediate_size,
device=device,
dtype=torch.int64,
)
self.c_strides2 = torch.full(
(num_experts_per_partition, 3),
(num_experts, 3),
hidden_size,
device=device,
dtype=torch.int64,
@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
self.s_strides2 = self.c_strides2

self.expert_offsets = torch.empty(
(num_experts_per_partition + 1), dtype=torch.int32, device=device
(num_experts + 1), dtype=torch.int32, device=device
)
self.problem_sizes1 = torch.empty(
(num_experts_per_partition, 3), dtype=torch.int32, device=device
(num_experts, 3), dtype=torch.int32, device=device
)
self.problem_sizes2 = torch.empty(
(num_experts_per_partition, 3), dtype=torch.int32, device=device
(num_experts, 3), dtype=torch.int32, device=device
)

return
@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
[w2_input_scale_max], dtype=dtype, device=device
)
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)

def apply(
self,
layer: EPMoE,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:

# TODO(ch-wan): move it out of this class
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

topk_ids, topk_weights, _ = topk_output
local_topk_ids = topk_ids
if layer.expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
layer.expert_map[topk_ids] != layer.num_experts,
layer.expert_map[topk_ids],
layer.num_experts,
)

return cutlass_w4a8_moe(
layer.start_expert_id,
layer.end_expert_id,
layer.num_experts,
hidden_states,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.a_strides1,
self.b_strides1,
self.c_strides1,
self.a_strides2,
self.b_strides2,
self.c_strides2,
self.s_strides13,
self.s_strides2,
self.expert_offsets,
self.problem_sizes1,
self.problem_sizes2,
layer.w13_input_scale,
layer.w2_input_scale,
)
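W4AFp8MoEMethod.apply above remaps the router's global expert ids through layer.expert_map before handing them to the cutlass kernel. A standalone illustration of that translation, assuming the EPMoE convention that non-local experts map to the num_experts sentinel:

import torch

def to_local_topk_ids(topk_ids, expert_map):
    # expert_map[g] is the local slot of global expert g, or the sentinel for other ranks.
    return expert_map[topk_ids.long()]

# 8 global experts over 2 ranks; rank 0 owns experts 0-3, sentinel is num_experts = 8.
expert_map = torch.tensor([0, 1, 2, 3, 8, 8, 8, 8], dtype=torch.int32)
print(to_local_topk_ids(torch.tensor([[1, 6], [4, 2]]), expert_map))  # [[1, 8], [8, 2]]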