[4/N] MoE Refactor: Unified Triton Kernel for FusedMoE and EPMoE (#8515)
This commit is contained in:
@@ -86,79 +86,6 @@ if use_flashinfer_trtllm_moe:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GroupedGemmRunner(torch.nn.Module):
|
|
||||||
flashinfer_gemm_warpper = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device,
|
|
||||||
use_flashinfer: bool = False,
|
|
||||||
use_per_token_if_dynamic: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
self.use_flashinfer = use_flashinfer
|
|
||||||
self.use_per_token_if_dynamic = use_per_token_if_dynamic
|
|
||||||
if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
|
|
||||||
GroupedGemmRunner._init_flashinfer_wrapper(device)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _init_flashinfer_wrapper(cls, device):
|
|
||||||
from flashinfer import SegmentGEMMWrapper
|
|
||||||
|
|
||||||
workspace_buffer = torch.empty(
|
|
||||||
128 * 1024 * 1024, dtype=torch.int8, device=device
|
|
||||||
)
|
|
||||||
cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
|
|
||||||
|
|
||||||
# c = a * b
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
a: torch.Tensor,
|
|
||||||
b: torch.Tensor,
|
|
||||||
c: torch.Tensor,
|
|
||||||
batch_size: int,
|
|
||||||
weight_column_major: bool,
|
|
||||||
seg_indptr: Optional[torch.Tensor] = None,
|
|
||||||
weight_indices: Optional[torch.Tensor] = None,
|
|
||||||
use_fp8_w8a8: bool = False,
|
|
||||||
scale_a: torch.Tensor = None,
|
|
||||||
scale_b: torch.Tensor = None,
|
|
||||||
block_shape: Optional[List[int]] = None,
|
|
||||||
c_dtype=None,
|
|
||||||
):
|
|
||||||
if self.use_flashinfer:
|
|
||||||
# TODO: flashinfer
|
|
||||||
assert False
|
|
||||||
assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
|
|
||||||
c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
|
|
||||||
x=a,
|
|
||||||
weights=b,
|
|
||||||
batch_size=batch_size,
|
|
||||||
weight_column_major=weight_column_major,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert weight_column_major == True
|
|
||||||
c = grouped_gemm_triton(
|
|
||||||
a,
|
|
||||||
b,
|
|
||||||
c,
|
|
||||||
batch_size,
|
|
||||||
weight_column_major,
|
|
||||||
seg_indptr,
|
|
||||||
weight_indices,
|
|
||||||
use_fp8_w8a8,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
block_shape=block_shape,
|
|
||||||
c_dtype=c_dtype,
|
|
||||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
|
||||||
)
|
|
||||||
return c
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
|
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||||
# Guess tokens per expert assuming perfect expert distribution first.
|
# Guess tokens per expert assuming perfect expert distribution first.
|
||||||
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
||||||
@@ -190,135 +117,50 @@ class EPMoE(FusedMoE):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
use_per_token_if_dynamic: bool = True,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
top_k=top_k,
|
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
top_k=top_k,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
enable_ep_moe=True,
|
enable_ep_moe=True,
|
||||||
skip_quant=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
self.layer_id = layer_id
|
|
||||||
self.num_local_experts, self.expert_map = self.determine_expert_map()
|
|
||||||
self.start_expert_id = self.ep_rank * self.num_local_experts
|
self.start_expert_id = self.ep_rank * self.num_local_experts
|
||||||
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
|
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
|
||||||
|
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.use_per_token_if_dynamic = use_per_token_if_dynamic
|
|
||||||
|
|
||||||
# TODO(ch-wan): move quant preparation to FusedMoE
|
if isinstance(quant_config, Fp8Config):
|
||||||
if quant_config is None:
|
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
|
||||||
UnquantizedFusedMoEMethod()
|
|
||||||
)
|
|
||||||
self.use_fp8_w8a8 = False
|
|
||||||
self.use_block_quant = False
|
|
||||||
self.block_shape = None
|
|
||||||
self.activation_scheme = None
|
|
||||||
self.w13_input_scale = None
|
|
||||||
self.w2_input_scale = None
|
|
||||||
self.w13_weight_scale = None
|
|
||||||
self.w2_weight_scale = None
|
|
||||||
elif isinstance(quant_config, W4AFp8Config):
|
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
|
|
||||||
quant_config
|
|
||||||
)
|
|
||||||
self.use_fp8_w8a8 = False
|
|
||||||
self.use_block_quant = False
|
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
|
||||||
self.w13_input_scale = None
|
|
||||||
self.w2_input_scale = None
|
|
||||||
self.w13_weight_scale = None
|
|
||||||
self.w2_weight_scale = None
|
|
||||||
self.activation_scheme = quant_config.moe_activation_scheme
|
|
||||||
elif isinstance(quant_config, Fp8Config):
|
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
|
|
||||||
self.use_fp8_w8a8 = True
|
|
||||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
self.block_shape = (
|
self.block_shape = (
|
||||||
self.quant_method.quant_config.weight_block_size
|
self.quant_method.quant_config.weight_block_size
|
||||||
if self.use_block_quant
|
if self.use_block_quant
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
self.use_fp8_w8a8 = True
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
self.activation_scheme = quant_config.activation_scheme
|
self.activation_scheme = quant_config.activation_scheme
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported quant_config: {quant_config}")
|
self.use_fp8_w8a8 = False
|
||||||
|
self.use_block_quant = False
|
||||||
self.quant_config = quant_config
|
self.block_shape = None
|
||||||
self.quant_method.create_weights(
|
self.activation_scheme = None
|
||||||
layer=self,
|
|
||||||
num_experts=self.num_local_experts,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
intermediate_size=self.intermediate_size,
|
|
||||||
params_dtype=params_dtype,
|
|
||||||
weight_loader=self.weight_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.grouped_gemm_runner = None
|
|
||||||
|
|
||||||
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
|
|
||||||
# Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
|
|
||||||
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Calculates how many experts should be assigned to each rank for EP and
|
|
||||||
creates a mapping from global to local expert index. Experts are
|
|
||||||
distributed evenly across ranks. Any remaining are assigned to the
|
|
||||||
last rank.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
|
|
||||||
- local_num_experts (int): The number of experts assigned
|
|
||||||
to the current rank.
|
|
||||||
- expert_map (Optional[torch.Tensor]): A tensor of shape
|
|
||||||
(global_num_experts,) mapping from global to local index.
|
|
||||||
Contains global_num_experts for experts not assigned to the current rank.
|
|
||||||
Returns None if ep_size is 1.
|
|
||||||
"""
|
|
||||||
ep_size = self.ep_size
|
|
||||||
ep_rank = self.ep_rank
|
|
||||||
global_num_experts = self.num_experts
|
|
||||||
|
|
||||||
assert ep_size > 0
|
|
||||||
if ep_size == 1:
|
|
||||||
return (global_num_experts, None)
|
|
||||||
|
|
||||||
local_num_experts = global_num_experts // ep_size
|
|
||||||
|
|
||||||
expert_map = torch.full(
|
|
||||||
(global_num_experts,), global_num_experts, dtype=torch.int32
|
|
||||||
)
|
|
||||||
if ep_rank < (ep_size - 1):
|
|
||||||
expert_map[
|
|
||||||
ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
|
|
||||||
] = torch.arange(0, local_num_experts, dtype=torch.int32)
|
|
||||||
else:
|
|
||||||
local_num_experts = global_num_experts - ep_rank * local_num_experts
|
|
||||||
|
|
||||||
expert_map[-local_num_experts:] = torch.arange(
|
|
||||||
0, local_num_experts, dtype=torch.int32
|
|
||||||
)
|
|
||||||
return (local_num_experts, expert_map)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||||
return self.forward_deepgemm(hidden_states, topk_output)
|
return self.forward_deepgemm(hidden_states, topk_output)
|
||||||
else:
|
else:
|
||||||
return self.forward_normal(hidden_states, topk_output)
|
return super().forward(hidden_states, topk_output)
|
||||||
|
|
||||||
def forward_deepgemm(
|
def forward_deepgemm(
|
||||||
self,
|
self,
|
||||||
@@ -477,303 +319,6 @@ class EPMoE(FusedMoE):
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
|
||||||
return self.quant_method.apply(self, hidden_states, topk_output)
|
|
||||||
|
|
||||||
def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
|
||||||
|
|
||||||
hidden_states_shape = hidden_states.shape
|
|
||||||
hidden_states_dtype = hidden_states.dtype
|
|
||||||
hidden_states_device = hidden_states.device
|
|
||||||
if self.grouped_gemm_runner is None:
|
|
||||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
|
||||||
hidden_states.device,
|
|
||||||
use_flashinfer=False, # TODO: use flashinfer
|
|
||||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_experts = self.num_experts
|
|
||||||
|
|
||||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
)
|
|
||||||
|
|
||||||
gateup_input = torch.empty(
|
|
||||||
(int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=(
|
|
||||||
self.fp8_dtype
|
|
||||||
if self.use_fp8_w8a8 and not self.use_block_quant
|
|
||||||
else hidden_states.dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
|
||||||
if self.use_per_token_if_dynamic:
|
|
||||||
max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
|
|
||||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
|
||||||
else:
|
|
||||||
max_value = (
|
|
||||||
torch.max(hidden_states)
|
|
||||||
.repeat(self.num_local_experts)
|
|
||||||
.to(torch.float32)
|
|
||||||
)
|
|
||||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
|
||||||
|
|
||||||
# PreReorder
|
|
||||||
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
|
|
||||||
hidden_states,
|
|
||||||
gateup_input,
|
|
||||||
src2dst,
|
|
||||||
topk_ids,
|
|
||||||
self.w13_input_scale,
|
|
||||||
self.start_expert_id,
|
|
||||||
self.end_expert_id,
|
|
||||||
self.top_k,
|
|
||||||
hidden_states.shape[1],
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
|
||||||
)
|
|
||||||
dispose_tensor(hidden_states)
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.activation_scheme == "dynamic"
|
|
||||||
and not self.use_block_quant
|
|
||||||
and self.use_per_token_if_dynamic
|
|
||||||
):
|
|
||||||
scale = torch.empty(
|
|
||||||
hidden_states_shape[0] * self.top_k,
|
|
||||||
device=hidden_states_device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
scale[src2dst] = (
|
|
||||||
self.w13_input_scale.unsqueeze(1)
|
|
||||||
.expand(hidden_states_shape[0], self.top_k)
|
|
||||||
.reshape(-1)
|
|
||||||
)
|
|
||||||
self.w13_input_scale = scale
|
|
||||||
|
|
||||||
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
|
|
||||||
weight_indices_cur_rank = torch.arange(
|
|
||||||
0,
|
|
||||||
self.num_local_experts,
|
|
||||||
device=hidden_states_device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
# GroupGemm-0
|
|
||||||
gateup_output = self.grouped_gemm_runner(
|
|
||||||
a=gateup_input,
|
|
||||||
b=self.w13_weight,
|
|
||||||
c=None,
|
|
||||||
c_dtype=hidden_states_dtype,
|
|
||||||
batch_size=self.num_local_experts,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
|
||||||
weight_indices=weight_indices_cur_rank,
|
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
|
||||||
scale_a=self.w13_input_scale,
|
|
||||||
scale_b=self.w13_weight_scale,
|
|
||||||
block_shape=self.block_shape,
|
|
||||||
)
|
|
||||||
del gateup_input
|
|
||||||
|
|
||||||
# Act
|
|
||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
|
||||||
self.w2_input_scale = None
|
|
||||||
down_input = torch.empty(
|
|
||||||
gateup_output.shape[0],
|
|
||||||
gateup_output.shape[1] // 2,
|
|
||||||
device=gateup_output.device,
|
|
||||||
dtype=hidden_states_dtype,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
down_input = torch.empty(
|
|
||||||
gateup_output.shape[0],
|
|
||||||
gateup_output.shape[1] // 2,
|
|
||||||
device=gateup_output.device,
|
|
||||||
dtype=(
|
|
||||||
self.fp8_dtype
|
|
||||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
|
||||||
else hidden_states_dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.activation == "silu":
|
|
||||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
|
||||||
gateup_output,
|
|
||||||
down_input,
|
|
||||||
gateup_output.shape[1],
|
|
||||||
reorder_topk_ids,
|
|
||||||
self.w2_input_scale,
|
|
||||||
self.start_expert_id,
|
|
||||||
self.end_expert_id,
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
elif self.activation == "gelu":
|
|
||||||
gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
|
||||||
gateup_output,
|
|
||||||
down_input,
|
|
||||||
gateup_output.shape[1],
|
|
||||||
reorder_topk_ids,
|
|
||||||
self.w2_input_scale,
|
|
||||||
self.start_expert_id,
|
|
||||||
self.end_expert_id,
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
|
||||||
del gateup_output
|
|
||||||
|
|
||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
|
||||||
if self.use_per_token_if_dynamic:
|
|
||||||
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
|
|
||||||
else:
|
|
||||||
self.w2_input_scale = torch.ones(
|
|
||||||
self.num_local_experts,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=hidden_states_device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# GroupGemm-1
|
|
||||||
down_output = torch.empty(
|
|
||||||
down_input.shape[0],
|
|
||||||
self.w2_weight.shape[1],
|
|
||||||
device=hidden_states_device,
|
|
||||||
dtype=hidden_states_dtype,
|
|
||||||
)
|
|
||||||
down_output = self.grouped_gemm_runner(
|
|
||||||
a=down_input,
|
|
||||||
b=self.w2_weight,
|
|
||||||
c=down_output,
|
|
||||||
batch_size=self.num_local_experts,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
|
||||||
weight_indices=weight_indices_cur_rank,
|
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
|
||||||
scale_a=self.w2_input_scale,
|
|
||||||
scale_b=self.w2_weight_scale,
|
|
||||||
block_shape=self.block_shape,
|
|
||||||
)
|
|
||||||
del down_input
|
|
||||||
|
|
||||||
# PostReorder
|
|
||||||
output = torch.empty(
|
|
||||||
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
|
|
||||||
)
|
|
||||||
post_reorder_triton_kernel[(hidden_states_shape[0],)](
|
|
||||||
down_output,
|
|
||||||
output,
|
|
||||||
src2dst,
|
|
||||||
topk_ids,
|
|
||||||
topk_weights,
|
|
||||||
self.start_expert_id,
|
|
||||||
self.end_expert_id,
|
|
||||||
self.top_k,
|
|
||||||
hidden_states_shape[1],
|
|
||||||
0,
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def make_expert_params_mapping(
|
|
||||||
cls,
|
|
||||||
ckpt_gate_proj_name: str,
|
|
||||||
ckpt_down_proj_name: str,
|
|
||||||
ckpt_up_proj_name: str,
|
|
||||||
num_experts: int,
|
|
||||||
) -> List[Tuple[str, str, int, str]]:
|
|
||||||
return [
|
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
|
||||||
(
|
|
||||||
(
|
|
||||||
"experts.w13_"
|
|
||||||
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
|
|
||||||
else "experts.w2_"
|
|
||||||
),
|
|
||||||
f"experts.{expert_id}.{weight_name}.",
|
|
||||||
expert_id,
|
|
||||||
shard_id,
|
|
||||||
)
|
|
||||||
for expert_id in range(num_experts)
|
|
||||||
for shard_id, weight_name in [
|
|
||||||
("w1", ckpt_gate_proj_name),
|
|
||||||
("w2", ckpt_down_proj_name),
|
|
||||||
("w3", ckpt_up_proj_name),
|
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def make_expert_input_scale_params_mapping(
|
|
||||||
cls,
|
|
||||||
num_experts: int,
|
|
||||||
) -> List[Tuple[str, str, int, str]]:
|
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
|
||||||
return [
|
|
||||||
(
|
|
||||||
"experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
|
|
||||||
f"experts.{expert_id}.{shard_id}.",
|
|
||||||
expert_id,
|
|
||||||
shard_id,
|
|
||||||
)
|
|
||||||
for expert_id in range(num_experts)
|
|
||||||
for shard_id in ["w1", "w2", "w3"]
|
|
||||||
]
|
|
||||||
|
|
||||||
def weight_loader(
|
|
||||||
self,
|
|
||||||
param: torch.nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
weight_name: str,
|
|
||||||
shard_id: str,
|
|
||||||
expert_id: int,
|
|
||||||
) -> None:
|
|
||||||
global_expert_location_metadata = get_global_expert_location_metadata()
|
|
||||||
if global_expert_location_metadata is None:
|
|
||||||
self._weight_loader_impl(
|
|
||||||
param=param,
|
|
||||||
loaded_weight=loaded_weight,
|
|
||||||
weight_name=weight_name,
|
|
||||||
shard_id=shard_id,
|
|
||||||
expert_id=expert_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
|
|
||||||
self.layer_id, expert_id
|
|
||||||
)
|
|
||||||
for physical_expert_id in physical_expert_ids:
|
|
||||||
self._weight_loader_physical(
|
|
||||||
param=param,
|
|
||||||
loaded_weight=loaded_weight,
|
|
||||||
weight_name=weight_name,
|
|
||||||
shard_id=shard_id,
|
|
||||||
expert_id=physical_expert_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _weight_loader_physical(
|
|
||||||
self,
|
|
||||||
param: torch.nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
weight_name: str,
|
|
||||||
shard_id: str,
|
|
||||||
expert_id: int,
|
|
||||||
) -> None:
|
|
||||||
if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
|
|
||||||
return
|
|
||||||
expert_id = expert_id - self.start_expert_id
|
|
||||||
|
|
||||||
self._weight_loader_impl(
|
|
||||||
param=param,
|
|
||||||
loaded_weight=loaded_weight,
|
|
||||||
weight_name=weight_name,
|
|
||||||
shard_id=shard_id,
|
|
||||||
expert_id=expert_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class DeepEPMoE(EPMoE):
|
class DeepEPMoE(EPMoE):
|
||||||
"""
|
"""
|
||||||
@@ -905,14 +450,15 @@ class DeepEPMoE(EPMoE):
|
|||||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||||
return self.forward_aiter(dispatch_output)
|
return self.forward_aiter(dispatch_output)
|
||||||
if dispatch_output.format.is_deepep_normal():
|
if dispatch_output.format.is_deepep_normal():
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
else:
|
|
||||||
return self.forward_normal(dispatch_output)
|
|
||||||
elif dispatch_output.format.is_deepep_ll():
|
elif dispatch_output.format.is_deepep_ll():
|
||||||
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_masked(dispatch_output)
|
return self.forward_deepgemm_masked(dispatch_output)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
raise ValueError(
|
||||||
|
f"Dispatch output format {dispatch_output.format} is not supported"
|
||||||
|
)
|
||||||
|
|
||||||
def combine(
|
def combine(
|
||||||
self,
|
self,
|
||||||
@@ -928,185 +474,6 @@ class DeepEPMoE(EPMoE):
|
|||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_for_normal(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_idx: torch.Tensor,
|
|
||||||
):
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
||||||
deepep_permute_triton_kernel,
|
|
||||||
deepep_run_moe_deep_preprocess,
|
|
||||||
)
|
|
||||||
|
|
||||||
if hidden_states.shape[0] == 0:
|
|
||||||
reorder_topk_ids = torch.empty(
|
|
||||||
(0,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
seg_indptr = torch.zeros(
|
|
||||||
(self.num_experts + 1,),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
return reorder_topk_ids, seg_indptr, hidden_states
|
|
||||||
else:
|
|
||||||
if _use_aiter:
|
|
||||||
# skip permutation here as aiter fused_moe has fused inside
|
|
||||||
reorder_topk_ids = torch.empty(
|
|
||||||
(0,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
seg_indptr = torch.zeros(
|
|
||||||
(self.num_experts + 1,),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
return reorder_topk_ids, seg_indptr, hidden_states
|
|
||||||
|
|
||||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
|
||||||
topk_idx, self.num_experts
|
|
||||||
)
|
|
||||||
num_total_tokens = reorder_topk_ids.numel()
|
|
||||||
gateup_input = torch.empty(
|
|
||||||
(int(num_total_tokens), hidden_states.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
# PreReorder
|
|
||||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
|
||||||
hidden_states,
|
|
||||||
gateup_input,
|
|
||||||
self.src2dst,
|
|
||||||
topk_idx,
|
|
||||||
None,
|
|
||||||
self.router_topk,
|
|
||||||
hidden_states.shape[1],
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
return reorder_topk_ids, seg_indptr, gateup_input
|
|
||||||
|
|
||||||
def forward_normal(
|
|
||||||
self,
|
|
||||||
dispatch_output: DeepEPNormalOutput,
|
|
||||||
):
|
|
||||||
hidden_states, topk_idx = (
|
|
||||||
dispatch_output.hidden_states,
|
|
||||||
dispatch_output.topk_idx,
|
|
||||||
)
|
|
||||||
reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
|
|
||||||
hidden_states, topk_idx
|
|
||||||
)
|
|
||||||
hidden_states_dtype = hidden_states.dtype
|
|
||||||
hidden_states_device = hidden_states.device
|
|
||||||
|
|
||||||
assert self.quant_method is not None
|
|
||||||
assert self.activation == "silu"
|
|
||||||
if self.grouped_gemm_runner is None:
|
|
||||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
|
||||||
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
|
||||||
max_value = (
|
|
||||||
torch.max(hidden_states)
|
|
||||||
.repeat(self.num_local_experts)
|
|
||||||
.to(torch.float32)
|
|
||||||
)
|
|
||||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
|
||||||
weight_indices_cur_rank = torch.arange(
|
|
||||||
0,
|
|
||||||
self.num_local_experts,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
|
|
||||||
# GroupGemm-0
|
|
||||||
if hidden_states.shape[0] > 0:
|
|
||||||
gateup_output = self.grouped_gemm_runner(
|
|
||||||
a=hidden_states,
|
|
||||||
b=self.w13_weight,
|
|
||||||
c=None,
|
|
||||||
c_dtype=hidden_states.dtype,
|
|
||||||
batch_size=self.num_local_experts,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices_cur_rank,
|
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
|
||||||
scale_a=self.w13_input_scale,
|
|
||||||
scale_b=(
|
|
||||||
self.w13_weight_scale_inv
|
|
||||||
if self.use_block_quant
|
|
||||||
else self.w13_weight_scale
|
|
||||||
),
|
|
||||||
block_shape=self.block_shape,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
gateup_output = torch.empty(
|
|
||||||
hidden_states.shape[0],
|
|
||||||
self.w13_weight.shape[1],
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
down_input = torch.empty(
|
|
||||||
gateup_output.shape[0],
|
|
||||||
gateup_output.shape[1] // 2,
|
|
||||||
device=gateup_output.device,
|
|
||||||
dtype=(
|
|
||||||
self.fp8_dtype
|
|
||||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
|
||||||
else hidden_states_dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if self.w2_input_scale is None and not self.use_block_quant:
|
|
||||||
self.w2_input_scale = torch.ones(
|
|
||||||
self.num_local_experts,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=hidden_states_device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.activation == "silu":
|
|
||||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
|
||||||
gateup_output,
|
|
||||||
down_input,
|
|
||||||
gateup_output.shape[1],
|
|
||||||
reorder_topk_ids,
|
|
||||||
self.w2_input_scale,
|
|
||||||
0,
|
|
||||||
self.num_local_experts - 1,
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
|
||||||
|
|
||||||
del gateup_output
|
|
||||||
|
|
||||||
# GroupGemm-1
|
|
||||||
down_output = torch.empty(
|
|
||||||
down_input.shape[0],
|
|
||||||
self.w2_weight.shape[1],
|
|
||||||
device=hidden_states_device,
|
|
||||||
dtype=hidden_states_dtype,
|
|
||||||
)
|
|
||||||
if down_input.shape[0] > 0:
|
|
||||||
down_output = self.grouped_gemm_runner(
|
|
||||||
a=down_input,
|
|
||||||
b=self.w2_weight,
|
|
||||||
c=down_output,
|
|
||||||
batch_size=self.num_local_experts,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr,
|
|
||||||
weight_indices=weight_indices_cur_rank,
|
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
|
||||||
scale_a=self.w2_input_scale,
|
|
||||||
scale_b=(
|
|
||||||
self.w2_weight_scale_inv
|
|
||||||
if self.use_block_quant
|
|
||||||
else self.w2_weight_scale
|
|
||||||
),
|
|
||||||
block_shape=self.block_shape,
|
|
||||||
)
|
|
||||||
return down_output
|
|
||||||
|
|
||||||
def forward_aiter(
|
def forward_aiter(
|
||||||
self,
|
self,
|
||||||
dispatch_output: DeepEPNormalOutput,
|
dispatch_output: DeepEPNormalOutput,
|
||||||
|
|||||||
@@ -413,18 +413,37 @@ def fused_moe_kernel(
|
|||||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||||
return
|
return
|
||||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||||
offs_token = offs_token.to(tl.int64)
|
offs_token = offs_token.to(tl.int64)
|
||||||
token_mask = offs_token < num_valid_tokens
|
token_mask = offs_token < num_valid_tokens
|
||||||
|
|
||||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
|
||||||
|
|
||||||
|
if off_experts == -1:
|
||||||
|
# -----------------------------------------------------------
|
||||||
|
# Write back zeros to the output when the expert is not
|
||||||
|
# in the current expert parallel rank.
|
||||||
|
write_zeros_to_output(
|
||||||
|
c_ptr,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
pid_n,
|
||||||
|
N,
|
||||||
|
offs_token,
|
||||||
|
token_mask,
|
||||||
|
BLOCK_SIZE_M,
|
||||||
|
BLOCK_SIZE_N,
|
||||||
|
compute_type,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
a_ptrs = a_ptr + (
|
a_ptrs = a_ptr + (
|
||||||
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
||||||
)
|
)
|
||||||
|
|
||||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
|
||||||
b_ptrs = (
|
b_ptrs = (
|
||||||
b_ptr
|
b_ptr
|
||||||
+ off_experts * stride_be
|
+ off_experts * stride_be
|
||||||
@@ -497,7 +516,6 @@ def fused_moe_kernel(
|
|||||||
|
|
||||||
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
|
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
|
||||||
else:
|
else:
|
||||||
# fix out of shared memory issue
|
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
accumulator = tl.dot(a, b, acc=accumulator)
|
accumulator = tl.dot(a, b, acc=accumulator)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from sglang.srt.distributed import (
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -79,7 +79,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
||||||
enable_ep_moe: Optional[bool] = False,
|
enable_ep_moe: Optional[bool] = False,
|
||||||
skip_quant: Optional[bool] = False,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -95,7 +94,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.num_fused_shared_experts = num_fused_shared_experts
|
self.num_fused_shared_experts = num_fused_shared_experts
|
||||||
self.expert_map = None
|
self.expert_map_cpu = None
|
||||||
|
self.expert_map_gpu = None
|
||||||
|
|
||||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||||
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
||||||
@@ -104,20 +104,22 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
|
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
|
||||||
if enable_ep_moe:
|
if enable_ep_moe:
|
||||||
|
# TODO(ch-wan): support shared experts fusion
|
||||||
self.ep_size = self.tp_size
|
self.ep_size = self.tp_size
|
||||||
self.ep_rank = self.tp_rank
|
self.ep_rank = self.tp_rank
|
||||||
self.tp_size = 1
|
self.tp_size = 1
|
||||||
self.tp_rank = 0
|
self.tp_rank = 0
|
||||||
# Create a tensor of size num_experts filled with -1
|
# Create a tensor of size num_experts filled with -1
|
||||||
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
|
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
|
||||||
# Create a expert map for the local experts
|
# Create a expert map for the local experts
|
||||||
assert num_experts % self.ep_size == 0
|
assert num_experts % self.ep_size == 0
|
||||||
self.num_local_experts = num_experts // self.ep_size
|
self.num_local_experts = num_experts // self.ep_size
|
||||||
self.expert_map[
|
self.expert_map_cpu[
|
||||||
self.ep_rank
|
self.ep_rank
|
||||||
* self.num_local_experts : (self.ep_rank + 1)
|
* self.num_local_experts : (self.ep_rank + 1)
|
||||||
* self.num_local_experts
|
* self.num_local_experts
|
||||||
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
||||||
|
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||||
else:
|
else:
|
||||||
self.ep_size = 1
|
self.ep_size = 1
|
||||||
self.ep_rank = 0
|
self.ep_rank = 0
|
||||||
@@ -136,9 +138,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if skip_quant:
|
|
||||||
return
|
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||||
self.use_triton_kernels
|
self.use_triton_kernels
|
||||||
@@ -367,9 +366,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
expert_data.copy_(loaded_weight)
|
expert_data.copy_(loaded_weight)
|
||||||
|
|
||||||
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
|
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
|
||||||
if self.expert_map is None:
|
if self.expert_map_cpu is None:
|
||||||
return expert_id
|
return expert_id
|
||||||
return self.expert_map[expert_id].item()
|
return self.expert_map_cpu[expert_id].item()
|
||||||
|
|
||||||
def weight_loader(
|
def weight_loader(
|
||||||
self,
|
self,
|
||||||
@@ -421,7 +420,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
|
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
|
||||||
if expert_id == -1:
|
if expert_id == -1:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._weight_loader_impl(
|
self._weight_loader_impl(
|
||||||
param=param,
|
param=param,
|
||||||
loaded_weight=loaded_weight,
|
loaded_weight=loaded_weight,
|
||||||
@@ -614,9 +612,14 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
if self.expert_map_gpu is not None:
|
||||||
|
topk_output = topk_output._replace(
|
||||||
|
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||||
|
)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
@@ -670,3 +673,20 @@ class FusedMoE(torch.nn.Module):
|
|||||||
("w3", ckpt_up_proj_name),
|
("w3", ckpt_up_proj_name),
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_expert_input_scale_params_mapping(
|
||||||
|
cls,
|
||||||
|
num_experts: int,
|
||||||
|
) -> List[Tuple[str, str, int, str]]:
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
"experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
|
||||||
|
f"experts.{expert_id}.{shard_id}.",
|
||||||
|
expert_id,
|
||||||
|
shard_id,
|
||||||
|
)
|
||||||
|
for expert_id in range(num_experts)
|
||||||
|
for shard_id in ["w1", "w2", "w3"]
|
||||||
|
]
|
||||||
|
|||||||
@@ -172,7 +172,6 @@ class Fp8Config(QuantizationConfig):
|
|||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional[QuantizeMethodBase]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -181,8 +180,6 @@ class Fp8Config(QuantizationConfig):
|
|||||||
return Fp8LinearMethod(self)
|
return Fp8LinearMethod(self)
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
return Fp8MoEMethod(self)
|
return Fp8MoEMethod(self)
|
||||||
elif isinstance(layer, EPMoE):
|
|
||||||
return Fp8EPMoEMethod(self)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
@@ -984,23 +981,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
if isinstance(layer, EPMoE):
|
|
||||||
layer.w13_weight_scale = (
|
|
||||||
layer.w13_weight_scale_inv
|
|
||||||
if self.block_quant
|
|
||||||
else layer.w13_weight_scale
|
|
||||||
)
|
|
||||||
layer.w2_weight_scale = (
|
|
||||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
|
||||||
)
|
|
||||||
return layer.run_moe(
|
|
||||||
hidden_states=x,
|
|
||||||
topk_output=topk_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_intel_amx_backend(layer):
|
if use_intel_amx_backend(layer):
|
||||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
|
|
||||||
|
|||||||
@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
|
|
||||||
if isinstance(layer, EPMoE):
|
|
||||||
return layer.run_moe(
|
|
||||||
hidden_states=x,
|
|
||||||
topk_output=topk_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.forward(
|
return self.forward(
|
||||||
x=x,
|
x=x,
|
||||||
layer=layer,
|
layer=layer,
|
||||||
|
|||||||
@@ -276,6 +276,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: EPMoE,
|
layer: EPMoE,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
# TODO(ch-wan): move it out of this class
|
# TODO(ch-wan): move it out of this class
|
||||||
|
|||||||
Reference in New Issue
Block a user