[8/N] MoE Refactor: deprecate EPMoE (#11211)
@@ -31,8 +31,8 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
+
+        from sglang.srt.layers.moe.utils import (
+            get_moe_a2a_backend,
+            get_moe_runner_backend,
+        )
+        from sglang.srt.layers.quantization import deep_gemm_wrapper
+
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        moe_runner_backend = get_moe_runner_backend()
+
+        if moe_runner_backend.is_auto():
+            if (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and get_moe_a2a_backend().is_deepep()
+            ):
+                moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
+            else:
+                moe_runner_backend = MoeRunnerBackend.TRITON
+        if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
+            self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
+        else:
+            # TODO(cwan): refactor other backends
+            pass
 
     def apply(
         self,
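The resolution rule in the new `create_moe_runner` is small: an `auto` backend becomes DeepGEMM only when JIT DeepGEMM is enabled and the all-to-all (a2a) backend is DeepEP; every other `auto` case falls back to Triton, and an explicitly requested backend is kept as-is. A minimal standalone sketch of that decision table (the `Backend` enum and `resolve_backend` helper below are illustrative stand-ins, not the sglang API):

from enum import Enum, auto

class Backend(Enum):
    AUTO = auto()
    TRITON = auto()
    DEEP_GEMM = auto()

def resolve_backend(requested: Backend, jit_deepgemm: bool, a2a_is_deepep: bool) -> Backend:
    # Mirrors the hunk above: AUTO resolves to DeepGEMM only when both
    # preconditions hold; otherwise Triton remains the default.
    if requested is Backend.AUTO:
        if jit_deepgemm and a2a_is_deepep:
            return Backend.DEEP_GEMM
        return Backend.TRITON
    return requested

assert resolve_backend(Backend.AUTO, True, True) is Backend.DEEP_GEMM
assert resolve_backend(Backend.AUTO, True, False) is Backend.TRITON
assert resolve_backend(Backend.TRITON, True, True) is Backend.TRITON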
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             return StandardCombineInput(hidden_states=output)
 
-        quant_info = TritonMoeQuantInfo(
-            w13_weight=layer.w13_weight,
-            w2_weight=layer.w2_weight,
-            use_fp8_w8a8=True,
-            w13_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a13_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if self.runner.runner_backend.is_deep_gemm():
+
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
+            if self.block_quant:
+                block_shape = self.quant_config.weight_block_size
+                w13_scale = layer.w13_weight_scale_inv
+                w2_scale = layer.w2_weight_scale_inv
+            else:
+                # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+                scale_block_size = 128
+                block_shape = [scale_block_size, scale_block_size]
+                w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
+                w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
+                w13_scale = (
+                    layer.w13_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w13_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w13_scale_k, dim=2)
+                )
+                w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
+                w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
+                w2_scale = (
+                    layer.w2_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w2_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w2_scale_k, dim=2)
+                )
+            quant_info = DeepGemmMoeQuantInfo(
+                w13_weight=w13_weight,
+                w2_weight=w2_weight,
+                use_fp8=True,
+                w13_scale=w13_scale,
+                w2_scale=w2_scale,
+                block_shape=block_shape,
+            )
+        elif self.runner.runner_backend.is_triton():
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                use_fp8_w8a8=True,
+                w13_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a13_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
+        else:
+            raise NotImplementedError(
+                "Unsupported runner backend: %s" % self.runner.runner_backend
+            )
+
         return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
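The non-block-quant branch above broadcasts one per-expert scale over a ceil(n/128) x ceil(k/128) grid so the DeepGEMM path always consumes a per-block layout; numerics are unchanged because every block carries the same scale. A quick shape check of that expansion with toy sizes (only `scale_block_size = 128` and the ceil-division/repeat_interleave scheme come from the hunk; the dimensions are made up):

import torch

scale_block_size = 128
E, n, k = 2, 256, 384  # toy: 2 experts, weight of shape [E, n, k]
per_tensor_scale = torch.rand(E)  # one scale per expert

n_blocks = (n - 1) // scale_block_size + 1  # ceil(256 / 128) = 2
k_blocks = (k - 1) // scale_block_size + 1  # ceil(384 / 128) = 3

per_block_scale = (
    per_tensor_scale.unsqueeze(1)
    .repeat_interleave(n_blocks, dim=1)
    .unsqueeze(2)
    .repeat_interleave(k_blocks, dim=2)
)
assert per_block_scale.shape == (E, n_blocks, k_blocks)
# Every block repeats the expert's single scale, so dequantization is identical.
assert torch.equal(per_block_scale[0], per_tensor_scale[0].expand(n_blocks, k_blocks))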
@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
-    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
     from sglang.srt.layers.moe.token_dispatcher import (
         CombineInput,
         StandardDispatchOutput,
@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def create_weights(
         self,
-        layer: EPMoE,
+        layer: Module,
         num_experts: int,
         hidden_size: int,
         intermediate_size_per_partition: int,
@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: EPMoE,
+        layer: Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
 
@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         topk_output = dispatch_output.topk_output
 
        topk_weights, topk_ids, _ = topk_output
-        local_topk_ids = topk_ids
-        if get_moe_expert_parallel_world_size() > 1:
-            local_topk_ids = torch.where(
-                topk_ids == -1,
-                layer.num_experts,
-                topk_ids,
-            )
-
         output = cutlass_w4a8_moe(
-            layer.start_expert_id,
-            layer.end_expert_id,
             layer.num_experts,
             x,
             layer.w13_weight,
             layer.w2_weight,
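The deleted block was the EPMoE-era remap: under expert parallelism the dispatcher marks tokens that route to no local expert with `-1`, and the old path redirected those slots to the out-of-range sentinel `num_experts` before calling the kernel. A toy reproduction of just that remap (the values are made up; `torch.full_like` stands in for the scalar argument in the original `torch.where` call):

import torch

num_experts = 8
topk_ids = torch.tensor([[0, 5, -1], [-1, 2, 7]])
# -1 = "no local expert"; the removed code pointed such slots one past
# the last real expert id so the kernel could skip them uniformly.
local_topk_ids = torch.where(
    topk_ids == -1, torch.full_like(topk_ids, num_experts), topk_ids
)
print(local_topk_ids)  # tensor([[0, 5, 8], [8, 2, 7]])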
@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale_inv,
             topk_weights,
             topk_ids,
-            local_topk_ids,
             self.a_strides1,
             self.b_strides1,
             self.c_strides1,