feat: remove the dependency on FusedMoE (#2153)
@@ -57,12 +57,23 @@ __all__ = [
     "QUANTIZATION_METHODS",
 ]
 
-"""
-def fp8_get_quant_method(
-    self, layer: torch.nn.Module, prefix: str
-) -> Optional["QuantizeMethodBase"]:
+
+def fp8_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.fp8 import (
+        Fp8LinearMethod,
+        Fp8MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.utils.quant_utils import (
+        is_layer_skipped,
+    )
+
+    from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
+
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
+            from sglang.srt.layers.linear import UnquantizedLinearMethod
+
             return UnquantizedLinearMethod()
         return Fp8LinearMethod(self)
     elif isinstance(layer, FusedMoE):
@@ -71,4 +82,3 @@ def fp8_get_quant_method(
 
 
 setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
-"""

python/sglang/srt/layers/triton_fused_moe/__init__.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.triton_fused_moe.fused_moe  # noqa
from sglang.srt.layers.triton_fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
    get_config_file_name,
    grouped_topk,
)
from sglang.srt.layers.triton_fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)

_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
    global _config
    old_config = _config
    _config = config
    yield
    _config = old_config


def get_config() -> Optional[Dict[str, Any]]:
    return _config


__all__ = [
    "FusedMoE",
    "FusedMoEMethodBase",
    "FusedMoeWeightScaleSupported",
    "override_config",
    "get_config",
    "fused_moe",
    "fused_topk",
    "fused_experts",
    "get_config_file_name",
    "grouped_topk",
]
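The `override_config` context manager pins a kernel configuration for a region of code; a minimal usage sketch based only on the module above (the config values are illustrative, not tuned):

from sglang.srt.layers.triton_fused_moe import get_config, override_config

cfg = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}

with override_config(cfg):
    assert get_config() == cfg  # kernels launched here pick up cfg
assert get_config() is None     # previous value restored on exit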

python/sglang/srt/layers/triton_fused_moe/configs/README (new file, 10 lines)
@@ -0,0 +1,10 @@
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.

The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
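As a hedged illustration of the schema the loader below expects, a hypothetical file named E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json could look like this (keys are batch sizes M as strings; the block sizes shown are placeholders, real values come from tuning):

{
  "1":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}
}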

python/sglang/srt/layers/triton_fused_moe/fused_moe.py (new file, 858 lines)
@@ -0,0 +1,858 @@
"""Fused MoE kernel."""

import functools
import json
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops

from sglang.srt.utils import direct_register_custom_op, get_device_name

logger = logging.getLogger(__name__)


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MoE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
      with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
      Tokens 12 are non-existent (padding) and are ignored in
      the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
      by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
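The alignment itself runs inside a CUDA custom op (`ops.moe_align_block_size`); the contract can be sketched in plain PyTorch as a reference (illustrative only, not the code path used above, and the capacity/count bookkeeping of the real op is simplified):

import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Reference semantics: bucket flattened token slots by expert,
    # pad each bucket to a multiple of block_size with a sentinel id.
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # sentinel for padding slots
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0]
        pad = (-idx.numel()) % block_size
        idx = torch.cat([idx, torch.full((pad,), pad_id, dtype=idx.dtype)])
        sorted_ids.append(idx)
        expert_ids += [e] * (idx.numel() // block_size)
    sorted_ids = torch.cat(sorted_ids).to(torch.int32)
    num_post_pad = torch.tensor([sorted_ids.numel()], dtype=torch.int32)
    return sorted_ids, torch.tensor(expert_ids, dtype=torch.int32), num_post_pad


# Reproduces the docstring example (expert ids go up to 4, so 5 buckets):
topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
ids, experts, total = moe_align_block_size_ref(topk_ids, block_size=4, num_experts=5)
# ids -> [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]  (12 = padding)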
def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None
    elif use_int8_w8a16:
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None

    grid = lambda META: (
        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        B.shape[2],
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )


def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    device_name = get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
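For instance (the device string is hypothetical):

# Assuming torch.cuda.get_device_name() == "NVIDIA H100 80GB HBM3":
# get_config_file_name(E=8, N=7168, dtype=None)
# -> "E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json"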
@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using configuration from %s for MoE layer.", config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        (
            "Using default MoE config. Performance might be sub-optimal! "
            "Config file not found at %s"
        ),
        config_file_path,
    )
    return None


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
) -> Dict[str, int]:
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
    # A heuristic: fused marlin works faster with this config for small M
    if M <= E or (is_marlin and M <= 32):
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return config


def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    is_marlin: bool = False,
):
    from sglang.srt.layers.triton_fused_moe import get_config

    override_config = get_config()
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        configs = get_moe_configs(E, N, dtype)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
    return config


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    token_expert_indicies = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )

    ops.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del token_expert_indicies  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
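fused_topk defers to a fused CUDA softmax-plus-top-k op; the same routing can be sketched in plain PyTorch for reference (slower, and only a sketch of the semantics):

import torch


def fused_topk_ref(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over experts, then pick the top-k probabilities per token.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)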
# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):

    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def get_config_dtype_str(
    dtype: torch.dtype,
    use_int8_w8a16: Optional[bool] = False,
    use_fp8_w8a8: Optional[bool] = False,
):
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None


def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> None:
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        True,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )


def inplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> None:
    pass


direct_register_custom_op(
    op_name="inplace_fused_experts",
    op_func=inplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=inplace_fused_experts_fake,
)


def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        False,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )


def outplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="outplace_fused_experts",
    op_func=outplace_fused_experts,
    mutates_args=[],
    fake_impl=outplace_fused_experts_fake,
)


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    if inplace:
        torch.ops.sglang.inplace_fused_experts(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            use_fp8_w8a8,
            use_int8_w8a16,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )
        return hidden_states
    else:
        return torch.ops.sglang.outplace_fused_experts(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            use_fp8_w8a8,
            use_int8_w8a16,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )


def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = 64 * 1024
    M = min(num_tokens, CHUNK_SIZE)
    config_dtype = get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        dtype=hidden_states.dtype,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        config_dtype,
    )

    config = get_config_func(M)

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], E
        )

        invoke_fused_moe_kernel(
            curr_hidden_states,
            w1,
            intermediate_cache1,
            a1_scale,
            w1_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,
            topk_ids.shape[1],
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
            intermediate_cache3,
            a2_scale,
            w2_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            True,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        ops.moe_sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
    return out_hidden_states


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - num_expert_group: Optional[int]: additional parameter for grouped_topk
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk.
        Note: the DeepSeek-V2 model uses grouped_topk.
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use int8 weights with fp16/bf16
        activations to compute the inner products for w1 and w2.
        Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
        )
    elif custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize
        )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
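A minimal sketch of calling the top-level entry point (shapes follow the docstring; values are random and illustrative, and a CUDA device with the vllm custom ops is assumed):

import torch
from sglang.srt.layers.triton_fused_moe.fused_moe import fused_moe

E, K, N, top_k, M = 8, 4096, 14336, 2, 16  # experts, hidden, intermediate, top-k, tokens
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * N, K, dtype=torch.float16, device="cuda")  # fused gate/up proj
w2 = torch.randn(E, K, N, dtype=torch.float16, device="cuda")      # down proj
gating = torch.randn(M, E, dtype=torch.float16, device="cuda")     # router logits

out = fused_moe(x, w1, w2, gating, topk=top_k, renormalize=True)   # [M, K]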
631
python/sglang/srt/layers/triton_fused_moe/layer.py
Normal file
631
python/sglang/srt/layers/triton_fused_moe/layer.py
Normal file
@@ -0,0 +1,631 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
|
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||||
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
|
if torch.cuda.is_available() or torch.hip.is_available():
|
||||||
|
from .fused_moe import fused_experts
|
||||||
|
else:
|
||||||
|
fused_experts = None # type: ignore
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoeWeightScaleSupported(Enum):
|
||||||
|
TENSOR = "tensor"
|
||||||
|
CHANNEL = "channel"
|
||||||
|
GROUP = "group"
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@register_custom_op("sglang_unquantized_fused_moe")
|
||||||
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
# Fused gate_up_proj (column parallel)
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# down_proj (row parallel)
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.forward(
|
||||||
|
x=x,
|
||||||
|
layer=layer,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
top_k: int,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_cpu(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError("The CPU backend currently does not support MoE.")
|
||||||
|
|
||||||
|
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||||
|
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||||
|
|
||||||
|
forward_native = forward_cuda
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoE(torch.nn.Module):
|
||||||
|
"""FusedMoE layer for MoE models.
|
||||||
|
|
||||||
|
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
||||||
|
w13) and RowParallelLinear weights (down_proj/ w2).
|
||||||
|
|
||||||
|
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
|
||||||
|
copy that naming convention here and handle any remapping in the
|
||||||
|
load_weights function in each model implementation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_experts: Number of experts in the model
|
||||||
|
top_k: Number of experts selected for each token
|
||||||
|
hidden_size: Input hidden state size of the transformer
|
||||||
|
intermediate_size: Intermediate size of the experts
|
||||||
|
params_dtype: Data type for the parameters.
|
||||||
|
reduce_results: Whether to all all_reduce on the output of the layer
|
||||||
|
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
||||||
|
quant_config: Quantization configure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_experts: int,
|
||||||
|
top_k: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
reduce_results: bool = False,
|
||||||
|
renormalize: bool = True,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if params_dtype is None:
|
||||||
|
params_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
self.tp_size = (
|
||||||
|
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
||||||
|
)
|
||||||
|
self.top_k = top_k
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||||
|
self.reduce_results = reduce_results
|
||||||
|
self.renormalize = renormalize
|
||||||
|
self.use_grouped_topk = use_grouped_topk
|
||||||
|
if self.use_grouped_topk:
|
||||||
|
assert num_expert_group is not None and topk_group is not None
|
||||||
|
self.num_expert_group = num_expert_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.custom_routing_function = custom_routing_function
|
||||||
|
|
||||||
|
if quant_config is None:
|
||||||
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||||
|
UnquantizedFusedMoEMethod()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||||
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
self.quant_method.create_weights(
|
||||||
|
layer=self,
|
||||||
|
num_experts=num_experts,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
intermediate_size=self.intermediate_size_per_partition,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
weight_loader=self.weight_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_per_tensor_weight_scale(
|
||||||
|
self,
|
||||||
|
shard_id: str,
|
||||||
|
param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
expert_id: int,
|
||||||
|
):
|
||||||
|
param_data = param.data
|
||||||
|
# for per tensor weight quantization
|
||||||
|
if shard_id in ("w1", "w3"):
|
||||||
|
# We have to keep the weight scales of w1 and w3 because
|
||||||
|
# we need to re-quantize w1/w3 weights after weight loading.
|
||||||
|
idx = 0 if shard_id == "w1" else 1
|
||||||
|
param_data[expert_id][idx] = loaded_weight
|
||||||
|
# If we are in the row parallel case (down_proj)
|
||||||
|
elif shard_id == "w2":
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
def _load_model_weight_or_group_weight_scale(
|
||||||
|
self,
|
||||||
|
shard_dim: int,
|
||||||
|
expert_data: torch.Tensor,
|
||||||
|
shard_id: str,
|
||||||
|
loaded_weight: torch.tensor,
|
||||||
|
tp_rank: int,
|
||||||
|
):
|
||||||
|
# Load grouped weight scales for group quantization
|
||||||
|
# or model weights
|
||||||
|
if shard_id == "w2":
|
||||||
|
self._load_w2(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
elif shard_id in ("w1", "w3"):
|
||||||
|
self._load_w13(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_per_channel_weight_scale(
|
||||||
|
self,
|
||||||
|
expert_data: torch.Tensor,
|
||||||
|
shard_dim: int,
|
||||||
|
shard_id: str,
|
||||||
|
loaded_weight: torch.tensor,
|
||||||
|
tp_rank: int,
|
||||||
|
):
|
||||||
|
# for per channel weight quantization
|
||||||
|
if shard_id == "w2":
|
||||||
|
expert_data.copy_(loaded_weight)
|
||||||
|
elif shard_id in ("w1", "w3"):
|
||||||
|
self._load_w13(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_w13(
|
||||||
|
self,
|
||||||
|
expert_data: torch.Tensor,
|
||||||
|
shard_dim: int,
|
||||||
|
shard_id: str,
|
||||||
|
loaded_weight: torch.tensor,
|
||||||
|
tp_rank: int,
|
||||||
|
):
|
||||||
|
|
||||||
|
# Index the loaded weight for tp sharding.
|
||||||
|
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||||
|
shard_size = expert_data.shape[shard_dim] // 2
|
||||||
|
loaded_weight = loaded_weight.narrow(
|
||||||
|
shard_dim, shard_size * tp_rank, shard_size
|
||||||
|
)
|
||||||
|
# Narrow parameter and load.
|
||||||
|
# w1, gate_proj: Load into first logical weight of w13.
|
||||||
|
if shard_id == "w1":
|
||||||
|
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||||
|
# w3, up_proj: Load into second logical weight of w13.
|
||||||
|
else:
|
||||||
|
assert shard_id == "w3"
|
||||||
|
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||||
|
expert_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
def _load_w2(
|
||||||
|
self,
|
||||||
|
expert_data: torch.Tensor,
|
||||||
|
shard_dim: int,
|
||||||
|
shard_id: str,
|
||||||
|
loaded_weight: torch.tensor,
|
||||||
|
tp_rank: int,
|
||||||
|
):
|
||||||
|
|
||||||
|
# Index the loaded weight for tp sharding.
|
||||||
|
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||||
|
# Narrow parameter and load.
|
||||||
|
shard_size = expert_data.shape[shard_dim]
|
||||||
|
loaded_weight = loaded_weight.narrow(
|
||||||
|
shard_dim, shard_size * tp_rank, shard_size
|
||||||
|
)
|
||||||
|
# w2, down_proj: Load into only logical weight of w2.
|
||||||
|
expert_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
def _load_single_value(
|
||||||
|
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
|
||||||
|
):
|
||||||
|
param_data = param.data
|
||||||
|
|
||||||
|
# Input scales can be loaded directly and should be equal.
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
def _load_g_idx(
|
||||||
|
self,
|
||||||
|
shard_id: str,
|
||||||
|
expert_data: torch.Tensor,
|
||||||
|
shard_dim: int,
|
||||||
|
loaded_weight: torch.tensor,
|
||||||
|
tp_rank: int,
|
||||||
|
):
|
||||||
|
|
||||||
|
if shard_id == "w2":
|
||||||
|
self._load_w2(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert shard_id in ("w1", "w3")
|
||||||
|
expert_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
def weight_loader(
|
||||||
|
self,
|
||||||
|
param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
shard_id: str,
|
||||||
|
expert_id: int,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||||
|
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||||
|
# against known CompressionFormat enum values that have this quality
|
||||||
|
loaded_weight = (
|
||||||
|
loaded_weight.t().contiguous()
|
||||||
|
if (
|
||||||
|
self.quant_method.__class__.__name__
|
||||||
|
== "CompressedTensorsWNA16MoEMethod"
|
||||||
|
)
|
||||||
|
else loaded_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
if shard_id not in ("w1", "w2", "w3"):
|
||||||
|
raise ValueError(
|
||||||
|
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
|
||||||
|
)
|
||||||
|
|
||||||
|
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
||||||
|
# Fetch the dim to shard the parameter/loaded weight
|
||||||
|
# based on the shard id. This will be whatever
|
||||||
|
# dimension intermediate_size is used.
|
||||||
|
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
||||||
|
|
||||||
|
expert_data = param.data[expert_id]
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
|
# is_transposed: if the dim to shard the weight
|
||||||
|
# should be flipped. Required by GPTQ, compressed-tensors
|
||||||
|
# should be whatever dimension intermediate_size is
|
||||||
|
is_transposed = getattr(param, "is_transposed", False)
|
||||||
|
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||||
|
if is_transposed:
|
||||||
|
shard_dim = ~shard_dim
|
||||||
|
|
||||||
|
# Case input scale: input_scale loading is only supported for fp8
|
||||||
|
if "input_scale" in weight_name:
|
||||||
|
# this is needed for compressed-tensors only
|
||||||
|
loaded_weight = loaded_weight.to(param.data.device)
|
||||||
|
|
||||||
|
if (
|
||||||
|
param.data[expert_id] != 1
|
||||||
|
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"input_scales of w1 and w3 of a layer "
|
||||||
|
f"must be equal. But got {param.data[expert_id]} "
|
||||||
|
f"vs. {loaded_weight}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._load_single_value(
|
||||||
|
param=param, loaded_weight=loaded_weight, expert_id=expert_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Case g_idx
|
||||||
|
if "g_idx" in weight_name:
|
||||||
|
self._load_g_idx(
|
||||||
|
shard_dim=0,
|
||||||
|
shard_id=shard_id,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Case weight scales and zero_points
|
||||||
|
if "scale" in weight_name or "zero" in weight_name:
|
||||||
|
# load the weight scales and zp based on the quantization scheme
|
||||||
|
# supported weight scales/zp can be found in
|
||||||
|
# FusedMoeWeightScaleSupported
|
||||||
|
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
|
||||||
|
# specific to each case
|
||||||
|
quant_method = getattr(param, "quant_method", None)
|
||||||
|
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
||||||
|
self._load_per_channel_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
|
||||||
|
self._load_model_weight_or_group_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
||||||
|
self._load_per_tensor_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_id=expert_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Case weight_shape
|
||||||
|
if "weight_shape" in weight_name:
|
||||||
|
# only required by compressed-tensors
|
||||||
|
self._load_single_value(
|
||||||
|
param=param, loaded_weight=loaded_weight, expert_id=expert_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Case model weights
|
||||||
|
if "weight" in weight_name:
|
||||||
|
self._load_model_weight_or_group_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
shard_dim=shard_dim,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_data=expert_data,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def select_experts(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
renormalize: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
):
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
|
fused_topk,
|
||||||
|
grouped_topk,
|
||||||
|
)
|
||||||
|
|
||||||
|
# DeekSeekv2 uses grouped_top_k
|
||||||
|
if use_grouped_topk:
|
||||||
|
assert topk_group is not None
|
||||||
|
assert num_expert_group is not None
|
||||||
|
topk_weights, topk_ids = grouped_topk(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
gating_output=router_logits,
|
||||||
|
topk=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
)
|
||||||
|
elif custom_routing_function is None:
|
||||||
|
topk_weights, topk_ids = fused_topk(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
gating_output=router_logits,
|
||||||
|
topk=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_weights, topk_ids = custom_routing_function(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
gating_output=router_logits,
|
||||||
|
topk=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
        )

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states

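Conceptually, `quant_method.apply` runs each token through its selected experts and combines the outputs with the routing weights. A naive, unquantized reference of that combine step (a sketch assuming the gate-up/down SwiGLU expert layout implied by the w13/w2 naming, not the fused kernel itself):

import torch

def moe_reference(x, w13, w2, topk_weights, topk_ids):
    # x: [T, H]; w13: [E, 2*I, H]; w2: [E, H, I]
    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        # Tokens (and which of their k slots) routed to expert e.
        token_idx, k_idx = torch.where(topk_ids == e)
        if token_idx.numel() == 0:
            continue
        gate_up = x[token_idx] @ w13[e].t()        # [t, 2*I]
        gate, up = gate_up.chunk(2, dim=-1)
        h = torch.nn.functional.silu(gate) * up    # SwiGLU activation
        out[token_idx] += topk_weights[token_idx, k_idx].unsqueeze(-1) * (h @ w2[e].t())
    return out
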
    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

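For a hypothetical checkpoint with two experts and Llama-style projection names, the mapping expands to (values shown for illustration, derived directly from the comprehension above):

mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=2,
)
# mapping == [
#     ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
#     ("experts.w2_",  "experts.0.down_proj.", 0, "w2"),
#     ("experts.w13_", "experts.0.up_proj.",   0, "w3"),
#     ("experts.w13_", "experts.1.gate_proj.", 1, "w1"),
#     ("experts.w2_",  "experts.1.down_proj.", 1, "w2"),
#     ("experts.w13_", "experts.1.up_proj.",   1, "w3"),
# ]
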
    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight

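Why both w1 and w3 scales are kept: after loading, the merged w13 weight is re-quantized to a single per-expert scale. A minimal sketch of that step (the actual logic lives in the FP8 quant method; the max-scale re-quantization shown here is assumed for illustration):

import torch

def requantize_w13(w13_fp8, w13_scales, fp8_max=448.0):
    # w13_fp8: [2*I, H] merged fp8 weight; w13_scales: [2] per-shard scales.
    # Dequantize each half with its own scale, then re-quantize with the max
    # so one scale covers the merged tensor.
    shard = w13_fp8.shape[0] // 2
    new_scale = w13_scales.max()
    out = torch.empty_like(w13_fp8)
    for i in range(2):
        part = w13_fp8[i * shard : (i + 1) * shard].to(torch.float32) * w13_scales[i]
        out[i * shard : (i + 1) * shard] = (
            (part / new_scale).clamp(-fp8_max, fp8_max).to(w13_fp8.dtype)
        )
    return out, new_scale
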
@@ -27,7 +27,6 @@ from vllm.distributed import (
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -42,6 +41,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -31,7 +31,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
 
 import numpy as np
 import psutil
@@ -45,6 +45,7 @@ from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
+from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
 def crash_on_warnings():
     # Crash on warning if we are running CI tests
     return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
+def get_device_name(device_id: int = 0) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        return torch.cuda.get_device_name(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        return torch.hip.get_device_name(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.xpu.get_device_name(device_id)
+
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return torch.hpu.get_device_name(device_id)
+
+
+sglang_lib = Library("sglang", "FRAGMENT")  # noqa
+
+
+def direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: List[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+):
+    import torch.library
+
+    if hasattr(torch.library, "infer_schema"):
+        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    else:
+        # for pytorch 2.4
+        import torch._custom_op.impl
+
+        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
+
+    my_lib = target_lib or sglang_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, "CUDA")
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
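Two usage notes on the new helpers. `get_device_name` probes backends in order and returns the first match; note that when no supported accelerator is present it falls through and implicitly returns None despite the `-> str` annotation, so it is only meant to run on accelerator nodes:

device = get_device_name()  # e.g. "NVIDIA H100 80GB HBM3" on an H100 node

`direct_register_custom_op` wraps the raw torch.library registration path. A hypothetical sketch of registering and calling an op through it (the op name and functions are invented for illustration; a CUDA device is required because the implementation is bound to the "CUDA" dispatch key):

import torch

def _toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

def _toy_scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype-only stand-in used by tracing (e.g. torch.compile).
    return torch.empty_like(x)

direct_register_custom_op(
    op_name="toy_scale",
    op_func=_toy_scale,
    mutates_args=[],
    fake_impl=_toy_scale_fake,
)

out = torch.ops.sglang.toy_scale(torch.ones(4, device="cuda"), 2.0)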