Fp8 MoE optimizations on AMD (#2388)
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm import _custom_ops as ops
|
|||||||
from sglang.srt.utils import direct_register_custom_op, get_device_name
|
from sglang.srt.utils import direct_register_custom_op, get_device_name
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -58,6 +59,7 @@ def fused_moe_kernel(
|
|||||||
compute_type: tl.constexpr,
|
compute_type: tl.constexpr,
|
||||||
use_fp8_w8a8: tl.constexpr,
|
use_fp8_w8a8: tl.constexpr,
|
||||||
use_int8_w8a16: tl.constexpr,
|
use_int8_w8a16: tl.constexpr,
|
||||||
|
even_Ks: tl.constexpr,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Implements the fused computation for a Mixture of Experts (MOE) using
|
Implements the fused computation for a Mixture of Experts (MOE) using
|
||||||
@@ -143,12 +145,21 @@ def fused_moe_kernel(
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||||
# Load the next block of A and B, generate a mask by checking the
|
# Load the next block of A and B, generate a mask by checking the
|
||||||
# K dimension.
|
# K dimension.
|
||||||
a = tl.load(
|
if even_Ks:
|
||||||
a_ptrs,
|
a = tl.load(
|
||||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
a_ptrs,
|
||||||
other=0.0,
|
mask=token_mask[:, None],
|
||||||
)
|
other=0.0,
|
||||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
)
|
||||||
|
b = tl.load(b_ptrs)
|
||||||
|
else:
|
||||||
|
a = tl.load(
|
||||||
|
a_ptrs,
|
||||||
|
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
|
||||||
# We accumulate along the K dimension.
|
# We accumulate along the K dimension.
|
||||||
if use_int8_w8a16:
|
if use_int8_w8a16:
|
||||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||||
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
|
|||||||
assert topk_weights.stride(1) == 1
|
assert topk_weights.stride(1) == 1
|
||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
|
|
||||||
|
padded_size = 0
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
|
padded_size = padding_size
|
||||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
elif use_int8_w8a16:
|
elif use_int8_w8a16:
|
||||||
@@ -268,6 +281,12 @@ def invoke_fused_moe_kernel(
|
|||||||
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
K = B.shape[2] - padded_size
|
||||||
|
if K % config["BLOCK_SIZE_K"] == 0:
|
||||||
|
even_Ks = True
|
||||||
|
else:
|
||||||
|
even_Ks = False
|
||||||
|
|
||||||
fused_moe_kernel[grid](
|
fused_moe_kernel[grid](
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
@@ -279,7 +298,7 @@ def invoke_fused_moe_kernel(
|
|||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
B.shape[1],
|
B.shape[1],
|
||||||
B.shape[2],
|
B.shape[2] - padded_size,
|
||||||
sorted_token_ids.shape[0],
|
sorted_token_ids.shape[0],
|
||||||
topk_ids.numel(),
|
topk_ids.numel(),
|
||||||
A.stride(0),
|
A.stride(0),
|
||||||
@@ -296,6 +315,7 @@ def invoke_fused_moe_kernel(
|
|||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
even_Ks=even_Ks,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -351,20 +371,39 @@ def get_default_config(
|
|||||||
dtype: Optional[str],
|
dtype: Optional[str],
|
||||||
is_marlin: bool,
|
is_marlin: bool,
|
||||||
) -> Dict[str, int]:
|
) -> Dict[str, int]:
|
||||||
config = {
|
if dtype == "fp8_w8a8":
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 32,
|
|
||||||
"GROUP_SIZE_M": 8,
|
|
||||||
}
|
|
||||||
# A heuristic: fused marlin works faster with this config for small M
|
|
||||||
if M <= E or (is_marlin and M <= 32):
|
|
||||||
config = {
|
config = {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 128,
|
||||||
"BLOCK_SIZE_N": 32,
|
"BLOCK_SIZE_N": 256,
|
||||||
"BLOCK_SIZE_K": 64,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4,
|
||||||
}
|
}
|
||||||
|
if M <= E:
|
||||||
|
config = {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config = {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 32,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
}
|
||||||
|
# A heuristic: fused marlin works faster with this config for small M
|
||||||
|
if M <= E or (is_marlin and M <= 32):
|
||||||
|
config = {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@@ -645,8 +684,12 @@ def fused_experts_impl(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
padded_size = padding_size
|
||||||
|
if not use_fp8_w8a8:
|
||||||
|
padded_size = 0
|
||||||
|
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||||
@@ -668,7 +711,7 @@ def fused_experts_impl(
|
|||||||
get_config_func = functools.partial(
|
get_config_func = functools.partial(
|
||||||
try_get_optimal_moe_config,
|
try_get_optimal_moe_config,
|
||||||
w1.shape,
|
w1.shape,
|
||||||
w2.shape,
|
(w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
|
||||||
topk_ids.shape[1],
|
topk_ids.shape[1],
|
||||||
config_dtype,
|
config_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@@ -24,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
|
|
||||||
|
from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
|
||||||
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -420,7 +423,7 @@ class Fp8MoEMethod:
|
|||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize in place.
|
# If checkpoint is fp16 or bfloat16, quantize in place.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
|
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
|
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
|
||||||
@@ -444,6 +447,19 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
|
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||||
|
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return
|
return
|
||||||
|
|
||||||
# If checkpoint is fp8, we need to handle that the
|
# If checkpoint is fp8, we need to handle that the
|
||||||
@@ -472,6 +488,7 @@ class Fp8MoEMethod:
|
|||||||
layer.w2_input_scale = torch.nn.Parameter(
|
layer.w2_input_scale = torch.nn.Parameter(
|
||||||
layer.w2_input_scale.max(), requires_grad=False
|
layer.w2_input_scale.max(), requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if is_hip():
|
if is_hip():
|
||||||
# Normalize the weights and scales
|
# Normalize the weights and scales
|
||||||
@@ -523,6 +540,19 @@ class Fp8MoEMethod:
|
|||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
layer.w13_weight_scale = torch.nn.Parameter(
|
||||||
max_w13_scales, requires_grad=False
|
max_w13_scales, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If ROCm, apply weight padding (min. Mem channel contention) only if set
|
||||||
|
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return
|
return
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
@@ -540,6 +570,7 @@ class Fp8MoEMethod:
|
|||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
|
# Expert selection
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -551,6 +582,7 @@ class Fp8MoEMethod:
|
|||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Expert fusion with FP8 quantization
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
|
|||||||
Reference in New Issue
Block a user