[2/N] MoE Refactor: Unify weight loader and quant methods (#8397)
This commit is contained in:
@@ -30,13 +30,13 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
is_fp8_fnuz,
|
is_fp8_fnuz,
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
sglang_per_token_quant_fp8,
|
sglang_per_token_quant_fp8,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
||||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -62,8 +62,6 @@ use_flashinfer_trtllm_moe = (
|
|||||||
if not (_is_npu or _is_hip):
|
if not (_is_npu or _is_hip):
|
||||||
from sgl_kernel import silu_and_mul
|
from sgl_kernel import silu_and_mul
|
||||||
|
|
||||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter import ActivationType, QuantType
|
from aiter import ActivationType, QuantType
|
||||||
from aiter.fused_moe import fused_moe
|
from aiter.fused_moe import fused_moe
|
||||||
@@ -162,7 +160,7 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
|
|||||||
return tile_tokens_dim
|
return tile_tokens_dim
|
||||||
|
|
||||||
|
|
||||||
class EPMoE(torch.nn.Module):
|
class EPMoE(FusedMoE):
|
||||||
"""
|
"""
|
||||||
MoE Expert Parallel Impl
|
MoE Expert Parallel Impl
|
||||||
|
|
||||||
@@ -184,51 +182,60 @@ class EPMoE(torch.nn.Module):
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
use_per_token_if_dynamic: bool = True,
|
use_per_token_if_dynamic: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
|
num_experts=num_experts,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
top_k=top_k,
|
||||||
|
layer_id=layer_id,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
quant_config=quant_config,
|
||||||
|
tp_size=tp_size,
|
||||||
|
prefix=prefix,
|
||||||
|
activation=activation,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
enable_ep_moe=True,
|
||||||
|
skip_quant=True,
|
||||||
|
)
|
||||||
|
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
self.tp_size = (
|
|
||||||
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
|
||||||
)
|
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.num_experts = num_experts
|
self.num_local_experts, self.expert_map = self.determine_expert_map()
|
||||||
assert self.num_experts % self.tp_size == 0
|
self.start_expert_id = self.ep_rank * self.num_local_experts
|
||||||
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
|
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
|
||||||
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
|
|
||||||
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
|
|
||||||
|
|
||||||
self.top_k = top_k
|
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.activation = activation
|
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
|
||||||
self.use_per_token_if_dynamic = use_per_token_if_dynamic
|
self.use_per_token_if_dynamic = use_per_token_if_dynamic
|
||||||
|
|
||||||
|
# TODO(ch-wan): move quant preparation to FusedMoE
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||||
|
UnquantizedFusedMoEMethod()
|
||||||
|
)
|
||||||
self.use_fp8_w8a8 = False
|
self.use_fp8_w8a8 = False
|
||||||
self.use_block_quant = False
|
self.use_block_quant = False
|
||||||
self.block_shape = None
|
self.block_shape = None
|
||||||
self.activation_scheme = None
|
self.activation_scheme = None
|
||||||
self.use_w4afp8 = False
|
self.w13_input_scale = None
|
||||||
|
self.w2_input_scale = None
|
||||||
|
self.w13_weight_scale = None
|
||||||
|
self.w2_weight_scale = None
|
||||||
elif isinstance(quant_config, W4AFp8Config):
|
elif isinstance(quant_config, W4AFp8Config):
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
|
||||||
quant_config
|
quant_config
|
||||||
)
|
)
|
||||||
self.use_w4afp8 = True
|
|
||||||
self.use_fp8_w8a8 = False
|
self.use_fp8_w8a8 = False
|
||||||
self.use_block_quant = False
|
self.use_block_quant = False
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
|
self.w13_input_scale = None
|
||||||
|
self.w2_input_scale = None
|
||||||
self.w13_weight_scale = None
|
self.w13_weight_scale = None
|
||||||
self.w2_weight_scale = None
|
self.w2_weight_scale = None
|
||||||
self.activation_scheme = quant_config.moe_activation_scheme
|
self.activation_scheme = quant_config.moe_activation_scheme
|
||||||
else:
|
elif isinstance(quant_config, Fp8Config):
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
|
||||||
quant_config
|
|
||||||
)
|
|
||||||
self.use_fp8_w8a8 = True
|
self.use_fp8_w8a8 = True
|
||||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
self.block_shape = (
|
self.block_shape = (
|
||||||
@@ -238,11 +245,13 @@ class EPMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
self.activation_scheme = quant_config.activation_scheme
|
self.activation_scheme = quant_config.activation_scheme
|
||||||
self.use_w4afp8 = False
|
else:
|
||||||
|
raise ValueError(f"Unsupported quant_config: {quant_config}")
|
||||||
|
|
||||||
|
self.quant_config = quant_config
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
num_experts_per_partition=self.num_experts_per_partition,
|
num_experts=self.num_local_experts,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=self.intermediate_size,
|
intermediate_size=self.intermediate_size,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
@@ -251,19 +260,6 @@ class EPMoE(torch.nn.Module):
|
|||||||
|
|
||||||
self.grouped_gemm_runner = None
|
self.grouped_gemm_runner = None
|
||||||
|
|
||||||
self.w13_weight_fp8 = (
|
|
||||||
self.w13_weight,
|
|
||||||
(
|
|
||||||
self.w13_weight_scale_inv
|
|
||||||
if self.use_block_quant
|
|
||||||
else self.w13_weight_scale
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.w2_weight_fp8 = (
|
|
||||||
self.w2_weight,
|
|
||||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
|
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
|
||||||
# Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
|
# Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
|
||||||
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
|
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
|
||||||
@@ -282,8 +278,8 @@ class EPMoE(torch.nn.Module):
|
|||||||
Contains global_num_experts for experts not assigned to the current rank.
|
Contains global_num_experts for experts not assigned to the current rank.
|
||||||
Returns None if ep_size is 1.
|
Returns None if ep_size is 1.
|
||||||
"""
|
"""
|
||||||
ep_size = self.tp_size
|
ep_size = self.ep_size
|
||||||
ep_rank = self.tp_rank
|
ep_rank = self.ep_rank
|
||||||
global_num_experts = self.num_experts
|
global_num_experts = self.num_experts
|
||||||
|
|
||||||
assert ep_size > 0
|
assert ep_size > 0
|
||||||
@@ -293,7 +289,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
local_num_experts = global_num_experts // ep_size
|
local_num_experts = global_num_experts // ep_size
|
||||||
|
|
||||||
expert_map = torch.full(
|
expert_map = torch.full(
|
||||||
(global_num_experts,), self.num_experts, dtype=torch.int32
|
(global_num_experts,), global_num_experts, dtype=torch.int32
|
||||||
)
|
)
|
||||||
if ep_rank < (ep_size - 1):
|
if ep_rank < (ep_size - 1):
|
||||||
expert_map[
|
expert_map[
|
||||||
@@ -318,6 +314,20 @@ class EPMoE(torch.nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
self.w13_weight_fp8 = (
|
||||||
|
self.w13_weight,
|
||||||
|
(
|
||||||
|
self.w13_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w13_weight_scale
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.w2_weight_fp8 = (
|
||||||
|
self.w2_weight,
|
||||||
|
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||||
|
)
|
||||||
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.activation == "silu"
|
||||||
hidden_states_shape = hidden_states.shape
|
hidden_states_shape = hidden_states.shape
|
||||||
@@ -457,7 +467,10 @@ class EPMoE(torch.nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
assert self.quant_method is not None
|
return self.quant_method.apply(self, hidden_states, topk_output)
|
||||||
|
|
||||||
|
def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
hidden_states_shape = hidden_states.shape
|
hidden_states_shape = hidden_states.shape
|
||||||
@@ -470,53 +483,11 @@ class EPMoE(torch.nn.Module):
|
|||||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_w4afp8:
|
num_experts = self.num_experts
|
||||||
local_topk_ids = topk_ids
|
|
||||||
if self.expert_map is not None:
|
|
||||||
"Translate info from expert_map to topk_ids"
|
|
||||||
local_topk_ids = torch.where(
|
|
||||||
self.expert_map[topk_ids] != self.num_experts,
|
|
||||||
self.expert_map[topk_ids],
|
|
||||||
self.num_experts,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = cutlass_w4a8_moe(
|
|
||||||
self.start_expert_id,
|
|
||||||
self.end_expert_id,
|
|
||||||
self.num_experts,
|
|
||||||
hidden_states,
|
|
||||||
self.w13_weight,
|
|
||||||
self.w2_weight,
|
|
||||||
self.w13_weight_scale_inv,
|
|
||||||
self.w2_weight_scale_inv,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
local_topk_ids,
|
|
||||||
self.quant_method.a_strides1,
|
|
||||||
self.quant_method.b_strides1,
|
|
||||||
self.quant_method.c_strides1,
|
|
||||||
self.quant_method.a_strides2,
|
|
||||||
self.quant_method.b_strides2,
|
|
||||||
self.quant_method.c_strides2,
|
|
||||||
self.quant_method.s_strides13,
|
|
||||||
self.quant_method.s_strides2,
|
|
||||||
self.quant_method.expert_offsets,
|
|
||||||
self.quant_method.problem_sizes1,
|
|
||||||
self.quant_method.problem_sizes2,
|
|
||||||
self.w13_input_scale,
|
|
||||||
self.w2_input_scale,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
if self.grouped_gemm_runner is None:
|
|
||||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
|
||||||
hidden_states.device,
|
|
||||||
use_flashinfer=False, # TODO: use flashinfer
|
|
||||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
|
||||||
)
|
|
||||||
|
|
||||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
||||||
topk_ids, self.num_experts
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
)
|
)
|
||||||
|
|
||||||
gateup_input = torch.empty(
|
gateup_input = torch.empty(
|
||||||
@@ -524,7 +495,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=(
|
dtype=(
|
||||||
self.fp8_dtype
|
self.fp8_dtype
|
||||||
if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
|
if self.use_fp8_w8a8 and not self.use_block_quant
|
||||||
else hidden_states.dtype
|
else hidden_states.dtype
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -535,7 +506,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
max_value = (
|
max_value = (
|
||||||
torch.max(hidden_states)
|
torch.max(hidden_states)
|
||||||
.repeat(self.num_experts_per_partition)
|
.repeat(self.num_local_experts)
|
||||||
.to(torch.float32)
|
.to(torch.float32)
|
||||||
)
|
)
|
||||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
||||||
@@ -576,7 +547,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
|
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
|
||||||
weight_indices_cur_rank = torch.arange(
|
weight_indices_cur_rank = torch.arange(
|
||||||
0,
|
0,
|
||||||
self.num_experts_per_partition,
|
self.num_local_experts,
|
||||||
device=hidden_states_device,
|
device=hidden_states_device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
@@ -586,17 +557,13 @@ class EPMoE(torch.nn.Module):
|
|||||||
b=self.w13_weight,
|
b=self.w13_weight,
|
||||||
c=None,
|
c=None,
|
||||||
c_dtype=hidden_states_dtype,
|
c_dtype=hidden_states_dtype,
|
||||||
batch_size=self.num_experts_per_partition,
|
batch_size=self.num_local_experts,
|
||||||
weight_column_major=True,
|
weight_column_major=True,
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
seg_indptr=seg_indptr_cur_rank,
|
||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||||
scale_a=self.w13_input_scale,
|
scale_a=self.w13_input_scale,
|
||||||
scale_b=(
|
scale_b=self.w13_weight_scale,
|
||||||
self.w13_weight_scale_inv
|
|
||||||
if self.use_block_quant
|
|
||||||
else self.w13_weight_scale
|
|
||||||
),
|
|
||||||
block_shape=self.block_shape,
|
block_shape=self.block_shape,
|
||||||
)
|
)
|
||||||
del gateup_input
|
del gateup_input
|
||||||
@@ -653,7 +620,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
|
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
|
||||||
else:
|
else:
|
||||||
self.w2_input_scale = torch.ones(
|
self.w2_input_scale = torch.ones(
|
||||||
self.num_experts_per_partition,
|
self.num_local_experts,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=hidden_states_device,
|
device=hidden_states_device,
|
||||||
)
|
)
|
||||||
@@ -669,17 +636,13 @@ class EPMoE(torch.nn.Module):
|
|||||||
a=down_input,
|
a=down_input,
|
||||||
b=self.w2_weight,
|
b=self.w2_weight,
|
||||||
c=down_output,
|
c=down_output,
|
||||||
batch_size=self.num_experts_per_partition,
|
batch_size=self.num_local_experts,
|
||||||
weight_column_major=True,
|
weight_column_major=True,
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
seg_indptr=seg_indptr_cur_rank,
|
||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||||
scale_a=self.w2_input_scale,
|
scale_a=self.w2_input_scale,
|
||||||
scale_b=(
|
scale_b=self.w2_weight_scale,
|
||||||
self.w2_weight_scale_inv
|
|
||||||
if self.use_block_quant
|
|
||||||
else self.w2_weight_scale
|
|
||||||
),
|
|
||||||
block_shape=self.block_shape,
|
block_shape=self.block_shape,
|
||||||
)
|
)
|
||||||
del down_input
|
del down_input
|
||||||
@@ -782,108 +745,15 @@ class EPMoE(torch.nn.Module):
|
|||||||
return
|
return
|
||||||
expert_id = expert_id - self.start_expert_id
|
expert_id = expert_id - self.start_expert_id
|
||||||
|
|
||||||
if shard_id not in ("w1", "w2", "w3"):
|
self._weight_loader_impl(
|
||||||
raise ValueError(
|
param=param,
|
||||||
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
|
loaded_weight=loaded_weight,
|
||||||
)
|
weight_name=weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
# Special case for fp8 scales.
|
expert_id=expert_id,
|
||||||
if "scale" in weight_name:
|
|
||||||
self._load_fp8_scale(
|
|
||||||
param.data,
|
|
||||||
loaded_weight,
|
|
||||||
weight_name,
|
|
||||||
shard_id,
|
|
||||||
expert_id,
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
|
||||||
if use_flashinfer_trtllm_moe:
|
|
||||||
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
|
||||||
else:
|
|
||||||
actual_shard_id = shard_id
|
|
||||||
|
|
||||||
if actual_shard_id == "w2":
|
|
||||||
param.data[expert_id] = loaded_weight
|
|
||||||
elif actual_shard_id == "w1":
|
|
||||||
param.data[expert_id][: self.intermediate_size, :] = loaded_weight
|
|
||||||
elif actual_shard_id == "w3":
|
|
||||||
param.data[expert_id][self.intermediate_size :, :] = loaded_weight
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")
|
|
||||||
|
|
||||||
def _load_fp8_scale(
|
|
||||||
self,
|
|
||||||
param: torch.nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
weight_name: str,
|
|
||||||
shard_id: str,
|
|
||||||
expert_id: int,
|
|
||||||
) -> None:
|
|
||||||
param_data = param.data
|
|
||||||
|
|
||||||
# Input scales can be loaded directly and should be equal.
|
|
||||||
if "input_scale" in weight_name:
|
|
||||||
if self.use_w4afp8:
|
|
||||||
if shard_id == "w1":
|
|
||||||
param_data[expert_id][0] = loaded_weight
|
|
||||||
elif shard_id == "w3":
|
|
||||||
param_data[expert_id][1] = loaded_weight
|
|
||||||
else:
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
(shard_id == "w1" or shard_id == "w3")
|
|
||||||
and param_data[expert_id] != 1
|
|
||||||
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"input_scales of w1 and w3 of a layer "
|
|
||||||
f"must be equal. But got {param_data[expert_id]} "
|
|
||||||
f"vs. {loaded_weight}"
|
|
||||||
)
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
# Weight scales
|
|
||||||
elif "weight_scale" in weight_name:
|
|
||||||
if self.use_block_quant:
|
|
||||||
if use_flashinfer_trtllm_moe:
|
|
||||||
actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
|
||||||
else:
|
|
||||||
actual_shard_id = shard_id
|
|
||||||
|
|
||||||
block_n, block_k = self.block_shape[0], self.block_shape[1]
|
|
||||||
|
|
||||||
if actual_shard_id == "w1":
|
|
||||||
param_data[expert_id][
|
|
||||||
: (self.intermediate_size + block_n - 1) // block_n, :
|
|
||||||
] = loaded_weight
|
|
||||||
elif actual_shard_id == "w3":
|
|
||||||
param_data[expert_id][
|
|
||||||
(self.intermediate_size + block_n - 1) // block_n :, :
|
|
||||||
] = loaded_weight
|
|
||||||
else: # w2
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
elif self.use_w4afp8:
|
|
||||||
if shard_id == "w1":
|
|
||||||
param_data[expert_id][: self.intermediate_size, :] = loaded_weight
|
|
||||||
elif shard_id == "w3":
|
|
||||||
param_data[expert_id][self.intermediate_size :, :] = loaded_weight
|
|
||||||
else:
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
# If we are in merged column case (gate_up_proj)
|
|
||||||
else:
|
|
||||||
if shard_id in ("w1", "w3"):
|
|
||||||
# We have to keep the weight scales of w1 and w3 because
|
|
||||||
# we need to re-quantize w1/w3 weights after weight loading.
|
|
||||||
idx = 0 if shard_id == "w1" else 1
|
|
||||||
param_data[expert_id][idx] = loaded_weight
|
|
||||||
|
|
||||||
# If we are in the row parallel case (down_proj)
|
|
||||||
else:
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
|
|
||||||
|
|
||||||
class DeepEPMoE(EPMoE):
|
class DeepEPMoE(EPMoE):
|
||||||
"""
|
"""
|
||||||
@@ -932,13 +802,13 @@ class DeepEPMoE(EPMoE):
|
|||||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
# expert_mask is of size (self.num_experts_per_partition + 1),
|
# expert_mask is of size (self.num_local_experts + 1),
|
||||||
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
||||||
# for instance, if we have 4 experts on this rank, we would have a expert_mask like:
|
# for instance, if we have 4 experts on this rank, we would have a expert_mask like:
|
||||||
# self.expert_mask = [1, 1, 1, 1, 0]
|
# self.expert_mask = [1, 1, 1, 1, 0]
|
||||||
# idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
|
# idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
|
||||||
self.expert_mask = torch.zeros(
|
self.expert_mask = torch.zeros(
|
||||||
(self.num_experts_per_partition + 1),
|
(self.num_local_experts + 1),
|
||||||
device=torch.cuda.current_device(),
|
device=torch.cuda.current_device(),
|
||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
)
|
)
|
||||||
@@ -1011,13 +881,13 @@ class DeepEPMoE(EPMoE):
|
|||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
||||||
max_value = (
|
max_value = (
|
||||||
torch.max(hidden_states)
|
torch.max(hidden_states)
|
||||||
.repeat(self.num_experts_per_partition)
|
.repeat(self.num_local_experts)
|
||||||
.to(torch.float32)
|
.to(torch.float32)
|
||||||
)
|
)
|
||||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
||||||
weight_indices_cur_rank = torch.arange(
|
weight_indices_cur_rank = torch.arange(
|
||||||
0,
|
0,
|
||||||
self.num_experts_per_partition,
|
self.num_local_experts,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
@@ -1029,7 +899,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
b=self.w13_weight,
|
b=self.w13_weight,
|
||||||
c=None,
|
c=None,
|
||||||
c_dtype=hidden_states.dtype,
|
c_dtype=hidden_states.dtype,
|
||||||
batch_size=self.num_experts_per_partition,
|
batch_size=self.num_local_experts,
|
||||||
weight_column_major=True,
|
weight_column_major=True,
|
||||||
seg_indptr=seg_indptr,
|
seg_indptr=seg_indptr,
|
||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
@@ -1063,7 +933,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
if self.w2_input_scale is None and not self.use_block_quant:
|
if self.w2_input_scale is None and not self.use_block_quant:
|
||||||
self.w2_input_scale = torch.ones(
|
self.w2_input_scale = torch.ones(
|
||||||
self.num_experts_per_partition,
|
self.num_local_experts,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=hidden_states_device,
|
device=hidden_states_device,
|
||||||
)
|
)
|
||||||
@@ -1076,7 +946,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
reorder_topk_ids,
|
reorder_topk_ids,
|
||||||
self.w2_input_scale,
|
self.w2_input_scale,
|
||||||
0,
|
0,
|
||||||
self.num_experts_per_partition - 1,
|
self.num_local_experts - 1,
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1096,7 +966,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
a=down_input,
|
a=down_input,
|
||||||
b=self.w2_weight,
|
b=self.w2_weight,
|
||||||
c=down_output,
|
c=down_output,
|
||||||
batch_size=self.num_experts_per_partition,
|
batch_size=self.num_local_experts,
|
||||||
weight_column_major=True,
|
weight_column_major=True,
|
||||||
seg_indptr=seg_indptr,
|
seg_indptr=seg_indptr,
|
||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
@@ -1121,9 +991,9 @@ class DeepEPMoE(EPMoE):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
# in original deepep, idx == -1 meaning invalid and will not be processed.
|
# in original deepep, idx == -1 meaning invalid and will not be processed.
|
||||||
# aiter does not accept -1, we use a expert mask to make these idx invalid
|
# aiter does not accept -1, we use a expert mask to make these idx invalid
|
||||||
# (idx == num_experts_per_partition) meaning not used in aiter fused_moe
|
# (idx == num_local_experts) meaning not used in aiter fused_moe
|
||||||
topk_idx_copy = topk_idx.to(torch.int32)
|
topk_idx_copy = topk_idx.to(torch.int32)
|
||||||
topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
|
topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
|
||||||
|
|
||||||
return fused_moe(
|
return fused_moe(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
||||||
enable_ep_moe: Optional[bool] = False,
|
enable_ep_moe: Optional[bool] = False,
|
||||||
|
skip_quant: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -99,9 +100,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
|
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
|
||||||
if enable_ep_moe:
|
if enable_ep_moe:
|
||||||
assert (
|
|
||||||
self.enable_flashinfer_cutlass_moe
|
|
||||||
), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
|
|
||||||
self.ep_size = self.tp_size
|
self.ep_size = self.tp_size
|
||||||
self.ep_rank = self.tp_rank
|
self.ep_rank = self.tp_rank
|
||||||
self.tp_size = 1
|
self.tp_size = 1
|
||||||
@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
|
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
|
||||||
# Create a expert map for the local experts
|
# Create a expert map for the local experts
|
||||||
assert num_experts % self.ep_size == 0
|
assert num_experts % self.ep_size == 0
|
||||||
self.local_num_experts = num_experts // self.ep_size
|
self.num_local_experts = num_experts // self.ep_size
|
||||||
self.expert_map[
|
self.expert_map[
|
||||||
self.ep_rank
|
self.ep_rank
|
||||||
* self.local_num_experts : (self.ep_rank + 1)
|
* self.num_local_experts : (self.ep_rank + 1)
|
||||||
* self.local_num_experts
|
* self.num_local_experts
|
||||||
] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
|
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
||||||
else:
|
else:
|
||||||
self.ep_size = 1
|
self.ep_size = 1
|
||||||
self.ep_rank = 0
|
self.ep_rank = 0
|
||||||
self.local_num_experts = num_experts
|
self.num_local_experts = num_experts
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
assert intermediate_size % self.tp_size == 0
|
assert intermediate_size % self.tp_size == 0
|
||||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||||
@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if skip_quant:
|
||||||
|
return
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||||
self.use_triton_kernels
|
self.use_triton_kernels
|
||||||
@@ -149,7 +150,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
num_experts=self.local_num_experts,
|
num_experts=self.num_local_experts,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
# FIXME: figure out which intermediate_size to use
|
# FIXME: figure out which intermediate_size to use
|
||||||
intermediate_size=self.intermediate_size_per_partition,
|
intermediate_size=self.intermediate_size_per_partition,
|
||||||
@@ -378,6 +379,23 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if expert_id == -1:
|
if expert_id == -1:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._weight_loader_impl(
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
weight_name=weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=expert_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _weight_loader_impl(
|
||||||
|
self,
|
||||||
|
param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
shard_id: str,
|
||||||
|
expert_id: int,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
# TP rank is set to 0 if EP is enabled
|
# TP rank is set to 0 if EP is enabled
|
||||||
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
|
tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
@@ -398,6 +416,10 @@ class FusedMoE(torch.nn.Module):
|
|||||||
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
|
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
||||||
|
if getattr(self, "use_flashinfer_trtllm_moe", False):
|
||||||
|
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||||
|
|
||||||
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
||||||
# Fetch the dim to shard the parameter/loaded weight
|
# Fetch the dim to shard the parameter/loaded weight
|
||||||
# based on the shard id. This will be whatever
|
# based on the shard id. This will be whatever
|
||||||
@@ -605,37 +627,3 @@ class FusedMoE(torch.nn.Module):
|
|||||||
("w3", ckpt_up_proj_name),
|
("w3", ckpt_up_proj_name),
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
def _load_fp8_scale(
|
|
||||||
self,
|
|
||||||
param: torch.nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
weight_name: str,
|
|
||||||
shard_id: str,
|
|
||||||
expert_id: int,
|
|
||||||
) -> None:
|
|
||||||
param_data = param.data
|
|
||||||
|
|
||||||
# Input scales can be loaded directly and should be equal.
|
|
||||||
if "input_scale" in weight_name:
|
|
||||||
if (
|
|
||||||
param_data[expert_id] != 1
|
|
||||||
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"input_scales of w1 and w3 of a layer "
|
|
||||||
f"must be equal. But got {param_data[expert_id]} "
|
|
||||||
f"vs. {loaded_weight}"
|
|
||||||
)
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
# Weight scales
|
|
||||||
elif "weight_scale" in weight_name:
|
|
||||||
# If we are in merged column case (gate_up_proj)
|
|
||||||
if shard_id in ("w1", "w3"):
|
|
||||||
# We have to keep the weight scales of w1 and w3 because
|
|
||||||
# we need to re-quantize w1/w3 weights after weight loading.
|
|
||||||
idx = 0 if shard_id == "w1" else 1
|
|
||||||
param_data[expert_id][idx] = loaded_weight
|
|
||||||
# If we are in the row parallel case (down_proj)
|
|
||||||
else:
|
|
||||||
param_data[expert_id] = loaded_weight
|
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig):
|
|||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional[QuantizeMethodBase]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig):
|
|||||||
return Fp8LinearMethod(self)
|
return Fp8LinearMethod(self)
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
return Fp8MoEMethod(self)
|
return Fp8MoEMethod(self)
|
||||||
|
elif isinstance(layer, EPMoE):
|
||||||
|
return Fp8EPMoEMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
@@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
# merged w13 weights and generate a single scaling factor.
|
# merged w13 weights and generate a single scaling factor.
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
layer.w13_weight_scale = torch.nn.Parameter(
|
||||||
torch.ones(
|
torch.ones(
|
||||||
layer.num_experts, dtype=torch.float32, device=w13_weight.device
|
layer.num_local_experts,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=w13_weight.device,
|
||||||
),
|
),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
for expert in range(layer.num_experts):
|
for expert in range(layer.num_local_experts):
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
)
|
)
|
||||||
@@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert layer.w13_weight_scale is not None
|
assert layer.w13_weight_scale is not None
|
||||||
shard_size = layer.intermediate_size_per_partition
|
shard_size = layer.intermediate_size_per_partition
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
for expert_id in range(layer.num_experts):
|
for expert_id in range(layer.num_local_experts):
|
||||||
start = 0
|
start = 0
|
||||||
for shard_id in range(2):
|
for shard_id in range(2):
|
||||||
dq_weight = per_tensor_dequantize(
|
dq_weight = per_tensor_dequantize(
|
||||||
@@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert layer.w13_weight_scale is not None
|
assert layer.w13_weight_scale is not None
|
||||||
shard_size = layer.intermediate_size_per_partition
|
shard_size = layer.intermediate_size_per_partition
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
for expert_id in range(layer.num_experts):
|
for expert_id in range(layer.num_local_experts):
|
||||||
start = 0
|
start = 0
|
||||||
max_w13_scale_fp8 = max_w13_scales[expert_id]
|
max_w13_scale_fp8 = max_w13_scales[expert_id]
|
||||||
for shard_id in range(2):
|
for shard_id in range(2):
|
||||||
@@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
|
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
|
||||||
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
|
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
|
||||||
for expert_id in range(layer.num_experts):
|
for expert_id in range(layer.num_local_experts):
|
||||||
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
|
layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
|
||||||
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
|
layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
|
||||||
|
|
||||||
@@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
|
if isinstance(layer, EPMoE):
|
||||||
|
layer.w13_weight_scale = (
|
||||||
|
layer.w13_weight_scale_inv
|
||||||
|
if self.block_quant
|
||||||
|
else layer.w13_weight_scale
|
||||||
|
)
|
||||||
|
layer.w2_weight_scale = (
|
||||||
|
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||||
|
)
|
||||||
|
return layer.run_moe(
|
||||||
|
hidden_states=x,
|
||||||
|
topk_output=topk_output,
|
||||||
|
)
|
||||||
|
|
||||||
if use_intel_amx_backend(layer):
|
if use_intel_amx_backend(layer):
|
||||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
|
|
||||||
@@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Fp8EPMoEMethod(Fp8MoEMethod):
|
|
||||||
"""MoE method for FP8.
|
|
||||||
Supports loading FP8 checkpoints with static weight scale and
|
|
||||||
dynamic/static activation scale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
quant_config: The quantization config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
|
||||||
self.quant_config = quant_config
|
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: Module,
|
|
||||||
num_experts_per_partition: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
params_dtype = torch.float8_e4m3fn
|
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
if self.block_quant:
|
|
||||||
block_n, block_k = (
|
|
||||||
self.quant_config.weight_block_size[0],
|
|
||||||
self.quant_config.weight_block_size[1],
|
|
||||||
)
|
|
||||||
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
|
||||||
# Required by column parallel or enabling merged weights
|
|
||||||
if intermediate_size % block_n != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The output_size of gate's and up's weight = "
|
|
||||||
f"{intermediate_size} is not divisible by "
|
|
||||||
f"weight quantization block_n = {block_n}."
|
|
||||||
)
|
|
||||||
if tp_size > 1:
|
|
||||||
# Required by row parallel
|
|
||||||
if intermediate_size % block_k != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The input_size of down's weight = "
|
|
||||||
f"{intermediate_size} is not divisible by "
|
|
||||||
f"weight quantization block_k = {block_k}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# WEIGHTS
|
|
||||||
w13_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
2 * intermediate_size,
|
|
||||||
hidden_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
hidden_size,
|
|
||||||
intermediate_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# WEIGHT_SCALES
|
|
||||||
if self.block_quant:
|
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(
|
|
||||||
num_experts_per_partition,
|
|
||||||
2 * ((intermediate_size + block_n - 1) // block_n),
|
|
||||||
(hidden_size + block_k - 1) // block_k,
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(
|
|
||||||
num_experts_per_partition,
|
|
||||||
(hidden_size + block_n - 1) // block_n,
|
|
||||||
(intermediate_size + block_k - 1) // block_k,
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
|
||||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
|
||||||
assert self.quant_config.activation_scheme == "dynamic"
|
|
||||||
else:
|
|
||||||
# WEIGHT_SCALES
|
|
||||||
# Allocate 2 scales for w1 and w3 respectively.
|
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
|
||||||
# Add the quantization method used (per tensor/grouped/channel)
|
|
||||||
# to ensure the weight scales are loaded in properly
|
|
||||||
extra_weight_attrs.update(
|
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
|
||||||
if self.block_quant
|
|
||||||
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
|
||||||
)
|
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
|
||||||
# process_weights_after_loading()
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
# INPUT_SCALES
|
|
||||||
if self.quant_config.activation_scheme == "static":
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
raise ValueError(
|
|
||||||
"Found static activation scheme for checkpoint that "
|
|
||||||
"was not serialized fp8."
|
|
||||||
)
|
|
||||||
|
|
||||||
w13_input_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
|
||||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_input_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
|
||||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
else:
|
|
||||||
layer.w13_input_scale = None
|
|
||||||
layer.w2_input_scale = None
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize in place.
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
# If rocm, use float8_e4m3fnuz as dtype
|
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
|
||||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
|
||||||
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(
|
|
||||||
layer.num_experts_per_partition,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=w13_weight.device,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
for expert in range(layer.num_experts_per_partition):
|
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
|
||||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
|
||||||
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If checkpoint is fp8, we need to handle that the
|
|
||||||
# MoE kernels require single activation scale and single weight
|
|
||||||
# scale for w13 per expert.
|
|
||||||
else:
|
|
||||||
if self.quant_config.activation_scheme == "static":
|
|
||||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
|
||||||
raise ValueError(
|
|
||||||
"QuantConfig has static quantization, but found "
|
|
||||||
"activation scales are None."
|
|
||||||
)
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.max(layer.w13_weight_scale, dim=1).values,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
if self.block_quant:
|
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
|
||||||
if _is_fp8_fnuz:
|
|
||||||
# activation_scheme: dynamic
|
|
||||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
|
||||||
weight=layer.w13_weight,
|
|
||||||
weight_scale=layer.w13_weight_scale_inv,
|
|
||||||
input_scale=None,
|
|
||||||
)
|
|
||||||
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
|
||||||
weight=layer.w2_weight,
|
|
||||||
weight_scale=layer.w2_weight_scale_inv,
|
|
||||||
input_scale=None,
|
|
||||||
)
|
|
||||||
# Reset the parameter
|
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
|
||||||
w13_weight, requires_grad=False
|
|
||||||
)
|
|
||||||
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
|
||||||
w13_weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
layer.w13_input_scale = None
|
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
|
||||||
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
|
||||||
w2_weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
layer.w2_input_scale = None
|
|
||||||
if _use_aiter:
|
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
|
||||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.w2_weight = torch.nn.Parameter(
|
|
||||||
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_output: TopKOutput,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
"""
|
"""
|
||||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||||
@@ -194,6 +195,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
|
|
||||||
|
if isinstance(layer, EPMoE):
|
||||||
|
return layer.run_moe(
|
||||||
|
hidden_states=x,
|
||||||
|
topk_output=topk_output,
|
||||||
|
)
|
||||||
|
|
||||||
return self.forward(
|
return self.forward(
|
||||||
x=x,
|
x=x,
|
||||||
layer=layer,
|
layer=layer,
|
||||||
@@ -354,69 +364,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||||
|
|
||||||
forward_native = forward_cpu
|
forward_native = forward_cpu
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
num_experts_per_partition: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
# Fused gate_up_proj (column parallel)
|
|
||||||
w13_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
2 * intermediate_size,
|
|
||||||
hidden_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# down_proj (row parallel)
|
|
||||||
w2_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
hidden_size,
|
|
||||||
intermediate_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# scale
|
|
||||||
layer.register_parameter("w13_input_scale", None)
|
|
||||||
layer.register_parameter("w13_weight_scale", None)
|
|
||||||
|
|
||||||
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
|
||||||
|
|
||||||
w2_input_scale = torch.nn.Parameter(
|
|
||||||
ones_tensor,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
|
||||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
|
||||||
ones_tensor,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_output: TopKOutput,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
|||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
|
|||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional[QuantizeMethodBase]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(prefix, self.ignored_layers):
|
if is_layer_skipped(prefix, self.ignored_layers):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
return Fp8LinearMethod(self)
|
return Fp8LinearMethod(self)
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, EPMoE):
|
||||||
return W4AFp8MoEMethod(self)
|
return W4AFp8MoEMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: Module,
|
layer: EPMoE,
|
||||||
num_experts_per_partition: int,
|
num_experts: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
# Fused gate_up_proj (column parallel)
|
# Fused gate_up_proj (column parallel)
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
num_experts_per_partition,
|
num_experts,
|
||||||
intermediate_size * 2,
|
intermediate_size * 2,
|
||||||
hidden_size // 2,
|
hidden_size // 2,
|
||||||
dtype=torch.int8,
|
dtype=torch.int8,
|
||||||
@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
# down_proj (row parallel)
|
# down_proj (row parallel)
|
||||||
w2_weight = torch.nn.Parameter(
|
w2_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
num_experts_per_partition,
|
num_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
intermediate_size // 2,
|
intermediate_size // 2,
|
||||||
dtype=torch.int8,
|
dtype=torch.int8,
|
||||||
@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts_per_partition,
|
num_experts,
|
||||||
2 * intermediate_size,
|
2 * intermediate_size,
|
||||||
hidden_size // self.quant_config.group_size,
|
hidden_size // self.quant_config.group_size,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts_per_partition,
|
num_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
intermediate_size // self.quant_config.group_size,
|
intermediate_size // self.quant_config.group_size,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
# Input scales
|
# Input scales
|
||||||
w13_input_scale = torch.nn.Parameter(
|
w13_input_scale = torch.nn.Parameter(
|
||||||
torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
|
torch.ones((num_experts, 2), dtype=torch.bfloat16),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
w2_input_scale = torch.nn.Parameter(
|
w2_input_scale = torch.nn.Parameter(
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
|
torch.ones(num_experts, dtype=torch.bfloat16),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
device = layer.w13_weight.device
|
device = layer.w13_weight.device
|
||||||
|
|
||||||
self.a_strides1 = torch.full(
|
self.a_strides1 = torch.full(
|
||||||
(num_experts_per_partition, 3),
|
(num_experts, 3),
|
||||||
hidden_size,
|
hidden_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
self.c_strides1 = torch.full(
|
self.c_strides1 = torch.full(
|
||||||
(num_experts_per_partition, 3),
|
(num_experts, 3),
|
||||||
2 * intermediate_size,
|
2 * intermediate_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
self.a_strides2 = torch.full(
|
self.a_strides2 = torch.full(
|
||||||
(num_experts_per_partition, 3),
|
(num_experts, 3),
|
||||||
intermediate_size,
|
intermediate_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
self.c_strides2 = torch.full(
|
self.c_strides2 = torch.full(
|
||||||
(num_experts_per_partition, 3),
|
(num_experts, 3),
|
||||||
hidden_size,
|
hidden_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.s_strides2 = self.c_strides2
|
self.s_strides2 = self.c_strides2
|
||||||
|
|
||||||
self.expert_offsets = torch.empty(
|
self.expert_offsets = torch.empty(
|
||||||
(num_experts_per_partition + 1), dtype=torch.int32, device=device
|
(num_experts + 1), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
self.problem_sizes1 = torch.empty(
|
self.problem_sizes1 = torch.empty(
|
||||||
(num_experts_per_partition, 3), dtype=torch.int32, device=device
|
(num_experts, 3), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
self.problem_sizes2 = torch.empty(
|
self.problem_sizes2 = torch.empty(
|
||||||
(num_experts_per_partition, 3), dtype=torch.int32, device=device
|
(num_experts, 3), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
[w2_input_scale_max], dtype=dtype, device=device
|
[w2_input_scale_max], dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
|
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: EPMoE,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_output: TopKOutput,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# TODO(ch-wan): move it out of this class
|
||||||
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
|
|
||||||
|
topk_ids, topk_weights, _ = topk_output
|
||||||
|
local_topk_ids = topk_ids
|
||||||
|
if layer.expert_map is not None:
|
||||||
|
"Translate info from expert_map to topk_ids"
|
||||||
|
local_topk_ids = torch.where(
|
||||||
|
layer.expert_map[topk_ids] != layer.num_experts,
|
||||||
|
layer.expert_map[topk_ids],
|
||||||
|
layer.num_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cutlass_w4a8_moe(
|
||||||
|
layer.start_expert_id,
|
||||||
|
layer.end_expert_id,
|
||||||
|
layer.num_experts,
|
||||||
|
hidden_states,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
layer.w13_weight_scale_inv,
|
||||||
|
layer.w2_weight_scale_inv,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
local_topk_ids,
|
||||||
|
self.a_strides1,
|
||||||
|
self.b_strides1,
|
||||||
|
self.c_strides1,
|
||||||
|
self.a_strides2,
|
||||||
|
self.b_strides2,
|
||||||
|
self.c_strides2,
|
||||||
|
self.s_strides13,
|
||||||
|
self.s_strides2,
|
||||||
|
self.expert_offsets,
|
||||||
|
self.problem_sizes1,
|
||||||
|
self.problem_sizes2,
|
||||||
|
layer.w13_input_scale,
|
||||||
|
layer.w2_input_scale,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user