Update fused_moe (#553)
This commit is contained in:
@@ -9,9 +9,9 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import is_hip
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -108,16 +108,12 @@ def fused_moe_kernel(
|
|||||||
|
|
||||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
a_ptrs = a_ptr + (
|
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
||||||
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
|
offs_k[None, :] * stride_ak)
|
||||||
)
|
|
||||||
|
|
||||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
off_experts = tl.load(expert_ids_ptr + pid_m)
|
||||||
b_ptrs = (
|
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
|
||||||
b_ptr
|
offs_bn[None, :] * stride_bn)
|
||||||
+ off_experts * stride_be
|
|
||||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_fp8:
|
if use_fp8:
|
||||||
a_scale = tl.load(a_scale_ptr)
|
a_scale = tl.load(a_scale_ptr)
|
||||||
@@ -133,12 +129,13 @@ def fused_moe_kernel(
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||||
# Load the next block of A and B, generate a mask by checking the
|
# Load the next block of A and B, generate a mask by checking the
|
||||||
# K dimension.
|
# K dimension.
|
||||||
a = tl.load(
|
a = tl.load(a_ptrs,
|
||||||
a_ptrs,
|
mask=token_mask[:, None] &
|
||||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||||
other=0.0,
|
other=0.0)
|
||||||
)
|
b = tl.load(b_ptrs,
|
||||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||||
|
other=0.0)
|
||||||
# We accumulate along the K dimension.
|
# We accumulate along the K dimension.
|
||||||
if use_fp8:
|
if use_fp8:
|
||||||
accumulator = tl.dot(a, b, acc=accumulator)
|
accumulator = tl.dot(a, b, acc=accumulator)
|
||||||
@@ -149,7 +146,9 @@ def fused_moe_kernel(
|
|||||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
if MUL_ROUTED_WEIGHT:
|
if MUL_ROUTED_WEIGHT:
|
||||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||||
|
mask=token_mask,
|
||||||
|
other=0)
|
||||||
accumulator = accumulator * moe_weight[:, None]
|
accumulator = accumulator * moe_weight[:, None]
|
||||||
|
|
||||||
if use_fp8:
|
if use_fp8:
|
||||||
@@ -159,14 +158,15 @@ def fused_moe_kernel(
|
|||||||
# -----------------------------------------------------------
|
# -----------------------------------------------------------
|
||||||
# Write back the block of the output
|
# Write back the block of the output
|
||||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
||||||
|
None, :]
|
||||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
|
def moe_align_block_size(
|
||||||
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
topk_ids: torch.Tensor, block_size: int,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Aligns the token distribution across experts to be compatible with block
|
Aligns the token distribution across experts to be compatible with block
|
||||||
size for matrix multiplication.
|
size for matrix multiplication.
|
||||||
@@ -205,38 +205,32 @@ def moe_align_block_size(
|
|||||||
by block_size for proper block matrix operations.
|
by block_size for proper block matrix operations.
|
||||||
"""
|
"""
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
sorted_ids = torch.empty(
|
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
dtype=torch.int32,
|
||||||
)
|
device=topk_ids.device)
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||||
expert_ids = torch.empty(
|
expert_ids = torch.empty((max_num_m_blocks, ),
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
dtype=torch.int32,
|
||||||
)
|
device=topk_ids.device)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1),
|
||||||
ops.moe_align_block_size(
|
dtype=torch.int32,
|
||||||
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
|
device=topk_ids.device)
|
||||||
)
|
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
|
||||||
|
expert_ids, num_tokens_post_pad)
|
||||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||||
|
|
||||||
|
|
||||||
def invoke_fused_moe_kernel(
|
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||||
A: torch.Tensor,
|
A_scale: Optional[torch.Tensor],
|
||||||
B: torch.Tensor,
|
B_scale: Optional[torch.Tensor],
|
||||||
C: torch.Tensor,
|
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||||
A_scale: Optional[torch.Tensor],
|
sorted_token_ids: torch.Tensor,
|
||||||
B_scale: Optional[torch.Tensor],
|
expert_ids: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
num_tokens_post_padded: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
mul_routed_weight: bool, top_k: int,
|
||||||
sorted_token_ids: torch.Tensor,
|
config: Dict[str, Any], compute_type: tl.dtype,
|
||||||
expert_ids: torch.Tensor,
|
use_fp8: bool) -> None:
|
||||||
num_tokens_post_padded: torch.Tensor,
|
|
||||||
mul_routed_weight: bool,
|
|
||||||
top_k: int,
|
|
||||||
config: Dict[str, Any],
|
|
||||||
compute_type: tl.dtype,
|
|
||||||
use_fp8: bool,
|
|
||||||
) -> None:
|
|
||||||
assert topk_weights.stride(1) == 1
|
assert topk_weights.stride(1) == 1
|
||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
|
|
||||||
@@ -247,10 +241,8 @@ def invoke_fused_moe_kernel(
|
|||||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
|
|
||||||
grid = lambda META: (
|
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
||||||
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
|
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
|
||||||
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
fused_moe_kernel[grid](
|
fused_moe_kernel[grid](
|
||||||
A,
|
A,
|
||||||
@@ -288,7 +280,8 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
def get_moe_configs(E: int, N: int,
|
||||||
|
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
||||||
"""
|
"""
|
||||||
Return optimized configurations for the fused MoE kernel.
|
Return optimized configurations for the fused MoE kernel.
|
||||||
|
|
||||||
@@ -303,11 +296,11 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
|
|||||||
json_file_name = get_config_file_name(E, N, dtype)
|
json_file_name = get_config_file_name(E, N, dtype)
|
||||||
|
|
||||||
config_file_path = os.path.join(
|
config_file_path = os.path.join(
|
||||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
|
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
||||||
)
|
|
||||||
if os.path.exists(config_file_path):
|
if os.path.exists(config_file_path):
|
||||||
with open(config_file_path) as f:
|
with open(config_file_path) as f:
|
||||||
logger.info("Using configuration from %s for MoE layer.", config_file_path)
|
logger.info("Using configuration from %s for MoE layer.",
|
||||||
|
config_file_path)
|
||||||
# If a configuration has been found, return it
|
# If a configuration has been found, return it
|
||||||
return {int(key): val for key, val in json.load(f).items()}
|
return {int(key): val for key, val in json.load(f).items()}
|
||||||
|
|
||||||
@@ -316,6 +309,165 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_config(
|
||||||
|
M: int,
|
||||||
|
E: int,
|
||||||
|
N: int,
|
||||||
|
K: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: Optional[str],
|
||||||
|
) -> Dict[str, int]:
|
||||||
|
config = {
|
||||||
|
'BLOCK_SIZE_M': 64,
|
||||||
|
'BLOCK_SIZE_N': 64,
|
||||||
|
'BLOCK_SIZE_K': 32,
|
||||||
|
'GROUP_SIZE_M': 8
|
||||||
|
}
|
||||||
|
if M <= E:
|
||||||
|
config = {
|
||||||
|
'BLOCK_SIZE_M': 16,
|
||||||
|
'BLOCK_SIZE_N': 32,
|
||||||
|
'BLOCK_SIZE_K': 64,
|
||||||
|
'GROUP_SIZE_M': 1
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def fused_topk(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
):
|
||||||
|
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||||
|
"Number of tokens mismatch")
|
||||||
|
|
||||||
|
M, _ = hidden_states.shape
|
||||||
|
|
||||||
|
topk_weights = torch.empty(M,
|
||||||
|
topk,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=hidden_states.device)
|
||||||
|
topk_ids = torch.empty(M,
|
||||||
|
topk,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=hidden_states.device)
|
||||||
|
token_expert_indicies = torch.empty(M,
|
||||||
|
topk,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=hidden_states.device)
|
||||||
|
ops.topk_softmax(
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
token_expert_indicies,
|
||||||
|
gating_output.float(), # TODO(woosuk): Optimize this.
|
||||||
|
)
|
||||||
|
del token_expert_indicies # Not used. Will be used in the future.
|
||||||
|
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
def fused_experts(hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
inplace: bool = False,
|
||||||
|
override_config: Optional[Dict[str, Any]] = None,
|
||||||
|
use_fp8: bool = False,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None):
|
||||||
|
# Check constraints.
|
||||||
|
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||||
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||||
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
|
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||||
|
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||||
|
assert hidden_states.dtype in [
|
||||||
|
torch.float32, torch.float16, torch.bfloat16
|
||||||
|
]
|
||||||
|
|
||||||
|
M, _ = hidden_states.shape
|
||||||
|
E, N, _ = w1.shape
|
||||||
|
|
||||||
|
if override_config:
|
||||||
|
config = override_config
|
||||||
|
else:
|
||||||
|
# First try to load optimal config from the file
|
||||||
|
configs = get_moe_configs(E, w2.shape[2],
|
||||||
|
"float8" if use_fp8 else None)
|
||||||
|
|
||||||
|
if configs:
|
||||||
|
# If an optimal configuration map has been found, look up the
|
||||||
|
# optimal config
|
||||||
|
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
||||||
|
else:
|
||||||
|
# Else use the default config
|
||||||
|
config = get_default_config(M, E, N, w1.shape[2],
|
||||||
|
topk_ids.shape[1],
|
||||||
|
"float8" if use_fp8 else None)
|
||||||
|
|
||||||
|
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype)
|
||||||
|
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype)
|
||||||
|
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype)
|
||||||
|
|
||||||
|
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||||
|
topk_ids, config['BLOCK_SIZE_M'], E)
|
||||||
|
compute_type = (tl.bfloat16
|
||||||
|
if hidden_states.dtype == torch.bfloat16 else tl.float16)
|
||||||
|
|
||||||
|
invoke_fused_moe_kernel(hidden_states,
|
||||||
|
w1,
|
||||||
|
intermediate_cache1,
|
||||||
|
a1_scale,
|
||||||
|
w1_scale,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_padded,
|
||||||
|
False,
|
||||||
|
topk_ids.shape[1],
|
||||||
|
config,
|
||||||
|
compute_type=compute_type,
|
||||||
|
use_fp8=use_fp8)
|
||||||
|
|
||||||
|
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
||||||
|
|
||||||
|
invoke_fused_moe_kernel(intermediate_cache2,
|
||||||
|
w2,
|
||||||
|
intermediate_cache3,
|
||||||
|
a2_scale,
|
||||||
|
w2_scale,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_padded,
|
||||||
|
True,
|
||||||
|
1,
|
||||||
|
config,
|
||||||
|
compute_type=compute_type,
|
||||||
|
use_fp8=use_fp8)
|
||||||
|
|
||||||
|
if inplace:
|
||||||
|
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
dim=1,
|
||||||
|
out=hidden_states)
|
||||||
|
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
dim=1)
|
||||||
|
|
||||||
|
|
||||||
def fused_moe(
|
def fused_moe(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
@@ -358,134 +510,19 @@ def fused_moe(
|
|||||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
|
||||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
|
||||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
|
||||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
|
||||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
|
||||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
|
||||||
M, _ = hidden_states.shape
|
|
||||||
E, N, _ = w1.shape
|
|
||||||
|
|
||||||
if is_hip():
|
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
||||||
# The MoE kernels are not yet supported on ROCm.
|
renormalize)
|
||||||
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
|
return fused_experts(hidden_states,
|
||||||
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
|
w1,
|
||||||
else:
|
w2,
|
||||||
import vllm._moe_C as moe_kernels
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
topk_weights = torch.empty(
|
inplace=inplace,
|
||||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
override_config=override_config,
|
||||||
)
|
use_fp8=use_fp8,
|
||||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
w1_scale=w1_scale,
|
||||||
token_expert_indicies = torch.empty(
|
w2_scale=w2_scale,
|
||||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
a1_scale=a1_scale,
|
||||||
)
|
a2_scale=a2_scale)
|
||||||
moe_kernels.topk_softmax(
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
token_expert_indicies,
|
|
||||||
gating_output.float(), # TODO(woosuk): Optimize this.
|
|
||||||
)
|
|
||||||
del token_expert_indicies # Not used. Will be used in the future.
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
if override_config:
|
|
||||||
config = override_config
|
|
||||||
else:
|
|
||||||
# First try to load optimal config from the file
|
|
||||||
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
|
|
||||||
|
|
||||||
if configs:
|
|
||||||
# If an optimal configuration map has been found, look up the
|
|
||||||
# optimal config
|
|
||||||
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
|
|
||||||
else:
|
|
||||||
# Else use the default config
|
|
||||||
config = {
|
|
||||||
"BLOCK_SIZE_M": 128,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 4,
|
|
||||||
}
|
|
||||||
|
|
||||||
if M <= E:
|
|
||||||
config = {
|
|
||||||
"BLOCK_SIZE_M": 128,
|
|
||||||
"BLOCK_SIZE_N": 256,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 8,
|
|
||||||
"num_stages": 4,
|
|
||||||
}
|
|
||||||
|
|
||||||
intermediate_cache1 = torch.empty(
|
|
||||||
(M, topk_ids.shape[1], N),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
intermediate_cache2 = torch.empty(
|
|
||||||
(M * topk_ids.shape[1], N // 2),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
intermediate_cache3 = torch.empty(
|
|
||||||
(M, topk_ids.shape[1], w2.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
|
||||||
topk_ids, config["BLOCK_SIZE_M"], E
|
|
||||||
)
|
|
||||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
|
||||||
hidden_states,
|
|
||||||
w1,
|
|
||||||
intermediate_cache1,
|
|
||||||
a1_scale,
|
|
||||||
w1_scale,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_padded,
|
|
||||||
False,
|
|
||||||
topk_ids.shape[1],
|
|
||||||
config,
|
|
||||||
compute_type=compute_type,
|
|
||||||
use_fp8=use_fp8,
|
|
||||||
)
|
|
||||||
|
|
||||||
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
|
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
|
||||||
intermediate_cache2,
|
|
||||||
w2,
|
|
||||||
intermediate_cache3,
|
|
||||||
a2_scale,
|
|
||||||
w2_scale,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_padded,
|
|
||||||
True,
|
|
||||||
1,
|
|
||||||
config,
|
|
||||||
compute_type=compute_type,
|
|
||||||
use_fp8=use_fp8,
|
|
||||||
)
|
|
||||||
|
|
||||||
if inplace:
|
|
||||||
return torch.sum(
|
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
|
||||||
dim=1,
|
|
||||||
out=hidden_states,
|
|
||||||
)
|
|
||||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
|
|
||||||
Reference in New Issue
Block a user