[Model] Support DeepSeek-V4

vllm_mlu/model_executor/layers/fused_moe/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

vllm_mlu/model_executor/layers/fused_moe/fused_moe.py (new file, 935 lines)
@@ -0,0 +1,935 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

"""Fused MoE kernel."""
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_moe_kernel_gptq_awq,
|
||||
write_zeros_to_output,
|
||||
get_default_config,
|
||||
try_get_optimal_moe_config,
|
||||
_get_config_quant_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
activation_without_mul,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
from vllm_mlu.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
import vllm_mlu._mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    '''
    =============================
    Modified by vllm_mlu
    =============================
    @brief: Split the program ID into two dimensions (pid_0 and pid_1)
    '''
    pid_0 = tl.program_id(axis=0)
    pid_1 = tl.program_id(axis=1)
    pid = pid_1 * tl.num_programs(axis=0) + pid_0
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
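    # Illustrative arithmetic for the 2D -> linear pid mapping above (example
    # numbers, not taken from a real launch): with a grid of
    # (num_programs_0=4, num_programs_1=3), the program at (pid_0=2, pid_1=1)
    # gets pid = 1 * 4 + 2 = 6 before the usual grouped (pid_m, pid_n)
    # decomposition below.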
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        # block-wise
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
            )
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        # tensor-wise
        else:
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)
    if HAS_BIAS:
        # bias shape: [num_experts, N]
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    # acc used to enable fp8_fast_accum
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    if HAS_BIAS:
        accumulator = accumulator + bias[None, :]
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)

    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: torch.Tensor | None,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: list[int] | None = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)

    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # Optimize for small batch sizes. We assume that the top-k ids of
        # each token are unique, so num_valid_experts <= batch_size
        # <= BLOCK_SIZE_M, and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])

    '''
    =============================
    Modified by vllm_mlu
    =============================
    @brief: Split the program ID into two dimensions (pid_0, pid_1)
    '''
    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']), triton.cdiv(
        B.shape[1], META['BLOCK_SIZE_N']), )

    assert not (use_int8_w8a16 or use_int4_w4a16)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
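    # Example of the grid lambda above (illustrative numbers only): with
    # EM = 128, BLOCK_SIZE_M = 64, B.shape[1] = 512 and BLOCK_SIZE_N = 64,
    # the launch grid is (cdiv(128, 64), cdiv(512, 64)) = (2, 8); the kernel
    # linearizes the two program ids back into a single pid.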
    HAS_BIAS = B_bias is not None
    if (
        (use_int8_w8a16 or use_int4_w4a16)
        and block_shape is not None
        and block_shape[1] > 0
    ):
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=num_tokens,
            group_size=block_shape[1],
            num_experts=B.size(0),
            bit=4 if use_int4_w4a16 else 8,
        )
        config = config.copy()
        config.update(
            get_moe_wna16_block_config(
                config=config,
                use_moe_wna16_cuda=use_moe_wna16_cuda,
                num_valid_tokens=num_tokens,
                size_k=A.size(1),
                size_n=B.size(1),
                num_experts=B.size(0),
                group_size=block_shape[1],
                real_top_k=top_k,
                block_size_m=config["BLOCK_SIZE_M"],
            )
        )

        if use_moe_wna16_cuda:
            bit = 4 if use_int4_w4a16 else 8
            ops.moe_wna16_gemm(
                A,
                C,
                B,
                B_scale,
                B_zp,
                topk_weights if mul_routed_weight else None,
                sorted_token_ids,
                expert_ids,
                num_tokens_post_padded,
                top_k,
                config["BLOCK_SIZE_M"],
                config["BLOCK_SIZE_N"],
                config["BLOCK_SIZE_K"],
                bit,
            )
            return
        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            A.size(1),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            B_scale.stride(0),
            B_scale.stride(2),
            B_scale.stride(1),
            B_zp.stride(0) if B_zp is not None else 0,
            B_zp.stride(2) if B_zp is not None else 0,
            B_zp.stride(1) if B_zp is not None else 0,
            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
            group_size=block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            has_zp=B_zp is not None,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w8a16=use_int8_w8a16,
            **config,
        )
    else:
        config = config.copy()
        config["SPLIT_K"] = 1
        BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
        if block_shape is not None:
            BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
        fused_moe_kernel[grid](
            A,
            B,
            C,
            B_bias,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            B.size(2),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_bias.stride(0) if B_bias is not None else 0,
            B_bias.stride(1) if B_bias is not None else 0,
            0 if block_shape is None else block_shape[0],
            0 if block_shape is None else block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            per_channel_quant=per_channel_quant,
            HAS_BIAS=HAS_BIAS,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            **config,
        )


def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    return fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        True,
        activation,
        apply_router_weight_on_input,
        use_fp8_w8a8,
        use_int8_w8a8,
        use_int8_w8a16,
        use_int4_w4a16,
        ocp_mx_scheme,
        per_channel_quant,
        global_num_experts,
        expert_map,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        w1_bias,
        w2_bias,
    )


def outplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> None:
    pass


direct_register_custom_op(
    op_name="outplace_fused_experts_mlu",
    op_func=outplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=outplace_fused_experts_fake,
    dispatch_key="PrivateUse1",
    tags=(
        ()
        if is_torch_equal_or_newer("2.7.0")
        else (torch.Tag.needs_fixed_stride_order,)
    ),
)

def fused_experts(hidden_states: torch.Tensor,
                  w1: torch.Tensor,
                  w2: torch.Tensor,
                  topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor,
                  inplace: bool = False,
                  activation: str = "silu",
                  apply_router_weight_on_input: bool = False,
                  use_fp8_w8a8: bool = False,
                  use_int8_w8a16: bool = False,
                  use_int4_w4a16: bool = False,
                  global_num_experts: int = -1,
                  expert_map: Optional[torch.Tensor] = None,
                  w1_scale: Optional[torch.Tensor] = None,
                  w2_scale: Optional[torch.Tensor] = None,
                  w1_zp: Optional[torch.Tensor] = None,
                  w2_zp: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
                  a2_scale: Optional[torch.Tensor] = None,
                  block_shape: Optional[List[int]] = None,
                  allow_deep_gemm: bool = False) -> torch.Tensor:
    return torch.ops.vllm.outplace_fused_experts_mlu(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape)


SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")


def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Check constraints.
    if use_int4_w4a16:
        assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
    elif ocp_mx_scheme is not None:
        if ocp_mx_scheme in {
            "w_mxfp4_a_mxfp4",
            "w_mxfp4_a_mxfp6_e3m2",
            "w_mxfp4_a_mxfp6_e2m3",
        }:
            # 16-bit activation and fp4x2 packed weight
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme in {
            "w_mxfp6_e3m2_a_mxfp6_e3m2",
            "w_mxfp6_e2m3_a_mxfp6_e2m3",
        }:
            assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
                "hidden size mismatch"
            )
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(2), (
            f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
        )

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
    # quantized prior to calling fused_experts.
    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )
    '''
    =============================
    Modified by vllm_mlu
    =============================
    @brief: Only use the default config
    '''
    config = get_default_config(M, E, N, w1.shape[2], topk_ids.shape[1],
                                hidden_states.dtype, block_shape)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
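    # For reference, the config is a plain dict of Triton tile parameters
    # consumed by invoke_fused_moe_kernel; an illustrative (not authoritative)
    # value would be {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64,
    # "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}.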

    # We can reuse the memory between these because by the time we need
    # cache3, we're done with cache1
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # This needs separate memory since it's used concurrently with cache1
    intermediate_cache2 = torch.empty(
        (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    if inplace and not disable_inplace():
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # TODO: On platforms for which `current_platform.supports_mx()` is True
        # and for which we have a native OCP mx fused MOE kernel,
        # this dequantization step should not be done.
        if ocp_mx_scheme in {
            OCP_MX_Scheme.w_mxfp4_a_mxfp4,
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
        }:
            # Weight has to be dequantized for mxfp4 emulation.
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        if use_fp8_w8a8:
            qcurr_hidden_states, a1q_scale = _fp8_quantize(
                curr_hidden_states, a1_scale, block_shape)
        else:
            qcurr_hidden_states = curr_hidden_states
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
        )

        invoke_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )
        '''
        =============================
        Modified by vllm_mlu
        =============================
        @brief: Apply the activation via mlu_ops
        '''
        intermediate_cache2 = mlu_ops.active(intermediate_cache1.view(-1, N),
                                             act_mode=activation,
                                             is_gated=True)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
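        # `mlu_ops.active` with is_gated=True is expected to compute the gated
        # activation act(x[:, :N // 2]) * x[:, N // 2:] in one fused op (the
        # SiluAndMul pattern), returning an (M * top_k, N // 2) tensor that
        # replaces the preallocated intermediate_cache2.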
        a2q_scale: Optional[torch.Tensor] = None

        if use_fp8_w8a8:
            qintermediate_cache2, a2q_scale = _fp8_quantize(
                intermediate_cache2, a2_scale, block_shape)
        else:
            qintermediate_cache2 = intermediate_cache2
        invoke_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        '''
        =============================
        Modified by vllm_mlu
        =============================
        @brief: Replace moe_sum with torch.sum
        Reference: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L1513
        '''
        if topk_ids.shape[1] == 1:
            # Top-1 routing: no reduction is needed, just copy the single
            # expert output into the output buffer.
            out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
                intermediate_cache3[:, 0])
        elif topk_ids.shape[1] == 2:
            torch.add(
                intermediate_cache3[:, 0],
                intermediate_cache3[:, 1],
                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
            )
        elif topk_ids.shape[1] > 2:
            torch.sum(
                intermediate_cache3,
                dim=1,
                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
            )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
    return out_hidden_states
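
Editor's note: the snippet below is not part of the commit; it is a minimal sketch of how the `fused_experts` entry point above would be driven, with all shapes and the "mlu" device string assumed for illustration (16 tokens, hidden size 128, 8 experts, intermediate size 256, top-2 routing; w1 stacks the gate and up projections, hence 2 * 256 output rows).

```python
import torch

hidden_states = torch.randn(16, 128, dtype=torch.bfloat16, device="mlu")
w1 = torch.randn(8, 512, 128, dtype=torch.bfloat16, device="mlu")  # (E, 2*I, K)
w2 = torch.randn(8, 128, 256, dtype=torch.bfloat16, device="mlu")  # (E, K, I)
router_logits = torch.randn(16, 8, dtype=torch.float32, device="mlu")

# Top-2 routing weights and expert ids, normalized over the selected experts.
topk_weights, topk_ids = torch.topk(router_logits.softmax(dim=-1), k=2, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

out = fused_experts(hidden_states, w1, w2,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids.to(torch.int32),
                    activation="silu")
assert out.shape == hidden_states.shape
```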

vllm_mlu/model_executor/layers/fused_moe/layer.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Optional, Callable

import torch

from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject

from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts


def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # TODO: support `routed_scaling_factor`
    assert routed_scaling_factor == 1.0, (
        f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
    )
    use_fused_kernel = topk_group is None
    if use_fused_kernel:
        assert not enable_eplb, "MLU does not support EPLB in the fused_moe kernel."
        assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
            "The params use_grouped_topk, num_expert_group and topk_group are not supported yet."
        return mlu_ops.fused_moe(
            x,
            router_logits,
            layer.w13_weight, layer.w2_weight,
            None, None,  # bias1, bias2
            None,  # residual
            None,  # input_smooth
            None,  # act_smooth
            None, None,  # w1_scale, w2_scale
            top_k,
            renormalize,
            True,  # gated
            activation
        )
    else:
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        if self.rocm_aiter_moe_enabled:
            assert expert_map is None
            return self.rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )


MluHijackObject.apply_hijack(
    UnquantizedFusedMoEMethod,
    UnquantizedFusedMoEMethod.forward_oot,
    vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
)
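
Editor's note: `MluHijackObject.apply_hijack` is defined elsewhere in vllm_mlu and is not part of this diff. From the call site it is assumed to be a monkey-patch helper along these lines (names and behavior are assumptions, not the actual implementation):

```python
class MluHijackObject:
    @staticmethod
    def apply_hijack(target_cls, old_func, new_func):
        # Rebind the method on the class so every later call of
        # UnquantizedFusedMoEMethod.forward_oot dispatches to the
        # MLU-specific implementation registered above.
        setattr(target_cls, old_func.__name__, new_func)
```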

vllm_mlu/model_executor/layers/fused_moe/moe_align_block_size.py (new file, 248 lines)
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, round_up


'''
=============================
Modified by vllm_mlu
=============================
@brief: Implementation of moe_align_block_size_triton.
Note: the implementation has been removed from vllm since the
CUDA implementation is more efficient.
'''
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)


# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    grid = (num_experts, )
    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
                              dtype=torch.int32,
                              device=topk_ids.device)
    cumsum = torch.zeros((num_experts + 1, ),
                         dtype=torch.int32,
                         device=topk_ids.device)
    tokens_per_thread = cdiv(numel, num_experts)
    sorted_token_ids.fill_(numel)
    expert_ids.zero_()

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )
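
    # Stage recap (editor's summary of the kernels above): stage1 counts
    # tokens per (program, expert) pair; stage2 turns each expert's column
    # into running per-program counts; stage3 pads every expert's total up to
    # a multiple of block_size and writes the cumulative offsets plus the
    # padded grand total; stage4 fills the per-block expert_ids and scatters
    # each token id into its padded slot.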
'''
==================
End of MLU Hijack
==================
'''


def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: torch.Tensor | None = None,
    pad_sorted_ids: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.
    - expert_map: A tensor of shape [num_experts] that maps the expert index
        from the global space to the local index space of the current
        expert parallel shard. If the expert is not in the current expert
        parallel shard, the mapping is set to -1.
    - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
        should be padded to a multiple of block_size.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
        with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    # Expert ids must be zeroed out to prevent index out of bounds error while
    # mapping global expert ids to local expert ids in expert parallelism.
    expert_ids = torch.zeros((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    '''
    =============================
    Modified by vllm_mlu
    =============================
    @brief: Only use the Triton implementation of moe_align_block_size
    '''
    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad
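
A small illustration of the padding behavior documented above (editor's sketch; in practice the tensors live on the MLU device so the Triton kernels can launch):

```python
import torch

# 4 tokens, top-2 routing over 4 experts; each expert receives 2 tokens.
topk_ids = torch.tensor([[0, 1], [1, 2], [0, 3], [2, 3]], dtype=torch.int32)

sorted_ids, expert_ids, num_post_pad = moe_align_block_size(
    topk_ids, block_size=4, num_experts=4)

# Each expert's 2 tokens are padded up to block_size=4, so num_post_pad is
# 16, expert_ids marks one block per expert, and padding slots in sorted_ids
# hold the sentinel topk_ids.numel() == 8.
```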

vllm_mlu/model_executor/layers/fused_moe/utils.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from math import prod
from typing import List, Optional, Tuple

import torch

from vllm.utils.math_utils import cdiv


def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs. A block_shape must be
    provided, and the output is blocked accordingly.
    """
    from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)
    assert block_shape is not None
    assert len(block_shape) == 2
    _, block_k = block_shape[0], block_shape[1]
    A, A_scale = per_token_group_quant_fp8(A, block_k)
    assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
    return A, A_scale
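
For intuition, `per_token_group_quant_fp8` (imported above and not part of this diff) is assumed to follow the standard per-token-group recipe: one scale per contiguous group of `block_k` elements along the last dimension. A pure-PyTorch sketch of those assumed semantics (reference only; K must be divisible by group_k):

```python
import torch

def per_token_group_quant_fp8_ref(A: torch.Tensor, group_k: int):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    M, K = A.shape
    groups = A.view(M, K // group_k, group_k).float()
    # One scale per (token, group): map the group's max magnitude to fp8_max.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(M, K), scale.squeeze(-1)
```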