[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)
This commit is contained in:
@@ -11,6 +11,7 @@ import triton
|
|||||||
from ray.experimental.tqdm_ray import tqdm
|
from ray.experimental.tqdm_ray import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
fused_moe,
|
fused_moe,
|
||||||
get_config_dtype_str,
|
get_config_dtype_str,
|
||||||
@@ -18,7 +19,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
|||||||
get_default_config,
|
get_default_config,
|
||||||
get_moe_configs,
|
get_moe_configs,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
@@ -117,17 +119,23 @@ def benchmark_config(
|
|||||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||||
|
|
||||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
topk_config = TopKConfig(
|
||||||
|
top_k=topk,
|
||||||
|
renormalize=True,
|
||||||
|
)
|
||||||
|
topk_output = select_experts(x, input_gating, topk_config)
|
||||||
|
|
||||||
def prepare(i: int):
|
def prepare(i: int):
|
||||||
input_gating = gating_output[i]
|
input_gating = gating_output[i]
|
||||||
new_topk_output = select_experts(x, input_gating, topk, renormalize=True)
|
new_topk_output = select_experts(x, input_gating, topk_config)
|
||||||
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
|
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
|
||||||
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
|
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
|
||||||
topk_output.router_logits.copy_(new_topk_output.router_logits)
|
topk_output.router_logits.copy_(new_topk_output.router_logits)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
moe_runner_config = MoeRunnerConfig(
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
with override_config(config):
|
with override_config(config):
|
||||||
fused_moe(
|
fused_moe(
|
||||||
@@ -135,7 +143,7 @@ def benchmark_config(
|
|||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
topk_output,
|
topk_output,
|
||||||
inplace=True,
|
moe_runner_config=moe_runner_config,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
|||||||
@@ -213,12 +213,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| Arguments | Description | Defaults |
|
| Arguments | Description | Defaults |
|
||||||
|-----------|-------------|----------|
|
|-----------|-------------|----------|
|
||||||
| `--ep-size` | The expert parallelism size. | 1 |
|
| `--ep-size` | The expert parallelism size. | 1 |
|
||||||
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | None |
|
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
|
||||||
| `--enable-flashinfer-cutlass-moe` | Enabling Flashinfer Cutlass MoE implementation for high throughput. | False |
|
| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' |
|
||||||
| `--enable-flashinfer-trtllm-moe` | Enabling Flashinfer Trtllm MoE implementation for low latency. | False |
|
|
||||||
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
|
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
|
||||||
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
|
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
|
||||||
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | None |
|
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
|
||||||
| `--init-expert-location` | Initial location of EP experts. | trivial |
|
| `--init-expert-location` | Initial location of EP experts. | trivial |
|
||||||
| `--enable-eplb` | Enable EPLB algorithm. | False |
|
| `--enable-eplb` | Enable EPLB algorithm. | False |
|
||||||
| `--eplb-algorithm` | Chosen EPLB algorithm. | auto |
|
| `--eplb-algorithm` | Chosen EPLB algorithm. | auto |
|
||||||
@@ -280,7 +279,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False |
|
| `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False |
|
||||||
| `--disable-fast-image-processor` | Disable fast image processor. | False |
|
| `--disable-fast-image-processor` | Disable fast image processor. | False |
|
||||||
| `--enable-return-hidden-states` | Enable returning hidden states. | False |
|
| `--enable-return-hidden-states` | Enable returning hidden states. | False |
|
||||||
| `--enable-triton-kernel-moe` | Enable Triton kernel for MoE. | False |
|
|
||||||
|
|
||||||
## Debug tensor dumps
|
## Debug tensor dumps
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ from sglang.srt.configs.model_config import ModelConfig
|
|||||||
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
|
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
|
||||||
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
|
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
from sglang.srt.managers.scheduler import Scheduler
|
from sglang.srt.managers.scheduler import Scheduler
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -300,11 +299,6 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
|||||||
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
||||||
spec_algorithm=SpeculativeAlgorithm.NONE,
|
spec_algorithm=SpeculativeAlgorithm.NONE,
|
||||||
speculative_num_draft_tokens=None,
|
speculative_num_draft_tokens=None,
|
||||||
enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
|
|
||||||
enable_deepep_moe=MoeA2ABackend(
|
|
||||||
model_runner.server_args.moe_a2a_backend
|
|
||||||
).is_deepep(),
|
|
||||||
deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
|
|
||||||
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
||||||
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import Withable, get_bool_env_var
|
from sglang.srt.utils import Withable, get_bool_env_var
|
||||||
@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if server_args.expert_distribution_recorder_mode == "stat_approx":
|
if server_args.expert_distribution_recorder_mode == "stat_approx":
|
||||||
if server_args.moe_a2a_backend is not None and (
|
if server_args.moe_a2a_backend != "none" and (
|
||||||
server_args.deepep_mode == "normal"
|
server_args.deepep_mode == "normal"
|
||||||
):
|
):
|
||||||
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
|
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
if server_args.moe_a2a_backend is not None:
|
if server_args.moe_a2a_backend != "none":
|
||||||
if server_args.deepep_mode == "normal":
|
if server_args.deepep_mode == "normal":
|
||||||
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
|
||||||
elif server_args.deepep_mode == "low_latency":
|
elif server_args.deepep_mode == "low_latency":
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from enum import Enum, auto
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch.distributed
|
import torch
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -35,6 +35,7 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
get_global_dp_buffer,
|
get_global_dp_buffer,
|
||||||
get_local_dp_buffer,
|
get_local_dp_buffer,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||||
from sglang.srt.layers.utils import is_sm100_supported
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -111,7 +112,7 @@ class LayerScatterModes:
|
|||||||
if context.is_layer_sparse:
|
if context.is_layer_sparse:
|
||||||
return (
|
return (
|
||||||
ScatterMode.SCATTERED
|
ScatterMode.SCATTERED
|
||||||
if not global_server_args_dict["moe_a2a_backend"].is_standard()
|
if not get_moe_a2a_backend().is_none()
|
||||||
else ScatterMode.FULL
|
else ScatterMode.FULL
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
29
python/sglang/srt/layers/moe/__init__.py
Normal file
29
python/sglang/srt/layers/moe/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.utils import (
|
||||||
|
DeepEPMode,
|
||||||
|
MoeA2ABackend,
|
||||||
|
MoeRunnerBackend,
|
||||||
|
get_deepep_config,
|
||||||
|
get_deepep_mode,
|
||||||
|
get_moe_a2a_backend,
|
||||||
|
get_moe_runner_backend,
|
||||||
|
get_tbo_token_distribution_threshold,
|
||||||
|
initialize_moe_config,
|
||||||
|
is_tbo_enabled,
|
||||||
|
should_use_flashinfer_trtllm_moe,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DeepEPMode",
|
||||||
|
"MoeA2ABackend",
|
||||||
|
"MoeRunnerConfig",
|
||||||
|
"MoeRunnerBackend",
|
||||||
|
"initialize_moe_config",
|
||||||
|
"get_moe_a2a_backend",
|
||||||
|
"get_moe_runner_backend",
|
||||||
|
"get_deepep_mode",
|
||||||
|
"should_use_flashinfer_trtllm_moe",
|
||||||
|
"is_tbo_enabled",
|
||||||
|
"get_tbo_token_distribution_threshold",
|
||||||
|
"get_deepep_config",
|
||||||
|
]
|
||||||
@@ -1,11 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||||
|
from sglang.srt.layers.moe import (
|
||||||
|
get_deepep_mode,
|
||||||
|
get_moe_a2a_backend,
|
||||||
|
get_moe_runner_backend,
|
||||||
|
should_use_flashinfer_trtllm_moe,
|
||||||
|
)
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
ep_gather,
|
ep_gather,
|
||||||
ep_scatter,
|
ep_scatter,
|
||||||
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8 import (
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||||
Fp8Config,
|
|
||||||
Fp8MoEMethod,
|
|
||||||
get_tile_tokens_dim,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
is_fp8_fnuz,
|
is_fp8_fnuz,
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
@@ -89,12 +90,11 @@ class EPMoE(FusedMoE):
|
|||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
gemm1_clamp_limit: Optional[float] = None,
|
||||||
with_bias: bool = False,
|
with_bias: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -106,13 +106,12 @@ class EPMoE(FusedMoE):
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
activation_alpha=activation_alpha,
|
gemm1_alpha=gemm1_alpha,
|
||||||
swiglu_limit=swiglu_limit,
|
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||||
with_bias=with_bias,
|
with_bias=with_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -163,7 +162,8 @@ class EPMoE(FusedMoE):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.moe_runner_config.activation == "silu"
|
||||||
|
|
||||||
hidden_states_shape = hidden_states.shape
|
hidden_states_shape = hidden_states.shape
|
||||||
hidden_states_dtype = hidden_states.dtype
|
hidden_states_dtype = hidden_states.dtype
|
||||||
hidden_states_device = hidden_states.device
|
hidden_states_device = hidden_states.device
|
||||||
@@ -327,8 +327,8 @@ class EPMoE(FusedMoE):
|
|||||||
m_max * self.start_expert_id,
|
m_max * self.start_expert_id,
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
if self.routed_scaling_factor is not None:
|
if self.moe_runner_config.routed_scaling_factor is not None:
|
||||||
output *= self.routed_scaling_factor
|
output *= self.moe_runner_config.routed_scaling_factor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@@ -349,11 +349,9 @@ class DeepEPMoE(EPMoE):
|
|||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
@@ -364,12 +362,11 @@ class DeepEPMoE(EPMoE):
|
|||||||
num_fused_shared_experts=num_fused_shared_experts,
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
self.deepep_mode = deepep_mode
|
self.deepep_mode = get_deepep_mode()
|
||||||
|
|
||||||
# TODO: move to the beginning of the file
|
# TODO: move to the beginning of the file
|
||||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||||
@@ -383,7 +380,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
num_local_experts=self.num_local_experts,
|
num_local_experts=self.num_local_experts,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
deepep_mode=deepep_mode,
|
deepep_mode=self.deepep_mode,
|
||||||
async_finish=True, # TODO
|
async_finish=True, # TODO
|
||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
@@ -458,15 +455,19 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def moe_impl(self, dispatch_output: DispatchOutput):
|
def moe_impl(self, dispatch_output: DispatchOutput):
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
|
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||||
return self.forward_aiter(dispatch_output)
|
return self.forward_aiter(dispatch_output)
|
||||||
if _is_npu:
|
if _is_npu:
|
||||||
|
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
|
||||||
return self.forward_npu(dispatch_output)
|
return self.forward_npu(dispatch_output)
|
||||||
if dispatch_output.format.is_deepep_normal():
|
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
elif dispatch_output.format.is_deepep_ll():
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_masked(dispatch_output)
|
return self.forward_deepgemm_masked(dispatch_output)
|
||||||
else:
|
else:
|
||||||
@@ -490,7 +491,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
def forward_aiter(
|
def forward_aiter(
|
||||||
self,
|
self,
|
||||||
dispatch_output: DeepEPNormalOutput,
|
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
|
||||||
):
|
):
|
||||||
hidden_states, topk_idx, topk_weights = (
|
hidden_states, topk_idx, topk_weights = (
|
||||||
dispatch_output.hidden_states,
|
dispatch_output.hidden_states,
|
||||||
@@ -516,7 +517,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
quant_type=QuantType.per_128x128,
|
quant_type=QuantType.per_128x128,
|
||||||
activation=(
|
activation=(
|
||||||
ActivationType.Silu
|
ActivationType.Silu
|
||||||
if self.activation == "silu"
|
if self.moe_runner_config.activation == "silu"
|
||||||
else ActivationType.Gelu
|
else ActivationType.Gelu
|
||||||
),
|
),
|
||||||
expert_mask=self.expert_mask,
|
expert_mask=self.expert_mask,
|
||||||
@@ -531,7 +532,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
|
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.moe_runner_config.activation == "silu"
|
||||||
if num_recv_tokens_per_expert is None:
|
if num_recv_tokens_per_expert is None:
|
||||||
return hidden_states_fp8.bfloat16()
|
return hidden_states_fp8.bfloat16()
|
||||||
all_tokens = sum(num_recv_tokens_per_expert)
|
all_tokens = sum(num_recv_tokens_per_expert)
|
||||||
@@ -652,7 +653,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
):
|
):
|
||||||
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
|
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.moe_runner_config.activation == "silu"
|
||||||
|
|
||||||
# GroupGemm-0
|
# GroupGemm-0
|
||||||
num_groups, m, k = hidden_states_fp8[0].size()
|
num_groups, m, k = hidden_states_fp8[0].size()
|
||||||
@@ -783,12 +784,12 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
|
|
||||||
def get_moe_impl_class():
|
def get_moe_impl_class():
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if get_moe_a2a_backend().is_deepep():
|
||||||
return DeepEPMoE
|
return DeepEPMoE
|
||||||
|
|
||||||
# NEW: Direct FP4 detection (bypasses EP requirements)
|
# NEW: Direct FP4 detection (bypasses EP requirements)
|
||||||
# Check for FP4 quantization with TRTLLM flag, regardless of EP
|
# Check for FP4 quantization with TRTLLM flag, regardless of EP
|
||||||
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
|
if get_moe_runner_backend().is_flashinfer_trtllm():
|
||||||
try:
|
try:
|
||||||
# Check the quantization argument directly
|
# Check the quantization argument directly
|
||||||
quantization = global_server_args_dict.get("quantization")
|
quantization = global_server_args_dict.get("quantization")
|
||||||
@@ -803,7 +804,7 @@ def get_moe_impl_class():
|
|||||||
|
|
||||||
if should_use_flashinfer_trtllm_moe():
|
if should_use_flashinfer_trtllm_moe():
|
||||||
return FlashInferFusedMoE
|
return FlashInferFusedMoE
|
||||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
|
if get_moe_runner_backend().is_flashinfer_cutlass():
|
||||||
return FusedMoE
|
return FusedMoE
|
||||||
if get_moe_expert_parallel_world_size() > 1:
|
if get_moe_expert_parallel_world_size() > 1:
|
||||||
return EPMoE
|
return EPMoE
|
||||||
|
|||||||
@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
|
|||||||
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
|
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||||
|
|
||||||
|
|
||||||
def fused_moe_forward_native(
|
def fused_moe_forward_native(
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: StandardTopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
if moe_runner_config.apply_router_weight_on_input:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
|
|||||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||||
w2_weights = layer.w2_weight[topk_ids]
|
w2_weights = layer.w2_weight[topk_ids]
|
||||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||||
if activation == "silu":
|
if moe_runner_config.activation == "silu":
|
||||||
x1 = F.silu(x1)
|
x1 = F.silu(x1)
|
||||||
elif activation == "gelu":
|
elif moe_runner_config.activation == "gelu":
|
||||||
x1 = F.gelu(x1)
|
x1 = F.gelu(x1)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported activation: {activation=}")
|
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
|
||||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||||
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
|
|||||||
def moe_forward_native(
|
def moe_forward_native(
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: StandardTopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
if moe_runner_config.apply_router_weight_on_input:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
@@ -72,12 +61,12 @@ def moe_forward_native(
|
|||||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||||
|
|
||||||
if activation == "silu":
|
if moe_runner_config.activation == "silu":
|
||||||
act = SiluAndMul()
|
act = SiluAndMul()
|
||||||
elif activation == "gelu":
|
elif moe_runner_config.activation == "gelu":
|
||||||
act = GeluAndMul()
|
act = GeluAndMul()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported activation: {activation=}")
|
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
|
|||||||
@@ -2,17 +2,20 @@
|
|||||||
|
|
||||||
"""Fused MoE kernel."""
|
"""Fused MoE kernel."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
scaled_fp8_quant,
|
scaled_fp8_quant,
|
||||||
@@ -1025,8 +1028,8 @@ def inplace_fused_experts(
|
|||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
fused_experts_impl(
|
fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -1053,8 +1056,8 @@ def inplace_fused_experts(
|
|||||||
block_shape,
|
block_shape,
|
||||||
False,
|
False,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
activation_alpha,
|
gemm1_alpha,
|
||||||
swiglu_limit,
|
gemm1_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1081,8 +1084,8 @@ def inplace_fused_experts_fake(
|
|||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -1119,8 +1122,8 @@ def outplace_fused_experts(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return fused_experts_impl(
|
return fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -1147,8 +1150,8 @@ def outplace_fused_experts(
|
|||||||
block_shape,
|
block_shape,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
activation_alpha=activation_alpha,
|
gemm1_alpha=gemm1_alpha,
|
||||||
swiglu_limit=swiglu_limit,
|
gemm1_limit=gemm1_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1176,8 +1179,8 @@ def outplace_fused_experts_fake(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty_like(hidden_states)
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
@@ -1194,12 +1197,10 @@ def fused_experts(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: StandardTopKOutput,
|
||||||
|
moe_runner_config: MoeRunnerConfig,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1212,14 +1213,10 @@ def fused_experts(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
activation_alpha: Optional[float] = None,
|
|
||||||
swiglu_limit: Optional[float] = None,
|
|
||||||
):
|
):
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
if inplace:
|
if moe_runner_config.inplace:
|
||||||
assert not no_combine, "no combine + inplace makes no sense"
|
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
||||||
torch.ops.sglang.inplace_fused_experts(
|
torch.ops.sglang.inplace_fused_experts(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
@@ -1228,8 +1225,8 @@ def fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
b1,
|
b1,
|
||||||
b2,
|
b2,
|
||||||
activation,
|
moe_runner_config.activation,
|
||||||
apply_router_weight_on_input,
|
moe_runner_config.apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
@@ -1242,9 +1239,9 @@ def fused_experts(
|
|||||||
a1_scale,
|
a1_scale,
|
||||||
a2_scale,
|
a2_scale,
|
||||||
block_shape,
|
block_shape,
|
||||||
routed_scaling_factor,
|
moe_runner_config.routed_scaling_factor,
|
||||||
activation_alpha,
|
moe_runner_config.gemm1_alpha,
|
||||||
swiglu_limit,
|
moe_runner_config.gemm1_clamp_limit,
|
||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -1256,8 +1253,8 @@ def fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
b1,
|
b1,
|
||||||
b2,
|
b2,
|
||||||
activation,
|
moe_runner_config.activation,
|
||||||
apply_router_weight_on_input,
|
moe_runner_config.apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
@@ -1270,10 +1267,10 @@ def fused_experts(
|
|||||||
a1_scale,
|
a1_scale,
|
||||||
a2_scale,
|
a2_scale,
|
||||||
block_shape,
|
block_shape,
|
||||||
no_combine=no_combine,
|
no_combine=moe_runner_config.no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=moe_runner_config.routed_scaling_factor,
|
||||||
activation_alpha=activation_alpha,
|
gemm1_alpha=moe_runner_config.gemm1_alpha,
|
||||||
swiglu_limit=swiglu_limit,
|
gemm1_limit=moe_runner_config.gemm1_clamp_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1370,11 +1367,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
|
|||||||
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile
|
||||||
def swiglu_with_alpha_and_limit(x, alpha, limit):
|
def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
|
||||||
gate, up = x[..., ::2], x[..., 1::2]
|
gate, up = x[..., ::2], x[..., 1::2]
|
||||||
gate = gate.clamp(min=None, max=limit)
|
gate = gate.clamp(min=None, max=gemm1_limit)
|
||||||
up = up.clamp(min=-limit, max=limit)
|
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
|
||||||
return gate * torch.sigmoid(gate * alpha) * (up + 1)
|
return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_impl(
|
def fused_experts_impl(
|
||||||
@@ -1402,8 +1399,8 @@ def fused_experts_impl(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
):
|
):
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||||
@@ -1533,12 +1530,12 @@ def fused_experts_impl(
|
|||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
if activation == "silu":
|
if activation == "silu":
|
||||||
if activation_alpha is not None:
|
if gemm1_alpha is not None:
|
||||||
assert swiglu_limit is not None
|
assert gemm1_limit is not None
|
||||||
intermediate_cache2 = swiglu_with_alpha_and_limit(
|
intermediate_cache2 = swiglu_with_alpha_and_limit(
|
||||||
intermediate_cache1.view(-1, N),
|
intermediate_cache1.view(-1, N),
|
||||||
activation_alpha,
|
gemm1_alpha,
|
||||||
swiglu_limit,
|
gemm1_limit,
|
||||||
)
|
)
|
||||||
elif _is_cuda:
|
elif _is_cuda:
|
||||||
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
@@ -1547,10 +1544,8 @@ def fused_experts_impl(
|
|||||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||||
)
|
)
|
||||||
elif activation == "gelu":
|
elif activation == "gelu":
|
||||||
assert (
|
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
|
||||||
activation_alpha is None
|
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
|
||||||
), "activation_alpha is not supported for gelu"
|
|
||||||
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
else:
|
else:
|
||||||
@@ -1641,12 +1636,10 @@ def fused_moe(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: StandardTopKOutput,
|
||||||
|
moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
@@ -1659,10 +1652,6 @@ def fused_moe(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
activation_alpha: Optional[float] = None,
|
|
||||||
swiglu_limit: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||||
@@ -1672,11 +1661,10 @@ def fused_moe(
|
|||||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||||
- w1 (torch.Tensor): The first set of expert weights.
|
- w1 (torch.Tensor): The first set of expert weights.
|
||||||
- w2 (torch.Tensor): The second set of expert weights.
|
- w2 (torch.Tensor): The second set of expert weights.
|
||||||
- topk_output (TopKOutput): The top-k output of the experts.
|
- topk_output (StandardTopKOutput): The top-k output of the experts.
|
||||||
|
- moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
|
||||||
- b1 (Optional[torch.Tensor]): Optional bias for w1.
|
- b1 (Optional[torch.Tensor]): Optional bias for w1.
|
||||||
- b2 (Optional[torch.Tensor]): Optional bias for w2.
|
- b2 (Optional[torch.Tensor]): Optional bias for w2.
|
||||||
- inplace (bool): If True, perform the operation in-place.
|
|
||||||
Defaults to False.
|
|
||||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||||
products for w1 and w2. Defaults to False.
|
products for w1 and w2. Defaults to False.
|
||||||
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
|
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
|
||||||
@@ -1696,9 +1684,9 @@ def fused_moe(
|
|||||||
a2.
|
a2.
|
||||||
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
||||||
quantization.
|
quantization.
|
||||||
- activation_alpha (Optional[float]): Optional alpha for the activation
|
- gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
|
||||||
function.
|
function.
|
||||||
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation
|
- gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
|
||||||
function.
|
function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1710,11 +1698,9 @@ def fused_moe(
|
|||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
topk_output,
|
topk_output,
|
||||||
|
moe_runner_config=moe_runner_config,
|
||||||
b1=b1,
|
b1=b1,
|
||||||
b2=b2,
|
b2=b2,
|
||||||
inplace=inplace,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
@@ -1727,8 +1713,4 @@ def fused_moe(
|
|||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
activation_alpha=activation_alpha,
|
|
||||||
swiglu_limit=swiglu_limit,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||||
|
|
||||||
import datetime
|
|
||||||
import glob
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -22,8 +18,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
|||||||
use_symmetric_memory,
|
use_symmetric_memory,
|
||||||
)
|
)
|
||||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
from sglang.srt.layers.moe import (
|
||||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
MoeRunnerConfig,
|
||||||
|
get_moe_runner_backend,
|
||||||
|
should_use_flashinfer_trtllm_moe,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -126,7 +126,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
reduce_results: bool = False,
|
reduce_results: bool = False,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
@@ -134,9 +133,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
enable_flashinfer_cutlass_moe: Optional[bool] = False,
|
gemm1_alpha: Optional[float] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_clamp_limit: Optional[float] = None,
|
||||||
swiglu_limit: Optional[float] = None,
|
|
||||||
use_weight_loader_fused: bool = False,
|
use_weight_loader_fused: bool = False,
|
||||||
with_bias=False,
|
with_bias=False,
|
||||||
):
|
):
|
||||||
@@ -153,9 +151,17 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.expert_map_cpu = None
|
self.expert_map_cpu = None
|
||||||
self.expert_map_gpu = None
|
self.expert_map_gpu = None
|
||||||
|
|
||||||
# For activation
|
self.moe_runner_config = MoeRunnerConfig(
|
||||||
self.activation_alpha = activation_alpha
|
activation=activation,
|
||||||
self.swiglu_limit = swiglu_limit
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
inplace=inplace,
|
||||||
|
no_combine=no_combine,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
gemm1_alpha=gemm1_alpha,
|
||||||
|
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
|
||||||
|
|
||||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||||
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
||||||
@@ -184,20 +190,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
* self.num_local_experts
|
* self.num_local_experts
|
||||||
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
||||||
|
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
|
||||||
assert intermediate_size % self.moe_tp_size == 0
|
assert intermediate_size % self.moe_tp_size == 0
|
||||||
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
|
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
|
||||||
self.reduce_results = reduce_results
|
self.reduce_results = reduce_results
|
||||||
self.activation = activation
|
|
||||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
|
||||||
self.use_presharded_weights = use_presharded_weights
|
self.use_presharded_weights = use_presharded_weights
|
||||||
self.inplace = inplace
|
|
||||||
self.no_combine = no_combine
|
|
||||||
|
|
||||||
self.use_triton_kernels = (
|
|
||||||
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||||
self.use_triton_kernels
|
self.use_triton_kernels
|
||||||
@@ -207,14 +205,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
|
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||||
"enable_flashinfer_mxfp4_moe", False
|
|
||||||
)
|
|
||||||
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
|
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
|
||||||
if (
|
if (
|
||||||
self.quant_config is not None
|
self.quant_config is not None
|
||||||
and self.quant_config.get_name() == "mxfp4"
|
and self.quant_config.get_name() == "mxfp4"
|
||||||
and self.use_enable_flashinfer_mxfp4_moe
|
and self.use_flashinfer_mxfp4_moe
|
||||||
):
|
):
|
||||||
hidden_size = round_up(hidden_size, 256)
|
hidden_size = round_up(hidden_size, 256)
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
@@ -794,7 +790,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
|
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
origin_hidden_states_dim = hidden_states.shape[-1]
|
origin_hidden_states_dim = hidden_states.shape[-1]
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
@@ -803,40 +799,22 @@ class FusedMoE(torch.nn.Module):
|
|||||||
# If we are in EP mode, we need to move the expert map to GPU.
|
# If we are in EP mode, we need to move the expert map to GPU.
|
||||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||||
|
|
||||||
if self.expert_map_gpu is not None and isinstance(
|
if self.expert_map_gpu is not None:
|
||||||
topk_output, StandardTopKOutput
|
if TopKOutputChecker.format_is_standard(topk_output):
|
||||||
):
|
|
||||||
topk_output = topk_output._replace(
|
topk_output = topk_output._replace(
|
||||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||||
)
|
)
|
||||||
|
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
with use_symmetric_memory(get_tp_group()) as sm:
|
with use_symmetric_memory(get_tp_group()) as sm:
|
||||||
kwargs = {}
|
|
||||||
if self.activation_alpha is not None:
|
|
||||||
kwargs["activation_alpha"] = self.activation_alpha
|
|
||||||
if self.swiglu_limit is not None:
|
|
||||||
kwargs["swiglu_limit"] = self.swiglu_limit
|
|
||||||
|
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
activation=self.activation,
|
moe_runner_config=self.moe_runner_config,
|
||||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
**(
|
|
||||||
dict(
|
|
||||||
tp_rank=self.moe_tp_rank,
|
|
||||||
tp_size=self.moe_tp_size,
|
|
||||||
ep_rank=self.moe_ep_rank,
|
|
||||||
ep_size=self.moe_ep_size,
|
|
||||||
)
|
|
||||||
if self.quant_method.__class__.__name__
|
|
||||||
== "ModelOptNvFp4FusedMoEMethod"
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
sm.tag(final_hidden_states)
|
sm.tag(final_hidden_states)
|
||||||
|
|
||||||
@@ -944,24 +922,10 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
class FlashInferFusedMoE(FusedMoE):
|
class FlashInferFusedMoE(FusedMoE):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
renormalize = kwargs.pop("renormalize", True)
|
|
||||||
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
|
|
||||||
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
|
|
||||||
num_expert_group = kwargs.pop("num_expert_group", None)
|
|
||||||
topk_group = kwargs.pop("topk_group", None)
|
|
||||||
correction_bias = kwargs.pop("correction_bias", None)
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.renormalize = renormalize
|
|
||||||
self.num_fused_shared_experts = num_fused_shared_experts
|
|
||||||
self.use_grouped_topk = use_grouped_topk
|
|
||||||
if self.use_grouped_topk:
|
|
||||||
assert num_expert_group is not None and topk_group is not None
|
|
||||||
self.num_expert_group = num_expert_group
|
|
||||||
self.topk_group = topk_group
|
|
||||||
self.correction_bias = correction_bias
|
|
||||||
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
|
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
assert self.use_flashinfer_trtllm_moe
|
assert self.use_flashinfer_trtllm_moe
|
||||||
assert (
|
assert (
|
||||||
self.activation == "silu"
|
self.activation == "silu"
|
||||||
@@ -974,20 +938,14 @@ class FlashInferFusedMoE(FusedMoE):
|
|||||||
self.num_fused_shared_experts == 0
|
self.num_fused_shared_experts == 0
|
||||||
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
|
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
|
||||||
|
|
||||||
# TRTLLM mode expects (TopK_config, router_logits) tuple
|
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||||
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
|
|
||||||
)
|
|
||||||
_, router_logits = topk_output
|
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply_with_router_logits(
|
final_hidden_states = self.quant_method.apply_with_router_logits(
|
||||||
layer=self,
|
layer=self,
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
router_logits=router_logits,
|
topk_output=topk_output,
|
||||||
activation=self.activation,
|
moe_runner_config=self.moe_runner_config,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||||
@@ -1000,28 +958,8 @@ class FlashInferFP4MoE(FusedMoE):
|
|||||||
"""FP4 TRTLLM MoE implementation using FlashInfer."""
|
"""FP4 TRTLLM MoE implementation using FlashInfer."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
# Extract DeepSeek-specific parameters
|
|
||||||
renormalize = kwargs.pop("renormalize", True)
|
|
||||||
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
|
|
||||||
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
|
|
||||||
num_expert_group = kwargs.pop("num_expert_group", None)
|
|
||||||
topk_group = kwargs.pop("topk_group", None)
|
|
||||||
correction_bias = kwargs.pop("correction_bias", None)
|
|
||||||
|
|
||||||
# Extract additional TopK parameters that were previously extracted in forward
|
|
||||||
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
|
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# Store DeepSeek parameters
|
|
||||||
self.renormalize = renormalize
|
|
||||||
self.num_fused_shared_experts = num_fused_shared_experts
|
|
||||||
self.use_grouped_topk = use_grouped_topk
|
|
||||||
self.num_expert_group = num_expert_group
|
|
||||||
self.topk_group = topk_group
|
|
||||||
self.correction_bias = correction_bias
|
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------
|
# ---------------------------------------------------------------------
|
||||||
# Helper: quantize hidden states to FP4 each forward pass
|
# Helper: quantize hidden states to FP4 each forward pass
|
||||||
# ---------------------------------------------------------------------
|
# ---------------------------------------------------------------------
|
||||||
@@ -1052,21 +990,17 @@ class FlashInferFP4MoE(FusedMoE):
|
|||||||
|
|
||||||
return hs_fp4, hs_sf
|
return hs_fp4, hs_sf
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output):
|
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
"""Forward pass using FP4 TRTLLM kernel.
|
"""Forward pass using FP4 TRTLLM kernel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hidden_states: Input tensor
|
hidden_states: Input tensor
|
||||||
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
|
topk_output: TopKOutput object with Bypassed format
|
||||||
"""
|
"""
|
||||||
|
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||||
|
|
||||||
# TRTLLM mode expects (TopK_config, router_logits) tuple
|
router_logits = topk_output.router_logits
|
||||||
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
|
topk_config = topk_output.topk_config
|
||||||
raise ValueError(
|
|
||||||
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_, router_logits = topk_output
|
|
||||||
|
|
||||||
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
|
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
|
||||||
|
|
||||||
@@ -1074,7 +1008,7 @@ class FlashInferFP4MoE(FusedMoE):
|
|||||||
|
|
||||||
result = trtllm_fp4_block_scale_moe(
|
result = trtllm_fp4_block_scale_moe(
|
||||||
routing_logits=router_logits,
|
routing_logits=router_logits,
|
||||||
routing_bias=self.correction_bias.to(hidden_states.dtype),
|
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
|
||||||
hidden_states=hs_fp4,
|
hidden_states=hs_fp4,
|
||||||
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
|
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
|
||||||
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
|
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
|
||||||
@@ -1094,15 +1028,15 @@ class FlashInferFP4MoE(FusedMoE):
|
|||||||
output1_scale_gate_scalar=self.g1_alphas.data,
|
output1_scale_gate_scalar=self.g1_alphas.data,
|
||||||
output2_scale_scalar=self.g2_alphas.data,
|
output2_scale_scalar=self.g2_alphas.data,
|
||||||
num_experts=self.num_experts,
|
num_experts=self.num_experts,
|
||||||
top_k=self.top_k,
|
top_k=topk_config.top_k,
|
||||||
n_group=self.num_expert_group,
|
n_group=topk_config.num_expert_group,
|
||||||
topk_group=self.topk_group,
|
topk_group=topk_config.topk_group,
|
||||||
intermediate_size=self.intermediate_size_per_partition,
|
intermediate_size=self.intermediate_size_per_partition,
|
||||||
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
|
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
|
||||||
local_num_experts=self.num_local_experts,
|
local_num_experts=self.num_local_experts,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
|
||||||
tile_tokens_dim=_get_tile_tokens_dim(
|
tile_tokens_dim=_get_tile_tokens_dim(
|
||||||
hidden_states.shape[0], self.top_k, self.num_local_experts
|
hidden_states.shape[0], topk_config.top_k, self.num_local_experts
|
||||||
),
|
),
|
||||||
routing_method_type=RoutingMethodType.DeepSeekV3,
|
routing_method_type=RoutingMethodType.DeepSeekV3,
|
||||||
do_finalize=True,
|
do_finalize=True,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
|||||||
from triton_kernels.swiglu import swiglu_fn
|
from triton_kernels.swiglu import swiglu_fn
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -55,8 +56,7 @@ def triton_kernel_moe_forward(
|
|||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
inplace: bool = False,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
per_channel_quant: bool = False,
|
per_channel_quant: bool = False,
|
||||||
@@ -69,7 +69,10 @@ def triton_kernel_moe_forward(
|
|||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
assert topk_output.format.is_triton_kernel()
|
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||||
|
|
||||||
|
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
|
||||||
|
|
||||||
routing_data, gather_idx, scatter_idx = topk_output
|
routing_data, gather_idx, scatter_idx = topk_output
|
||||||
|
|
||||||
return triton_kernel_fused_experts(
|
return triton_kernel_fused_experts(
|
||||||
@@ -79,8 +82,8 @@ def triton_kernel_moe_forward(
|
|||||||
routing_data,
|
routing_data,
|
||||||
gather_idx,
|
gather_idx,
|
||||||
scatter_idx,
|
scatter_idx,
|
||||||
inplace=inplace,
|
inplace=False, # triton kernel doesn't support inplace
|
||||||
activation=activation,
|
activation=moe_runner_config.activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
per_channel_quant=per_channel_quant,
|
per_channel_quant=per_channel_quant,
|
||||||
@@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward(
|
|||||||
w2_pcg,
|
w2_pcg,
|
||||||
b2: torch.Tensor,
|
b2: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
inplace: bool = False,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
per_channel_quant: bool = False,
|
per_channel_quant: bool = False,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
@@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
|
||||||
swiglu_limit: Optional[int] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert topk_output.format.is_triton_kernel()
|
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||||
|
|
||||||
|
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
|
||||||
|
|
||||||
routing_data, gather_idx, scatter_idx = topk_output
|
routing_data, gather_idx, scatter_idx = topk_output
|
||||||
|
|
||||||
return triton_kernel_fused_experts_with_bias(
|
return triton_kernel_fused_experts_with_bias(
|
||||||
@@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward(
|
|||||||
routing_data=routing_data,
|
routing_data=routing_data,
|
||||||
gather_indx=gather_idx,
|
gather_indx=gather_idx,
|
||||||
scatter_indx=scatter_idx,
|
scatter_indx=scatter_idx,
|
||||||
inplace=inplace,
|
inplace=False, # triton kernel doesn't support inplace
|
||||||
activation=activation,
|
activation=moe_runner_config.activation,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
per_channel_quant=per_channel_quant,
|
per_channel_quant=per_channel_quant,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward(
|
|||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
activation_alpha=activation_alpha,
|
gemm1_alpha=moe_runner_config.gemm1_alpha,
|
||||||
swiglu_limit=swiglu_limit,
|
gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
activation_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
swiglu_limit: Optional[int] = None,
|
gemm1_clamp_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
|
|
||||||
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
||||||
assert per_channel_quant == False, "per_channel_quant is not supported"
|
assert per_channel_quant == False, "per_channel_quant is not supported"
|
||||||
assert expert_map == None, "expert_map is not supported"
|
assert expert_map == None, "expert_map is not supported"
|
||||||
@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias(
|
|||||||
|
|
||||||
act = FusedActivation(
|
act = FusedActivation(
|
||||||
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
|
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
|
||||||
(activation_alpha, swiglu_limit),
|
(gemm1_alpha, gemm1_clamp_limit),
|
||||||
2,
|
2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
3
python/sglang/srt/layers/moe/moe_runner/__init__.py
Normal file
3
python/sglang/srt/layers/moe/moe_runner/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
||||||
|
|
||||||
|
__all__ = ["MoeRunnerConfig"]
|
||||||
13
python/sglang/srt/layers/moe/moe_runner/base.py
Normal file
13
python/sglang/srt/layers/moe/moe_runner/base.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MoeRunnerConfig:
|
||||||
|
activation: str = "silu"
|
||||||
|
apply_router_weight_on_input: bool = False
|
||||||
|
inplace: bool = True
|
||||||
|
no_combine: bool = False
|
||||||
|
routed_scaling_factor: Optional[float] = None
|
||||||
|
gemm1_alpha: Optional[float] = None
|
||||||
|
gemm1_clamp_limit: Optional[float] = None
|
||||||
@@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
|||||||
BaseDispatcher,
|
BaseDispatcher,
|
||||||
BaseDispatcherConfig,
|
BaseDispatcherConfig,
|
||||||
DispatchOutput,
|
DispatchOutput,
|
||||||
|
DispatchOutputChecker,
|
||||||
DispatchOutputFormat,
|
DispatchOutputFormat,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
||||||
|
AscendDeepEPLLOutput,
|
||||||
DeepEPConfig,
|
DeepEPConfig,
|
||||||
DeepEPDispatcher,
|
DeepEPDispatcher,
|
||||||
DeepEPLLOutput,
|
DeepEPLLOutput,
|
||||||
DeepEPNormalOutput,
|
DeepEPNormalOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AscendDeepEPLLOutput",
|
||||||
"BaseDispatcher",
|
"BaseDispatcher",
|
||||||
"BaseDispatcherConfig",
|
"BaseDispatcherConfig",
|
||||||
"DispatchOutput",
|
"DispatchOutput",
|
||||||
"DispatchOutputFormat",
|
"DispatchOutputFormat",
|
||||||
|
"DispatchOutputChecker",
|
||||||
|
"StandardDispatchOutput",
|
||||||
"DeepEPConfig",
|
"DeepEPConfig",
|
||||||
"DeepEPDispatcher",
|
"DeepEPDispatcher",
|
||||||
"DeepEPNormalOutput",
|
"DeepEPNormalOutput",
|
||||||
|
|||||||
@@ -2,35 +2,76 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Protocol, runtime_checkable
|
from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
|
AscendDeepEPLLOutput,
|
||||||
|
DeepEPLLOutput,
|
||||||
|
DeepEPNormalOutput,
|
||||||
|
StandardDispatchOutput,
|
||||||
|
)
|
||||||
|
|
||||||
class MoEA2ABackend(Enum):
|
|
||||||
none = "none"
|
|
||||||
deepep = "deepep"
|
|
||||||
|
|
||||||
def is_none(self):
|
class DispatchOutputChecker:
|
||||||
return self == MoEA2ABackend.none
|
|
||||||
|
|
||||||
def is_deepep(self):
|
@staticmethod
|
||||||
return self == MoEA2ABackend.deepep
|
def format_is_standard(
|
||||||
|
dispatch_output: DispatchOutput,
|
||||||
|
) -> TypeGuard[StandardDispatchOutput]:
|
||||||
|
return dispatch_output.format.is_standard()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_is_deepep_normal(
|
||||||
|
dispatch_output: DispatchOutput,
|
||||||
|
) -> TypeGuard[DeepEPNormalOutput]:
|
||||||
|
return dispatch_output.format.is_deepep_normal()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_is_deepep_ll(
|
||||||
|
dispatch_output: DispatchOutput,
|
||||||
|
) -> TypeGuard[DeepEPLLOutput]:
|
||||||
|
return dispatch_output.format.is_deepep_ll()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_is_deepep(
|
||||||
|
dispatch_output: DispatchOutput,
|
||||||
|
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
|
||||||
|
return dispatch_output.format.is_deepep()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_is_ascent_ll(
|
||||||
|
dispatch_output: DispatchOutput,
|
||||||
|
) -> TypeGuard[AscendDeepEPLLOutput]:
|
||||||
|
return dispatch_output.format.is_ascent_ll()
|
||||||
|
|
||||||
|
|
||||||
class DispatchOutputFormat(Enum):
|
class DispatchOutputFormat(Enum):
|
||||||
standard = auto()
|
|
||||||
deepep_normal = auto()
|
STANDARD = auto()
|
||||||
deepep_ll = auto()
|
DEEPEP_NORMAL = auto()
|
||||||
|
DEEPEP_LL = auto()
|
||||||
|
ASCENT_LL = auto()
|
||||||
|
|
||||||
def is_standard(self) -> bool:
|
def is_standard(self) -> bool:
|
||||||
return self == DispatchOutputFormat.standard
|
return self == DispatchOutputFormat.STANDARD
|
||||||
|
|
||||||
def is_deepep_normal(self) -> bool:
|
def is_deepep_normal(self) -> bool:
|
||||||
return self == DispatchOutputFormat.deepep_normal
|
return self == DispatchOutputFormat.DEEPEP_NORMAL
|
||||||
|
|
||||||
def is_deepep_ll(self) -> bool:
|
def is_deepep_ll(self) -> bool:
|
||||||
return self == DispatchOutputFormat.deepep_ll
|
return self == DispatchOutputFormat.DEEPEP_LL
|
||||||
|
|
||||||
|
def is_deepep(self) -> bool:
|
||||||
|
return self in [
|
||||||
|
DispatchOutputFormat.DEEPEP_NORMAL,
|
||||||
|
DispatchOutputFormat.DEEPEP_LL,
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_ascent_ll(self) -> bool:
|
||||||
|
return self == DispatchOutputFormat.ASCENT_LL
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|||||||
@@ -2,27 +2,17 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||||
TYPE_CHECKING,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Protocol,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
runtime_checkable,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
|
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
|
||||||
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
||||||
BaseDispatcher,
|
BaseDispatcher,
|
||||||
BaseDispatcherConfig,
|
BaseDispatcherConfig,
|
||||||
DispatchOutput,
|
DispatchOutput,
|
||||||
DispatchOutputFormat,
|
DispatchOutputFormat,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_int_env_var,
|
get_int_env_var,
|
||||||
@@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def format(self) -> DispatchOutputFormat:
|
def format(self) -> DispatchOutputFormat:
|
||||||
return DispatchOutputFormat.deepep_normal
|
return DispatchOutputFormat.DEEPEP_NORMAL
|
||||||
|
|
||||||
|
|
||||||
class DeepEPLLOutput(NamedTuple):
|
class DeepEPLLOutput(NamedTuple):
|
||||||
@@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def format(self) -> DispatchOutputFormat:
|
def format(self) -> DispatchOutputFormat:
|
||||||
return DispatchOutputFormat.deepep_ll
|
return DispatchOutputFormat.DEEPEP_LL
|
||||||
|
|
||||||
|
|
||||||
class AscendDeepEPLLOutput(NamedTuple):
|
class AscendDeepEPLLOutput(NamedTuple):
|
||||||
@@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def format(self) -> DispatchOutputFormat:
|
def format(self) -> DispatchOutputFormat:
|
||||||
return DispatchOutputFormat.deepep_ll
|
return DispatchOutputFormat.ASCENT_LL
|
||||||
|
|
||||||
|
|
||||||
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||||
@@ -128,8 +118,8 @@ class DeepEPBuffer:
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
param_bytes: int,
|
param_bytes: int,
|
||||||
deepep_mode: DeepEPMode,
|
deepep_mode: DeepEPMode,
|
||||||
num_max_dispatch_tokens_per_rank: int = None,
|
num_max_dispatch_tokens_per_rank: int = -1,
|
||||||
num_experts: int = None,
|
num_experts: int = -1,
|
||||||
):
|
):
|
||||||
if cls._buffer is not None:
|
if cls._buffer is not None:
|
||||||
return cls._buffer
|
return cls._buffer
|
||||||
@@ -156,8 +146,8 @@ class DeepEPBuffer:
|
|||||||
num_rdma_bytes,
|
num_rdma_bytes,
|
||||||
)
|
)
|
||||||
if deepep_mode.enable_low_latency():
|
if deepep_mode.enable_low_latency():
|
||||||
assert num_max_dispatch_tokens_per_rank is not None
|
assert num_max_dispatch_tokens_per_rank != -1
|
||||||
assert num_experts is not None and num_experts % group.size() == 0
|
assert num_experts != -1 and num_experts % group.size() == 0
|
||||||
num_rdma_bytes = max(
|
num_rdma_bytes = max(
|
||||||
Buffer.get_low_latency_rdma_size_hint(
|
Buffer.get_low_latency_rdma_size_hint(
|
||||||
num_max_dispatch_tokens_per_rank,
|
num_max_dispatch_tokens_per_rank,
|
||||||
@@ -181,7 +171,7 @@ class DeepEPBuffer:
|
|||||||
).multi_processor_count
|
).multi_processor_count
|
||||||
if (
|
if (
|
||||||
(deepep_mode != DeepEPMode.LOW_LATENCY)
|
(deepep_mode != DeepEPMode.LOW_LATENCY)
|
||||||
and not global_server_args_dict["enable_two_batch_overlap"]
|
and not is_tbo_enabled()
|
||||||
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
|
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig):
|
|||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
config_str = global_server_args_dict["deepep_config"]
|
config_str = get_deepep_config()
|
||||||
if config_str:
|
if config_str:
|
||||||
config_parsed = load_json_config(config_str)
|
config_parsed = load_json_config(config_str)
|
||||||
if torch.distributed.get_rank() == 0:
|
if torch.distributed.get_rank() == 0:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def format(self) -> DispatchOutputFormat:
|
def format(self) -> DispatchOutputFormat:
|
||||||
return DispatchOutputFormat.standard
|
return DispatchOutputFormat.STANDARD
|
||||||
|
|
||||||
|
|
||||||
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
assert isinstance(StandardDispatchOutput, DispatchOutput)
|
||||||
|
|||||||
@@ -14,9 +14,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
TypeGuard,
|
||||||
|
runtime_checkable,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
|
|||||||
ExpertLocationDispatchInfo,
|
ExpertLocationDispatchInfo,
|
||||||
topk_ids_logical_to_physical,
|
topk_ids_logical_to_physical,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.layers.moe import (
|
||||||
|
get_moe_runner_backend,
|
||||||
|
should_use_flashinfer_trtllm_moe,
|
||||||
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
@@ -43,6 +55,7 @@ try:
|
|||||||
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -65,13 +78,48 @@ if _use_aiter:
|
|||||||
if _is_npu:
|
if _is_npu:
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
|
# -------------------------------- TopKConfig ---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TopKConfig:
|
||||||
|
top_k: int
|
||||||
|
use_grouped_topk: bool = False
|
||||||
|
topk_group: int = 0
|
||||||
|
num_expert_group: int = 0
|
||||||
|
renormalize: bool = True
|
||||||
|
num_fused_shared_experts: int = 0
|
||||||
|
custom_routing_function: Optional[Callable] = None
|
||||||
|
correction_bias: Optional[torch.Tensor] = None
|
||||||
|
torch_native: bool = False
|
||||||
|
routed_scaling_factor: Optional[float] = None
|
||||||
|
apply_routed_scaling_factor_on_output: bool = False
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------- TopKOutput ---------------------------------------
|
# -------------------------------- TopKOutput ---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TopKOutputChecker:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
|
||||||
|
return topk_output.format.is_standard()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_is_triton_kernel(
|
||||||
|
topk_output: TopKOutput,
|
||||||
|
) -> TypeGuard[TritonKernelTopKOutput]:
|
||||||
|
return topk_output.format.is_triton_kernel()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
|
||||||
|
return topk_output.format.is_bypassed()
|
||||||
|
|
||||||
|
|
||||||
class TopKOutputFormat(Enum):
|
class TopKOutputFormat(Enum):
|
||||||
STANDARD = auto()
|
STANDARD = auto()
|
||||||
TRITON_KERNEL = auto()
|
TRITON_KERNEL = auto()
|
||||||
|
BYPASSED = auto()
|
||||||
|
|
||||||
def is_standard(self) -> bool:
|
def is_standard(self) -> bool:
|
||||||
return self == TopKOutputFormat.STANDARD
|
return self == TopKOutputFormat.STANDARD
|
||||||
@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
|
|||||||
def is_triton_kernel(self) -> bool:
|
def is_triton_kernel(self) -> bool:
|
||||||
return self == TopKOutputFormat.TRITON_KERNEL
|
return self == TopKOutputFormat.TRITON_KERNEL
|
||||||
|
|
||||||
|
def is_bypassed(self) -> bool:
|
||||||
|
return self == TopKOutputFormat.BYPASSED
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class TopKOutput(Protocol):
|
class TopKOutput(Protocol):
|
||||||
@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
|
|||||||
return TopKOutputFormat.TRITON_KERNEL
|
return TopKOutputFormat.TRITON_KERNEL
|
||||||
|
|
||||||
|
|
||||||
|
class BypassedTopKOutput(NamedTuple):
|
||||||
|
"""Bypassed top-k output format."""
|
||||||
|
|
||||||
|
hidden_states: torch.Tensor
|
||||||
|
router_logits: torch.Tensor
|
||||||
|
topk_config: TopKConfig
|
||||||
|
num_token_non_padded: Optional[torch.Tensor] = None
|
||||||
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format(self) -> TopKOutputFormat:
|
||||||
|
return TopKOutputFormat.BYPASSED
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------- TopK ---------------------------------------
|
# -------------------------------- TopK ---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -124,8 +189,8 @@ class TopK(CustomOp):
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
*,
|
*,
|
||||||
use_grouped_topk: bool = False,
|
use_grouped_topk: bool = False,
|
||||||
topk_group: Optional[int] = None,
|
topk_group: int = 0,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: int = 0,
|
||||||
renormalize: bool = True,
|
renormalize: bool = True,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
@@ -136,19 +201,23 @@ class TopK(CustomOp):
|
|||||||
# NOTE: scoring_func is not used for now, but we keep it for future use
|
# NOTE: scoring_func is not used for now, but we keep it for future use
|
||||||
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if use_grouped_topk:
|
if use_grouped_topk:
|
||||||
assert num_expert_group is not None and topk_group is not None
|
assert num_expert_group is not None and topk_group is not None
|
||||||
self.top_k = top_k
|
|
||||||
self.use_grouped_topk = use_grouped_topk
|
|
||||||
self.renormalize = renormalize
|
|
||||||
self.topk_group = topk_group
|
|
||||||
self.num_expert_group = num_expert_group
|
|
||||||
self.num_fused_shared_experts = num_fused_shared_experts
|
|
||||||
self.custom_routing_function = custom_routing_function
|
|
||||||
self.correction_bias = correction_bias
|
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
|
||||||
|
|
||||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
self.topk_config = TopKConfig(
|
||||||
|
top_k=top_k,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
@@ -158,20 +227,11 @@ class TopK(CustomOp):
|
|||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
) -> TopKOutput:
|
) -> TopKOutput:
|
||||||
torch_native = True
|
self.topk_config.torch_native = True
|
||||||
return select_experts(
|
return select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
top_k=self.top_k,
|
topk_config=self.topk_config,
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_expert_group=self.num_expert_group,
|
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
|
||||||
custom_routing_function=self.custom_routing_function,
|
|
||||||
correction_bias=self.correction_bias,
|
|
||||||
torch_native=torch_native,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
num_token_non_padded=num_token_non_padded,
|
num_token_non_padded=num_token_non_padded,
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
)
|
)
|
||||||
@@ -187,24 +247,28 @@ class TopK(CustomOp):
|
|||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
# renormalize=True is equivalent to sm_first=False
|
# renormalize=True is equivalent to sm_first=False
|
||||||
routing_data, gather_idx, scatter_idx = routing(
|
routing_data, gather_idx, scatter_idx = routing(
|
||||||
router_logits, self.top_k, sm_first=not self.renormalize
|
router_logits,
|
||||||
|
self.topk_config.top_k,
|
||||||
|
sm_first=not self.topk_config.renormalize,
|
||||||
)
|
)
|
||||||
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
||||||
|
elif (
|
||||||
|
should_use_flashinfer_trtllm_moe()
|
||||||
|
or get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||||
|
):
|
||||||
|
return BypassedTopKOutput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
topk_config=self.topk_config,
|
||||||
|
num_token_non_padded=num_token_non_padded,
|
||||||
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
torch_native = False
|
self.topk_config.torch_native = False
|
||||||
return select_experts(
|
return select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
top_k=self.top_k,
|
topk_config=self.topk_config,
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_expert_group=self.num_expert_group,
|
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
|
||||||
custom_routing_function=self.custom_routing_function,
|
|
||||||
correction_bias=self.correction_bias,
|
|
||||||
torch_native=torch_native,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
num_token_non_padded=num_token_non_padded,
|
num_token_non_padded=num_token_non_padded,
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
)
|
)
|
||||||
@@ -220,15 +284,7 @@ class TopK(CustomOp):
|
|||||||
return select_experts(
|
return select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
top_k=self.top_k,
|
topk_config=self.topk_config,
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_expert_group=self.num_expert_group,
|
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
|
||||||
custom_routing_function=self.custom_routing_function,
|
|
||||||
correction_bias=self.correction_bias,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
num_token_non_padded=num_token_non_padded,
|
num_token_non_padded=num_token_non_padded,
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
)
|
)
|
||||||
@@ -244,35 +300,29 @@ class TopK(CustomOp):
|
|||||||
global_num_experts = router_logits.shape[-1]
|
global_num_experts = router_logits.shape[-1]
|
||||||
|
|
||||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||||
if global_num_experts == 256:
|
if global_num_experts == 256 and self.topk_config.renormalize is False:
|
||||||
|
|
||||||
|
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
||||||
router_logits = router_logits.to(torch.float32)
|
router_logits = router_logits.to(torch.float32)
|
||||||
|
|
||||||
return torch_npu.npu_moe_gating_top_k(
|
return torch_npu.npu_moe_gating_top_k(
|
||||||
router_logits,
|
router_logits,
|
||||||
k=self.top_k,
|
k=self.topk_config.top_k,
|
||||||
bias=self.correction_bias.to(torch.float32),
|
bias=self.topk_config.correction_bias.to(torch.float32),
|
||||||
k_group=self.topk_group,
|
k_group=self.topk_config.topk_group,
|
||||||
group_count=self.num_expert_group,
|
group_count=self.topk_config.num_expert_group,
|
||||||
group_select_mode=1,
|
group_select_mode=1,
|
||||||
renorm=0,
|
renorm=0,
|
||||||
norm_type=1,
|
norm_type=1,
|
||||||
routed_scaling_factor=1,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
eps=float(1e-20),
|
eps=float(1e-20),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
torch_native = True
|
self.topk_config.torch_native = True
|
||||||
return select_experts(
|
return select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
top_k=self.top_k,
|
topk_config=self.topk_config,
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_expert_group=self.num_expert_group,
|
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
|
||||||
custom_routing_function=self.custom_routing_function,
|
|
||||||
correction_bias=self.correction_bias,
|
|
||||||
torch_native=torch_native,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
num_token_non_padded=num_token_non_padded,
|
num_token_non_padded=num_token_non_padded,
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
)
|
)
|
||||||
@@ -670,20 +720,23 @@ else:
|
|||||||
def select_experts(
|
def select_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
top_k: int,
|
topk_config: TopKConfig,
|
||||||
*,
|
*,
|
||||||
use_grouped_topk: bool = False,
|
|
||||||
renormalize: bool = False,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
num_fused_shared_experts: int = 0,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
torch_native: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
) -> TopKOutput:
|
) -> StandardTopKOutput:
|
||||||
|
|
||||||
|
top_k = topk_config.top_k
|
||||||
|
use_grouped_topk = topk_config.use_grouped_topk
|
||||||
|
topk_group = topk_config.topk_group
|
||||||
|
num_expert_group = topk_config.num_expert_group
|
||||||
|
renormalize = topk_config.renormalize
|
||||||
|
num_fused_shared_experts = topk_config.num_fused_shared_experts
|
||||||
|
custom_routing_function = topk_config.custom_routing_function
|
||||||
|
correction_bias = topk_config.correction_bias
|
||||||
|
torch_native = topk_config.torch_native
|
||||||
|
routed_scaling_factor = topk_config.routed_scaling_factor
|
||||||
|
|
||||||
router_logits, correction_bias = (
|
router_logits, correction_bias = (
|
||||||
expert_location_dispatch.transform_select_experts_inputs(
|
expert_location_dispatch.transform_select_experts_inputs(
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
|
|||||||
@@ -1,55 +1,80 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.utils import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
@lru_cache(maxsize=1)
|
from sglang.srt.server_args import ServerArgs
|
||||||
def should_use_flashinfer_trtllm_moe():
|
|
||||||
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
|
|
||||||
not importlib.util.find_spec("flashinfer")
|
|
||||||
or pkg_version.parse(__import__("flashinfer").__version__)
|
|
||||||
>= pkg_version.parse("0.2.9rc1")
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class MoeA2ABackend(Enum):
|
class MoeA2ABackend(Enum):
|
||||||
|
|
||||||
STANDARD = ("standard", "none")
|
NONE = "none"
|
||||||
DEEPEP = "deepep"
|
DEEPEP = "deepep"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _missing_(cls, value):
|
def _missing_(cls, value):
|
||||||
if value is None:
|
if value is None:
|
||||||
return cls.STANDARD
|
return cls.NONE
|
||||||
for member in cls:
|
for member in cls:
|
||||||
if value in member.value:
|
if value == member.value:
|
||||||
return member
|
return member
|
||||||
raise ValueError(f"No {cls.__name__} member for value {value}")
|
raise ValueError(f"No {cls.__name__} member for value {value}")
|
||||||
|
|
||||||
|
def is_none(self):
|
||||||
|
return self == MoeA2ABackend.NONE
|
||||||
|
|
||||||
def is_deepep(self):
|
def is_deepep(self):
|
||||||
return self == MoeA2ABackend.DEEPEP
|
return self == MoeA2ABackend.DEEPEP
|
||||||
|
|
||||||
def is_standard(self):
|
|
||||||
return self == MoeA2ABackend.STANDARD
|
class MoeRunnerBackend(Enum):
|
||||||
|
|
||||||
|
AUTO = "auto"
|
||||||
|
TRITON = "triton"
|
||||||
|
TRITON_KERNEL = "triton_kernel"
|
||||||
|
FLASHINFER = "flashinfer_trtllm"
|
||||||
|
FLASHINFER_CUTLASS = "flashinfer_cutlass"
|
||||||
|
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
|
||||||
|
|
||||||
|
def is_auto(self):
|
||||||
|
return self == MoeRunnerBackend.AUTO
|
||||||
|
|
||||||
|
def is_triton(self):
|
||||||
|
return self == MoeRunnerBackend.TRITON
|
||||||
|
|
||||||
|
def is_triton_kernel(self):
|
||||||
|
return self == MoeRunnerBackend.TRITON_KERNEL
|
||||||
|
|
||||||
|
def is_flashinfer_trtllm(self):
|
||||||
|
return self == MoeRunnerBackend.FLASHINFER
|
||||||
|
|
||||||
|
def is_flashinfer_cutlass(self):
|
||||||
|
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
|
||||||
|
|
||||||
|
def is_flashinfer_mxfp4(self):
|
||||||
|
return self == MoeRunnerBackend.FLASHINFER_MXFP4
|
||||||
|
|
||||||
|
|
||||||
class DeepEPMode(Enum):
|
class DeepEPMode(Enum):
|
||||||
|
|
||||||
NORMAL = "normal"
|
NORMAL = "normal"
|
||||||
LOW_LATENCY = "low_latency"
|
LOW_LATENCY = "low_latency"
|
||||||
AUTO = "auto"
|
AUTO = "auto"
|
||||||
|
|
||||||
def enable_normal(self):
|
def enable_normal(self) -> bool:
|
||||||
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
|
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
|
||||||
|
|
||||||
def enable_low_latency(self):
|
def enable_low_latency(self) -> bool:
|
||||||
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
|
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
|
||||||
|
|
||||||
def resolve(self, is_extend_in_batch: bool):
|
def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
|
||||||
if self != DeepEPMode.AUTO:
|
if self != DeepEPMode.AUTO:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -57,3 +82,96 @@ class DeepEPMode(Enum):
|
|||||||
return DeepEPMode.NORMAL
|
return DeepEPMode.NORMAL
|
||||||
else:
|
else:
|
||||||
return DeepEPMode.LOW_LATENCY
|
return DeepEPMode.LOW_LATENCY
|
||||||
|
|
||||||
|
def is_normal(self) -> bool:
|
||||||
|
return self == DeepEPMode.NORMAL
|
||||||
|
|
||||||
|
def is_low_latency(self) -> bool:
|
||||||
|
return self == DeepEPMode.LOW_LATENCY
|
||||||
|
|
||||||
|
def is_auto(self) -> bool:
|
||||||
|
return self == DeepEPMode.AUTO
|
||||||
|
|
||||||
|
|
||||||
|
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
|
||||||
|
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
|
||||||
|
DEEPEP_MODE: Optional[DeepEPMode] = None
|
||||||
|
IS_TBO_ENABLED: Optional[bool] = None
|
||||||
|
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
||||||
|
DEEPEP_CONFIG: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_moe_config(server_args: ServerArgs):
|
||||||
|
global MOE_A2A_BACKEND
|
||||||
|
global MOE_RUNNER_BACKEND
|
||||||
|
global DEEPEP_MODE
|
||||||
|
global DEEPEP_CONFIG
|
||||||
|
global IS_TBO_ENABLED
|
||||||
|
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||||
|
|
||||||
|
MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
|
||||||
|
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
|
||||||
|
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
|
||||||
|
DEEPEP_CONFIG = server_args.deepep_config or ""
|
||||||
|
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
||||||
|
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_a2a_backend() -> MoeA2ABackend:
|
||||||
|
global MOE_A2A_BACKEND
|
||||||
|
if MOE_A2A_BACKEND is None:
|
||||||
|
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
|
||||||
|
MOE_A2A_BACKEND = MoeA2ABackend(None)
|
||||||
|
return MOE_A2A_BACKEND
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_runner_backend() -> MoeRunnerBackend:
|
||||||
|
global MOE_RUNNER_BACKEND
|
||||||
|
if MOE_RUNNER_BACKEND is None:
|
||||||
|
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
|
||||||
|
MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
|
||||||
|
return MOE_RUNNER_BACKEND
|
||||||
|
|
||||||
|
|
||||||
|
def get_deepep_mode() -> DeepEPMode:
|
||||||
|
global DEEPEP_MODE
|
||||||
|
if DEEPEP_MODE is None:
|
||||||
|
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
|
||||||
|
DEEPEP_MODE = DeepEPMode("auto")
|
||||||
|
return DEEPEP_MODE
|
||||||
|
|
||||||
|
|
||||||
|
def get_deepep_config() -> str:
|
||||||
|
global DEEPEP_CONFIG
|
||||||
|
if DEEPEP_CONFIG is None:
|
||||||
|
logger.warning("DEEPEP_CONFIG is not initialized, using default config")
|
||||||
|
DEEPEP_CONFIG = ""
|
||||||
|
return DEEPEP_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def is_tbo_enabled() -> bool:
|
||||||
|
global IS_TBO_ENABLED
|
||||||
|
if IS_TBO_ENABLED is None:
|
||||||
|
logger.warning("IS_TBO_ENABLED is not initialized, using False")
|
||||||
|
IS_TBO_ENABLED = False
|
||||||
|
return IS_TBO_ENABLED
|
||||||
|
|
||||||
|
|
||||||
|
def get_tbo_token_distribution_threshold() -> float:
|
||||||
|
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||||
|
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
|
||||||
|
logger.warning(
|
||||||
|
"TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
|
||||||
|
)
|
||||||
|
TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
|
||||||
|
return TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def should_use_flashinfer_trtllm_moe():
|
||||||
|
result = get_moe_runner_backend().is_flashinfer_trtllm() and (
|
||||||
|
not importlib.util.find_spec("flashinfer")
|
||||||
|
or pkg_version.parse(__import__("flashinfer").__version__)
|
||||||
|
>= pkg_version.parse("0.2.9rc1")
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
|||||||
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
|
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda, is_hip
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: StandardTopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
assert (
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
moe_runner_config.activation == "silu"
|
||||||
|
), "Only SiLU activation is supported."
|
||||||
|
|
||||||
# The input must currently be float16
|
# The input must currently be float16
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
|
|||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace,
|
moe_runner_config=moe_runner_config,
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_int8_w8a8=True,
|
use_int8_w8a8=True,
|
||||||
w1_scale=(layer.w13_weight_scale_inv),
|
w1_scale=(layer.w13_weight_scale_inv),
|
||||||
w2_scale=(layer.w2_weight_scale_inv),
|
w2_scale=(layer.w2_weight_scale_inv),
|
||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
block_shape=self.quant_config.weight_block_size,
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
@@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
||||||
|
|
||||||
@@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace,
|
moe_runner_config=moe_runner_config,
|
||||||
activation=activation,
|
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
per_channel_quant=self.weight_quant.strategy
|
per_channel_quant=self.weight_quant.strategy
|
||||||
== QuantizationStrategy.CHANNEL,
|
== QuantizationStrategy.CHANNEL,
|
||||||
@@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=layer.w2_weight_scale,
|
||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert (
|
||||||
|
moe_runner_config.activation == "silu"
|
||||||
|
), "Only SiLU activation is supported."
|
||||||
|
|
||||||
topk_weights, topk_ids, router_logits = topk_output
|
topk_weights, topk_ids, router_logits = topk_output
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class MxFp4MoEMethod:
|
class MxFp4MoEMethod(FusedMoEMethodBase):
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
if not hasattr(cls, "_initialized"):
|
def __init__(self, quant_config: Mxfp4Config):
|
||||||
original_init = cls.__init__
|
self.quant_config = quant_config
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_moe_method(
|
def get_moe_method(
|
||||||
@@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
@@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
|
|||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=layer.w2_weight_scale,
|
||||||
activation=(
|
activation=(
|
||||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
ActivationType.Silu
|
||||||
|
if moe_runner_config.activation == "silu"
|
||||||
|
else ActivationType.Gelu
|
||||||
),
|
),
|
||||||
doweight_stage1=False,
|
doweight_stage1=False,
|
||||||
)
|
)
|
||||||
@@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
@@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
|
|||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=layer.w2_weight_scale,
|
||||||
activation=(
|
activation=(
|
||||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
ActivationType.Silu
|
||||||
|
if moe_runner_config.activation == "silu"
|
||||||
|
else ActivationType.Gelu
|
||||||
),
|
),
|
||||||
doweight_stage1=False,
|
doweight_stage1=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||||
|
|
||||||
@@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
@@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
x, topk_weights = apply_topk_weights_cpu(
|
x, topk_weights = apply_topk_weights_cpu(
|
||||||
apply_router_weight_on_input, topk_weights, x
|
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||||
)
|
)
|
||||||
|
|
||||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
@@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer,
|
layer,
|
||||||
x,
|
x,
|
||||||
topk_output,
|
topk_output,
|
||||||
activation,
|
moe_runner_config.activation,
|
||||||
no_combine,
|
moe_runner_config.no_combine,
|
||||||
)
|
)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
return ret
|
return ret
|
||||||
@@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
use_fp8_blockscale=True,
|
use_fp8_blockscale=True,
|
||||||
)
|
)
|
||||||
# TODO: Fuse into select_experts
|
# TODO: Fuse into select_experts
|
||||||
if routed_scaling_factor is not None:
|
if moe_runner_config.routed_scaling_factor is not None:
|
||||||
output *= routed_scaling_factor
|
output *= moe_runner_config.routed_scaling_factor
|
||||||
return output
|
return output
|
||||||
# Expert fusion with FP8 quantization
|
# Expert fusion with FP8 quantization
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace and not no_combine,
|
moe_runner_config=moe_runner_config,
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
w1_scale=(
|
w1_scale=(
|
||||||
layer.w13_weight_scale_inv
|
layer.w13_weight_scale_inv
|
||||||
@@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
block_shape=self.quant_config.weight_block_size,
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_with_router_logits(
|
def apply_with_router_logits(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
activation = moe_runner_config.activation
|
||||||
|
routed_scaling_factor = moe_runner_config.routed_scaling_factor
|
||||||
|
|
||||||
|
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||||
|
|
||||||
|
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||||
|
router_logits = topk_output.router_logits
|
||||||
|
topk_config = topk_output.topk_config
|
||||||
assert (
|
assert (
|
||||||
activation == "silu"
|
activation == "silu"
|
||||||
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
), "Only silu is supported for flashinfer blockscale fp8 moe"
|
||||||
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
|
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
|
||||||
# NOTE: scales of hidden states have to be transposed!
|
# NOTE: scales of hidden states have to be transposed!
|
||||||
a_sf_t = a_sf.t().contiguous()
|
a_sf_t = a_sf.t().contiguous()
|
||||||
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
|
|
||||||
|
|
||||||
return trtllm_fp8_block_scale_moe(
|
return trtllm_fp8_block_scale_moe(
|
||||||
routing_logits=router_logits.to(torch.float32),
|
routing_logits=router_logits.to(torch.float32),
|
||||||
@@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
gemm2_weights=layer.w2_weight,
|
gemm2_weights=layer.w2_weight,
|
||||||
gemm2_weights_scale=layer.w2_weight_scale_inv,
|
gemm2_weights_scale=layer.w2_weight_scale_inv,
|
||||||
num_experts=layer.num_experts,
|
num_experts=layer.num_experts,
|
||||||
top_k=layer.top_k,
|
top_k=topk_config.top_k,
|
||||||
n_group=layer.num_expert_group,
|
n_group=topk_config.num_expert_group,
|
||||||
topk_group=layer.topk_group,
|
topk_group=topk_config.topk_group,
|
||||||
intermediate_size=layer.w2_weight.shape[2],
|
intermediate_size=layer.w2_weight.shape[2],
|
||||||
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
|
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
|
||||||
local_num_experts=layer.num_local_experts,
|
local_num_experts=layer.num_local_experts,
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
|
|||||||
return weight, weight_scale, input_scale
|
return weight, weight_scale, input_scale
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(ch-wan): define these backends in --moe-runner-backend
|
||||||
def cutlass_block_fp8_supported() -> bool:
|
def cutlass_block_fp8_supported() -> bool:
|
||||||
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
|
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda
|
||||||
@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Delay the import to avoid circular dependency
|
# Delay the import to avoid circular dependency
|
||||||
|
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert (
|
||||||
|
moe_runner_config.activation == "silu"
|
||||||
|
), "Only SiLU activation is supported."
|
||||||
|
|
||||||
# The input must currently be float16
|
# The input must currently be float16
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
|||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
|
|
||||||
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
|
||||||
hidden_size = layer.hidden_size
|
hidden_size = layer.hidden_size
|
||||||
intermediate_size_per_partition = layer.intermediate_size_per_partition
|
intermediate_size_per_partition = layer.intermediate_size_per_partition
|
||||||
# apply_router_weight_on_input is not supported for moe marlin
|
# apply_router_weight_on_input is not supported for moe marlin
|
||||||
supports_router_weight = not layer.apply_router_weight_on_input
|
supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
|
||||||
# moe marlin requires the activation to be silu
|
# moe marlin requires the activation to be silu
|
||||||
supports_activation = layer.activation == "silu"
|
supports_activation = layer.moe_runner_config.activation == "silu"
|
||||||
|
|
||||||
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
|
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
|
||||||
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
|
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
|
||||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.utils import is_cuda, next_power_of_2
|
from sglang.srt.utils import is_cuda, next_power_of_2
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace,
|
moe_runner_config=moe_runner_config,
|
||||||
activation=activation,
|
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
per_channel_quant=False, # ModelOpt uses per-tensor quantization
|
per_channel_quant=False, # ModelOpt uses per-tensor quantization
|
||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=layer.w2_weight_scale,
|
||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
no_combine=no_combine,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def enable_flashinfer_cutlass_moe(self) -> bool:
|
def enable_flashinfer_cutlass_moe(self) -> bool:
|
||||||
|
from sglang.srt.layers.moe import get_moe_runner_backend
|
||||||
|
|
||||||
"""Access the global enable_flashinfer_cutlass_moe setting."""
|
"""Access the global enable_flashinfer_cutlass_moe setting."""
|
||||||
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
|
return get_moe_runner_backend().is_flashinfer_cutlass()
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: FusedMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
ep_rank: Optional[int] = None,
|
|
||||||
ep_size: Optional[int] = None,
|
|
||||||
tp_rank: Optional[int] = None,
|
|
||||||
tp_size: Optional[int] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert (
|
||||||
|
moe_runner_config.activation == "silu"
|
||||||
|
), "Only SiLU activation is supported."
|
||||||
|
|
||||||
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
|
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
|
||||||
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
|
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
|
||||||
@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
if self.enable_flashinfer_cutlass_moe:
|
if self.enable_flashinfer_cutlass_moe:
|
||||||
assert (
|
assert (
|
||||||
not apply_router_weight_on_input
|
not moe_runner_config.apply_router_weight_on_input
|
||||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||||
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
||||||
# and fp4 quantized weights loaded from the checkpoint
|
# and fp4 quantized weights loaded from the checkpoint
|
||||||
@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w2_blockscale_swizzled.view(torch.int32),
|
layer.w2_blockscale_swizzled.view(torch.int32),
|
||||||
layer.g2_alphas,
|
layer.g2_alphas,
|
||||||
],
|
],
|
||||||
ep_size=ep_size,
|
ep_size=layer.moe_ep_size,
|
||||||
ep_rank=ep_rank,
|
ep_rank=layer.moe_ep_rank,
|
||||||
tp_size=tp_size,
|
tp_size=layer.moe_tp_size,
|
||||||
tp_rank=tp_rank,
|
tp_rank=layer.moe_tp_rank,
|
||||||
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
||||||
)[0]
|
)[0]
|
||||||
if routed_scaling_factor is not None:
|
if moe_runner_config.routed_scaling_factor is not None:
|
||||||
output *= routed_scaling_factor
|
output *= moe_runner_config.routed_scaling_factor
|
||||||
return output
|
return output
|
||||||
|
|
||||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||||
@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
params=layer.cutlass_moe_params,
|
params=layer.cutlass_moe_params,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
|
||||||
).to(x.dtype)
|
).to(x.dtype)
|
||||||
if routed_scaling_factor is not None:
|
if moe_runner_config.routed_scaling_factor is not None:
|
||||||
output *= routed_scaling_factor
|
output *= moe_runner_config.routed_scaling_factor
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# avoid circular import
|
# avoid circular import
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert (
|
||||||
|
moe_runner_config.activation == "silu"
|
||||||
|
), "Only SiLU activation is supported."
|
||||||
|
|
||||||
weight_bits = self.quant_config.weight_bits
|
weight_bits = self.quant_config.weight_bits
|
||||||
has_zp = self.quant_config.has_zp
|
has_zp = self.quant_config.has_zp
|
||||||
@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
layer.w13_qweight,
|
layer.w13_qweight,
|
||||||
layer.w2_qweight,
|
layer.w2_qweight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace,
|
moe_runner_config=moe_runner_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_int4_w4a16=weight_bits == 4,
|
use_int4_w4a16=weight_bits == 4,
|
||||||
use_int8_w8a16=weight_bits == 8,
|
use_int8_w8a16=weight_bits == 8,
|
||||||
w1_scale=layer.w13_scales,
|
w1_scale=layer.w13_scales,
|
||||||
@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||||
block_shape=[0, layer.group_size],
|
block_shape=[0, layer.group_size],
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "w13_qzeros" in weight_name:
|
if "w13_qzeros" in weight_name:
|
||||||
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
|
tensor = loaded_weight.view(
|
||||||
tp_rank
|
layer.moe_tp_size, -1, loaded_weight.size(1)
|
||||||
]
|
)[tp_rank]
|
||||||
if shard_id == "w1":
|
if shard_id == "w1":
|
||||||
param.data[expert_id, : shard_size // 2] = tensor
|
param.data[expert_id, : shard_size // 2] = tensor
|
||||||
else:
|
else:
|
||||||
param.data[expert_id, shard_size // 2 :] = tensor
|
param.data[expert_id, shard_size // 2 :] = tensor
|
||||||
elif "w2_qzeros" in weight_name:
|
elif "w2_qzeros" in weight_name:
|
||||||
param.data[expert_id] = loaded_weight.view(
|
param.data[expert_id] = loaded_weight.view(
|
||||||
loaded_weight.size(0), layer.tp_size, -1
|
loaded_weight.size(0), layer.moe_tp_size, -1
|
||||||
)[:, tp_rank]
|
)[:, tp_rank]
|
||||||
else:
|
else:
|
||||||
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
|
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
|
||||||
|
|||||||
@@ -16,14 +16,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib.util
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton.language as tl
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.utils import get_moe_runner_backend
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.layers.utils import is_sm100_supported
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
@@ -60,6 +58,7 @@ if is_flashinfer_available():
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
OCP_MX_BLOCK_SIZE = 32
|
OCP_MX_BLOCK_SIZE = 32
|
||||||
@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
):
|
):
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.topk_indices_dtype = None
|
self.topk_indices_dtype = None
|
||||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||||
self.with_bias = False
|
self.with_bias = False
|
||||||
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
|
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
|
||||||
|
|
||||||
self.triton_kernel_moe_forward = None
|
self.triton_kernel_moe_forward = None
|
||||||
self.triton_kernel_moe_with_bias_forward = None
|
self.triton_kernel_moe_with_bias_forward = None
|
||||||
@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
logger,
|
logger,
|
||||||
f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
|
f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
|
||||||
)
|
)
|
||||||
|
# TODO: these values are hardcoded for now, we need to get them from the model
|
||||||
layer.gemm1_alpha = Parameter(
|
layer.gemm1_alpha = Parameter(
|
||||||
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
activation_alpha: Optional[float] = None,
|
|
||||||
swiglu_limit: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.use_flashinfer:
|
if self.use_flashinfer:
|
||||||
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
||||||
@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
b1=layer.w13_weight_bias,
|
b1=layer.w13_weight_bias,
|
||||||
b2=layer.w2_weight_bias,
|
b2=layer.w2_weight_bias,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
activation=activation,
|
moe_runner_config=moe_runner_config,
|
||||||
activation_alpha=activation_alpha,
|
|
||||||
swiglu_limit=swiglu_limit,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.triton_kernel_moe_forward(
|
return self.triton_kernel_moe_forward(
|
||||||
@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
|
moe_runner_config=moe_runner_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
|
moe_runner_config=moe_runner_config,
|
||||||
b1=layer.w13_weight_bias,
|
b1=layer.w13_weight_bias,
|
||||||
b2=layer.w2_weight_bias,
|
b2=layer.w2_weight_bias,
|
||||||
inplace=inplace,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
activation_alpha=activation_alpha,
|
|
||||||
swiglu_limit=swiglu_limit,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib.util
|
||||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -24,7 +24,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||||
@@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
activation_alpha: Optional[float] = None,
|
|
||||||
swiglu_limit: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
kwargs = {}
|
|
||||||
if activation_alpha is not None:
|
|
||||||
kwargs["activation_alpha"] = activation_alpha
|
|
||||||
if swiglu_limit is not None:
|
|
||||||
kwargs["swiglu_limit"] = swiglu_limit
|
|
||||||
|
|
||||||
return self.forward(
|
return self.forward(
|
||||||
x=x,
|
x=x,
|
||||||
layer=layer,
|
layer=layer,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
activation=activation,
|
moe_runner_config=moe_runner_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
inplace=inplace,
|
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
@@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
activation_alpha: Optional[float] = None,
|
|
||||||
swiglu_limit: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
if self.with_bias:
|
if self.with_bias:
|
||||||
|
assert self.triton_kernel_moe_with_bias_forward is not None
|
||||||
return self.triton_kernel_moe_with_bias_forward(
|
return self.triton_kernel_moe_with_bias_forward(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
b1=layer.w13_weight_bias,
|
b1=layer.w13_weight_bias,
|
||||||
b2=layer.w2_weight_bias,
|
b2=layer.w2_weight_bias,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
activation=activation,
|
moe_runner_config=moe_runner_config,
|
||||||
activation_alpha=activation_alpha,
|
|
||||||
swiglu_limit=swiglu_limit,
|
|
||||||
w1_pcg=None,
|
w1_pcg=None,
|
||||||
w2_pcg=None,
|
w2_pcg=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert self.triton_kernel_moe_forward is not None
|
||||||
return self.triton_kernel_moe_forward(
|
return self.triton_kernel_moe_forward(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
|
moe_runner_config=moe_runner_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
assert not no_combine, "unsupported"
|
assert not moe_runner_config.no_combine, "unsupported"
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
if apply_router_weight_on_input:
|
if moe_runner_config.apply_router_weight_on_input:
|
||||||
assert (
|
assert (
|
||||||
topk_weights.dim() == 2
|
topk_weights.dim() == 2
|
||||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
@@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
activation=(
|
activation=(
|
||||||
ActivationType.Silu
|
ActivationType.Silu
|
||||||
if activation == "silu"
|
if moe_runner_config.activation == "silu"
|
||||||
else ActivationType.Gelu
|
else ActivationType.Gelu
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
b1=getattr(layer, "w13_weight_bias", None),
|
b1=getattr(layer, "w13_weight_bias", None),
|
||||||
b2=getattr(layer, "w2_weight_bias", None),
|
b2=getattr(layer, "w2_weight_bias", None),
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace and not no_combine,
|
moe_runner_config=moe_runner_config,
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
activation_alpha=activation_alpha,
|
|
||||||
swiglu_limit=swiglu_limit,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cpu(
|
def forward_cpu(
|
||||||
@@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert activation == "silu", f"activation = {activation} is not supported."
|
assert (
|
||||||
|
moe_runner_config.activation == "silu"
|
||||||
|
), f"activation = {moe_runner_config.activation} is not supported."
|
||||||
|
|
||||||
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
if (
|
||||||
|
use_intel_amx_backend(layer)
|
||||||
|
and not moe_runner_config.apply_router_weight_on_input
|
||||||
|
):
|
||||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
x, topk_weights = apply_topk_weights_cpu(
|
x, topk_weights = apply_topk_weights_cpu(
|
||||||
apply_router_weight_on_input, topk_weights, x
|
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||||
)
|
)
|
||||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
x,
|
x,
|
||||||
@@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer,
|
layer,
|
||||||
x,
|
x,
|
||||||
topk_output,
|
topk_output,
|
||||||
activation=activation,
|
moe_runner_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
inplace=inplace,
|
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_npu(
|
def forward_npu(
|
||||||
@@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||||
|
|
||||||
@@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
layer,
|
layer,
|
||||||
x,
|
x,
|
||||||
topk_output,
|
topk_output,
|
||||||
activation=activation,
|
moe_runner_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
inplace=inplace,
|
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
|
|||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
layer: EPMoE,
|
layer: EPMoE,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: StandardTopKOutput,
|
||||||
activation: str = "silu",
|
moe_runner_config: MoeRunnerConfig,
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
# TODO(ch-wan): move it out of this class
|
# TODO(ch-wan): move it out of this class
|
||||||
@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_input_scale,
|
layer.w13_input_scale,
|
||||||
layer.w2_input_scale,
|
layer.w2_input_scale,
|
||||||
)
|
)
|
||||||
if routed_scaling_factor is not None:
|
if moe_runner_config.routed_scaling_factor is not None:
|
||||||
output *= routed_scaling_factor
|
output *= moe_runner_config.routed_scaling_factor
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||||
|
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: StandardTopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace,
|
moe_runner_config=moe_runner_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
activation=activation,
|
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
per_channel_quant=True,
|
per_channel_quant=True,
|
||||||
w1_scale=(layer.w13_weight_scale),
|
w1_scale=(layer.w13_weight_scale),
|
||||||
w2_scale=(layer.w2_weight_scale),
|
w2_scale=(layer.w2_weight_scale),
|
||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
*,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
x, topk_weights = apply_topk_weights_cpu(
|
x, topk_weights = apply_topk_weights_cpu(
|
||||||
apply_router_weight_on_input, topk_weights, x
|
moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||||
)
|
)
|
||||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
x,
|
x,
|
||||||
@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace,
|
moe_runner_config=moe_runner_config,
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_int8_w8a8=True,
|
use_int8_w8a8=True,
|
||||||
per_channel_quant=True,
|
per_channel_quant=True,
|
||||||
w1_scale=(layer.w13_weight_scale),
|
w1_scale=(layer.w13_weight_scale),
|
||||||
w2_scale=(layer.w2_weight_scale),
|
w2_scale=(layer.w2_weight_scale),
|
||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer,
|
layer,
|
||||||
x,
|
x,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
**kwargs,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
|||||||
ScheduleBatchDisaggregationDecodeMixin,
|
ScheduleBatchDisaggregationDecodeMixin,
|
||||||
)
|
)
|
||||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||||
|
from sglang.srt.layers.moe import is_tbo_enabled
|
||||||
from sglang.srt.mem_cache.allocator import (
|
from sglang.srt.mem_cache.allocator import (
|
||||||
BaseTokenToKVPoolAllocator,
|
BaseTokenToKVPoolAllocator,
|
||||||
SWATokenToKVPoolAllocator,
|
SWATokenToKVPoolAllocator,
|
||||||
@@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"device",
|
"device",
|
||||||
"disable_chunked_prefix_cache",
|
"disable_chunked_prefix_cache",
|
||||||
"disable_radix_cache",
|
"disable_radix_cache",
|
||||||
"enable_two_batch_overlap",
|
|
||||||
"tbo_token_distribution_threshold",
|
|
||||||
"enable_dp_lm_head",
|
"enable_dp_lm_head",
|
||||||
"moe_a2a_backend",
|
|
||||||
"deepep_mode",
|
|
||||||
"enable_flashinfer_cutlass_moe",
|
|
||||||
"enable_flashinfer_trtllm_moe",
|
|
||||||
"enable_flashinfer_allreduce_fusion",
|
"enable_flashinfer_allreduce_fusion",
|
||||||
"moe_dense_tp_size",
|
"moe_dense_tp_size",
|
||||||
"ep_dispatch_algorithm",
|
"ep_dispatch_algorithm",
|
||||||
"deepep_config",
|
|
||||||
"ep_num_redundant_experts",
|
"ep_num_redundant_experts",
|
||||||
"enable_nan_detection",
|
"enable_nan_detection",
|
||||||
"flashinfer_mla_disable_ragged",
|
"flashinfer_mla_disable_ragged",
|
||||||
@@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"triton_attention_reduce_in_fp32",
|
"triton_attention_reduce_in_fp32",
|
||||||
"num_reserved_decode_tokens",
|
"num_reserved_decode_tokens",
|
||||||
"weight_loader_disable_mmap",
|
"weight_loader_disable_mmap",
|
||||||
"enable_triton_kernel_moe",
|
|
||||||
"enable_flashinfer_mxfp4_moe",
|
|
||||||
"enable_multimodal",
|
"enable_multimodal",
|
||||||
"enable_symm_mem",
|
"enable_symm_mem",
|
||||||
"quantization",
|
"quantization",
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
|
from sglang.srt.layers.moe import initialize_moe_config
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
@@ -245,6 +245,9 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Init model config
|
||||||
|
self.model_config = ModelConfig.from_server_args(server_args)
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
self.idle_sleeper = None
|
self.idle_sleeper = None
|
||||||
@@ -292,6 +295,9 @@ class Scheduler(
|
|||||||
# Init tokenizer
|
# Init tokenizer
|
||||||
self.init_tokenizer()
|
self.init_tokenizer()
|
||||||
|
|
||||||
|
# Init moe config
|
||||||
|
self.init_moe_config()
|
||||||
|
|
||||||
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
|
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
|
||||||
if self.server_args.reasoning_parser and self.tokenizer:
|
if self.server_args.reasoning_parser and self.tokenizer:
|
||||||
reasoning_parser = ReasoningParser(
|
reasoning_parser = ReasoningParser(
|
||||||
@@ -538,8 +544,6 @@ class Scheduler(
|
|||||||
|
|
||||||
def init_tokenizer(self):
|
def init_tokenizer(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
self.model_config = ModelConfig.from_server_args(server_args)
|
|
||||||
self.is_generation = self.model_config.is_generation
|
self.is_generation = self.model_config.is_generation
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
@@ -761,6 +765,10 @@ class Scheduler(
|
|||||||
# The prefill requests that are in the middle of kv sending
|
# The prefill requests that are in the middle of kv sending
|
||||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||||
|
|
||||||
|
def init_moe_config(self):
|
||||||
|
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
|
||||||
|
initialize_moe_config(self.server_args)
|
||||||
|
|
||||||
@DynamicGradMode()
|
@DynamicGradMode()
|
||||||
def event_loop_normal(self):
|
def event_loop_normal(self):
|
||||||
"""A normal scheduler loop."""
|
"""A normal scheduler loop."""
|
||||||
@@ -1823,11 +1831,6 @@ class Scheduler(
|
|||||||
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
||||||
spec_algorithm=self.spec_algorithm,
|
spec_algorithm=self.spec_algorithm,
|
||||||
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||||
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
|
|
||||||
enable_deepep_moe=MoeA2ABackend(
|
|
||||||
self.server_args.moe_a2a_backend
|
|
||||||
).is_deepep(),
|
|
||||||
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
|
|
||||||
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
|
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
|
||||||
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
||||||
)
|
)
|
||||||
@@ -1922,9 +1925,6 @@ class Scheduler(
|
|||||||
disable_cuda_graph: bool,
|
disable_cuda_graph: bool,
|
||||||
spec_algorithm,
|
spec_algorithm,
|
||||||
speculative_num_draft_tokens,
|
speculative_num_draft_tokens,
|
||||||
enable_two_batch_overlap: bool,
|
|
||||||
enable_deepep_moe: bool,
|
|
||||||
deepep_mode: DeepEPMode,
|
|
||||||
require_mlp_tp_gather: bool,
|
require_mlp_tp_gather: bool,
|
||||||
disable_overlap_schedule: bool,
|
disable_overlap_schedule: bool,
|
||||||
):
|
):
|
||||||
@@ -1972,9 +1972,6 @@ class Scheduler(
|
|||||||
is_extend_in_batch,
|
is_extend_in_batch,
|
||||||
*tbo_preparer.prepare_all_gather(
|
*tbo_preparer.prepare_all_gather(
|
||||||
local_batch,
|
local_batch,
|
||||||
deepep_mode,
|
|
||||||
enable_deepep_moe,
|
|
||||||
enable_two_batch_overlap,
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
|
|||||||
@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
initialize_dp_attention,
|
initialize_dp_attention,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
|
|
||||||
from sglang.srt.layers.quantization import (
|
from sglang.srt.layers.quantization import (
|
||||||
deep_gemm_wrapper,
|
deep_gemm_wrapper,
|
||||||
monkey_patch_isinstance_for_vllm_base_layer,
|
monkey_patch_isinstance_for_vllm_base_layer,
|
||||||
@@ -219,8 +218,6 @@ class ModelRunner:
|
|||||||
# TODO it is indeed not a "server args"
|
# TODO it is indeed not a "server args"
|
||||||
"use_mla_backend": self.use_mla_backend,
|
"use_mla_backend": self.use_mla_backend,
|
||||||
"speculative_algorithm": self.spec_algorithm,
|
"speculative_algorithm": self.spec_algorithm,
|
||||||
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
|
|
||||||
"deepep_mode": DeepEPMode(server_args.deepep_mode),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
|
|||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
|
|
||||||
self.router = DbrxRouter(config, self.params_dtype)
|
self.router = DbrxRouter(config, self.params_dtype)
|
||||||
|
self.topk = TopK(
|
||||||
|
self.top_k,
|
||||||
|
renormalize=True,
|
||||||
|
)
|
||||||
|
self.moe_runner_config = MoeRunnerConfig(inplace=True)
|
||||||
self.ws = nn.Parameter(
|
self.ws = nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
self.num_total_experts,
|
self.num_total_experts,
|
||||||
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
|
|||||||
hidden_states = hidden_states.view(-1, self.d_model)
|
hidden_states = hidden_states.view(-1, self.d_model)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.router(hidden_states)
|
router_logits = self.router(hidden_states)
|
||||||
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
final_hidden_states = fused_moe(
|
final_hidden_states = fused_moe(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self.ws,
|
self.ws,
|
||||||
self.w2s,
|
self.w2s,
|
||||||
router_logits,
|
topk_output,
|
||||||
self.top_k,
|
self.moe_runner_config,
|
||||||
renormalize=True,
|
|
||||||
inplace=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
|
|||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.norm_1(hidden_states)
|
hidden_states = self.norm_1(hidden_states)
|
||||||
x = self.attn(
|
x = self.attn(
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
|
|||||||
w1=self.w1,
|
w1=self.w1,
|
||||||
w2=self.w2,
|
w2=self.w2,
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=True,
|
moe_runner_config=MoeRunnerConfig(inplace=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.n_shared_experts is not None:
|
if self.config.n_shared_experts is not None:
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import (
|
|||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
get_local_attention_dp_size,
|
|
||||||
is_dp_attention_enabled,
|
is_dp_attention_enabled,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
@@ -61,9 +60,10 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
@@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
**(
|
|
||||||
dict(deepep_mode=global_server_args_dict["deepep_mode"])
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
# Additional args for FusedMoE
|
|
||||||
**(
|
|
||||||
dict(
|
|
||||||
enable_flashinfer_cutlass_moe=True,
|
|
||||||
)
|
|
||||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
**(
|
|
||||||
dict(
|
|
||||||
renormalize=config.norm_topk_prob,
|
|
||||||
use_grouped_topk=True,
|
|
||||||
num_expert_group=config.n_group,
|
|
||||||
topk_group=config.topk_group,
|
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
|
||||||
)
|
|
||||||
if should_use_flashinfer_trtllm_moe()
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.shared_experts_is_int8 = False
|
self.shared_experts_is_int8 = False
|
||||||
@@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
prefix=add_prefix("shared_experts", prefix),
|
prefix=add_prefix("shared_experts", prefix),
|
||||||
**(
|
**(
|
||||||
dict(tp_rank=0, tp_size=1)
|
dict(tp_rank=0, tp_size=1)
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
if get_moe_a2a_backend().is_deepep()
|
||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if get_moe_a2a_backend().is_deepep():
|
||||||
# TODO: we will support tp < ep in the future
|
# TODO: we will support tp < ep in the future
|
||||||
self.ep_size = get_moe_expert_parallel_world_size()
|
self.ep_size = get_moe_expert_parallel_world_size()
|
||||||
self.num_experts = (
|
self.num_experts = (
|
||||||
@@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
params_dtype=config.torch_dtype,
|
params_dtype=config.torch_dtype,
|
||||||
deepep_mode=global_server_args_dict["deepep_mode"],
|
deepep_mode=get_deepep_mode(),
|
||||||
async_finish=True,
|
async_finish=True,
|
||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
|
||||||
|
|
||||||
def get_moe_weights(self):
|
def get_moe_weights(self):
|
||||||
return [
|
return [
|
||||||
@@ -484,12 +460,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
kwargs = {"hidden_states": hidden_states}
|
||||||
|
|
||||||
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
|
|
||||||
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
|
|
||||||
if should_use_flashinfer_trtllm_moe():
|
|
||||||
kwargs["topk_output"] = (self.topk, router_logits)
|
|
||||||
else:
|
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||||
|
|
||||||
final_hidden_states = self.experts(**kwargs)
|
final_hidden_states = self.experts(**kwargs)
|
||||||
@@ -520,12 +490,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
kwargs = {"hidden_states": hidden_states}
|
||||||
|
|
||||||
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
|
|
||||||
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
|
|
||||||
if should_use_flashinfer_trtllm_moe():
|
|
||||||
kwargs["topk_output"] = (self.topk, router_logits)
|
|
||||||
else:
|
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||||
|
|
||||||
final_hidden_states = self.experts(**kwargs)
|
final_hidden_states = self.experts(**kwargs)
|
||||||
@@ -2478,18 +2442,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
||||||
)
|
)
|
||||||
if self.quant_config and self.quant_config.get_name() == "w4afp8":
|
if self.quant_config and self.quant_config.get_name() == "w4afp8":
|
||||||
expert_params_mapping += (
|
expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
|
||||||
get_moe_impl_class().make_expert_input_scale_params_mapping(
|
|
||||||
num_experts=self.config.n_routed_experts
|
num_experts=self.config.n_routed_experts
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||||
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
||||||
|
|||||||
@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
|
from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
|
||||||
@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
|
class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import (
|
|||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
get_local_attention_dp_size,
|
|
||||||
is_dp_attention_enabled,
|
is_dp_attention_enabled,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
@@ -51,9 +50,10 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
is_fp8_fnuz,
|
is_fp8_fnuz,
|
||||||
@@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import (
|
|||||||
DeepseekV2Model,
|
DeepseekV2Model,
|
||||||
DeepseekV2MoE,
|
DeepseekV2MoE,
|
||||||
)
|
)
|
||||||
from sglang.srt.two_batch_overlap import (
|
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||||
MaybeTboDeepEPDispatcher,
|
|
||||||
model_forward_maybe_tbo,
|
|
||||||
)
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
BumpAllocator,
|
BumpAllocator,
|
||||||
LazyValue,
|
LazyValue,
|
||||||
@@ -414,8 +411,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||||
)
|
)
|
||||||
|
|
||||||
self.topk = (
|
self.topk = TopK(
|
||||||
TopK(
|
|
||||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
use_grouped_topk=True,
|
use_grouped_topk=True,
|
||||||
@@ -425,9 +421,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
correction_bias=self.gate.e_score_correction_bias,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
)
|
)
|
||||||
if not should_use_flashinfer_trtllm_moe()
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class()(
|
||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
@@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
**(
|
|
||||||
dict(deepep_mode=global_server_args_dict["deepep_mode"])
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
# Additional args for FusedMoE
|
|
||||||
**(
|
|
||||||
dict(
|
|
||||||
enable_flashinfer_cutlass_moe=True,
|
|
||||||
)
|
|
||||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
**(
|
|
||||||
dict(
|
|
||||||
renormalize=config.norm_topk_prob,
|
|
||||||
use_grouped_topk=True,
|
|
||||||
num_expert_group=config.n_group,
|
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
|
||||||
topk_group=config.topk_group,
|
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
|
||||||
)
|
|
||||||
if should_use_flashinfer_trtllm_moe()
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.shared_experts_is_int8 = False
|
self.shared_experts_is_int8 = False
|
||||||
@@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
|
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if get_moe_a2a_backend().is_deepep():
|
||||||
# TODO: we will support tp < ep in the future
|
# TODO: we will support tp < ep in the future
|
||||||
self.ep_size = get_moe_expert_parallel_world_size()
|
self.ep_size = get_moe_expert_parallel_world_size()
|
||||||
self.num_experts = (
|
self.num_experts = (
|
||||||
@@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
params_dtype=config.torch_dtype,
|
params_dtype=config.torch_dtype,
|
||||||
deepep_mode=global_server_args_dict["deepep_mode"],
|
deepep_mode=get_deepep_mode(),
|
||||||
async_finish=True,
|
async_finish=True,
|
||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
|
||||||
|
|
||||||
def forward_normal_dual_stream(
|
def forward_normal_dual_stream(
|
||||||
self,
|
self,
|
||||||
@@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
kwargs = {"hidden_states": hidden_states}
|
||||||
if self.topk is not None:
|
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||||
else:
|
|
||||||
kwargs["router_logits"] = router_logits
|
|
||||||
final_hidden_states = self.experts(**kwargs)
|
final_hidden_states = self.experts(**kwargs)
|
||||||
if not _is_cuda:
|
if not _is_cuda:
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
@@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
kwargs = {"hidden_states": hidden_states}
|
||||||
if self.topk is not None:
|
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||||
else:
|
|
||||||
kwargs["router_logits"] = router_logits
|
|
||||||
final_hidden_states = self.experts(**kwargs)
|
final_hidden_states = self.experts(**kwargs)
|
||||||
if not _is_cuda and not _use_aiter:
|
if not _is_cuda and not _use_aiter:
|
||||||
# fused in biased_grouped_topk so we can skip here
|
# fused in biased_grouped_topk so we can skip here
|
||||||
@@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model):
|
|||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.dp_size = get_local_attention_dp_size()
|
|
||||||
|
|
||||||
|
|
||||||
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
||||||
|
|
||||||
@@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
|||||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.dp_size = get_local_attention_dp_size()
|
|
||||||
|
|
||||||
self._routed_experts_weights_of_layer = LazyValue(
|
self._routed_experts_weights_of_layer = LazyValue(
|
||||||
lambda: {
|
lambda: {
|
||||||
@@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
|
|||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_moe_expert_parallel_world_size,
|
get_moe_expert_parallel_world_size,
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
parallel_state,
|
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.layers.dp_attention import (
|
|
||||||
get_attention_tp_rank,
|
|
||||||
get_attention_tp_size,
|
|
||||||
get_local_attention_dp_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
@@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
|
|||||||
config.moe_layer_freq = 1
|
config.moe_layer_freq = 1
|
||||||
self.config = config
|
self.config = config
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.dp_size = get_local_attention_dp_size()
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
|
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
|
||||||
self.num_fused_shared_experts = (
|
self.num_fused_shared_experts = (
|
||||||
@@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
|||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
get_local_attention_dp_size,
|
|
||||||
is_dp_attention_enabled,
|
is_dp_attention_enabled,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
@@ -50,9 +49,10 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
|
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -110,12 +110,9 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.activation = config.hidden_act
|
self.activation = config.hidden_act
|
||||||
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
||||||
self.swiglu_limit = config.swiglu_limit
|
self.gemm1_clamp_limit = config.swiglu_limit
|
||||||
|
|
||||||
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
|
|
||||||
self.topk = None
|
|
||||||
else:
|
|
||||||
self.topk = TopK(
|
self.topk = TopK(
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
renormalize=True,
|
renormalize=True,
|
||||||
@@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
quant_config.get_name() if quant_config is not None else None
|
quant_config.get_name() if quant_config is not None else None
|
||||||
)
|
)
|
||||||
extra_kwargs = {
|
extra_kwargs = {
|
||||||
"enable_flashinfer_cutlass_moe": global_server_args_dict[
|
|
||||||
"enable_flashinfer_cutlass_moe"
|
|
||||||
],
|
|
||||||
# for moe gate_up_proj and down_proj and their bias loading
|
# for moe gate_up_proj and down_proj and their bias loading
|
||||||
"use_weight_loader_fused": quant_config_name != "mxfp4",
|
"use_weight_loader_fused": quant_config_name
|
||||||
|
!= "mxfp4"
|
||||||
}
|
}
|
||||||
self.experts = experts_type(
|
self.experts = experts_type(
|
||||||
num_experts=config.num_local_experts
|
num_experts=config.num_local_experts
|
||||||
@@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
activation_alpha=self.activation_alpha,
|
gemm1_alpha=self.gemm1_alpha,
|
||||||
swiglu_limit=self.swiglu_limit,
|
gemm1_clamp_limit=self.gemm1_clamp_limit,
|
||||||
with_bias=True,
|
with_bias=True,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
**(
|
|
||||||
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
forward_batch: Optional[ForwardBatch] = None,
|
forward_batch: Optional[ForwardBatch] = None,
|
||||||
should_allreduce_fusion: bool = False,
|
should_allreduce_fusion: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if not get_moe_a2a_backend().is_deepep():
|
||||||
return self.forward_normal(hidden_states, should_allreduce_fusion)
|
return self.forward_normal(hidden_states, should_allreduce_fusion)
|
||||||
else:
|
else:
|
||||||
raise Exception("forward_deepep branch not implemented yet")
|
raise Exception("forward_deepep branch not implemented yet")
|
||||||
@@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
should_allreduce_fusion: bool = False,
|
should_allreduce_fusion: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
|
||||||
router_logits, _ = self.router(hidden_states)
|
router_logits, _ = self.router(hidden_states)
|
||||||
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
if self.topk is not None:
|
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
|
||||||
else:
|
|
||||||
kwargs["topk_output"] = (self.top_k, router_logits)
|
|
||||||
final_hidden_states = self.experts(**kwargs)
|
|
||||||
|
|
||||||
if self.tp_size > 1 and not should_allreduce_fusion:
|
if self.tp_size > 1 and not should_allreduce_fusion:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
@@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.attn_tp_rank = get_attention_tp_rank()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
self.local_dp_size = get_local_attention_dp_size()
|
|
||||||
|
|
||||||
# GptOss all layers are sparse and have no nextn now
|
# GptOss all layers are sparse and have no nextn now
|
||||||
self.is_layer_sparse = True
|
self.is_layer_sparse = True
|
||||||
@@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
("qkv_proj", "k_proj", "k"),
|
("qkv_proj", "k_proj", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
]
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
|
||||||
ckpt_gate_up_proj_name="gate_up_proj",
|
ckpt_gate_up_proj_name="gate_up_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
||||||
|
|||||||
@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
|
|||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
|
||||||
prefix=f"{prefix}.experts",
|
prefix=f"{prefix}.experts",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -135,7 +135,6 @@ class Grok1MoE(nn.Module):
|
|||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
|
||||||
activation="gelu",
|
activation="gelu",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from sglang.srt.distributed import parallel_state
|
from sglang.srt.distributed import parallel_state
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.managers.mm_utils import (
|
from sglang.srt.managers.mm_utils import (
|
||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
@@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module):
|
|||||||
]
|
]
|
||||||
expert_params_mapping = []
|
expert_params_mapping = []
|
||||||
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
|
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
|
|||||||
|
|
||||||
from sglang.srt.distributed import parallel_state
|
from sglang.srt.distributed import parallel_state
|
||||||
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
|
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.managers.mm_utils import (
|
from sglang.srt.managers.mm_utils import (
|
||||||
MultiModalityDataPaddingPatternTokenPairs,
|
MultiModalityDataPaddingPatternTokenPairs,
|
||||||
@@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
|||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
get_local_attention_dp_size,
|
|
||||||
is_dp_attention_enabled,
|
is_dp_attention_enabled,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
rope_theta = config.rope_theta
|
rope_theta = config.rope_theta
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
max_position_embeddings = config.max_position_embeddings
|
max_position_embeddings = config.max_position_embeddings
|
||||||
self.local_dp_size = get_local_attention_dp_size()
|
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.attn_tp_rank = get_attention_tp_rank()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix, is_cuda
|
from sglang.srt.utils import add_prefix, is_cuda
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix, make_layers
|
from sglang.srt.utils import add_prefix, make_layers
|
||||||
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
|
|||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
|
|||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,8 +17,6 @@
|
|||||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum, auto
|
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.eplb.expert_distribution import (
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
ExpertDistributionRecorder,
|
|
||||||
get_global_expert_distribution_recorder,
|
|
||||||
)
|
|
||||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.communicator import (
|
from sglang.srt.layers.communicator import (
|
||||||
@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
|
|||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_tp_rank,
|
get_attention_tp_rank,
|
||||||
get_attention_tp_size,
|
get_attention_tp_size,
|
||||||
get_local_attention_dp_size,
|
|
||||||
is_dp_attention_enabled,
|
is_dp_attention_enabled,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
|
|||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
# Additional args for FusedMoE
|
|
||||||
**(
|
|
||||||
dict(
|
|
||||||
enable_flashinfer_cutlass_moe=True,
|
|
||||||
)
|
|
||||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(
|
self.gate = ReplicatedLinear(
|
||||||
@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.attn_tp_rank = get_attention_tp_rank()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
self.local_dp_size = get_local_attention_dp_size()
|
|
||||||
|
|
||||||
# Qwen2MoE all layers are sparse and have no nextn now
|
# Qwen2MoE all layers are sparse and have no nextn now
|
||||||
self.is_layer_sparse = True
|
self.is_layer_sparse = True
|
||||||
|
|||||||
@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
|
|||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
parallel_state,
|
|
||||||
split_tensor_along_last_dim,
|
|
||||||
tensor_model_parallel_all_gather,
|
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
|
||||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||||
get_attention_tp_rank,
|
|
||||||
get_attention_tp_size,
|
|
||||||
get_local_attention_dp_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
from sglang.srt.layers.utils import get_layer_id
|
from sglang.srt.layers.utils import get_layer_id
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
ParallelLMHead,
|
|
||||||
VocabParallelEmbedding,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
ForwardBatch,
|
|
||||||
ForwardMode,
|
|
||||||
PPProxyTensors,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
||||||
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
||||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
|
||||||
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
|
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
|
||||||
|
|
||||||
Qwen3MoeConfig = None
|
Qwen3MoeConfig = None
|
||||||
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
**(
|
|
||||||
dict(deepep_mode=global_server_args_dict["deepep_mode"])
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
# Additional args for FusedMoE
|
|
||||||
**(
|
|
||||||
dict(
|
|
||||||
enable_flashinfer_cutlass_moe=True,
|
|
||||||
)
|
|
||||||
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(
|
self.gate = ReplicatedLinear(
|
||||||
@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
prefix=add_prefix("gate", prefix),
|
prefix=add_prefix("gate", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if get_moe_a2a_backend().is_deepep():
|
||||||
# TODO: we will support tp < ep in the future
|
# TODO: we will support tp < ep in the future
|
||||||
self.ep_size = get_moe_expert_parallel_world_size()
|
self.ep_size = get_moe_expert_parallel_world_size()
|
||||||
self.num_experts = (
|
self.num_experts = (
|
||||||
@@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
use_reduce_scatter: bool = False,
|
use_reduce_scatter: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if not get_moe_a2a_backend().is_deepep():
|
||||||
return self.forward_normal(hidden_states, use_reduce_scatter)
|
return self.forward_normal(hidden_states, use_reduce_scatter)
|
||||||
else:
|
else:
|
||||||
return self.forward_deepep(hidden_states, forward_batch)
|
return self.forward_deepep(hidden_states, forward_batch)
|
||||||
@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.attn_tp_rank = get_attention_tp_rank()
|
self.attn_tp_rank = get_attention_tp_rank()
|
||||||
self.local_dp_size = get_local_attention_dp_size()
|
|
||||||
|
|
||||||
# Qwen3MoE all layers are sparse and have no nextn now
|
# Qwen3MoE all layers are sparse and have no nextn now
|
||||||
self.is_layer_sparse = True
|
self.is_layer_sparse = True
|
||||||
@@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
|
|||||||
prefix=add_prefix("gate", prefix),
|
prefix=add_prefix("gate", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if get_moe_a2a_backend().is_deepep():
|
||||||
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
|
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.pack_params()
|
self.pack_params()
|
||||||
|
self.moe_runner_config = MoeRunnerConfig(inplace=True)
|
||||||
|
|
||||||
self.router = ReplicatedLinear(
|
self.router = ReplicatedLinear(
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
|
|||||||
quant_config=None,
|
quant_config=None,
|
||||||
prefix=add_prefix("router", prefix),
|
prefix=add_prefix("router", prefix),
|
||||||
)
|
)
|
||||||
|
self.topk = TopK(
|
||||||
|
top_k=self.top_k,
|
||||||
|
renormalize=getattr(self.config, "norm_topk_prob", False),
|
||||||
|
)
|
||||||
|
|
||||||
if config.num_shared_experts is not None:
|
if config.num_shared_experts is not None:
|
||||||
intermediate_size = config.intermediate_size * config.num_shared_experts
|
intermediate_size = config.intermediate_size * config.num_shared_experts
|
||||||
@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
|
|||||||
shared_output = self.shared_experts(hidden_states)
|
shared_output = self.shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.router(hidden_states)
|
router_logits, _ = self.router(hidden_states)
|
||||||
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
final_hidden_states = fused_moe(
|
final_hidden_states = fused_moe(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self.w1,
|
self.w1,
|
||||||
self.w2,
|
self.w2,
|
||||||
router_logits,
|
topk_output,
|
||||||
self.top_k,
|
self.moe_runner_config,
|
||||||
renormalize=getattr(self.config, "norm_topk_prob", False),
|
|
||||||
inplace=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.num_shared_experts is not None:
|
if self.config.num_shared_experts is not None:
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
|
|||||||
is_hip,
|
is_hip,
|
||||||
is_port_available,
|
is_port_available,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
|
is_triton_kernels_available,
|
||||||
is_valid_ipv6_address,
|
is_valid_ipv6_address,
|
||||||
nullable_str,
|
nullable_str,
|
||||||
)
|
)
|
||||||
@@ -175,9 +176,15 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
ep_size: int = 1
|
ep_size: int = 1
|
||||||
moe_a2a_backend: Optional[Literal["deepep"]] = None
|
moe_a2a_backend: Literal["none", "deepep"] = "none"
|
||||||
enable_flashinfer_cutlass_moe: bool = False
|
moe_runner_backend: Literal[
|
||||||
enable_flashinfer_trtllm_moe: bool = False
|
"auto",
|
||||||
|
"triton",
|
||||||
|
"triton_kernel",
|
||||||
|
"flashinfer_trtllm",
|
||||||
|
"flashinfer_cutlass",
|
||||||
|
"flashinfer_mxfp4",
|
||||||
|
] = "auto"
|
||||||
enable_flashinfer_allreduce_fusion: bool = False
|
enable_flashinfer_allreduce_fusion: bool = False
|
||||||
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
|
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
|
||||||
ep_num_redundant_experts: int = 0
|
ep_num_redundant_experts: int = 0
|
||||||
@@ -250,8 +257,6 @@ class ServerArgs:
|
|||||||
disable_chunked_prefix_cache: bool = False
|
disable_chunked_prefix_cache: bool = False
|
||||||
disable_fast_image_processor: bool = False
|
disable_fast_image_processor: bool = False
|
||||||
enable_return_hidden_states: bool = False
|
enable_return_hidden_states: bool = False
|
||||||
enable_triton_kernel_moe: bool = False
|
|
||||||
enable_flashinfer_mxfp4_moe: bool = False
|
|
||||||
scheduler_recv_interval: int = 1
|
scheduler_recv_interval: int = 1
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
@@ -282,6 +287,9 @@ class ServerArgs:
|
|||||||
# Deprecated arguments
|
# Deprecated arguments
|
||||||
enable_ep_moe: bool = False
|
enable_ep_moe: bool = False
|
||||||
enable_deepep_moe: bool = False
|
enable_deepep_moe: bool = False
|
||||||
|
enable_flashinfer_cutlass_moe: bool = False
|
||||||
|
enable_flashinfer_trtllm_moe: bool = False
|
||||||
|
enable_triton_kernel_moe: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Check deprecated arguments
|
# Check deprecated arguments
|
||||||
@@ -298,6 +306,21 @@ class ServerArgs:
|
|||||||
print_deprecated_warning(
|
print_deprecated_warning(
|
||||||
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
|
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
|
||||||
)
|
)
|
||||||
|
if self.enable_triton_kernel_moe:
|
||||||
|
self.moe_runner_backend = "triton_kernel"
|
||||||
|
print_deprecated_warning(
|
||||||
|
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
|
||||||
|
)
|
||||||
|
if self.enable_flashinfer_cutlass_moe:
|
||||||
|
self.moe_runner_backend = "flashinfer_cutlass"
|
||||||
|
print_deprecated_warning(
|
||||||
|
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
|
||||||
|
)
|
||||||
|
if self.enable_flashinfer_trtllm_moe:
|
||||||
|
self.moe_runner_backend = "flashinfer_trtllm"
|
||||||
|
print_deprecated_warning(
|
||||||
|
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
|
||||||
|
)
|
||||||
|
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
@@ -517,7 +540,7 @@ class ServerArgs:
|
|||||||
), "Please enable dp attention when setting enable_dp_lm_head. "
|
), "Please enable dp attention when setting enable_dp_lm_head. "
|
||||||
|
|
||||||
# MoE kernel
|
# MoE kernel
|
||||||
if self.enable_flashinfer_cutlass_moe:
|
if self.moe_runner_backend == "flashinfer_cutlass":
|
||||||
assert (
|
assert (
|
||||||
self.quantization == "modelopt_fp4"
|
self.quantization == "modelopt_fp4"
|
||||||
), "modelopt_fp4 quantization is required for Flashinfer MOE"
|
), "modelopt_fp4 quantization is required for Flashinfer MOE"
|
||||||
@@ -527,7 +550,7 @@ class ServerArgs:
|
|||||||
self.tp_size,
|
self.tp_size,
|
||||||
], "The expert parallel size must be 1 or the same as the tensor parallel size"
|
], "The expert parallel size must be 1 or the same as the tensor parallel size"
|
||||||
|
|
||||||
if self.enable_flashinfer_trtllm_moe:
|
if self.moe_runner_backend == "flashinfer_trtllm":
|
||||||
if not self.disable_shared_experts_fusion:
|
if not self.disable_shared_experts_fusion:
|
||||||
self.disable_shared_experts_fusion = True
|
self.disable_shared_experts_fusion = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -556,7 +579,7 @@ class ServerArgs:
|
|||||||
self.ep_dispatch_algorithm = "static"
|
self.ep_dispatch_algorithm = "static"
|
||||||
|
|
||||||
if self.enable_eplb:
|
if self.enable_eplb:
|
||||||
assert self.ep_size > 1 or self.moe_a2a_backend is not None
|
assert self.ep_size > 1
|
||||||
|
|
||||||
if self.enable_expert_distribution_metrics and (
|
if self.enable_expert_distribution_metrics and (
|
||||||
self.expert_distribution_recorder_mode is None
|
self.expert_distribution_recorder_mode is None
|
||||||
@@ -1446,19 +1469,22 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--moe-a2a-backend",
|
"--moe-a2a-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["deepep"],
|
choices=["none", "deepep"],
|
||||||
default=ServerArgs.moe_a2a_backend,
|
default=ServerArgs.moe_a2a_backend,
|
||||||
help="Choose the backend for MoE A2A.",
|
help="Choose the backend for MoE A2A.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-flashinfer-cutlass-moe",
|
"--moe-runner-backend",
|
||||||
action="store_true",
|
type=str,
|
||||||
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
|
choices=[
|
||||||
)
|
"auto",
|
||||||
parser.add_argument(
|
"triton",
|
||||||
"--enable-flashinfer-trtllm-moe",
|
"triton_kernel",
|
||||||
action="store_true",
|
"flashinfer_trtllm",
|
||||||
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
|
"flashinfer_cutlass",
|
||||||
|
],
|
||||||
|
default=ServerArgs.moe_runner_backend,
|
||||||
|
help="Choose the runner backend for MoE.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-flashinfer-allreduce-fusion",
|
"--enable-flashinfer-allreduce-fusion",
|
||||||
@@ -1825,11 +1851,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable returning hidden states with responses.",
|
help="Enable returning hidden states with responses.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--enable-triton-kernel-moe",
|
|
||||||
action="store_true",
|
|
||||||
help="Use triton moe grouped gemm kernel.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-flashinfer-mxfp4-moe",
|
"--enable-flashinfer-mxfp4-moe",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -1965,6 +1986,21 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
|
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-flashinfer-cutlass-moe",
|
||||||
|
action="store_true",
|
||||||
|
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-flashinfer-trtllm-moe",
|
||||||
|
action="store_true",
|
||||||
|
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-triton-kernel-moe",
|
||||||
|
action="store_true",
|
||||||
|
help="(Deprecated) Use triton moe grouped gemm kernel.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
@@ -2143,18 +2179,21 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_sm100_supported() and is_mxfp4_quant_format:
|
if is_sm100_supported() and is_mxfp4_quant_format:
|
||||||
self.enable_flashinfer_mxfp4_moe = True
|
self.moe_runner_backend = "flashinfer_mxfp4"
|
||||||
self.enable_triton_kernel_moe = False
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.enable_triton_kernel_moe:
|
if self.moe_runner_backend == "triton_kernel":
|
||||||
assert (
|
assert (
|
||||||
self.ep_size == 1
|
self.ep_size == 1
|
||||||
), "Triton kernel MoE is only supported when ep_size == 1"
|
), "Triton kernel MoE is only supported when ep_size == 1"
|
||||||
if not self.enable_triton_kernel_moe and self.ep_size == 1:
|
if (
|
||||||
self.enable_triton_kernel_moe = True
|
self.moe_runner_backend == "auto"
|
||||||
|
and self.ep_size == 1
|
||||||
|
and is_triton_kernels_available()
|
||||||
|
):
|
||||||
|
self.moe_runner_backend = "triton_kernel"
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
|
|||||||
CommunicateSummableTensorPairFn,
|
CommunicateSummableTensorPairFn,
|
||||||
ScatterMode,
|
ScatterMode,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe import (
|
||||||
|
get_deepep_mode,
|
||||||
|
get_moe_a2a_backend,
|
||||||
|
get_tbo_token_distribution_threshold,
|
||||||
|
is_tbo_enabled,
|
||||||
|
)
|
||||||
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
|
|||||||
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
|
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
|
||||||
left_sum = sum(extend_lens[:vanilla_split_seq_index])
|
left_sum = sum(extend_lens[:vanilla_split_seq_index])
|
||||||
overall_sum = sum(extend_lens)
|
overall_sum = sum(extend_lens)
|
||||||
threshold = global_server_args_dict["tbo_token_distribution_threshold"]
|
threshold = get_tbo_token_distribution_threshold()
|
||||||
assert threshold <= 0.5, f"{threshold=}"
|
assert threshold <= 0.5, f"{threshold=}"
|
||||||
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
|
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
|
||||||
1 - threshold
|
1 - threshold
|
||||||
@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
|
|||||||
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
|
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
|
||||||
|
|
||||||
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
|
||||||
if not global_server_args_dict["enable_two_batch_overlap"]:
|
if not is_tbo_enabled():
|
||||||
return
|
return
|
||||||
token_num_per_seq = get_token_num_per_seq(
|
token_num_per_seq = get_token_num_per_seq(
|
||||||
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
forward_mode=batch.forward_mode, spec_info=batch.spec_info
|
||||||
@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
|
|||||||
def prepare_all_gather(
|
def prepare_all_gather(
|
||||||
self,
|
self,
|
||||||
local_batch: ScheduleBatch,
|
local_batch: ScheduleBatch,
|
||||||
deepep_mode: DeepEPMode,
|
|
||||||
enable_deepep_moe: bool,
|
|
||||||
enable_two_batch_overlap: bool,
|
|
||||||
):
|
):
|
||||||
|
|
||||||
|
deepep_mode = get_deepep_mode()
|
||||||
|
enable_deepep_moe = get_moe_a2a_backend().is_deepep()
|
||||||
|
enable_two_batch_overlap = is_tbo_enabled()
|
||||||
|
|
||||||
self.enable_two_batch_overlap = enable_two_batch_overlap
|
self.enable_two_batch_overlap = enable_two_batch_overlap
|
||||||
|
|
||||||
if local_batch is not None:
|
if local_batch is not None:
|
||||||
@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
|
|||||||
and not local_batch.forward_mode.is_target_verify()
|
and not local_batch.forward_mode.is_target_verify()
|
||||||
)
|
)
|
||||||
and enable_deepep_moe
|
and enable_deepep_moe
|
||||||
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
|
and (resolved_deepep_mode.is_low_latency())
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.local_tbo_split_seq_index = 0
|
self.local_tbo_split_seq_index = 0
|
||||||
@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
|
|||||||
"req_to_token_pool",
|
"req_to_token_pool",
|
||||||
"token_to_kv_pool",
|
"token_to_kv_pool",
|
||||||
"can_run_dp_cuda_graph",
|
"can_run_dp_cuda_graph",
|
||||||
|
"dp_padding_mode",
|
||||||
"global_forward_mode",
|
"global_forward_mode",
|
||||||
"spec_algorithm",
|
"spec_algorithm",
|
||||||
"capture_hidden_mode",
|
"capture_hidden_mode",
|
||||||
@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
|
|||||||
tbo_children=None,
|
tbo_children=None,
|
||||||
global_num_tokens_gpu=None,
|
global_num_tokens_gpu=None,
|
||||||
global_num_tokens_cpu=None,
|
global_num_tokens_cpu=None,
|
||||||
dp_padding_mode=None,
|
|
||||||
global_dp_buffer_len=global_dp_buffer_len,
|
global_dp_buffer_len=global_dp_buffer_len,
|
||||||
global_num_tokens_for_logprob_gpu=None,
|
global_num_tokens_for_logprob_gpu=None,
|
||||||
global_num_tokens_for_logprob_cpu=None,
|
global_num_tokens_for_logprob_cpu=None,
|
||||||
@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
|
|||||||
|
|
||||||
class MaybeTboDeepEPDispatcher:
|
class MaybeTboDeepEPDispatcher:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
num_inner_dispatchers = (
|
num_inner_dispatchers = 2 if is_tbo_enabled() else 1
|
||||||
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
|
|
||||||
)
|
|
||||||
self._inners = [
|
self._inners = [
|
||||||
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
|
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args):
|
|||||||
return True
|
return True
|
||||||
elif not server_args.enable_dp_lm_head:
|
elif not server_args.enable_dp_lm_head:
|
||||||
return True
|
return True
|
||||||
elif server_args.moe_a2a_backend is None:
|
elif server_args.moe_a2a_backend == "none":
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args):
|
|||||||
Check if the input of attention is scattered.
|
Check if the input of attention is scattered.
|
||||||
"""
|
"""
|
||||||
assert server_args.moe_dense_tp_size in [1, None]
|
assert server_args.moe_dense_tp_size in [1, None]
|
||||||
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
|
if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
|
||||||
if server_args.enable_dp_attention:
|
if server_args.enable_dp_attention:
|
||||||
return server_args.dp_size < server_args.tp_size
|
return server_args.dp_size < server_args.tp_size
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_tensor_quant_mla_fp8,
|
per_tensor_quant_mla_fp8,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
|||||||
score = torch.randn((M, E), dtype=dtype)
|
score = torch.randn((M, E), dtype=dtype)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
ref_out = torch_w8a8_block_fp8_moe(
|
||||||
|
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
||||||
|
)
|
||||||
topk_output = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
renormalize=False,
|
|
||||||
)
|
)
|
||||||
out = fused_moe(
|
out = fused_moe(
|
||||||
a,
|
a,
|
||||||
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
|||||||
w2_scale=w2_s,
|
w2_scale=w2_s,
|
||||||
block_shape=block_size,
|
block_shape=block_size,
|
||||||
)
|
)
|
||||||
ref_out = torch_w8a8_block_fp8_moe(
|
|
||||||
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
run_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
silu_and_mul_triton_kernel,
|
silu_and_mul_triton_kernel,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -22,35 +22,26 @@ def ep_moe(
|
|||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
top_k: int,
|
topk_config: TopKConfig,
|
||||||
renormalize: bool,
|
|
||||||
# ep config
|
# ep config
|
||||||
num_experts: int = 256,
|
num_experts: int = 256,
|
||||||
fp8_dtype: torch.types = torch.float8_e4m3fn,
|
fp8_dtype: torch.types = torch.float8_e4m3fn,
|
||||||
num_experts_per_partition: int = 128,
|
num_experts_per_partition: int = 128,
|
||||||
start_expert_id: int = 0,
|
start_expert_id: int = 0,
|
||||||
end_expert_id: int = 127,
|
end_expert_id: int = 127,
|
||||||
use_grouped_topk: bool = False,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
w1_scale_inv: Optional[torch.Tensor] = None,
|
w1_scale_inv: Optional[torch.Tensor] = None,
|
||||||
w2_scale_inv: Optional[torch.Tensor] = None,
|
w2_scale_inv: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
use_blockwise_fp8 = block_shape is not None
|
use_blockwise_fp8 = block_shape is not None
|
||||||
topk_weights, topk_ids, _ = select_experts(
|
top_k = topk_config.top_k
|
||||||
|
topk_output = select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
top_k=top_k,
|
topk_config=topk_config,
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
# correction_bias=correction_bias, #skip this in test
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
)
|
)
|
||||||
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
|
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
|
||||||
|
|
||||||
@@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
|||||||
start_id = cur_rank * num_experts_per_partition
|
start_id = cur_rank * num_experts_per_partition
|
||||||
end_id = start_id + num_experts_per_partition - 1
|
end_id = start_id + num_experts_per_partition - 1
|
||||||
|
|
||||||
|
topk_config = TopKConfig(
|
||||||
|
top_k=topk,
|
||||||
|
renormalize=False,
|
||||||
|
)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
out = ep_moe(
|
out = ep_moe(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=topk_config,
|
||||||
renormalize=False,
|
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
w1_scale_inv=w1_s,
|
w1_scale_inv=w1_s,
|
||||||
w2_scale_inv=w2_s,
|
w2_scale_inv=w2_s,
|
||||||
@@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
|||||||
w1=w1_ref,
|
w1=w1_ref,
|
||||||
w2=w2_ref,
|
w2=w2_ref,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=topk_config,
|
||||||
renormalize=False,
|
|
||||||
use_fp8_w8a8=False,
|
use_fp8_w8a8=False,
|
||||||
w1_scale_inv=None,
|
w1_scale_inv=None,
|
||||||
w2_scale_inv=None,
|
w2_scale_inv=None,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
|
|
||||||
|
|
||||||
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
|
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -100,11 +100,12 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
|
|||||||
s_strides2 = c_strides2
|
s_strides2 = c_strides2
|
||||||
|
|
||||||
score = torch.randn((M, E), dtype=dtype, device=device)
|
score = torch.randn((M, E), dtype=dtype, device=device)
|
||||||
topk_weights, topk_ids, _ = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
)
|
)
|
||||||
|
topk_weights, topk_ids, _ = topk_output
|
||||||
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
||||||
expert_map[local_e:] = E
|
expert_map[local_e:] = E
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from sgl_kernel import scaled_fp4_quant
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
|
|
||||||
if torch.cuda.get_device_capability() < (10, 0):
|
if torch.cuda.get_device_capability() < (10, 0):
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
@@ -163,11 +163,12 @@ def check_moe(
|
|||||||
|
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
)
|
)
|
||||||
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||||
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -175,10 +175,13 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
|
|||||||
topk_output = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
ref_out = torch_w8a8_block_int8_moe(
|
||||||
|
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
||||||
|
)
|
||||||
out = fused_moe(
|
out = fused_moe(
|
||||||
a,
|
a,
|
||||||
w1,
|
w1,
|
||||||
@@ -189,9 +192,6 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
|
|||||||
w2_scale=w2_s,
|
w2_scale=w2_s,
|
||||||
block_shape=block_size,
|
block_shape=block_size,
|
||||||
)
|
)
|
||||||
ref_out = torch_w8a8_block_int8_moe(
|
|
||||||
a, w1, w2, w1_s, w2_s, score, topk, block_size
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
@@ -118,7 +118,7 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
|
|||||||
topk_output = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
)
|
)
|
||||||
out = fused_moe(
|
out = fused_moe(
|
||||||
a,
|
a,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
@@ -136,19 +136,7 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
topk_output = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
)
|
|
||||||
|
|
||||||
sglang_output = fused_moe(
|
|
||||||
a,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_output,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_output = self.torch_naive_moe(
|
torch_output = self.torch_naive_moe(
|
||||||
@@ -162,6 +150,18 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
a1_scale,
|
a1_scale,
|
||||||
a2_scale,
|
a2_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sglang_output = fused_moe(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_output,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
sglang_output, torch_output, rtol=rtol, atol=atol
|
sglang_output, torch_output, rtol=rtol, atol=atol
|
||||||
)
|
)
|
||||||
@@ -174,7 +174,7 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
topk_output = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
triton_output = fused_moe(a, w1, w2, topk_output)
|
triton_output = fused_moe(a, w1, w2, topk_output)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ class TestW8A8FP8FusedMoE(CustomTestCase):
|
|||||||
topk_output = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||||
)
|
)
|
||||||
out = fused_moe(
|
out = fused_moe(
|
||||||
a,
|
a,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64]
|
NUM_EXPERTS = [8, 64]
|
||||||
TOP_KS = [2, 6]
|
TOP_KS = [2, 6]
|
||||||
@@ -223,7 +223,7 @@ def test_fused_moe_wn16(
|
|||||||
topk_output = select_experts(
|
topk_output = select_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
router_logits=score,
|
router_logits=score,
|
||||||
top_k=topk,
|
topk_config=TopKConfig(top_k=topk),
|
||||||
)
|
)
|
||||||
|
|
||||||
triton_output = fused_moe(
|
triton_output = fused_moe(
|
||||||
|
|||||||
Reference in New Issue
Block a user