feat: support DeepSeek-R1-W4AFP8 model with ep-moe mode (#7762)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
This commit is contained in:
@@ -359,7 +359,17 @@ class ModelConfig:
|
||||
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
|
||||
quant_cfg = modelopt_quant_config
|
||||
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
||||
quant_cfg = modelopt_quant_config
|
||||
quant_config_file = os.path.join(
|
||||
self.model_path, "hf_quant_config.json"
|
||||
)
|
||||
with open(quant_config_file) as f:
|
||||
quant_config_dict = json.load(f)
|
||||
json_quant_configs = quant_config_dict["quantization"]
|
||||
quant_algo = json_quant_configs.get("quant_algo", None)
|
||||
if quant_algo == "MIXED_PRECISION":
|
||||
quant_cfg = {"quant_method": "w4afp8"}
|
||||
else:
|
||||
quant_cfg = modelopt_quant_config
|
||||
return quant_cfg
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
@@ -389,6 +399,7 @@ class ModelConfig:
|
||||
"w8a8_fp8",
|
||||
"moe_wna16",
|
||||
"qoq",
|
||||
"w4afp8",
|
||||
]
|
||||
compatible_quantization_methods = {
|
||||
"modelopt_fp4": ["modelopt"],
|
||||
|
||||
215
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
Normal file
215
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Cutlass W4A8 MoE kernel."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import (
|
||||
cutlass_w4a8_moe_mm,
|
||||
get_cutlass_w4a8_moe_mm_data,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
silu_and_mul,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||
run_cutlass_moe_ep_preproess,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_w4a8_moe(
|
||||
start_expert_id: int,
|
||||
end_expert_id: int,
|
||||
total_num_experts: int,
|
||||
a: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids_: torch.Tensor,
|
||||
local_topk_ids: torch.Tensor,
|
||||
a_strides1: torch.Tensor,
|
||||
b_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
a_strides2: torch.Tensor,
|
||||
b_strides2: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
s_strides13: torch.Tensor,
|
||||
s_strides2: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
problem_sizes1: torch.Tensor,
|
||||
problem_sizes2: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
|
||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
||||
mechanism. The matrix multiplications are implemented with CUTLASS
|
||||
grouped gemm.
|
||||
|
||||
Parameters:
|
||||
- a (torch.Tensor): The input tensor to the MoE layer.
|
||||
Shape: [M, K]
|
||||
- w1_q (torch.Tensor): The first set of int4-quantized expert weights.
|
||||
Shape: [num_experts, N * 2, K // 2]
|
||||
(the weights are passed transposed and int4-packed)
|
||||
- w2_q (torch.Tensor): The second set of int4-quantized expert weights.
|
||||
Shape: [num_experts, K, N // 2]
|
||||
(the weights are passed transposed and int4-packed)
|
||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
||||
Shape: [num_experts, K // 512, N * 8]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts, N // 512, K * 4]
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
|
||||
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
|
||||
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
||||
- a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
|
||||
- b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
|
||||
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
|
||||
- s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
|
||||
- s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [1, K]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [1, N]
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is 1.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
|
||||
"""
|
||||
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
||||
assert w1_q.dtype == torch.int8
|
||||
assert w2_q.dtype == torch.int8
|
||||
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
|
||||
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
|
||||
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
|
||||
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
||||
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
||||
assert (
|
||||
w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
|
||||
and w1_scale.shape[2] == w1_q.shape[1] * 4
|
||||
), "W1 scale shape mismatch"
|
||||
assert (
|
||||
w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
|
||||
and w2_scale.shape[2] == w2_q.shape[1] * 4
|
||||
), "W2 scale shape mismatch"
|
||||
|
||||
assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
|
||||
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
|
||||
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
|
||||
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
|
||||
num_experts = w1_q.size(0)
|
||||
m = a.size(0)
|
||||
k = w1_q.size(2) * 2 # w1_q is transposed and packed
|
||||
n = w2_q.size(2) * 2 # w2_q is transposed and packed
|
||||
topk = topk_ids_.size(1)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
|
||||
|
||||
device = a.device
|
||||
|
||||
_, src2dst, _ = run_cutlass_moe_ep_preproess(
|
||||
local_topk_ids,
|
||||
num_experts,
|
||||
)
|
||||
|
||||
gateup_input = torch.empty(
|
||||
(m * topk, k),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
|
||||
pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
|
||||
a,
|
||||
gateup_input,
|
||||
src2dst,
|
||||
local_topk_ids,
|
||||
a1_scale,
|
||||
total_num_experts,
|
||||
topk,
|
||||
k,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
|
||||
# they are kept to allow for a quick switch of the permutation logic
|
||||
# from the current triton kernel implementation to the cutlass-based one if needed.
|
||||
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
get_cutlass_w4a8_moe_mm_data(
|
||||
local_topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
a_map,
|
||||
c_map,
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
)
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
|
||||
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
|
||||
|
||||
cutlass_w4a8_moe_mm(
|
||||
c1,
|
||||
gateup_input,
|
||||
w1_q,
|
||||
a1_scale.float(),
|
||||
w1_scale,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes1,
|
||||
a_strides1,
|
||||
b_strides1,
|
||||
c_strides1,
|
||||
s_strides13,
|
||||
128,
|
||||
topk,
|
||||
)
|
||||
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
intermediate_q = torch.empty(
|
||||
intermediate.shape, dtype=torch.float8_e4m3fn, device=device
|
||||
)
|
||||
sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
|
||||
|
||||
cutlass_w4a8_moe_mm(
|
||||
c2,
|
||||
intermediate_q,
|
||||
w2_q,
|
||||
a2_scale.float(),
|
||||
w2_scale,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes2,
|
||||
a_strides2,
|
||||
b_strides2,
|
||||
c_strides2,
|
||||
s_strides2,
|
||||
128,
|
||||
topk,
|
||||
)
|
||||
|
||||
output = torch.empty_like(a)
|
||||
post_reorder_triton_kernel[(m,)](
|
||||
c2,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids_,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
k,
|
||||
0,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return output
|
||||
@@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
|
||||
|
||||
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
||||
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
||||
|
||||
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
||||
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
||||
|
||||
@@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
||||
compute_src2dst_triton_kernel[grid](
|
||||
reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
|
||||
)
|
||||
|
||||
return reorder_topk_ids, src2dst, seg_indptr
|
||||
|
||||
|
||||
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
|
||||
reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
|
||||
|
||||
seg_indptr = torch.zeros(
|
||||
local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
|
||||
)
|
||||
src2dst = torch.empty(
|
||||
local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
BLOCK_SIZE = 512
|
||||
grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
|
||||
compute_src2dst_triton_kernel[grid](
|
||||
reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
|
||||
)
|
||||
|
||||
return reorder_topk_ids, src2dst, seg_indptr
|
||||
|
||||
|
||||
@triton.jit
|
||||
def pre_reorder_triton_kernel_for_cutlass_moe(
|
||||
input_ptr,
|
||||
gateup_input_ptr,
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
a1_scales_ptr,
|
||||
num_experts,
|
||||
topk,
|
||||
hidden_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
OutDtype = gateup_input_ptr.dtype.element_ty
|
||||
|
||||
src_idx = tl.program_id(0)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||
|
||||
src_ptr = input_ptr + src_idx * hidden_size
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id != num_experts:
|
||||
if a1_scales_ptr is not None:
|
||||
scale = 1.0 / tl.load(a1_scales_ptr)
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
dst_idx = tl.load(src2dst_ptr + idx)
|
||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset < hidden_size
|
||||
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
|
||||
out_data = (in_data * scale).to(OutDtype)
|
||||
tl.store(dst_ptr + offset, out_data, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def pre_reorder_triton_kernel(
|
||||
input_ptr,
|
||||
|
||||
@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
|
||||
)
|
||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
ep_gather,
|
||||
ep_scatter,
|
||||
@@ -20,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
moe_ep_deepgemm_preprocess,
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||
run_cutlass_moe_ep_preproess,
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_masked_post_quant_fwd,
|
||||
silu_and_mul_triton_kernel,
|
||||
@@ -41,6 +44,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_quant_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import (
|
||||
@@ -191,7 +195,7 @@ class EPMoE(torch.nn.Module):
|
||||
num_fused_shared_experts == 0
|
||||
), "num_fused_shared_experts is not supported in EP"
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.num_experts_per_partition = self.num_experts // self.tp_size
|
||||
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
|
||||
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
|
||||
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
|
||||
|
||||
@@ -215,6 +219,18 @@ class EPMoE(torch.nn.Module):
|
||||
self.use_block_quant = False
|
||||
self.block_shape = None
|
||||
self.activation_scheme = None
|
||||
self.use_w4afp8 = False
|
||||
elif isinstance(quant_config, W4AFp8Config):
|
||||
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
|
||||
quant_config
|
||||
)
|
||||
self.use_w4afp8 = True
|
||||
self.use_fp8_w8a8 = False
|
||||
self.use_block_quant = False
|
||||
self.fp8_dtype = torch.float8_e4m3fn
|
||||
self.w13_weight_scale = None
|
||||
self.w2_weight_scale = None
|
||||
self.activation_scheme = quant_config.moe_activation_scheme
|
||||
else:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
|
||||
quant_config
|
||||
@@ -228,6 +244,7 @@ class EPMoE(torch.nn.Module):
|
||||
)
|
||||
self.fp8_dtype = torch.float8_e4m3fn
|
||||
self.activation_scheme = quant_config.activation_scheme
|
||||
self.use_w4afp8 = False
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
@@ -253,6 +270,49 @@ class EPMoE(torch.nn.Module):
|
||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||
)
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
|
||||
# Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
|
||||
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Calculates how many experts should be assigned to each rank for EP and
|
||||
creates a mapping from global to local expert index. Experts are
|
||||
distributed evenly across ranks. Any remaining are assigned to the
|
||||
last rank.
|
||||
|
||||
Returns:
|
||||
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
|
||||
- local_num_experts (int): The number of experts assigned
|
||||
to the current rank.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor of shape
|
||||
(global_num_experts,) mapping from global to local index.
|
||||
Contains global_num_experts for experts not assigned to the current rank.
|
||||
Returns None if ep_size is 1.
|
||||
"""
|
||||
ep_size = self.tp_size
|
||||
ep_rank = self.tp_rank
|
||||
global_num_experts = self.num_experts
|
||||
|
||||
assert ep_size > 0
|
||||
if ep_size == 1:
|
||||
return (global_num_experts, None)
|
||||
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
|
||||
expert_map = torch.full(
|
||||
(global_num_experts,), self.num_experts, dtype=torch.int32
|
||||
)
|
||||
if ep_rank < (ep_size - 1):
|
||||
expert_map[
|
||||
ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
|
||||
] = torch.arange(0, local_num_experts, dtype=torch.int32)
|
||||
else:
|
||||
local_num_experts = global_num_experts - ep_rank * local_num_experts
|
||||
|
||||
expert_map[-local_num_experts:] = torch.arange(
|
||||
0, local_num_experts, dtype=torch.int32
|
||||
)
|
||||
return (local_num_experts, expert_map)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||
return self.forward_deepgemm(hidden_states, router_logits)
|
||||
@@ -440,6 +500,51 @@ class EPMoE(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
if self.use_w4afp8:
|
||||
local_topk_ids = topk_ids
|
||||
if self.expert_map is not None:
|
||||
"Translate info from expert_map to topk_ids"
|
||||
local_topk_ids = torch.where(
|
||||
self.expert_map[topk_ids] != self.num_experts,
|
||||
self.expert_map[topk_ids],
|
||||
self.num_experts,
|
||||
)
|
||||
|
||||
output = cutlass_w4a8_moe(
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
self.num_experts,
|
||||
hidden_states,
|
||||
self.w13_weight,
|
||||
self.w2_weight,
|
||||
self.w13_weight_scale_inv,
|
||||
self.w2_weight_scale_inv,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
local_topk_ids,
|
||||
self.quant_method.a_strides1,
|
||||
self.quant_method.b_strides1,
|
||||
self.quant_method.c_strides1,
|
||||
self.quant_method.a_strides2,
|
||||
self.quant_method.b_strides2,
|
||||
self.quant_method.c_strides2,
|
||||
self.quant_method.s_strides13,
|
||||
self.quant_method.s_strides2,
|
||||
self.quant_method.expert_offsets,
|
||||
self.quant_method.problem_sizes1,
|
||||
self.quant_method.problem_sizes2,
|
||||
self.w13_input_scale,
|
||||
self.w2_input_scale,
|
||||
)
|
||||
return output
|
||||
|
||||
if self.grouped_gemm_runner is None:
|
||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
||||
hidden_states.device,
|
||||
use_flashinfer=False, # TODO: use flashinfer
|
||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||
)
|
||||
|
||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
||||
topk_ids, self.num_experts
|
||||
)
|
||||
@@ -449,7 +554,7 @@ class EPMoE(torch.nn.Module):
|
||||
device=hidden_states.device,
|
||||
dtype=(
|
||||
self.fp8_dtype
|
||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
||||
if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
|
||||
else hidden_states.dtype
|
||||
),
|
||||
)
|
||||
@@ -656,6 +761,23 @@ class EPMoE(torch.nn.Module):
|
||||
]
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def make_expert_input_scale_params_mapping(
|
||||
cls,
|
||||
num_experts: int,
|
||||
) -> List[Tuple[str, str, int, str]]:
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
return [
|
||||
(
|
||||
"experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
|
||||
f"experts.{expert_id}.{shard_id}.",
|
||||
expert_id,
|
||||
shard_id,
|
||||
)
|
||||
for expert_id in range(num_experts)
|
||||
for shard_id in ["w1", "w2", "w3"]
|
||||
]
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
param: torch.nn.Parameter,
|
||||
@@ -727,6 +849,15 @@ class EPMoE(torch.nn.Module):
|
||||
|
||||
# Input scales can be loaded directly and should be equal.
|
||||
if "input_scale" in weight_name:
|
||||
if self.use_w4afp8:
|
||||
if shard_id == "w1":
|
||||
param_data[expert_id][0] = loaded_weight
|
||||
elif shard_id == "w3":
|
||||
param_data[expert_id][1] = loaded_weight
|
||||
else:
|
||||
param_data[expert_id] = loaded_weight
|
||||
return
|
||||
|
||||
if (
|
||||
(shard_id == "w1" or shard_id == "w3")
|
||||
and param_data[expert_id] != 1
|
||||
@@ -752,6 +883,13 @@ class EPMoE(torch.nn.Module):
|
||||
] = loaded_weight
|
||||
else: # w2
|
||||
param_data[expert_id] = loaded_weight
|
||||
elif self.use_w4afp8:
|
||||
if shard_id == "w1":
|
||||
param_data[expert_id][: self.intermediate_size, :] = loaded_weight
|
||||
elif shard_id == "w3":
|
||||
param_data[expert_id][self.intermediate_size :, :] = loaded_weight
|
||||
else:
|
||||
param_data[expert_id] = loaded_weight
|
||||
# If we are in merged column case (gate_up_proj)
|
||||
else:
|
||||
if shard_id in ("w1", "w3"):
|
||||
|
||||
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
|
||||
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"moe_wna16": MoeWNA16Config,
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"qoq": QoQConfig,
|
||||
"w4afp8": W4AFp8Config,
|
||||
}
|
||||
|
||||
# VLLM-dependent quantization methods
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
@@ -286,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
if hasattr(self.quant_config, "activation_scheme"):
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
elif hasattr(self.quant_config, "linear_activation_scheme"):
|
||||
assert self.quant_config.linear_activation_scheme == "dynamic"
|
||||
scale = BlockQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
(output_size_per_partition + block_n - 1) // block_n,
|
||||
@@ -308,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if (
|
||||
hasattr(self.quant_config, "activation_scheme")
|
||||
and self.quant_config.activation_scheme == "static"
|
||||
) or (
|
||||
hasattr(self.quant_config, "linear_activation_scheme")
|
||||
and self.quant_config.linear_activation_scheme == "static"
|
||||
):
|
||||
scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
@@ -371,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if (
|
||||
hasattr(self.quant_config, "activation_scheme")
|
||||
and self.quant_config.activation_scheme == "static"
|
||||
) or (
|
||||
hasattr(self.quant_config, "linear_activation_scheme")
|
||||
and self.quant_config.linear_activation_scheme == "static"
|
||||
):
|
||||
layer.input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data, requires_grad=False
|
||||
)
|
||||
@@ -405,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Update layer with new values.
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if (
|
||||
hasattr(self.quant_config, "activation_scheme")
|
||||
and self.quant_config.activation_scheme == "static"
|
||||
) or (
|
||||
hasattr(self.quant_config, "linear_activation_scheme")
|
||||
and self.quant_config.linear_activation_scheme == "static"
|
||||
):
|
||||
layer.input_scale = Parameter(
|
||||
layer.input_scale.max(), requires_grad=False
|
||||
)
|
||||
|
||||
264
python/sglang/srt/layers/quantization/w4afp8.py
Normal file
264
python/sglang/srt/layers/quantization/w4afp8.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class W4AFp8Config(QuantizationConfig):
|
||||
"""Config class for MIXED_PRECISION W4AFp8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = True,
|
||||
is_checkpoint_w4afp8_serialized: bool = True,
|
||||
linear_activation_scheme: str = "dynamic",
|
||||
moe_activation_scheme: str = "static",
|
||||
ignored_layers: Optional[List[str]] = None,
|
||||
weight_block_size: Optional[List[int]] = None,
|
||||
group_size: int = 128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
|
||||
if is_checkpoint_w4afp8_serialized:
|
||||
logger.warning("Detected w4afp8 checkpoint. Please note that")
|
||||
if moe_activation_scheme not in ACTIVATION_SCHEMES:
|
||||
raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
|
||||
self.linear_activation_scheme = linear_activation_scheme
|
||||
self.moe_activation_scheme = moe_activation_scheme
|
||||
self.ignored_layers = ignored_layers or []
|
||||
self.weight_block_size = [128, 128]
|
||||
self.group_size = group_size
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "w4afp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float8_e4m3fn]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
|
||||
linear_activation_scheme = "dynamic"
|
||||
moe_activation_scheme = "static"
|
||||
weight_block_size = [128, 128]
|
||||
return cls(
|
||||
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
|
||||
linear_activation_scheme=linear_activation_scheme,
|
||||
moe_activation_scheme=moe_activation_scheme,
|
||||
weight_block_size=weight_block_size,
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return W4AFp8MoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class W4AFp8MoEMethod:
|
||||
|
||||
def __init__(self, quant_config: W4AFp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
num_experts_per_partition: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
assert "weight_loader" in extra_weight_attrs
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
intermediate_size * 2,
|
||||
hidden_size // 2,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
hidden_size,
|
||||
intermediate_size // 2,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts_per_partition,
|
||||
2 * intermediate_size,
|
||||
hidden_size // self.quant_config.group_size,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts_per_partition,
|
||||
hidden_size,
|
||||
intermediate_size // self.quant_config.group_size,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# Input scales
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
# Pre-populate the strides
|
||||
device = layer.w13_weight.device
|
||||
|
||||
self.a_strides1 = torch.full(
|
||||
(num_experts_per_partition, 3),
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.c_strides1 = torch.full(
|
||||
(num_experts_per_partition, 3),
|
||||
2 * intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.a_strides2 = torch.full(
|
||||
(num_experts_per_partition, 3),
|
||||
intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.c_strides2 = torch.full(
|
||||
(num_experts_per_partition, 3),
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.b_strides1 = self.a_strides1
|
||||
self.s_strides13 = self.c_strides1
|
||||
self.b_strides2 = self.a_strides2
|
||||
self.s_strides2 = self.c_strides2
|
||||
|
||||
self.expert_offsets = torch.empty(
|
||||
(num_experts_per_partition + 1), dtype=torch.int32, device=device
|
||||
)
|
||||
self.problem_sizes1 = torch.empty(
|
||||
(num_experts_per_partition, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
self.problem_sizes2 = torch.empty(
|
||||
(num_experts_per_partition, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
|
||||
"""Interleave scales in groups of 4 similar to TRT-LLM implementation."""
|
||||
s_shape = scales.shape
|
||||
# Reshape to separate groups of 4
|
||||
scales_interleaved = scales.reshape(
|
||||
s_shape[0], s_shape[1], (s_shape[2] // 4), 4
|
||||
)
|
||||
# Permute dimensions to interleave
|
||||
scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
|
||||
# Reshape back to original dimensions but with interleaved values
|
||||
scales_interleaved = scales_interleaved.reshape(
|
||||
s_shape[0], s_shape[2] // 4, s_shape[1] * 4
|
||||
)
|
||||
return scales_interleaved.contiguous()
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
    """Finalize w4afp8 MoE weights once the checkpoint is loaded.

    Casts the group-wise weight scales of both projections to bf16 and
    interleaves them in groups of 4 (TRT-LLM layout), then collapses each
    per-expert input-scale tensor into a single maximum value.
    """
    target_dtype = torch.bfloat16
    target_device = layer.w2_weight.device

    # Cast + interleave the group-wise weight scales
    # (gate_up_proj first, then down_proj).
    for scale_name in ("w13_weight_scale_inv", "w2_weight_scale_inv"):
        raw_scale = getattr(layer, scale_name).to(target_dtype)
        setattr(
            layer,
            scale_name,
            Parameter(self._interleave_scales(raw_scale), requires_grad=False),
        )

    # Replace the per-expert input scales with a single per-partition max,
    # matching the scalar activation scale the kernel expects.
    for scale_name in ("w13_input_scale", "w2_input_scale"):
        peak = getattr(layer, scale_name).max().to(target_dtype).item()
        setattr(
            layer,
            scale_name,
            Parameter(
                torch.tensor([peak], dtype=target_dtype, device=target_device),
                requires_grad=False,
            ),
        )
|
||||
@@ -2363,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
||||
)
|
||||
if self.quant_config and self.quant_config.get_name() == "w4afp8":
|
||||
expert_params_mapping += (
|
||||
get_moe_impl_class().make_expert_input_scale_params_mapping(
|
||||
num_experts=self.config.n_routed_experts
|
||||
)
|
||||
)
|
||||
|
||||
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
||||
|
||||
@@ -708,6 +708,7 @@ class ServerArgs:
|
||||
"w8a8_fp8",
|
||||
"moe_wna16",
|
||||
"qoq",
|
||||
"w4afp8",
|
||||
],
|
||||
help="The quantization method.",
|
||||
)
|
||||
|
||||
281
python/sglang/test/test_cutlass_w4a8_moe.py
Normal file
281
python/sglang/test/test_cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
|
||||
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
    """Pack pairs of int4 values (one per int8 slot) into single int8 bytes.

    Element 2i becomes the low nibble and element 2i+1 the high nibble of
    output byte i along the last dimension, which therefore must be even.
    """
    if int4_values_interleaved.size(-1) % 2:
        raise ValueError(
            "the last dim size of int4_values_interleaved tensor must be even."
        )

    as_int8 = int4_values_interleaved.to(torch.int8)
    # Even positions supply the low nibbles, odd positions the high nibbles.
    lo = as_int8[..., 0::2]
    hi = as_int8[..., 1::2]
    return ((hi << 4) | (lo & 0x0F)).to(torch.int8)
|
||||
|
||||
|
||||
def pack_interleave(num_experts, ref_weight, ref_scale):
    """Pack int4 reference weights into int8 and interleave their scales.

    Returns (w_q, w_scale): the nibble-packed weights of shape
    [E, N, K/2] and the scales rearranged to [E, K/4, N*4] in the
    TRT-LLM interleaved layout the cutlass kernel consumes.
    """
    out_dim, in_dim = ref_weight.shape[1], ref_weight.shape[2]

    # Pack two int4 values per int8 byte, then restore the expert layout.
    packed = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
    w_q = packed.view((num_experts, out_dim, in_dim // 2)).view(torch.int8)
    w_q = w_q.contiguous()

    # [E, N, K/4, 4] -> [E, K/4, N, 4] -> [E, K/4, N*4]
    e, n_s, k_s = ref_scale.shape
    grouped = ref_scale.reshape(e, n_s, k_s // 4, 4)
    grouped = grouped.permute(0, 2, 1, 3)
    w_scale = grouped.reshape(e, k_s // 4, n_s * 4).contiguous()

    return w_q, w_scale
|
||||
|
||||
|
||||
# End-to-end check: the fused cutlass w4a8 MoE path must match the eager
# bf16 reference within loose fp8 tolerances. K=7168 / N=2048 / E=256 /
# top-8 routing mirror the DeepSeek-R1-W4AFP8 deployment this kernel
# targets, with an 8-way expert-parallel split. Requires a CUDA device.
@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("N", [2048])
@pytest.mark.parametrize("K", [7168])
@pytest.mark.parametrize("E", [256])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
    # Experts hosted on this (simulated) expert-parallel rank.
    local_e = E // ep_size

    # debug=True swaps random inputs for constants so kernel mismatches are
    # easy to eyeball; keep False for the real randomized test.
    debug = False
    if debug:
        a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001
        ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda")
        ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda")
        a1_scale = torch.ones(1, dtype=torch.float32, device="cuda")
        a2_scale = torch.ones(1, dtype=torch.float32, device="cuda")
        scale_1 = torch.ones(
            (local_e, N * 2, K // group_size), dtype=dtype, device="cuda"
        )
        scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda")
    else:
        a = torch.randn(M, K, dtype=dtype, device="cuda")
        # int4 weights in [-8, 8) stored one-per-int8 before packing.
        ref_weight_1 = torch.randint(
            -8, 8, (local_e, N * 2, K), dtype=torch.int8, device="cuda"
        )
        ref_weight_2 = torch.randint(
            -8, 8, (local_e, K, N), dtype=torch.int8, device="cuda"
        )
        # Small scale magnitude keeps dequantized weights in a realistic range.
        affine_coeff = 0.005
        a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
        a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
        scale_1 = (
            torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda")
            * affine_coeff
        )
        scale_2 = (
            torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda")
            * affine_coeff
        )

    # Pack int4 weights into int8 bytes and interleave scales for the kernel.
    w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)

    # Per-expert stride tensors (3 entries per expert) for the grouped GEMMs:
    # GEMM1 reads K-strided activations and writes 2N-wide outputs, GEMM2
    # reads N-strided activations and writes K-wide outputs.
    device = "cuda"
    a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
    c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64)
    a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64)
    c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64)
    b_strides1 = a_strides1
    s_strides13 = c_strides1
    b_strides2 = a_strides2
    s_strides2 = c_strides2

    # Route tokens with random logits through the production top-k selector.
    score = torch.randn((M, E), dtype=dtype, device=device)
    topk_weights, topk_ids = select_experts(
        hidden_states=a,
        router_logits=score,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=False,
    )
    # Global->local expert id map; experts outside this rank map to the
    # sentinel value E.
    expert_map = torch.arange(E, dtype=torch.int32, device=device)
    expert_map[local_e:] = E

    output = cutlass_moe(
        a,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        a_strides1,
        b_strides1,
        c_strides1,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides13,
        s_strides2,
        0,
        local_e - 1,
        E,
        a1_scale,
        a2_scale,
        expert_map,
    )

    # Eager reference over the same unpacked weights and scales.
    ref_output = ref(
        a,
        local_e,
        topk_weights,
        topk_ids,
        ref_weight_1,
        ref_weight_2,
        scale_1,
        scale_2,
        has_pre_quant=True,
        has_alpha=True,
        pre_quant_scale_1=a1_scale,
        pre_quant_scale_2=a2_scale,
        alpha_1=a1_scale,
        alpha_2=a2_scale,
    )

    # compare
    torch.cuda.synchronize()

    # compare final output; tolerances are loose because of the fp8
    # activation round trip in both paths.
    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
    print("SUCCESS: Final output tensors are close.")
|
||||
|
||||
|
||||
def cutlass_moe(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    E: int,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
):
    """Test-side wrapper that dispatches to ``cutlass_w4a8_moe``.

    Allocates the scratch buffers (expert offsets and per-expert GEMM
    problem sizes) that the production path pre-allocates on the layer,
    translates global expert ids to rank-local ids, and calls the kernel.

    Args mirror ``cutlass_w4a8_moe``. ``expert_map`` maps global expert ids
    to local ids with the sentinel value ``E`` marking experts that do not
    live on this rank; when omitted, ids are used as-is.
    """
    # Translate global expert ids to local ids. (Fixes the original code,
    # which had a dead `local_topk_ids = topk_ids_` assignment and crashed
    # with the documented default of expert_map=None.)
    if expert_map is not None:
        local_topk_ids = torch.where(
            expert_map[topk_ids_] != E, expert_map[topk_ids_], E
        )
    else:
        local_topk_ids = topk_ids_
    device = a.device

    local_num_experts = end_expert_id - start_expert_id + 1
    # Scratch buffers filled by the kernel's preprocessing step.
    expert_offsets = torch.empty(
        (local_num_experts + 1), dtype=torch.int32, device=device
    )
    problem_sizes1 = torch.empty(
        (local_num_experts, 3), dtype=torch.int32, device=device
    )
    problem_sizes2 = torch.empty(
        (local_num_experts, 3), dtype=torch.int32, device=device
    )
    return cutlass_w4a8_moe(
        start_expert_id,
        end_expert_id,
        E,
        a,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids_,
        local_topk_ids,
        a_strides1,
        b_strides1,
        c_strides1,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides13,
        s_strides2,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a1_scale,
        a2_scale,
        apply_router_weight_on_input,
    )
|
||||
|
||||
|
||||
def ref(
|
||||
x: torch.Tensor,
|
||||
num_experts: int,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ref_weight_1: torch.Tensor,
|
||||
ref_weight_2: torch.Tensor,
|
||||
ref_weight_scale_1: torch.Tensor,
|
||||
ref_weight_scale_2: torch.Tensor,
|
||||
has_pre_quant: bool = False,
|
||||
has_alpha: bool = False,
|
||||
pre_quant_scale_1: Optional[torch.Tensor] = None,
|
||||
pre_quant_scale_2: Optional[torch.Tensor] = None,
|
||||
alpha_1: Optional[torch.Tensor] = None,
|
||||
alpha_2: Optional[torch.Tensor] = None,
|
||||
):
|
||||
results = torch.zeros_like(x)
|
||||
dtype = x.dtype
|
||||
for e_idx in range(num_experts):
|
||||
mask = topk_ids == e_idx
|
||||
activated_tokens = mask.sum(1).bool()
|
||||
act = x[activated_tokens, :]
|
||||
if act.shape[0] == 0:
|
||||
continue
|
||||
final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
|
||||
|
||||
act = (
|
||||
torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
|
||||
.to(torch.float8_e4m3fn)
|
||||
.to(dtype)
|
||||
)
|
||||
w3_w1 = ref_weight_1[e_idx]
|
||||
ref_w_scale_repeat = (
|
||||
ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
|
||||
)
|
||||
w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
|
||||
fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
|
||||
|
||||
gate, fc1 = fc1.chunk(2, dim=-1)
|
||||
fc1 = fc1 * torch.nn.functional.silu(gate)
|
||||
act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
|
||||
act = act.to(dtype)
|
||||
|
||||
w2 = ref_weight_2[e_idx]
|
||||
ref_w_scale_repeat = (
|
||||
ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
|
||||
)
|
||||
w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
|
||||
fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
|
||||
|
||||
results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
|
||||
|
||||
return results
|
||||
Reference in New Issue
Block a user