Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)
This commit is contained in:
@@ -148,7 +148,11 @@ class PyNcclCommunicator:
|
||||
)
|
||||
|
||||
def all_gather(
|
||||
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
stream=None,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
dst_slice = output_tensor[split_offset : split_offset + split_size]
|
||||
self.nccl.ncclBroadcast(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(dst_slice.data_ptr()),
|
||||
dst_slice.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
else:
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def reduce_scatter(
|
||||
self,
|
||||
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
chunk = input_tensor[split_offset : split_offset + split_size, ...]
|
||||
|
||||
self.nccl.ncclReduce(
|
||||
buffer_type(chunk.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
chunk.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
else:
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
if self.disabled:
|
||||
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
def group_start(self):
|
||||
self.nccl.ncclGroupStart()
|
||||
|
||||
def group_end(self):
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
@contextmanager
|
||||
def change_state(
|
||||
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
||||
|
||||
@@ -206,6 +206,26 @@ class NCCLLibrary:
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, int root,
|
||||
# ncclComm_t comm, cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function(
|
||||
"ncclReduce",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduceScatter(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
@@ -278,6 +298,10 @@ class NCCLLibrary:
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
# ncclResult_t ncclGroupStart();
|
||||
Function("ncclGroupStart", ncclResult_t, []),
|
||||
# ncclResult_t ncclGroupEnd();
|
||||
Function("ncclGroupEnd", ncclResult_t, []),
|
||||
]
|
||||
|
||||
exported_functions_symm_mem = [
|
||||
@@ -400,6 +424,28 @@ class NCCLLibrary:
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
root: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclReduce"](
|
||||
sendbuff, recvbuff, count, datatype, op, root, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduceScatter(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
@@ -499,6 +545,12 @@ class NCCLLibrary:
|
||||
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
|
||||
|
||||
def ncclGroupStart(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
|
||||
|
||||
def ncclGroupEnd(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary",
|
||||
|
||||
@@ -583,6 +583,39 @@ class GroupCoordinator:
|
||||
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
|
||||
return output
|
||||
|
||||
def reduce_scatterv(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
sizes: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
|
||||
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
|
||||
assert (
|
||||
pynccl_comm is not None and not pynccl_comm.disabled
|
||||
), "pynccl is required for reduce_scatterv"
|
||||
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_.shape[0] == sum(sizes)
|
||||
chunk_size = sizes[self.rank_in_group]
|
||||
else:
|
||||
assert input_.shape[0] % world_size == 0
|
||||
chunk_size = input_.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_.shape[1:]
|
||||
|
||||
if output is None:
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
else:
|
||||
assert output.shape == output_shape
|
||||
|
||||
pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
|
||||
return output
|
||||
|
||||
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
@@ -673,6 +706,54 @@ class GroupCoordinator:
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: Union[torch.Tensor, List[torch.Tensor]],
|
||||
sizes: Optional[List[int]] = None,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Supports varying sizes per rank and input tensor list.
|
||||
`sizes`: a list of len(world_size) with the number of items per rank to gather.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
|
||||
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
|
||||
assert (
|
||||
pynccl_comm is not None and not pynccl_comm.disabled
|
||||
), "pynccl is required for all_gatherv"
|
||||
|
||||
def _all_gather_single(
|
||||
input_: torch.Tensor, sizes: Optional[List[int]] = None
|
||||
):
|
||||
input_size = input_.size()
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_.shape[0] == sizes[self.rank_in_group]
|
||||
output_size = (sum(sizes),) + input_size[1:]
|
||||
# 'sizes' is not needed if all inputs in the same group have the same shape
|
||||
if all(s == sizes[0] for s in sizes):
|
||||
sizes = None
|
||||
else:
|
||||
output_size = (input_size[0] * world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
|
||||
return output_tensor
|
||||
|
||||
if isinstance(input_, torch.Tensor):
|
||||
return _all_gather_single(input_, sizes)
|
||||
|
||||
output_list = []
|
||||
pynccl_comm.group_start()
|
||||
for inp in input_:
|
||||
output_list.append(_all_gather_single(inp, sizes=sizes))
|
||||
pynccl_comm.group_end()
|
||||
|
||||
return output_list
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
@@ -35,7 +35,10 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_global_dp_buffer,
|
||||
get_local_dp_buffer,
|
||||
)
|
||||
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe import (
|
||||
get_moe_a2a_backend,
|
||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||
)
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -112,7 +115,11 @@ class LayerScatterModes:
|
||||
if context.is_layer_sparse:
|
||||
return (
|
||||
ScatterMode.SCATTERED
|
||||
if not get_moe_a2a_backend().is_none()
|
||||
if (
|
||||
# Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
|
||||
not get_moe_a2a_backend().is_none()
|
||||
or should_use_flashinfer_cutlass_moe_fp4_allgather()
|
||||
)
|
||||
else ScatterMode.FULL
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
|
||||
_device: torch.device
|
||||
_global_dp_buffer_len: int
|
||||
_local_dp_buffer_len: int
|
||||
_global_num_tokens: Optional[List[int]]
|
||||
|
||||
@classmethod
|
||||
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
|
||||
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
|
||||
cls._device = device
|
||||
|
||||
@classmethod
|
||||
def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
|
||||
def set_dp_buffer_len(
|
||||
cls,
|
||||
global_dp_buffer_len: int,
|
||||
local_dp_buffer_len: int,
|
||||
global_num_tokens: Optional[List[int]] = None,
|
||||
):
|
||||
cls._global_dp_buffer_len = global_dp_buffer_len
|
||||
cls._local_dp_buffer_len = local_dp_buffer_len
|
||||
cls._global_num_tokens = global_num_tokens
|
||||
|
||||
@classmethod
|
||||
def get_global_dp_buffer(cls) -> torch.Tensor:
|
||||
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
|
||||
def get_local_dp_buffer_len(cls) -> int:
|
||||
return cls._local_dp_buffer_len
|
||||
|
||||
@classmethod
|
||||
def get_dp_global_num_tokens(cls) -> List[int]:
|
||||
return cls._global_num_tokens
|
||||
|
||||
def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
|
||||
|
||||
def set_dp_buffer_len(
|
||||
global_dp_buffer_len: int,
|
||||
local_dp_buffer_len: int,
|
||||
global_num_tokens: Optional[List[int]] = None,
|
||||
):
|
||||
_DpGatheredBufferWrapper.set_dp_buffer_len(
|
||||
global_dp_buffer_len, local_dp_buffer_len
|
||||
global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
|
||||
)
|
||||
|
||||
|
||||
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
|
||||
return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
|
||||
|
||||
|
||||
def get_dp_global_num_tokens() -> List[int]:
|
||||
return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
|
||||
|
||||
|
||||
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
if not enable_dp_attention:
|
||||
return tp_rank, tp_size, 0
|
||||
|
||||
@@ -191,7 +191,11 @@ class LogitsMetadata:
|
||||
else:
|
||||
self.global_dp_buffer_len = self.global_dp_buffer_len
|
||||
|
||||
set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
|
||||
set_dp_buffer_len(
|
||||
self.global_dp_buffer_len,
|
||||
self.dp_local_num_tokens,
|
||||
self.global_num_tokens_for_logprob_cpu,
|
||||
)
|
||||
|
||||
|
||||
class LogitsProcessor(nn.Module):
|
||||
|
||||
@@ -10,6 +10,7 @@ from sglang.srt.layers.moe.utils import (
|
||||
get_tbo_token_distribution_threshold,
|
||||
initialize_moe_config,
|
||||
is_tbo_enabled,
|
||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
|
||||
@@ -23,6 +24,7 @@ __all__ = [
|
||||
"get_moe_runner_backend",
|
||||
"get_deepep_mode",
|
||||
"should_use_flashinfer_trtllm_moe",
|
||||
"should_use_flashinfer_cutlass_moe_fp4_allgather",
|
||||
"is_tbo_enabled",
|
||||
"get_tbo_token_distribution_threshold",
|
||||
"get_deepep_config",
|
||||
|
||||
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
||||
@@ -621,9 +622,7 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
if "ModelOpt" in self.quant_method.__class__.__name__:
|
||||
# Determine per-tensor weight scale patterns based on variant
|
||||
is_fp4_variant = (
|
||||
"ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
|
||||
)
|
||||
is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
|
||||
|
||||
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
|
||||
per_tensor_conditions = (
|
||||
|
||||
@@ -327,6 +327,13 @@ class TopK(CustomOp):
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
|
||||
def empty_topk_output(self, device: torch.device) -> TopKOutput:
|
||||
topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
|
||||
topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
|
||||
topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
|
||||
router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
|
||||
return StandardTopKOutput(topk_weights, topk_idx, router_logits)
|
||||
|
||||
|
||||
# ------------------------------- TopK implementation -------------------------------------
|
||||
|
||||
|
||||
@@ -7,6 +7,11 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from packaging import version as pkg_version
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_dp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.utils import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -99,6 +104,7 @@ DEEPEP_MODE: Optional[DeepEPMode] = None
|
||||
IS_TBO_ENABLED: Optional[bool] = None
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
||||
DEEPEP_CONFIG: Optional[str] = None
|
||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
|
||||
|
||||
|
||||
def initialize_moe_config(server_args: ServerArgs):
|
||||
@@ -108,6 +114,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
||||
global DEEPEP_CONFIG
|
||||
global IS_TBO_ENABLED
|
||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
|
||||
|
||||
MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
|
||||
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
|
||||
@@ -115,6 +122,9 @@ def initialize_moe_config(server_args: ServerArgs):
|
||||
DEEPEP_CONFIG = server_args.deepep_config or ""
|
||||
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
|
||||
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
|
||||
)
|
||||
|
||||
|
||||
def get_moe_a2a_backend() -> MoeA2ABackend:
|
||||
@@ -175,3 +185,16 @@ def should_use_flashinfer_trtllm_moe():
|
||||
>= pkg_version.parse("0.2.9rc1")
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def should_use_flashinfer_cutlass_moe_fp4_allgather():
|
||||
"""
|
||||
Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
|
||||
"""
|
||||
return (
|
||||
not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
|
||||
and get_moe_runner_backend().is_flashinfer_cutlass()
|
||||
and is_dp_attention_enabled()
|
||||
and get_moe_expert_parallel_world_size() == get_attention_dp_size()
|
||||
)
|
||||
|
||||
@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
|
||||
from sglang.srt.distributed import get_tp_group
|
||||
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
|
||||
from sglang.srt.layers.moe import (
|
||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
|
||||
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
|
||||
|
||||
output_dtype = x.dtype
|
||||
x_sf = None
|
||||
if should_use_flashinfer_cutlass_moe_fp4_allgather():
|
||||
from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
|
||||
|
||||
# Quantize before comm, swizzle after.
|
||||
if x.shape[0] > 0:
|
||||
x, x_sf = fp4_quantize(
|
||||
x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
|
||||
)
|
||||
else:
|
||||
x_col = x.shape[1]
|
||||
x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
|
||||
x_sf = torch.zeros(
|
||||
0, x_col // 16, dtype=torch.uint8, device=x.device
|
||||
)
|
||||
topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
|
||||
[topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
|
||||
)
|
||||
x_sf = nvfp4_block_scale_interleave(x_sf)
|
||||
|
||||
output = flashinfer_cutlass_fused_moe(
|
||||
x,
|
||||
topk_ids.to(torch.int),
|
||||
topk_weights,
|
||||
layer.w13_weight.view(torch.long),
|
||||
layer.w2_weight.view(torch.long),
|
||||
x.dtype,
|
||||
input=x,
|
||||
token_selected_experts=topk_ids.to(torch.int),
|
||||
token_final_scales=topk_weights,
|
||||
fc1_expert_weights=layer.w13_weight.view(torch.long),
|
||||
fc2_expert_weights=layer.w2_weight.view(torch.long),
|
||||
output_dtype=output_dtype,
|
||||
input_sf=x_sf,
|
||||
quant_scales=[
|
||||
layer.w13_input_scale_quant,
|
||||
layer.w13_blockscale_swizzled.view(torch.int32),
|
||||
@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
)[0]
|
||||
if moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= moe_runner_config.routed_scaling_factor
|
||||
if should_use_flashinfer_cutlass_moe_fp4_allgather():
|
||||
output, global_output = get_local_dp_buffer(), output
|
||||
get_tp_group().reduce_scatterv(
|
||||
global_output, output=output, sizes=get_dp_global_num_tokens()
|
||||
)
|
||||
return output
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
|
||||
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"chunked_prefill_size",
|
||||
"device",
|
||||
"disable_chunked_prefix_cache",
|
||||
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
||||
"disable_radix_cache",
|
||||
"enable_dp_lm_head",
|
||||
"enable_flashinfer_allreduce_fusion",
|
||||
|
||||
@@ -649,7 +649,7 @@ class ForwardBatch:
|
||||
num_tokens = global_num_tokens[0]
|
||||
|
||||
self.global_dp_buffer_len = buffer_len
|
||||
set_dp_buffer_len(buffer_len, num_tokens)
|
||||
set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
|
||||
|
||||
bs = self.batch_size
|
||||
|
||||
|
||||
@@ -60,7 +60,11 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe import (
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
@@ -343,7 +347,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.shared_experts_weight_block_size = None
|
||||
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
# disable tp for shared experts when enable deepep moe
|
||||
# disable tp for shared experts when enable deepep moe, or with fp4 allgather
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
@@ -354,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
**(
|
||||
dict(tp_rank=0, tp_size=1)
|
||||
if get_moe_a2a_backend().is_deepep()
|
||||
or should_use_flashinfer_cutlass_moe_fp4_allgather()
|
||||
else {}
|
||||
),
|
||||
)
|
||||
@@ -433,14 +438,19 @@ class DeepseekV2MoE(nn.Module):
|
||||
if (
|
||||
self.alt_stream is not None
|
||||
and self.num_fused_shared_experts == 0
|
||||
and hidden_states.shape[0] > 0
|
||||
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||
):
|
||||
return self.forward_normal_dual_stream(
|
||||
hidden_states, should_allreduce_fusion, use_reduce_scatter
|
||||
hidden_states,
|
||||
should_allreduce_fusion,
|
||||
use_reduce_scatter,
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(
|
||||
hidden_states, should_allreduce_fusion, use_reduce_scatter
|
||||
hidden_states,
|
||||
should_allreduce_fusion,
|
||||
use_reduce_scatter,
|
||||
)
|
||||
else:
|
||||
return self.forward_deepep(hidden_states, forward_batch)
|
||||
@@ -471,7 +481,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||
final_hidden_states = final_hidden_states_out
|
||||
sm.tag(final_hidden_states)
|
||||
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
|
||||
if (
|
||||
self.tp_size > 1
|
||||
and not should_allreduce_fusion
|
||||
and not use_reduce_scatter
|
||||
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
|
||||
):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
@@ -486,10 +501,14 @@ class DeepseekV2MoE(nn.Module):
|
||||
):
|
||||
return self.forward_cpu(hidden_states, should_allreduce_fusion)
|
||||
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
if hidden_states.shape[0] > 0:
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
shared_output = None
|
||||
topk_output = self.topk.empty_topk_output(hidden_states.device)
|
||||
|
||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||
if not _is_cuda and not _use_aiter:
|
||||
@@ -501,7 +520,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||
final_hidden_states = final_hidden_states_out
|
||||
sm.tag(final_hidden_states)
|
||||
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
|
||||
if (
|
||||
self.tp_size > 1
|
||||
and not should_allreduce_fusion
|
||||
and not use_reduce_scatter
|
||||
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
|
||||
):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
@@ -580,11 +604,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
topk_idx = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
topk_weights, topk_idx, _ = self.topk.empty_topk_output(
|
||||
hidden_states.device
|
||||
)
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
|
||||
@@ -84,6 +84,7 @@ class _StageExecutor:
|
||||
forward_batch: ForwardBatch = inputs["forward_batch"]
|
||||
self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
|
||||
self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
|
||||
self._global_num_tokens = forward_batch.global_num_tokens_cpu
|
||||
|
||||
def next(self):
|
||||
assert not self.done
|
||||
@@ -91,7 +92,11 @@ class _StageExecutor:
|
||||
stage = self._stages[self._index]
|
||||
|
||||
if self._global_dp_buffer_len is not None:
|
||||
set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
|
||||
set_dp_buffer_len(
|
||||
self._global_dp_buffer_len,
|
||||
self._local_dp_buffer_len,
|
||||
self._global_num_tokens,
|
||||
)
|
||||
|
||||
with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
|
||||
for op in stage:
|
||||
|
||||
@@ -230,6 +230,7 @@ class ServerArgs:
|
||||
enable_cudagraph_gc: bool = False
|
||||
enable_nccl_nvls: bool = False
|
||||
enable_symm_mem: bool = False
|
||||
disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
|
||||
enable_tokenizer_batch_encode: bool = False
|
||||
disable_outlines_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
@@ -1714,6 +1715,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable NCCL symmetric memory for fast collectives.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-flashinfer-cutlass-moe-fp4-allgather",
|
||||
action="store_true",
|
||||
help="Disables quantize before all-gather for flashinfer cutlass moe.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-tokenizer-batch-encode",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user