feat: support DeepSeek-R1-W4AFP8 model with ep-moe mode (#7762)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
This commit is contained in:
@@ -359,7 +359,17 @@ class ModelConfig:
|
|||||||
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
|
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
|
||||||
quant_cfg = modelopt_quant_config
|
quant_cfg = modelopt_quant_config
|
||||||
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
||||||
quant_cfg = modelopt_quant_config
|
quant_config_file = os.path.join(
|
||||||
|
self.model_path, "hf_quant_config.json"
|
||||||
|
)
|
||||||
|
with open(quant_config_file) as f:
|
||||||
|
quant_config_dict = json.load(f)
|
||||||
|
json_quant_configs = quant_config_dict["quantization"]
|
||||||
|
quant_algo = json_quant_configs.get("quant_algo", None)
|
||||||
|
if quant_algo == "MIXED_PRECISION":
|
||||||
|
quant_cfg = {"quant_method": "w4afp8"}
|
||||||
|
else:
|
||||||
|
quant_cfg = modelopt_quant_config
|
||||||
return quant_cfg
|
return quant_cfg
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||||
@@ -389,6 +399,7 @@ class ModelConfig:
|
|||||||
"w8a8_fp8",
|
"w8a8_fp8",
|
||||||
"moe_wna16",
|
"moe_wna16",
|
||||||
"qoq",
|
"qoq",
|
||||||
|
"w4afp8",
|
||||||
]
|
]
|
||||||
compatible_quantization_methods = {
|
compatible_quantization_methods = {
|
||||||
"modelopt_fp4": ["modelopt"],
|
"modelopt_fp4": ["modelopt"],
|
||||||
|
|||||||
215
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
Normal file
215
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Cutlass W4A8 MoE kernel."""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import (
|
||||||
|
cutlass_w4a8_moe_mm,
|
||||||
|
get_cutlass_w4a8_moe_mm_data,
|
||||||
|
sgl_per_tensor_quant_fp8,
|
||||||
|
silu_and_mul,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
|
post_reorder_triton_kernel,
|
||||||
|
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||||
|
run_cutlass_moe_ep_preproess,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_w4a8_moe(
    start_expert_id: int,
    end_expert_id: int,
    total_num_experts: int,
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    local_topk_ids: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Pipeline: reorder + fp8-quantize the input, grouped gemm with w1_q,
    silu_and_mul, per-tensor fp8 quantization of the intermediate, grouped
    gemm with w2_q, then post-reorder back to the layout of `a`.

    Parameters:
    - start_expert_id (int): Forwarded to the post-reorder kernel together
        with end_expert_id.
    - end_expert_id (int): Forwarded to the post-reorder kernel together
        with start_expert_id.
    - total_num_experts (int): Forwarded to the pre-reorder kernel, which
        skips entries of local_topk_ids equal to this value.
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2, K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids_ (torch.Tensor): The token->expert ids handed to the
        post-reorder kernel. Must have the same shape as topk_weights.
    - local_topk_ids (torch.Tensor): The token->expert ids used for the
        pre-reorder step and for building the grouped gemm metadata.
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - expert_offsets (torch.Tensor): Buffer handed to
        get_cutlass_w4a8_moe_mm_data; expert_offsets[:-1] is then passed to
        both grouped gemms.
    - problem_sizes1 (torch.Tensor): Buffer handed to
        get_cutlass_w4a8_moe_mm_data and then to the first grouped gemm.
    - problem_sizes2 (torch.Tensor): Buffer handed to
        get_cutlass_w4a8_moe_mm_data and then to the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The fp32 scale to quantize a.
        Shape: scalar or [1, K]
        Although the signature defaults to None, a value is required.
    - a2_scale (Optional[torch.Tensor]): The fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [1, N]
        Although the signature defaults to None, a value is required.
    - apply_router_weight_on_input (bool): When true, the topk weights are
        applied directly on the inputs. This is only applicable when topk is 1.
        NOTE(review): this body only validates topk == 1 for the flag and
        does not otherwise read it -- confirm whether that is intended.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer. It is
        allocated with torch.empty_like(a), so it has the shape and dtype
        of `a`.

    Raises:
    - ValueError: If a1_scale or a2_scale is None.
    - AssertionError: If the dtype/shape checks on the inputs fail.
    """
    # Both scales are dereferenced unconditionally below; fail fast with a
    # clear message instead of an AttributeError after kernels were launched.
    if a1_scale is None or a2_scale is None:
        raise ValueError("cutlass_w4a8_moe requires both a1_scale and a2_scale")

    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    # True division on purpose: a dimension that is not a multiple of 512
    # yields a non-integer and therefore fails the comparison.
    assert (
        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
        and w1_scale.shape[2] == w1_q.shape[1] * 4
    ), "W1 scale shape mismatch"
    assert (
        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
        and w2_scale.shape[2] == w2_q.shape[1] * 4
    ), "W2 scale shape mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)

    if apply_router_weight_on_input:
        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"

    device = a.device
    # Convert each scale once; a2_scale is consumed by two kernels below.
    a1_scale_f32 = a1_scale.float()
    a2_scale_f32 = a2_scale.float()

    # Only the source->destination row mapping is needed from the preprocess.
    _, src2dst, _ = run_cutlass_moe_ep_preproess(
        local_topk_ids,
        num_experts,
    )

    gateup_input = torch.empty(
        (m * topk, k),
        device=device,
        dtype=torch.float8_e4m3fn,
    )

    # One program per input row: scatter the row to its destination rows
    # and scale/cast it into the fp8 buffer.
    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
        a,
        gateup_input,
        src2dst,
        local_topk_ids,
        a1_scale,
        total_num_experts,
        topk,
        k,
        BLOCK_SIZE=512,
    )

    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
    # they are kept to allow for a quick switch of the permutation logic
    # from the current triton kernel implementation to the cutlass-based one if needed.
    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
    # c2 is zero-initialised (unlike c1) before being handed to the second gemm.
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)

    # First grouped gemm (gate/up projection).
    # NOTE(review): the literal 128 is presumably the weight quantization
    # group size expected by sgl_kernel -- confirm against its API.
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale_f32,
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,
        topk,
    )

    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
    silu_and_mul(c1, intermediate)

    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale_f32, True)

    # Second grouped gemm (down projection).
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale_f32,
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    # Gather the per-expert rows back into the layout of `a`.
    output = torch.empty_like(a)
    post_reorder_triton_kernel[(m,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
        k,
        0,
        BLOCK_SIZE=512,
    )
    return output
|
||||||
@@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
|
|||||||
|
|
||||||
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
||||||
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
||||||
|
|
||||||
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
||||||
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
||||||
|
|
||||||
@@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
|||||||
compute_src2dst_triton_kernel[grid](
|
compute_src2dst_triton_kernel[grid](
|
||||||
reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
|
reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
|
||||||
)
|
)
|
||||||
|
|
||||||
return reorder_topk_ids, src2dst, seg_indptr
|
return reorder_topk_ids, src2dst, seg_indptr
|
||||||
|
|
||||||
|
|
||||||
|
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
    """Sort the flattened expert ids and build the source->destination map.

    Parameters:
    - local_topk_ids (torch.Tensor): Expert ids; flattened with view(-1)
        before being stably sorted.
    - local_num_experts (int): Only used to size the returned seg_indptr
        buffer (local_num_experts + 1 entries).

    Returns:
    - A tuple (reorder_topk_ids, src2dst, seg_indptr):
        reorder_topk_ids is the stably sorted flattened ids, src2dst is an
        int32 buffer filled by compute_src2dst_triton_kernel, and
        seg_indptr is an int64 buffer of zeros that this function never
        writes to.
    """
    target_device = local_topk_ids.device
    total_slots = local_topk_ids.numel()
    launch_block = 512

    flat_ids = local_topk_ids.view(-1)
    reorder_topk_ids, sort_permutation = torch.sort(flat_ids, stable=True)

    # Allocated in the same order as before: segment pointers first, then
    # the mapping buffer that the kernel below fills in.
    seg_indptr = torch.zeros(
        local_num_experts + 1, device=target_device, dtype=torch.int64
    )
    src2dst = torch.empty(total_slots, device=target_device, dtype=torch.int32)

    launch_grid = (triton.cdiv(total_slots, launch_block),)
    compute_src2dst_triton_kernel[launch_grid](
        sort_permutation, src2dst, total_slots, launch_block
    )

    return reorder_topk_ids, src2dst, seg_indptr
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def pre_reorder_triton_kernel_for_cutlass_moe(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    num_experts,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatter one input row to its destination rows in gateup_input_ptr,
    # scaling by 1 / a1_scale and casting to the output buffer's element type.
    #
    # Each program handles the input row selected by program_id(0). For each
    # of that row's `topk` slots it reads the expert id and the destination
    # row index, then copies `hidden_size` elements in BLOCK_SIZE chunks.
    # Slots whose expert id equals `num_experts` are skipped.
    # NOTE(review): `num_experts` acts as a "not routed here" sentinel; the
    # visible caller passes the global expert count -- confirm it matches the
    # value used when building the ids in topk_ids_ptr.

    # Element type of the destination buffer; stored values are cast to it.
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    # Advance both per-slot tables to this row's `topk` entries.
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk

    src_ptr = input_ptr + src_idx * hidden_size
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id != num_experts:
            # Only the first element of a1_scales_ptr is read, i.e. a single
            # scale is applied to the whole row.
            if a1_scales_ptr is not None:
                scale = 1.0 / tl.load(a1_scales_ptr)
            else:
                scale = 1.0

            dst_idx = tl.load(src2dst_ptr + idx)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                # Mask the tail when hidden_size is not a multiple of BLOCK_SIZE.
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
                out_data = (in_data * scale).to(OutDtype)
                tl.store(dst_ptr + offset, out_data, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def pre_reorder_triton_kernel(
|
def pre_reorder_triton_kernel(
|
||||||
input_ptr,
|
input_ptr,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
ep_gather,
|
ep_gather,
|
||||||
ep_scatter,
|
ep_scatter,
|
||||||
@@ -20,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
moe_ep_deepgemm_preprocess,
|
moe_ep_deepgemm_preprocess,
|
||||||
post_reorder_triton_kernel,
|
post_reorder_triton_kernel,
|
||||||
pre_reorder_triton_kernel,
|
pre_reorder_triton_kernel,
|
||||||
|
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||||
|
run_cutlass_moe_ep_preproess,
|
||||||
run_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
silu_and_mul_masked_post_quant_fwd,
|
silu_and_mul_masked_post_quant_fwd,
|
||||||
silu_and_mul_triton_kernel,
|
silu_and_mul_triton_kernel,
|
||||||
@@ -41,6 +44,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
sglang_per_token_quant_fp8,
|
sglang_per_token_quant_fp8,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -191,7 +195,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
num_fused_shared_experts == 0
|
num_fused_shared_experts == 0
|
||||||
), "num_fused_shared_experts is not supported in EP"
|
), "num_fused_shared_experts is not supported in EP"
|
||||||
self.num_fused_shared_experts = num_fused_shared_experts
|
self.num_fused_shared_experts = num_fused_shared_experts
|
||||||
self.num_experts_per_partition = self.num_experts // self.tp_size
|
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
|
||||||
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
|
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
|
||||||
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
|
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
|
||||||
|
|
||||||
@@ -215,6 +219,18 @@ class EPMoE(torch.nn.Module):
|
|||||||
self.use_block_quant = False
|
self.use_block_quant = False
|
||||||
self.block_shape = None
|
self.block_shape = None
|
||||||
self.activation_scheme = None
|
self.activation_scheme = None
|
||||||
|
self.use_w4afp8 = False
|
||||||
|
elif isinstance(quant_config, W4AFp8Config):
|
||||||
|
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
|
||||||
|
quant_config
|
||||||
|
)
|
||||||
|
self.use_w4afp8 = True
|
||||||
|
self.use_fp8_w8a8 = False
|
||||||
|
self.use_block_quant = False
|
||||||
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
|
self.w13_weight_scale = None
|
||||||
|
self.w2_weight_scale = None
|
||||||
|
self.activation_scheme = quant_config.moe_activation_scheme
|
||||||
else:
|
else:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
|
||||||
quant_config
|
quant_config
|
||||||
@@ -228,6 +244,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
self.activation_scheme = quant_config.activation_scheme
|
self.activation_scheme = quant_config.activation_scheme
|
||||||
|
self.use_w4afp8 = False
|
||||||
|
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
@@ -253,6 +270,49 @@ class EPMoE(torch.nn.Module):
|
|||||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
|
||||||
|
# Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
|
||||||
|
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
    """
    Split the experts across the EP ranks and build this rank's
    global->local expert index table. Every rank except the last receives
    num_experts // tp_size experts; the last rank receives whatever is
    left, including any remainder.

    Returns:
        Tuple[int, Optional[torch.Tensor]]: A tuple containing:
            - local_num_experts (int): How many experts this rank owns.
            - expert_map (Optional[torch.Tensor]): An int32 tensor of shape
                (global_num_experts,) holding the local index of each
                expert owned by this rank and global_num_experts for
                every other expert. None when there is a single rank.

    NOTE(review): on the last rank the returned count can exceed the
    per-rank count of the other ranks; code that derives a start id as
    rank * local_num_experts should be checked for uneven splits.
    """
    world_size, rank, total_experts = self.tp_size, self.tp_rank, self.num_experts

    assert world_size > 0
    if world_size == 1:
        # Single rank: it owns everything and needs no translation table.
        return (total_experts, None)

    even_share = total_experts // world_size
    on_last_rank = not (rank < (world_size - 1))
    owned = total_experts - rank * even_share if on_last_rank else even_share

    # Start with every expert marked as "not on this rank".
    mapping = torch.full((total_experts,), self.num_experts, dtype=torch.int32)
    local_indices = torch.arange(0, owned, dtype=torch.int32)
    if on_last_rank:
        # The last rank owns the tail of the expert range.
        mapping[-owned:] = local_indices
    else:
        first = rank * even_share
        mapping[first : first + even_share] = local_indices
    return (owned, mapping)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||||
return self.forward_deepgemm(hidden_states, router_logits)
|
return self.forward_deepgemm(hidden_states, router_logits)
|
||||||
@@ -440,6 +500,51 @@ class EPMoE(torch.nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.use_w4afp8:
|
||||||
|
local_topk_ids = topk_ids
|
||||||
|
if self.expert_map is not None:
|
||||||
|
"Translate info from expert_map to topk_ids"
|
||||||
|
local_topk_ids = torch.where(
|
||||||
|
self.expert_map[topk_ids] != self.num_experts,
|
||||||
|
self.expert_map[topk_ids],
|
||||||
|
self.num_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = cutlass_w4a8_moe(
|
||||||
|
self.start_expert_id,
|
||||||
|
self.end_expert_id,
|
||||||
|
self.num_experts,
|
||||||
|
hidden_states,
|
||||||
|
self.w13_weight,
|
||||||
|
self.w2_weight,
|
||||||
|
self.w13_weight_scale_inv,
|
||||||
|
self.w2_weight_scale_inv,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
local_topk_ids,
|
||||||
|
self.quant_method.a_strides1,
|
||||||
|
self.quant_method.b_strides1,
|
||||||
|
self.quant_method.c_strides1,
|
||||||
|
self.quant_method.a_strides2,
|
||||||
|
self.quant_method.b_strides2,
|
||||||
|
self.quant_method.c_strides2,
|
||||||
|
self.quant_method.s_strides13,
|
||||||
|
self.quant_method.s_strides2,
|
||||||
|
self.quant_method.expert_offsets,
|
||||||
|
self.quant_method.problem_sizes1,
|
||||||
|
self.quant_method.problem_sizes2,
|
||||||
|
self.w13_input_scale,
|
||||||
|
self.w2_input_scale,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
if self.grouped_gemm_runner is None:
|
||||||
|
self.grouped_gemm_runner = GroupedGemmRunner(
|
||||||
|
hidden_states.device,
|
||||||
|
use_flashinfer=False, # TODO: use flashinfer
|
||||||
|
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
|
||||||
|
)
|
||||||
|
|
||||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
||||||
topk_ids, self.num_experts
|
topk_ids, self.num_experts
|
||||||
)
|
)
|
||||||
@@ -449,7 +554,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=(
|
dtype=(
|
||||||
self.fp8_dtype
|
self.fp8_dtype
|
||||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
|
||||||
else hidden_states.dtype
|
else hidden_states.dtype
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -656,6 +761,23 @@ class EPMoE(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
def make_expert_input_scale_params_mapping(
    cls,
    num_experts: int,
) -> List[Tuple[str, str, int, str]]:
    """Build the input-scale name mapping for every expert and shard.

    For each expert id in range(num_experts) and each shard in
    ("w1", "w2", "w3"), in that nesting order, one tuple
    (param_name, weight_name, expert_id, shard_id) is produced. Shards
    "w1" and "w3" use the "experts.w13_" parameter prefix and "w2" uses
    "experts.w2_".
    """
    entries = []
    for expert_id in range(num_experts):
        for shard_id in ("w1", "w2", "w3"):
            if shard_id == "w2":
                param_prefix = "experts.w2_"
            else:
                param_prefix = "experts.w13_"
            weight_prefix = f"experts.{expert_id}.{shard_id}."
            entries.append((param_prefix, weight_prefix, expert_id, shard_id))
    return entries
|
||||||
|
|
||||||
def weight_loader(
|
def weight_loader(
|
||||||
self,
|
self,
|
||||||
param: torch.nn.Parameter,
|
param: torch.nn.Parameter,
|
||||||
@@ -727,6 +849,15 @@ class EPMoE(torch.nn.Module):
|
|||||||
|
|
||||||
# Input scales can be loaded directly and should be equal.
|
# Input scales can be loaded directly and should be equal.
|
||||||
if "input_scale" in weight_name:
|
if "input_scale" in weight_name:
|
||||||
|
if self.use_w4afp8:
|
||||||
|
if shard_id == "w1":
|
||||||
|
param_data[expert_id][0] = loaded_weight
|
||||||
|
elif shard_id == "w3":
|
||||||
|
param_data[expert_id][1] = loaded_weight
|
||||||
|
else:
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(shard_id == "w1" or shard_id == "w3")
|
(shard_id == "w1" or shard_id == "w3")
|
||||||
and param_data[expert_id] != 1
|
and param_data[expert_id] != 1
|
||||||
@@ -752,6 +883,13 @@ class EPMoE(torch.nn.Module):
|
|||||||
] = loaded_weight
|
] = loaded_weight
|
||||||
else: # w2
|
else: # w2
|
||||||
param_data[expert_id] = loaded_weight
|
param_data[expert_id] = loaded_weight
|
||||||
|
elif self.use_w4afp8:
|
||||||
|
if shard_id == "w1":
|
||||||
|
param_data[expert_id][: self.intermediate_size, :] = loaded_weight
|
||||||
|
elif shard_id == "w3":
|
||||||
|
param_data[expert_id][self.intermediate_size :, :] = loaded_weight
|
||||||
|
else:
|
||||||
|
param_data[expert_id] = loaded_weight
|
||||||
# If we are in merged column case (gate_up_proj)
|
# If we are in merged column case (gate_up_proj)
|
||||||
else:
|
else:
|
||||||
if shard_id in ("w1", "w3"):
|
if shard_id in ("w1", "w3"):
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||||
from sglang.srt.layers.quantization.qoq import QoQConfig
|
from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||||
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
|
|
||||||
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"moe_wna16": MoeWNA16Config,
|
"moe_wna16": MoeWNA16Config,
|
||||||
"compressed-tensors": CompressedTensorsConfig,
|
"compressed-tensors": CompressedTensorsConfig,
|
||||||
"qoq": QoQConfig,
|
"qoq": QoQConfig,
|
||||||
|
"w4afp8": W4AFp8Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
# VLLM-dependent quantization methods
|
# VLLM-dependent quantization methods
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
quant_config: The quantization config.
|
quant_config: The quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
|
|
||||||
@@ -286,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
# WEIGHT SCALE
|
# WEIGHT SCALE
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert self.quant_config.activation_scheme == "dynamic"
|
if hasattr(self.quant_config, "activation_scheme"):
|
||||||
|
assert self.quant_config.activation_scheme == "dynamic"
|
||||||
|
elif hasattr(self.quant_config, "linear_activation_scheme"):
|
||||||
|
assert self.quant_config.linear_activation_scheme == "dynamic"
|
||||||
scale = BlockQuantScaleParameter(
|
scale = BlockQuantScaleParameter(
|
||||||
data=torch.empty(
|
data=torch.empty(
|
||||||
(output_size_per_partition + block_n - 1) // block_n,
|
(output_size_per_partition + block_n - 1) // block_n,
|
||||||
@@ -308,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.register_parameter("weight_scale", scale)
|
layer.register_parameter("weight_scale", scale)
|
||||||
|
|
||||||
# INPUT ACTIVATION SCALE
|
# INPUT ACTIVATION SCALE
|
||||||
if self.quant_config.activation_scheme == "static":
|
if (
|
||||||
|
hasattr(self.quant_config, "activation_scheme")
|
||||||
|
and self.quant_config.activation_scheme == "static"
|
||||||
|
) or (
|
||||||
|
hasattr(self.quant_config, "linear_activation_scheme")
|
||||||
|
and self.quant_config.linear_activation_scheme == "static"
|
||||||
|
):
|
||||||
scale = PerTensorScaleParameter(
|
scale = PerTensorScaleParameter(
|
||||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||||
weight_loader=weight_loader,
|
weight_loader=weight_loader,
|
||||||
@@ -371,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.weight_scale = torch.nn.Parameter(
|
layer.weight_scale = torch.nn.Parameter(
|
||||||
layer.weight_scale.data, requires_grad=False
|
layer.weight_scale.data, requires_grad=False
|
||||||
)
|
)
|
||||||
if self.quant_config.activation_scheme == "static":
|
if (
|
||||||
|
hasattr(self.quant_config, "activation_scheme")
|
||||||
|
and self.quant_config.activation_scheme == "static"
|
||||||
|
) or (
|
||||||
|
hasattr(self.quant_config, "linear_activation_scheme")
|
||||||
|
and self.quant_config.linear_activation_scheme == "static"
|
||||||
|
):
|
||||||
layer.input_scale = torch.nn.Parameter(
|
layer.input_scale = torch.nn.Parameter(
|
||||||
layer.input_scale.data, requires_grad=False
|
layer.input_scale.data, requires_grad=False
|
||||||
)
|
)
|
||||||
@@ -405,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
# Update layer with new values.
|
# Update layer with new values.
|
||||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
if self.quant_config.activation_scheme == "static":
|
if (
|
||||||
|
hasattr(self.quant_config, "activation_scheme")
|
||||||
|
and self.quant_config.activation_scheme == "static"
|
||||||
|
) or (
|
||||||
|
hasattr(self.quant_config, "linear_activation_scheme")
|
||||||
|
and self.quant_config.linear_activation_scheme == "static"
|
||||||
|
):
|
||||||
layer.input_scale = Parameter(
|
layer.input_scale = Parameter(
|
||||||
layer.input_scale.max(), requires_grad=False
|
layer.input_scale.max(), requires_grad=False
|
||||||
)
|
)
|
||||||
|
|||||||
264
python/sglang/srt/layers/quantization/w4afp8.py
Normal file
264
python/sglang/srt/layers/quantization/w4afp8.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import Module
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||||
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class W4AFp8Config(QuantizationConfig):
    """Config class for MIXED_PRECISION W4AFp8.

    Linear layers are dispatched to ``Fp8LinearMethod`` and fused MoE layers
    to ``W4AFp8MoEMethod`` (see ``get_quant_method``).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = True,
        is_checkpoint_w4afp8_serialized: bool = True,
        linear_activation_scheme: str = "dynamic",
        moe_activation_scheme: str = "static",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
        group_size: int = 128,
    ) -> None:
        """Store the quantization settings.

        Args:
            is_checkpoint_fp8_serialized: Stored as-is on the instance.
            is_checkpoint_w4afp8_serialized: Stored as-is; when true a warning
                is logged.
            linear_activation_scheme: Activation scheme for linear layers; one
                of ``ACTIVATION_SCHEMES``.
            moe_activation_scheme: Activation scheme for MoE layers; one of
                ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer prefixes that stay unquantized. ``None`` is
                treated as an empty list.
            weight_block_size: Block size exposed as ``self.weight_block_size``.
                ``None`` selects ``[128, 128]``.
            group_size: Group size used when sizing the MoE weight scales.

        Raises:
            ValueError: If either activation scheme is not in
                ``ACTIVATION_SCHEMES``.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
        if is_checkpoint_w4afp8_serialized:
            logger.warning("Detected w4afp8 checkpoint.")

        # Validate both schemes; previously only the MoE scheme was checked.
        for scheme in (linear_activation_scheme, moe_activation_scheme):
            if scheme not in ACTIVATION_SCHEMES:
                raise ValueError(f"Unsupported activation scheme {scheme}")

        self.linear_activation_scheme = linear_activation_scheme
        self.moe_activation_scheme = moe_activation_scheme
        self.ignored_layers = ignored_layers or []
        # Honour the caller's block size instead of silently discarding it.
        # The attribute is never None, so the historical [128, 128] default
        # is kept for callers that do not pass one.
        self.weight_block_size = (
            [128, 128] if weight_block_size is None else list(weight_block_size)
        )
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "w4afp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        return [torch.bfloat16, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value (90)."""
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names to search for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
        """Build a config from a quantization dict.

        Only ``quant_method`` is read from ``config``. The activation schemes
        and weight block size are fixed here to ``"dynamic"`` (linear),
        ``"static"`` (MoE) and ``[128, 128]``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # Substring checks: the name "w4afp8" satisfies both of them.
        return cls(
            is_checkpoint_fp8_serialized="fp8" in quant_method,
            is_checkpoint_w4afp8_serialized="w4afp8" in quant_method,
            linear_activation_scheme="dynamic",
            moe_activation_scheme="static",
            weight_block_size=[128, 128],
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Returns ``UnquantizedLinearMethod`` for linear layers whose prefix is
        in ``ignored_layers``, ``Fp8LinearMethod`` for other linear layers,
        ``W4AFp8MoEMethod`` for ``FusedMoE`` layers, and ``None`` otherwise.
        """
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import - TODO confirm).
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return W4AFp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
|
||||||
|
|
||||||
|
|
||||||
|
class W4AFp8MoEMethod:
    """Create and post-process the parameters of a W4AFp8 fused MoE layer.

    Besides registering parameters on the layer, ``create_weights`` keeps a
    set of pre-allocated stride / offset / problem-size tensors on this
    object (presumably consumed by the cutlass W4A8 MoE kernel - TODO
    confirm against the caller).
    """

    def __init__(self, quant_config: "W4AFp8Config"):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register weight, scale and input-scale parameters on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts_per_partition: Size of the leading (expert) dimension
                of every parameter and buffer created here.
            hidden_size: Model hidden size.
            intermediate_size: MoE intermediate size.
            params_dtype: Accepted for interface compatibility; not used, as
                the dtypes of all parameters are fixed below.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``; must contain ``"weight_loader"``.
        """
        assert "weight_loader" in extra_weight_attrs

        num_experts = num_experts_per_partition
        group_size = self.quant_config.group_size

        def _register(name: str, data: torch.Tensor) -> None:
            # Every parameter is frozen and carries the same extra attributes.
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel). The last dim is halved and
        # stored as int8 (presumably two 4-bit values per byte - TODO confirm).
        _register(
            "w13_weight",
            torch.empty(
                num_experts, intermediate_size * 2, hidden_size // 2, dtype=torch.int8
            ),
        )
        # down_proj (row parallel)
        _register(
            "w2_weight",
            torch.empty(
                num_experts, hidden_size, intermediate_size // 2, dtype=torch.int8
            ),
        )

        # Group-wise weight scales: one float32 value per `group_size` inputs.
        _register(
            "w13_weight_scale_inv",
            torch.zeros(
                num_experts,
                2 * intermediate_size,
                hidden_size // group_size,
                dtype=torch.float32,
            ),
        )
        _register(
            "w2_weight_scale_inv",
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size // group_size,
                dtype=torch.float32,
            ),
        )

        # Input scales
        _register(
            "w13_input_scale", torch.ones((num_experts, 2), dtype=torch.bfloat16)
        )
        _register("w2_input_scale", torch.ones(num_experts, dtype=torch.bfloat16))

        # Pre-populate the strides
        device = layer.w13_weight.device

        def _full_strides(value: int) -> torch.Tensor:
            return torch.full(
                (num_experts, 3), value, device=device, dtype=torch.int64
            )

        self.a_strides1 = _full_strides(hidden_size)
        self.c_strides1 = _full_strides(2 * intermediate_size)
        self.a_strides2 = _full_strides(intermediate_size)
        self.c_strides2 = _full_strides(hidden_size)
        # These share storage with the tensors above (aliases, not copies).
        self.b_strides1 = self.a_strides1
        self.s_strides13 = self.c_strides1
        self.b_strides2 = self.a_strides2
        self.s_strides2 = self.c_strides2

        # Uninitialised scratch buffers; nothing in this class fills them.
        self.expert_offsets = torch.empty(
            (num_experts + 1), dtype=torch.int32, device=device
        )
        self.problem_sizes1 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )
        self.problem_sizes2 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )

    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
        """Interleave scales in groups of 4 similar to TRT-LLM implementation.

        Maps a ``[E, N, K]`` tensor to a contiguous ``[E, K // 4, N * 4]``
        tensor with ``out[e, g, n * 4 + j] == scales[e, n, g * 4 + j]``.

        Raises:
            ValueError: If ``K`` (dimension 2) is not a multiple of 4.
        """
        num_experts, rows, groups = scales.shape[0], scales.shape[1], scales.shape[2]
        if groups % 4 != 0:
            # Fail with an actionable message instead of an opaque reshape
            # error from the call below.
            raise ValueError(
                "scales dimension 2 must be a multiple of 4 to be interleaved, "
                f"got shape {tuple(scales.shape)}"
            )
        # Reshape to separate groups of 4
        grouped = scales.reshape(num_experts, rows, groups // 4, 4)
        # Permute dimensions to interleave
        grouped = grouped.permute(0, 2, 1, 3)
        # Reshape back to a 3-D tensor with interleaved values
        return grouped.reshape(num_experts, groups // 4, rows * 4).contiguous()

    @staticmethod
    def _collapse_to_max(
        scale: torch.Tensor, dtype: torch.dtype, device: torch.device
    ) -> Parameter:
        """Return a frozen one-element parameter holding ``scale.max()``."""
        largest = scale.max().to(dtype).item()
        return Parameter(
            torch.tensor([largest], dtype=dtype, device=device), requires_grad=False
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded scales into their final layout.

        Weight scales are cast to bfloat16 and interleaved; each input-scale
        parameter is replaced by a single-element bfloat16 tensor holding its
        maximum, placed on the device of ``layer.w2_weight``.
        """
        dtype = torch.bfloat16
        device = layer.w2_weight.device

        # Interleave w13_weight_scale (gate_up_proj)
        layer.w13_weight_scale_inv = Parameter(
            self._interleave_scales(layer.w13_weight_scale_inv.to(dtype)),
            requires_grad=False,
        )

        # Interleave w2_weight_scale (down_proj)
        layer.w2_weight_scale_inv = Parameter(
            self._interleave_scales(layer.w2_weight_scale_inv.to(dtype)),
            requires_grad=False,
        )

        # Process input scales
        layer.w13_input_scale = self._collapse_to_max(
            layer.w13_input_scale, dtype, device
        )
        layer.w2_input_scale = self._collapse_to_max(
            layer.w2_input_scale, dtype, device
        )
|
||||||
@@ -2363,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
||||||
)
|
)
|
||||||
|
if self.quant_config and self.quant_config.get_name() == "w4afp8":
|
||||||
|
expert_params_mapping += (
|
||||||
|
get_moe_impl_class().make_expert_input_scale_params_mapping(
|
||||||
|
num_experts=self.config.n_routed_experts
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||||
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
|
||||||
|
|||||||
@@ -708,6 +708,7 @@ class ServerArgs:
|
|||||||
"w8a8_fp8",
|
"w8a8_fp8",
|
||||||
"moe_wna16",
|
"moe_wna16",
|
||||||
"qoq",
|
"qoq",
|
||||||
|
"w4afp8",
|
||||||
],
|
],
|
||||||
help="The quantization method.",
|
help="The quantization method.",
|
||||||
)
|
)
|
||||||
|
|||||||
281
python/sglang/test/test_cutlass_w4a8_moe.py
Normal file
281
python/sglang/test/test_cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
|
|
||||||
|
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
    """Pack adjacent pairs along the last dimension into single int8 values.

    For every pair, the even-indexed element supplies the low nibble and the
    odd-indexed element the high nibble, so the last dimension is halved.

    Raises:
        ValueError: If the last dimension has odd length.
    """
    if int4_values_interleaved.shape[-1] % 2:
        raise ValueError(
            "the last dim size of int4_values_interleaved tensor must be even."
        )

    as_int8 = int4_values_interleaved.to(torch.int8)
    even_cols, odd_cols = as_int8[..., 0::2], as_int8[..., 1::2]

    # Masking keeps only the low 4 bits so negative values do not leak
    # sign bits into the high nibble.
    return ((odd_cols << 4) | (even_cols & 0x0F)).to(torch.int8)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_interleave(num_experts, ref_weight, ref_scale):
    """Pack int4 reference weights and interleave their scales.

    Args:
        num_experts: Expected size of the leading dimension of ``ref_weight``.
        ref_weight: ``[E, N, K]`` tensor of int4-range values.
        ref_scale: ``[E, N, G]`` tensor of scales; ``G`` must be a multiple
            of 4.

    Returns:
        ``(w_q, w_scale)``: ``w_q`` is a contiguous int8 ``[E, N, K // 2]``
        tensor on the device of ``ref_weight``; ``w_scale`` is a contiguous
        ``[E, G // 4, N * 4]`` tensor.
    """
    n, k = ref_weight.shape[1], ref_weight.shape[2]

    # Packing is done on the CPU as before; the result goes back to the
    # weight's own device rather than to a hard-coded CUDA device, so the
    # helper also works for CPU inputs and non-default GPUs.
    packed = pack_int4_values_to_int8(ref_weight.cpu()).to(ref_weight.device)
    w_q = packed.view((num_experts, n, k // 2)).contiguous()

    experts, rows, groups = ref_scale.shape[0], ref_scale.shape[1], ref_scale.shape[2]
    w_scale = (
        ref_scale.reshape(experts, rows, groups // 4, 4)  # [E, N, K/4, 4]
        .permute(0, 2, 1, 3)  # [E, K/4, N, 4]
        .reshape(experts, groups // 4, rows * 4)  # [E, K/4, N*4]
        .contiguous()
    )

    return w_q, w_scale
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("M", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("N", [2048])
@pytest.mark.parametrize("K", [7168])
@pytest.mark.parametrize("E", [256])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.parametrize("topk", [8])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
    """Compare ``cutlass_moe`` against the pure-PyTorch ``ref`` on CUDA.

    Only the first ``E // ep_size`` experts are treated as local; all other
    expert ids are mapped to ``E`` in ``expert_map``.
    """
    device = "cuda"
    local_e = E // ep_size

    w1_shape = (local_e, N * 2, K)
    w2_shape = (local_e, K, N)
    s1_shape = (local_e, N * 2, K // group_size)
    s2_shape = (local_e, K, N // group_size)

    debug = False
    if debug:
        # Constant inputs, handy when chasing a mismatch by hand.
        a = torch.ones((M, K), dtype=dtype, device=device) * 0.001
        ref_weight_1 = torch.ones(w1_shape, dtype=torch.int8, device=device)
        ref_weight_2 = torch.ones(w2_shape, dtype=torch.int8, device=device)
        a1_scale = torch.ones(1, dtype=torch.float32, device=device)
        a2_scale = torch.ones(1, dtype=torch.float32, device=device)
        scale_1 = torch.ones(s1_shape, dtype=dtype, device=device)
        scale_2 = torch.ones(s2_shape, dtype=dtype, device=device)
    else:
        # The order of the random draws below is kept as-is on purpose.
        a = torch.randn(M, K, dtype=dtype, device=device)
        ref_weight_1 = torch.randint(-8, 8, w1_shape, dtype=torch.int8, device=device)
        ref_weight_2 = torch.randint(-8, 8, w2_shape, dtype=torch.int8, device=device)
        affine_coeff = 0.005
        a1_scale = torch.randn(1, dtype=torch.float32, device=device)
        a2_scale = torch.randn(1, dtype=torch.float32, device=device)
        scale_1 = torch.randn(*s1_shape, dtype=dtype, device=device) * affine_coeff
        scale_2 = torch.randn(*s2_shape, dtype=dtype, device=device) * affine_coeff

    w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1)
    w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2)

    def _strides(value):
        return torch.full((local_e, 3), value, device=device, dtype=torch.int64)

    a_strides1 = _strides(K)
    c_strides1 = _strides(2 * N)
    a_strides2 = _strides(N)
    c_strides2 = _strides(K)
    # The b/s stride arguments reuse the a/c tensors (same objects).
    b_strides1, s_strides13 = a_strides1, c_strides1
    b_strides2, s_strides2 = a_strides2, c_strides2

    score = torch.randn((M, E), dtype=dtype, device=device)
    topk_weights, topk_ids = select_experts(
        hidden_states=a,
        router_logits=score,
        top_k=topk,
        use_grouped_topk=False,
        renormalize=False,
    )

    # Identity for local experts, sentinel value E for everything else.
    expert_map = torch.arange(E, dtype=torch.int32, device=device)
    expert_map[local_e:] = E

    output = cutlass_moe(
        a,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        a_strides1,
        b_strides1,
        c_strides1,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides13,
        s_strides2,
        start_expert_id=0,
        end_expert_id=local_e - 1,
        E=E,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        expert_map=expert_map,
    )

    ref_output = ref(
        a,
        local_e,
        topk_weights,
        topk_ids,
        ref_weight_1,
        ref_weight_2,
        scale_1,
        scale_2,
        has_pre_quant=True,
        has_alpha=True,
        pre_quant_scale_1=a1_scale,
        pre_quant_scale_2=a2_scale,
        alpha_1=a1_scale,
        alpha_2=a2_scale,
    )

    # Make sure all queued GPU work has finished before comparing.
    torch.cuda.synchronize()

    # compare final output
    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
    print("SUCCESS: Final output tensors are close.")
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_moe(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    E: int,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
):
    """Allocate scratch buffers and forward everything to ``cutlass_w4a8_moe``.

    Args:
        a: Input activations; scratch buffers are created on its device.
        topk_ids_: Expert ids per token, forwarded unchanged and also used to
            derive the local ids.
        start_expert_id, end_expert_id: Inclusive range of local experts.
        E: Sentinel id used for experts that are not local.
        expert_map: Optional lookup table indexed by ``topk_ids_``. When
            omitted, ``topk_ids_`` is passed as the local ids as well
            (NOTE(review): assumes ids are already local in that case -
            confirm against the kernel's contract).

    All remaining arguments are forwarded to ``cutlass_w4a8_moe`` unchanged.

    Returns:
        Whatever ``cutlass_w4a8_moe`` returns.
    """
    if expert_map is None:
        # The parameter is optional; previously None crashed on indexing.
        local_topk_ids = topk_ids_
    else:
        # Gather once and reuse the result for both the test and the value.
        mapped_ids = expert_map[topk_ids_]
        local_topk_ids = torch.where(mapped_ids != E, mapped_ids, E)

    device = a.device
    local_num_experts = end_expert_id - start_expert_id + 1

    # Uninitialised scratch buffers handed to the kernel wrapper.
    expert_offsets = torch.empty(
        (local_num_experts + 1), dtype=torch.int32, device=device
    )
    problem_sizes1 = torch.empty(
        (local_num_experts, 3), dtype=torch.int32, device=device
    )
    problem_sizes2 = torch.empty(
        (local_num_experts, 3), dtype=torch.int32, device=device
    )

    return cutlass_w4a8_moe(
        start_expert_id,
        end_expert_id,
        E,
        a,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        topk_weights,
        topk_ids_,
        local_topk_ids,
        a_strides1,
        b_strides1,
        c_strides1,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides13,
        s_strides2,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a1_scale,
        a2_scale,
        apply_router_weight_on_input,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def ref(
|
||||||
|
x: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
ref_weight_1: torch.Tensor,
|
||||||
|
ref_weight_2: torch.Tensor,
|
||||||
|
ref_weight_scale_1: torch.Tensor,
|
||||||
|
ref_weight_scale_2: torch.Tensor,
|
||||||
|
has_pre_quant: bool = False,
|
||||||
|
has_alpha: bool = False,
|
||||||
|
pre_quant_scale_1: Optional[torch.Tensor] = None,
|
||||||
|
pre_quant_scale_2: Optional[torch.Tensor] = None,
|
||||||
|
alpha_1: Optional[torch.Tensor] = None,
|
||||||
|
alpha_2: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
results = torch.zeros_like(x)
|
||||||
|
dtype = x.dtype
|
||||||
|
for e_idx in range(num_experts):
|
||||||
|
mask = topk_ids == e_idx
|
||||||
|
activated_tokens = mask.sum(1).bool()
|
||||||
|
act = x[activated_tokens, :]
|
||||||
|
if act.shape[0] == 0:
|
||||||
|
continue
|
||||||
|
final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1)
|
||||||
|
|
||||||
|
act = (
|
||||||
|
torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0)
|
||||||
|
.to(torch.float8_e4m3fn)
|
||||||
|
.to(dtype)
|
||||||
|
)
|
||||||
|
w3_w1 = ref_weight_1[e_idx]
|
||||||
|
ref_w_scale_repeat = (
|
||||||
|
ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float)
|
||||||
|
)
|
||||||
|
w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype)
|
||||||
|
fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16)
|
||||||
|
|
||||||
|
gate, fc1 = fc1.chunk(2, dim=-1)
|
||||||
|
fc1 = fc1 * torch.nn.functional.silu(gate)
|
||||||
|
act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn)
|
||||||
|
act = act.to(dtype)
|
||||||
|
|
||||||
|
w2 = ref_weight_2[e_idx]
|
||||||
|
ref_w_scale_repeat = (
|
||||||
|
ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float)
|
||||||
|
)
|
||||||
|
w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype)
|
||||||
|
fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16)
|
||||||
|
|
||||||
|
results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype)
|
||||||
|
|
||||||
|
return results
|
||||||
Reference in New Issue
Block a user