First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
]
# if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
"grouped_topk",
]

View File

@@ -0,0 +1,313 @@
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Any, Dict, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp:
assert num_bits == 4
return scalar_types.uint4
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Its purpose is testing and debugging the fused MoE kernel.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
- w (torch.Tensor): The set of expert weights.
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
- sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // (num_bits // 2)
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
# This might not be an optimal config for a single MMM
get_config_func = functools.partial(try_get_optimal_moe_config,
w.shape,
w.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)
block_size_m = config['BLOCK_SIZE_M']
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=hidden_states.device,
requires_grad=False)
has_zero_point = w_zeros is not None
if w_zeros is None:
w_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if g_idx is None:
g_idx = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
if sort_indices is None:
sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
is_k_full, E, topk, block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[
0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
topk = topk_ids.shape[1]
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True,
)
config = get_config_func(M)
block_size_m = config["BLOCK_SIZE_M"]
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
if has_no_zp:
w1_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
w2_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if has_no_act_order:
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states,
w1,
sorted_token_ids,
topk_weights,
topk_ids,
w1_scale,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
scalar_type1,
M,
2 * N,
K,
is_k_full,
E,
topk,
block_size_m,
True,
False,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2,
w2,
sorted_token_ids,
topk_weights,
topk_ids,
w2_scale,
w2_zeros,
g_idx2,
sort_indices2,
workspace,
scalar_type2,
M,
K,
N,
is_k_full,
E,
topk,
block_size_m,
False,
True,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)

View File

@@ -0,0 +1,693 @@
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Callable, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
if use_int8_w8a16:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
# max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
max_num_m_blocks = topk_ids.numel() + num_experts
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16)
return
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
elif use_int8_w8a16:
assert B_scale is not None
else:
assert A_scale is None
assert B_scale is None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
fused_moe_kernel[grid](
A,
B,
C,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
**config,
)
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
@functools.lru_cache
def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the fused_moe kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
("Using default MoE config. Performance might be sub-optimal! "
"Config file not found at %s"), config_file_path)
return None
def get_default_config(
M: int,
E: int,
N: int,
K: int,
topk: int,
dtype: Optional[str],
is_marlin: bool,
) -> Dict[str, int]:
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}
numel = M * topk
if numel <= 64:
config['BLOCK_SIZE_M'] = 32
elif numel <= 1024:
config['BLOCK_SIZE_M'] = 64
else:
config['BLOCK_SIZE_M'] = 256
return config
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
is_marlin: bool = False,
):
if override_config:
config = override_config
else:
# First try to load optimal config from the file
E, _, N = w2_shape
# configs = get_moe_configs(E, N, dtype)
configs = None
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin)
return config
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
M, _ = hidden_states.shape
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
# This is used by the Deepseek-V2 model
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def get_config_dtype_str(dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return "float32"
return None
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
if use_fp8_w8a8 or use_int8_w8a16:
raise NotImplementedError("fused_experts has not implemented fp8_w8a8 and int8_w8a16 yet.")
if a1_scale is not None or a2_scale is not None:
raise NotImplementedError("fused_experts has not implemented static_w8a8 yet.")
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
dtype=hidden_states.dtype)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
config_dtype,
override_config=override_config,
)
config = get_config_func(M)
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
invoke_fused_moe_kernel(curr_hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16)
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize,
num_expert_group, topk_group)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
override_config=override_config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)

View File

@@ -0,0 +1,717 @@
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
# from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import CompressionFormat, QuantizationStrategy
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import create_per_channel_scale_param
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, x: torch.Tensor,
router_logits: torch.Tensor, top_k: int, renormalize: bool,
use_grouped_topk: bool) -> torch.Tensor:
raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
def forward_cuda(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True)
def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")
def forward_tpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
return fused_moe(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
gating_output=router_logits,
renormalize=renormalize)
class W8A8QuantizedFusedMoEMethod(FusedMoEMethodBase):
"""MoE method W8A8 quantization. This class is for compressed-tensors format loading"""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
self.strategy = extra_weight_attrs['quant_config'].target_scheme_map['Linear']['weights'].strategy
self.is_static_input_scheme = not extra_weight_attrs['quant_config'].target_scheme_map['Linear']['input_activations'].dynamic
# assert self.is_static_input_scheme, "W8A8 int quantization only support static input activation for now"
self.quant_config = extra_weight_attrs["quant_config"]
self.weight_loader = extra_weight_attrs["weight_loader"]
self.logical_widths_13 = [intermediate_size * 2]
self.logical_widths_2 = [intermediate_size * 2]
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": self.weight_loader,
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": self.weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
scale = torch.nn.Parameter(torch.empty((num_experts, intermediate_size * 2, 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"input_dim": 1, "output_dim": 0, **layer_kwargs})
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = torch.nn.Parameter(torch.empty((num_experts, 2), dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_weight_scale": True})
layer.register_parameter("w13_weight_scale", scale)
# INPUT SCALE
if self.is_static_input_scheme:
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_input_scale": True})
layer.register_parameter("w13_input_scale", scale)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": self.weight_loader,
})
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
scale = torch.nn.Parameter(torch.empty((num_experts, hidden_size, 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"input_dim": 0, "output_dim": 1, **layer_kwargs})
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_weight_scale": True})
layer.register_parameter("w2_weight_scale", scale)
# INPUT SCALE
if self.is_static_input_scheme:
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_input_scale": True})
layer.register_parameter("w2_input_scale", scale)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
assert False, "fused_moe w8a8 use ixformer.contrib.vllm.layers.mixtral_decoder_layer_forward"
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
if self.is_static_input_scheme:
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
use_int8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
else:
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
use_int8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None)
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
copy that naming convention here and handle any remapping in the
load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
self.top_k = top_k
self.num_experts = num_experts
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod())
elif (isinstance(quant_config, CompressedTensorsConfig)
and quant_config.quant_format == CompressionFormat.int_quantized.value
and quant_config.target_scheme_map['Linear']['input_activations'].num_bits == 8
and quant_config.target_scheme_map['Linear']['weights'].num_bits == 8):
self.quant_method: Optional[QuantizeMethodBase] = (
W8A8QuantizedFusedMoEMethod())
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
quant_config=quant_config)
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
expert_id: int):
param_data = param.data
# for per tensor weight quantization
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
# Special case for fp8 scales.
if getattr(param, "is_fp8_scale", False):
self._load_fp8_scale(param.data, loaded_weight, weight_name,
shard_id, expert_id)
return
elif getattr(param, "is_int8_input_scale", False):
self._load_int8_input_scale(param.data, loaded_weight,
shard_id, expert_id)
return
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
# Case weight scales and zero_points
if ("scale" in weight_name or "zero" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", FusedMoeWeightScaleSupported.CHANNEL.value)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
@staticmethod
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk)
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states
@classmethod
def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
for expert_id in range(num_experts) for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
def _load_int8_input_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if param_data[expert_id] != 1 and (param_data[expert_id].to(loaded_weight.device) -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight

View File

@@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import _histogram
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> torch.Tensor:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
device = hidden_states.device
dtype = hidden_states.dtype
assert (num_tokens * topk) % 16 == 0, (
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
f"16 but got {num_tokens * topk}")
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
topk_indices = topk_indices.flatten()
topk_argsort_indices = topk_indices.argsort()
topk_argsort_revert_indices = topk_argsort_indices.argsort()
token_indices = torch.arange(num_tokens,
device=device).repeat_interleave(topk)
token_indices = token_indices[topk_argsort_indices]
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
# from HF Transformers.
w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2)
x = hidden_states[token_indices]
x = torch.ops.xla.gmm(x, w1, group_sizes)
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
x = torch.ops.xla.gmm(x, w2, group_sizes)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * topk_weights.unsqueeze_(dim=-1)
x = x.sum(dim=-2)
x = x.reshape(orig_shape)
return x