feat: remove the dependency on FusedMoE (#2153)
This commit is contained in:
@@ -57,12 +57,23 @@ __all__ = [
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
|
||||
"""
|
||||
def fp8_get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
|
||||
def fp8_get_quant_method(self, layer, prefix):
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8LinearMethod,
|
||||
Fp8MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
@@ -71,4 +82,3 @@ def fp8_get_quant_method(
|
||||
|
||||
|
||||
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
||||
"""
|
||||
|
||||
44
python/sglang/srt/layers/triton_fused_moe/__init__.py
Normal file
44
python/sglang/srt/layers/triton_fused_moe/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import sglang.srt.layers.triton_fused_moe.fused_moe # noqa
|
||||
from sglang.srt.layers.triton_fused_moe.fused_moe import (
|
||||
fused_experts,
|
||||
fused_topk,
|
||||
get_config_file_name,
|
||||
grouped_topk,
|
||||
)
|
||||
from sglang.srt.layers.triton_fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_config(config):
|
||||
global _config
|
||||
old_config = _config
|
||||
_config = config
|
||||
yield
|
||||
_config = old_config
|
||||
|
||||
|
||||
def get_config() -> Optional[Dict[str, Any]]:
|
||||
return _config
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FusedMoE",
|
||||
"FusedMoEMethodBase",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"override_config",
|
||||
"get_config",
|
||||
"fused_moe",
|
||||
"fused_topk",
|
||||
"fused_experts",
|
||||
"get_config_file_name",
|
||||
"grouped_topk",
|
||||
]
|
||||
10
python/sglang/srt/layers/triton_fused_moe/configs/README
Normal file
10
python/sglang/srt/layers/triton_fused_moe/configs/README
Normal file
@@ -0,0 +1,10 @@
|
||||
This directory contains tuned configurations for different settings of the fused_moe kernel.
|
||||
For different settings of
|
||||
- E (number of experts)
|
||||
- N (intermediate size)
|
||||
- device_name (torch.cuda.get_device_name())
|
||||
the JSON file contains a mapping from M (batch size) to the chosen configuration.
|
||||
|
||||
The example configurations provided are for the Mixtral model for TP2 on H100
|
||||
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
|
||||
N = 7168 and for TP4 we have N = 3584.
|
||||
858
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
Normal file
858
python/sglang/srt/layers/triton_fused_moe/fused_moe.py
Normal file
@@ -0,0 +1,858 @@
|
||||
"""Fused MoE kernel."""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import direct_register_custom_op, get_device_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
# Matrix dimensions
|
||||
N,
|
||||
K,
|
||||
EM,
|
||||
num_valid_tokens,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_bse,
|
||||
stride_bsn,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Implements the fused computation for a Mixture of Experts (MOE) using
|
||||
token and expert matrices.
|
||||
|
||||
Key Parameters:
|
||||
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
||||
be any shape representing batches and K is the feature dimension of
|
||||
each token.
|
||||
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
||||
the number of experts, K is the input feature dimension, and N is
|
||||
the output feature dimension.
|
||||
- C: The output cache tensor with shape (M, topk, N), where M is the
|
||||
total number of tokens post padding, topk is the number of times
|
||||
each token is repeated, and N is the output feature dimension.
|
||||
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
||||
repeated topk times and arranged by the expert index they are
|
||||
assigned to.
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse.
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
||||
)
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
||||
b_ptrs = (
|
||||
b_ptr
|
||||
+ off_experts * stride_be
|
||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
)
|
||||
if use_int8_w8a16:
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
b_scale = tl.load(b_scale_ptr + off_experts)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop.
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the
|
||||
# K dimension.
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
if use_int8_w8a16:
|
||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||
elif use_fp8_w8a8:
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns the token distribution across experts to be compatible with block
|
||||
size for matrix multiplication.
|
||||
|
||||
Parameters:
|
||||
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
|
||||
top-k expert indices for each token.
|
||||
- block_size: The block size used in block matrix multiplication.
|
||||
- num_experts: The total number of experts.
|
||||
|
||||
Returns:
|
||||
- sorted_token_ids: A tensor containing the sorted token indices according
|
||||
to their allocated expert.
|
||||
- expert_ids: A tensor indicating the assigned expert index for each block.
|
||||
- num_tokens_post_padded: The total number of tokens after padding,
|
||||
ensuring divisibility by block_size.
|
||||
|
||||
This function pads the number of tokens that each expert needs to process
|
||||
so that it is divisible by block_size.
|
||||
Padding ensures that during block matrix multiplication, the dimensions
|
||||
align correctly.
|
||||
|
||||
Example:
|
||||
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
|
||||
block_size = 4, and num_experts = 4:
|
||||
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
|
||||
with each expert needing to process 3 tokens.
|
||||
- As block_size is 4, we pad 1 token for each expert.
|
||||
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
|
||||
- Then append padding tokens [12, 12, 12, 12] for each block.
|
||||
- After sorting by expert index, we obtain token_ids
|
||||
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
|
||||
Tokens 12 are non-existent (padding) and are ignored in
|
||||
the subsequent matrix multiplication.
|
||||
- The padding ensures that the total number of tokens is now divisible
|
||||
by block_size for proper block matrix operations.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
expert_ids = torch.empty(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
ops.moe_align_block_size(
|
||||
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
|
||||
)
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: Dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
) -> None:
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8:
|
||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||
assert B_scale is not None
|
||||
elif use_int8_w8a16:
|
||||
assert B_scale is not None
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
B.shape[2],
|
||||
sorted_token_ids.shape[0],
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
||||
device_name = get_device_name().replace(" ", "_")
|
||||
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
||||
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the fused MoE kernel.
|
||||
|
||||
The return value will be a dictionary that maps an irregular grid of
|
||||
batch sizes to configurations of the fused_moe kernel. To evaluate the
|
||||
kernel on a given batch size bs, the closest batch size in the grid should
|
||||
be picked and the associated configuration chosen to invoke the kernel.
|
||||
"""
|
||||
|
||||
# First look up if an optimized configuration is available in the configs
|
||||
# directory
|
||||
json_file_name = get_config_file_name(E, N, dtype)
|
||||
|
||||
config_file_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
|
||||
)
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path) as f:
|
||||
logger.info("Using configuration from %s for MoE layer.", config_file_path)
|
||||
# If a configuration has been found, return it
|
||||
return {int(key): val for key, val in json.load(f).items()}
|
||||
|
||||
# If no optimized configuration is available, we will use the default
|
||||
# configuration
|
||||
logger.warning(
|
||||
(
|
||||
"Using default MoE config. Performance might be sub-optimal! "
|
||||
"Config file not found at %s"
|
||||
),
|
||||
config_file_path,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_default_config(
|
||||
M: int,
|
||||
E: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
dtype: Optional[str],
|
||||
is_marlin: bool,
|
||||
) -> Dict[str, int]:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
}
|
||||
# A heuristic: fused marlin works faster with this config for small M
|
||||
if M <= E or (is_marlin and M <= 32):
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def try_get_optimal_moe_config(
|
||||
w1_shape: Tuple[int, ...],
|
||||
w2_shape: Tuple[int, ...],
|
||||
top_k: int,
|
||||
dtype: Optional[str],
|
||||
M: int,
|
||||
is_marlin: bool = False,
|
||||
):
|
||||
from sglang.srt.layers.triton_fused_moe import get_config
|
||||
|
||||
override_config = get_config()
|
||||
if override_config:
|
||||
config = override_config
|
||||
else:
|
||||
# First try to load optimal config from the file
|
||||
E, _, N = w2_shape
|
||||
configs = get_moe_configs(E, N, dtype)
|
||||
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
# optimal config
|
||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||
else:
|
||||
# Else use the default config
|
||||
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
|
||||
return config
|
||||
|
||||
|
||||
def fused_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
M, _ = hidden_states.shape
|
||||
|
||||
topk_weights = torch.empty(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||
token_expert_indicies = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(), # TODO(woosuk): Optimize this.
|
||||
)
|
||||
del token_expert_indicies # Not used. Will be used in the future.
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
# This is used by the Deepseek-V2 model
|
||||
def grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
):
|
||||
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
num_token = scores.shape[0]
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False,
|
||||
):
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif dtype == torch.float:
|
||||
# avoiding cases where kernel fails when float32 MoE
|
||||
# use fp16/bfloat16 configs
|
||||
return "float32"
|
||||
return None
|
||||
|
||||
|
||||
def inplace_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
True,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="inplace_fused_experts",
|
||||
op_func=inplace_fused_experts,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=inplace_fused_experts_fake,
|
||||
)
|
||||
|
||||
|
||||
def outplace_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
False,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="outplace_fused_experts",
|
||||
op_func=outplace_fused_experts,
|
||||
mutates_args=[],
|
||||
fake_impl=outplace_fused_experts_fake,
|
||||
)
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if inplace:
|
||||
torch.ops.sglang.inplace_fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
return hidden_states
|
||||
else:
|
||||
return torch.ops.sglang.outplace_fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
E, N, _ = w1.shape
|
||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = 64 * 1024
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
config_dtype = get_config_dtype_str(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.shape,
|
||||
w2.shape,
|
||||
topk_ids.shape[1],
|
||||
config_dtype,
|
||||
)
|
||||
|
||||
config = get_config_func(M)
|
||||
|
||||
intermediate_cache1 = torch.empty(
|
||||
(M, topk_ids.shape[1], N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * topk_ids.shape[1], N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache3 = torch.empty(
|
||||
(M, topk_ids.shape[1], w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||
|
||||
if inplace:
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
||||
begin_chunk_idx, end_chunk_idx = (
|
||||
chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE, num_tokens),
|
||||
)
|
||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||
|
||||
if tokens_in_chunk == 0:
|
||||
break
|
||||
|
||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||
# Adjust the intermediate cache size and config for the last
|
||||
# chunk. Note that in most cases we only have one chunk
|
||||
# so the cache size and config are already set correctly and
|
||||
# do not need to be adjusted.
|
||||
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
||||
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
|
||||
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
||||
config = get_config_func(tokens_in_chunk)
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids, config["BLOCK_SIZE_M"], E
|
||||
)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
curr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
topk_ids.shape[1],
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
)
|
||||
|
||||
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
)
|
||||
|
||||
ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- gating_output (torch.Tensor): The output of the gating operation
|
||||
(before softmax).
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
||||
- topk_group: Optional[int]: additional parameter for grouped_topk
|
||||
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
||||
note: Deepseekv2 model uses grouped_topk
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w2.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
# Check constraints.
|
||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
hidden_states,
|
||||
gating_output,
|
||||
topk,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
631
python/sglang/srt/layers/triton_fused_moe/layer.py
Normal file
631
python/sglang/srt/layers/triton_fused_moe/layer.py
Normal file
@@ -0,0 +1,631 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if torch.cuda.is_available() or torch.hip.is_available():
|
||||
from .fused_moe import fused_experts
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
GROUP = "group"
|
||||
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@register_custom_op("sglang_unquantized_fused_moe")
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
def forward_cpu(self, *args, **kwargs):
|
||||
raise NotImplementedError("The CPU backend currently does not support MoE.")
|
||||
|
||||
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||
|
||||
forward_native = forward_cuda
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
||||
w13) and RowParallelLinear weights (down_proj/ w2).
|
||||
|
||||
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
|
||||
copy that naming convention here and handle any remapping in the
|
||||
load_weights function in each model implementation.
|
||||
|
||||
Args:
|
||||
num_experts: Number of experts in the model
|
||||
top_k: Number of experts selected for each token
|
||||
hidden_size: Input hidden state size of the transformer
|
||||
intermediate_size: Intermediate size of the experts
|
||||
params_dtype: Data type for the parameters.
|
||||
reduce_results: Whether to all all_reduce on the output of the layer
|
||||
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = False,
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
self.tp_size = (
|
||||
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
||||
)
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.custom_routing_function = custom_routing_function
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
UnquantizedFusedMoEMethod()
|
||||
)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
num_experts=num_experts,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
)
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
shard_id: str,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
expert_id: int,
|
||||
):
|
||||
param_data = param.data
|
||||
# for per tensor weight quantization
|
||||
if shard_id in ("w1", "w3"):
|
||||
# We have to keep the weight scales of w1 and w3 because
|
||||
# we need to re-quantize w1/w3 weights after weight loading.
|
||||
idx = 0 if shard_id == "w1" else 1
|
||||
param_data[expert_id][idx] = loaded_weight
|
||||
# If we are in the row parallel case (down_proj)
|
||||
elif shard_id == "w2":
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def _load_model_weight_or_group_weight_scale(
|
||||
self,
|
||||
shard_dim: int,
|
||||
expert_data: torch.Tensor,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
# Load grouped weight scales for group quantization
|
||||
# or model weights
|
||||
if shard_id == "w2":
|
||||
self._load_w2(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
|
||||
def _load_per_channel_weight_scale(
|
||||
self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
# for per channel weight quantization
|
||||
if shard_id == "w2":
|
||||
expert_data.copy_(loaded_weight)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
|
||||
def _load_w13(
|
||||
self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(
|
||||
self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_single_value(
|
||||
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
|
||||
):
|
||||
param_data = param.data
|
||||
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
def _load_g_idx(
|
||||
self,
|
||||
shard_id: str,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
loaded_weight: torch.tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
|
||||
if shard_id == "w2":
|
||||
self._load_w2(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
else:
|
||||
assert shard_id in ("w1", "w3")
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
) -> None:
|
||||
|
||||
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||
# against known CompressionFormat enum values that have this quality
|
||||
loaded_weight = (
|
||||
loaded_weight.t().contiguous()
|
||||
if (
|
||||
self.quant_method.__class__.__name__
|
||||
== "CompressedTensorsWNA16MoEMethod"
|
||||
)
|
||||
else loaded_weight
|
||||
)
|
||||
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(
|
||||
f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
|
||||
)
|
||||
|
||||
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
||||
# Fetch the dim to shard the parameter/loaded weight
|
||||
# based on the shard id. This will be whatever
|
||||
# dimension intermediate_size is used.
|
||||
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
||||
|
||||
expert_data = param.data[expert_id]
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# is_transposed: if the dim to shard the weight
|
||||
# should be flipped. Required by GPTQ, compressed-tensors
|
||||
# should be whatever dimension intermediate_size is
|
||||
is_transposed = getattr(param, "is_transposed", False)
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if is_transposed:
|
||||
shard_dim = ~shard_dim
|
||||
|
||||
# Case input scale: input_scale loading is only supported for fp8
|
||||
if "input_scale" in weight_name:
|
||||
# this is needed for compressed-tensors only
|
||||
loaded_weight = loaded_weight.to(param.data.device)
|
||||
|
||||
if (
|
||||
param.data[expert_id] != 1
|
||||
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
|
||||
):
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param.data[expert_id]} "
|
||||
f"vs. {loaded_weight}"
|
||||
)
|
||||
|
||||
self._load_single_value(
|
||||
param=param, loaded_weight=loaded_weight, expert_id=expert_id
|
||||
)
|
||||
return
|
||||
|
||||
# Case g_idx
|
||||
if "g_idx" in weight_name:
|
||||
self._load_g_idx(
|
||||
shard_dim=0,
|
||||
shard_id=shard_id,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
return
|
||||
|
||||
# Case weight scales and zero_points
|
||||
if "scale" in weight_name or "zero" in weight_name:
|
||||
# load the weight scales and zp based on the quantization scheme
|
||||
# supported weight scales/zp can be found in
|
||||
# FusedMoeWeightScaleSupported
|
||||
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
|
||||
# specific to each case
|
||||
quant_method = getattr(param, "quant_method", None)
|
||||
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
||||
self._load_per_channel_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
||||
self._load_per_tensor_weight_scale(
|
||||
shard_id=shard_id,
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
|
||||
)
|
||||
return
|
||||
|
||||
# Case weight_shape
|
||||
if "weight_shape" in weight_name:
|
||||
# only required by compressed-tensors
|
||||
self._load_single_value(
|
||||
param=param, loaded_weight=loaded_weight, expert_id=expert_id
|
||||
)
|
||||
return
|
||||
|
||||
# Case model weights
|
||||
if "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk,
|
||||
grouped_topk,
|
||||
)
|
||||
|
||||
# DeekSeekv2 uses grouped_top_k
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls,
|
||||
ckpt_gate_proj_name: str,
|
||||
ckpt_down_proj_name: str,
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int,
|
||||
) -> List[Tuple[str, str, int, str]]:
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
(
|
||||
(
|
||||
"experts.w13_"
|
||||
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
|
||||
else "experts.w2_"
|
||||
),
|
||||
f"experts.{expert_id}.{weight_name}.",
|
||||
expert_id,
|
||||
shard_id,
|
||||
)
|
||||
for expert_id in range(num_experts)
|
||||
for shard_id, weight_name in [
|
||||
("w1", ckpt_gate_proj_name),
|
||||
("w2", ckpt_down_proj_name),
|
||||
("w3", ckpt_up_proj_name),
|
||||
]
|
||||
]
|
||||
|
||||
def _load_fp8_scale(
|
||||
self,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str,
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
) -> None:
|
||||
param_data = param.data
|
||||
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
if "input_scale" in weight_name:
|
||||
if (
|
||||
param_data[expert_id] != 1
|
||||
and (param_data[expert_id] - loaded_weight).abs() > 1e-5
|
||||
):
|
||||
raise ValueError(
|
||||
"input_scales of w1 and w3 of a layer "
|
||||
f"must be equal. But got {param_data[expert_id]} "
|
||||
f"vs. {loaded_weight}"
|
||||
)
|
||||
param_data[expert_id] = loaded_weight
|
||||
# Weight scales
|
||||
elif "weight_scale" in weight_name:
|
||||
# If we are in merged column case (gate_up_proj)
|
||||
if shard_id in ("w1", "w3"):
|
||||
# We have to keep the weight scales of w1 and w3 because
|
||||
# we need to re-quantize w1/w3 weights after weight loading.
|
||||
idx = 0 if shard_id == "w1" else 1
|
||||
param_data[expert_id][idx] = loaded_weight
|
||||
# If we are in the row parallel case (down_proj)
|
||||
else:
|
||||
param_data[expert_id] = loaded_weight
|
||||
@@ -27,7 +27,6 @@ from vllm.distributed import (
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
@@ -42,6 +41,7 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.triton_fused_moe import FusedMoE
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
|
||||
@@ -31,7 +31,7 @@ import time
|
||||
import warnings
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@@ -45,6 +45,7 @@ from packaging import version as pkg_version
|
||||
from starlette.routing import Mount
|
||||
from torch import nn
|
||||
from torch.func import functional_call
|
||||
from torch.library import Library
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
|
||||
def crash_on_warnings():
|
||||
# Crash on warning if we are running CI tests
|
||||
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||
|
||||
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return torch.cuda.get_device_name(device_id)
|
||||
|
||||
if hasattr(torch, "hip") and torch.hip.is_available():
|
||||
return torch.hip.get_device_name(device_id)
|
||||
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
return torch.xpu.get_device_name(device_id)
|
||||
|
||||
if hasattr(torch, "hpu") and torch.hpu.is_available():
|
||||
return torch.hpu.get_device_name(device_id)
|
||||
|
||||
|
||||
sglang_lib = Library("sglang", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def direct_register_custom_op(
|
||||
op_name: str,
|
||||
op_func: Callable,
|
||||
mutates_args: List[str],
|
||||
fake_impl: Optional[Callable] = None,
|
||||
target_lib: Optional[Library] = None,
|
||||
):
|
||||
import torch.library
|
||||
|
||||
if hasattr(torch.library, "infer_schema"):
|
||||
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
||||
else:
|
||||
# for pytorch 2.4
|
||||
import torch._custom_op.impl
|
||||
|
||||
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
|
||||
|
||||
my_lib = target_lib or sglang_lib
|
||||
my_lib.define(op_name + schema_str)
|
||||
my_lib.impl(op_name, op_func, "CUDA")
|
||||
if fake_impl is not None:
|
||||
my_lib._register_fake(op_name, fake_impl)
|
||||
|
||||
Reference in New Issue
Block a user