[3/n] chore: decouple AWQ implementation from vLLM dependency (#8113)
Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
This commit is contained in:
286
python/sglang/test/test_marlin_moe.py
Normal file
286
python/sglang/test/test_marlin_moe.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import fused_marlin_moe
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
|
||||
|
||||
|
||||
def stack_and_dev(tensors: list[torch.Tensor]):
    """Stack ``tensors`` along a new leading dimension.

    The stacked result is moved to the device of the first tensor in the
    list before being returned.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
|
||||
|
||||
|
||||
def torch_experts(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
    """Unfused reference implementation of the MoE expert computation.

    Every token is replicated ``topk`` times, each copy is pushed through the
    expert selected in ``topk_ids`` (``w1`` projection, ``SiluAndMul``, ``w2``
    projection), and the per-expert outputs are combined with ``topk_weight``.

    Args:
        a: 2-D activations; unpacked as ``(M, K)``.
        w1: Per-expert first projection, indexed by expert on dim 0.
        w2: Per-expert second projection, indexed by expert on dim 0; the
            output width is ``w2.shape[1]``.
        topk_weight: Routing weights, viewed as ``(M, topk, 1)`` when combining.
        topk_ids: 2-D tensor of selected expert ids with ``topk`` columns.
        global_num_experts: ``-1`` to skip the consistency check, otherwise
            must match ``w1.shape[0]`` (no ``expert_map``) or
            ``expert_map.shape[0]``.
        expert_map: Optional lookup tensor applied to the flattened
            ``topk_ids`` before expert selection.
        quant_dtype: Only ``None`` is supported by this reference.
        apply_router_weights_on_input: If True, ``a`` is scaled by
            ``topk_weight`` up front (requires ``topk == 1``) and the
            un-reduced expert outputs are returned.

    Returns:
        The expert outputs; reduced over the ``topk`` dimension with
        ``topk_weight`` unless ``apply_router_weights_on_input`` is set.

    Raises:
        NotImplementedError: If ``quant_dtype`` is not ``None``. Previously
            this case silently skipped all expert compute and returned zeros.
    """
    assert (
        global_num_experts == -1
        or (global_num_experts == w1.shape[0] and expert_map is None)
        or (expert_map is not None and global_num_experts == expert_map.shape[0])
    )

    if quant_dtype is not None:
        # Fail loudly rather than hand back an all-zero "reference" result.
        raise NotImplementedError(
            f"torch_experts only supports quant_dtype=None, got {quant_dtype}"
        )

    M, K = a.shape
    topk = topk_ids.shape[1]

    if apply_router_weights_on_input:
        assert topk == 1
        a = a * topk_weight.to(a.dtype)

    # One row per (token, selected expert) pair.
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)

    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    num_experts = w1.shape[0]

    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]

    f32 = torch.float32
    activation = SiluAndMul()

    for i in range(num_experts):
        mask = topk_ids == i
        if mask.any():
            tmp1 = a[mask] @ w1[i].transpose(0, 1)
            tmp2 = activation(tmp1)
            out[mask] = tmp2 @ w2[i].transpose(0, 1)

    if apply_router_weights_on_input:
        return out

    # Accumulate the weighted sum in float32, then cast back to the
    # activation dtype.
    return (
        (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
        .sum(dim=1)
        .to(out.dtype)
    )
|
||||
|
||||
|
||||
def torch_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference MoE forward pass: softmax routing followed by ``torch_experts``.

    ``score`` is softmaxed over its last dimension in float32, the ``topk``
    largest probabilities and their indices are selected, and the result is
    forwarded to ``torch_experts`` together with the activations and weights.
    """
    routing_probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    selected_weights, selected_ids = torch.topk(routing_probs, topk)
    return torch_experts(
        a,
        w1,
        w2,
        selected_weights,
        selected_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )
|
||||
|
||||
|
||||
def marlin_moe_generate_valid_test_cases():
    """Build the parameter tuples used to parametrize ``test_fused_marlin_moe``.

    The full cartesian product of the option lists below is generated and
    then filtered down to the combinations accepted by ``_is_valid``.

    Returns:
        A list of ``(m, n, k, e, topk, dtype, group_size, act_order,
        quant_type, is_k_full)`` tuples.
    """
    import itertools

    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.uint4,
        scalar_types.uint4b8,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    def _is_valid(
        m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
    ):
        # This predicate was previously named ``is_invalid`` even though it
        # returns True for the combinations that are kept.
        if act_order:
            # With act_order the group size must not be -1 or equal to k or n,
            # and only uint4b8 is accepted.
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            # Without act_order only is_k_full=True is kept.
            return False

        return True

    return [case for case in all_combinations if _is_valid(*case)]
|
||||
|
||||
|
||||
def _quantize_expert_weights(weights, quant_type, group_size, act_order, has_zp):
    """Quantize each expert's weight matrix to Marlin format and stack the results.

    Each ``weights[i]`` is transposed before quantization. With ``has_zp`` the
    zero-point path (``awq_marlin_quantize``) is used; otherwise
    ``marlin_quantize`` is called with a fresh random permutation whose length
    is the last dimension of ``weights``.

    Returns:
        ``(w_ref, qweight, scales, g_idx, zeros, sort_indices)`` stacked over
        experts; entries not produced by the chosen path are ``None``.
    """
    w_ref_l = []
    qweight_l = []
    scales_l = []
    zeros_l = []
    g_idx_l = []
    sort_indices_l = []

    perm_size = weights.shape[2]

    for i in range(weights.shape[0]):
        if has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                weights[i].transpose(1, 0), quant_type, group_size
            )
            zeros_l.append(zeros)
        else:
            test_perm = torch.randperm(perm_size)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                weights[i].transpose(1, 0),
                quant_type,
                group_size,
                act_order,
                test_perm,
            )
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)

        w_ref_l.append(w_ref.T)
        qweight_l.append(qweight)
        scales_l.append(scales)

    return (
        stack_and_dev(w_ref_l),
        stack_and_dev(qweight_l).contiguous(),
        stack_and_dev(scales_l),
        stack_and_dev(g_idx_l) if g_idx_l else None,
        stack_and_dev(zeros_l) if zeros_l else None,
        stack_and_dev(sort_indices_l) if sort_indices_l else None,
    )


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
    marlin_moe_generate_valid_test_cases(),
)
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    """Compare ``fused_marlin_moe`` against the PyTorch reference ``torch_moe``.

    Random activations and expert weights are generated on CUDA, the weights
    are quantized to Marlin format, and the fused kernel output is required
    to match the reference (computed on the dequantized reference weights)
    within an absolute tolerance of 5e-2.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    torch.manual_seed(0)

    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    # Same filtering as the case generator: bail out on combinations that
    # are not exercised.
    if act_order:
        if group_size == -1 or group_size in (k, n) or has_zp:
            return
    elif not is_k_full:
        return

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    e_map = None

    # w1 is quantized before w2 so the random permutations are drawn in the
    # same order as before.
    w_ref1, qweight1, scales1, g_idx1, zeros1, sort_indices1 = (
        _quantize_expert_weights(w1, quant_type, group_size, act_order, has_zp)
    )
    w_ref2, qweight2, scales2, g_idx2, zeros2, sort_indices2 = (
        _quantize_expert_weights(w2, quant_type, group_size, act_order, has_zp)
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    from sglang.srt.layers.moe.topk import fused_topk_torch_native

    topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)

    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        num_bits=4,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Run every test in this file and propagate pytest's exit code so that a
    # failing run does not exit with status 0.
    sys.exit(pytest.main([__file__]))
|
||||
171
python/sglang/test/test_marlin_utils.py
Normal file
171
python/sglang/test/test_marlin_utils.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
|
||||
"""
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Utility functions used for tests and benchmarks"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.marlin_utils import (
|
||||
GPTQ_MARLIN_TILE,
|
||||
marlin_permute_scales,
|
||||
marlin_zero_points,
|
||||
)
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
get_pack_factor,
|
||||
gptq_quantize_weights,
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
|
||||
|
||||
class MarlinWorkspace:
    """Holder for a zero-initialised int32 scratch tensor on the CUDA device.

    The tensor is stored in ``self.scratch`` and has
    ``(out_features // min_thread_n) * max_parallel`` elements.
    ``out_features`` must be a multiple of ``min_thread_n``.
    """

    def __init__(self, out_features, min_thread_n, max_parallel):
        leftover = out_features % min_thread_n
        assert (
            leftover == 0
        ), "out_features = {} is undivisible by min_thread_n = {}".format(
            out_features, min_thread_n
        )

        slots_per_pass = out_features // min_thread_n
        workspace_len = slots_per_pass * max_parallel

        self.scratch = torch.zeros(workspace_len, dtype=torch.int, device="cuda")
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Rearrange a ``(size_k, size_n)`` weight matrix into tiled, permuted order.

    The matrix is first regrouped into ``tile`` x ``tile`` blocks, with all
    blocks of one block-row flattened into a single row. ``perm`` is then
    applied to consecutive chunks of ``perm.numel()`` elements.

    Args:
        q_w: Tensor of shape ``(size_k, size_n)``.
        size_k: Number of rows of ``q_w``; must be a multiple of ``tile``.
        size_n: Number of columns of ``q_w``; must be a multiple of ``tile``.
        perm: 1-D index tensor applied within each chunk.
        tile: Tile edge length.

    Returns:
        Tensor of shape ``(size_k // tile, size_n * tile)``.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # The message previously labelled this value as ``size_k``.
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    # Apply ``perm`` to every chunk of perm.numel() elements, then restore
    # the tiled shape.
    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
|
||||
|
||||
|
||||
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute quantized weights and bit-pack them into 32-bit words.

    The weights are reordered with ``marlin_permute_weights`` and then, on the
    CPU via NumPy, every ``get_pack_factor(num_bits)`` consecutive columns are
    OR-ed into one word, with column ``i`` of each group shifted left by
    ``num_bits * i``. The packed result is returned as an int32 tensor on the
    device of the permuted weights.
    """
    # Permute
    permuted = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    values_per_word = get_pack_factor(num_bits)
    source_device = permuted.device

    unpacked = permuted.cpu().numpy().astype(np.uint32)
    n_rows, n_cols = unpacked.shape

    packed = np.zeros((n_rows, n_cols // values_per_word), dtype=np.uint32)
    for slot in range(values_per_word):
        packed |= unpacked[:, slot::values_per_word] << (num_bits * slot)

    return torch.from_numpy(packed.astype(np.int32)).to(source_device)
|
||||
|
||||
|
||||
def get_weight_perm(num_bits: int):
    """Build the 1024-element index permutation used to lay out Marlin weights.

    Args:
        num_bits: Bit width of the quantized values; must be 4 or 8. It
            selects the interleave pattern applied to the base permutation.

    Returns:
        A 1-D ``torch.Tensor`` holding a permutation of ``0..1023``.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    # Validate first so an unsupported width fails before any work is done.
    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm_list: list[int] = []
    for i in range(32):
        col = i // 4
        base_row = 2 * (i % 4)
        rows = (base_row, base_row + 1, base_row + 8, base_row + 9)
        perm1 = [16 * row + col + 8 * block for block in (0, 1) for row in rows]
        # Repeat the 8-index pattern for each of the four 256-element blocks.
        for j in range(4):
            perm_list.extend(p + 256 * j for p in perm1)

    perm = np.array(perm_list)
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
|
||||
|
||||
|
||||
def marlin_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """Quantize ``w`` with ``gptq_quantize_weights`` and convert it to Marlin layout.

    Args:
        w: 2-D weight tensor; unpacked as ``(size_k, size_n)``.
        quant_type: Quantized scalar type; its ``size_bits`` gives the bit width.
        group_size: Quantization group size; ``-1`` means ``size_k``.
        act_order: Forwarded to ``gptq_quantize_weights``; when True the
            quantized weights and group indices are additionally passed
            through ``sort_weights``.
        test_perm: Optional permutation forwarded to ``gptq_quantize_weights``.

    Returns:
        A list ``[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices,
        rand_perm]`` with every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape
    bit_width = quant_type.size_bits

    # Normalize group_size
    effective_group = size_k if group_size == -1 else group_size
    assert effective_group <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, scales, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, effective_group, act_order, test_perm
    )

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing; otherwise sort_indices stays an empty tensor.
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
    else:
        sort_indices = torch.empty(0, dtype=torch.int, device=w.device)

    # Reformat to marlin
    layout_perm = get_weight_perm(bit_width)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, bit_width, layout_perm)
    marlin_s = marlin_permute_scales(scales, size_k, size_n, effective_group)

    # Create result
    outputs = (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    return [tensor.to(w.device) for tensor in outputs]
|
||||
|
||||
|
||||
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
    """Quantize ``w`` with zero points and convert the result to Marlin layout.

    Args:
        w: 2-D weight tensor; unpacked as ``(size_k, size_n)``.
        quant_type: Quantized scalar type; its ``size_bits`` gives the bit width.
        group_size: Quantization group size; ``-1`` means ``size_k``. Must
            divide ``size_k``.

    Returns:
        A list ``[w_ref, marlin_q_w, marlin_s, marlin_zp]`` with every tensor
        moved to ``w.device``.
    """
    size_k, size_n = w.shape
    bit_width = quant_type.size_bits

    # Normalize group_size
    effective_group = size_k if group_size == -1 else group_size
    assert effective_group <= size_k

    # Detect num groups
    assert size_k % effective_group == 0
    num_groups = size_k // effective_group

    # Quantize with zp
    w_ref, q_w, scales, zero_pts = quantize_weights(
        w, quant_type, effective_group, zero_points=True
    )

    # Reformat to marlin
    layout_perm = get_weight_perm(bit_width)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, bit_width, layout_perm)
    marlin_s = marlin_permute_scales(scales, size_k, size_n, effective_group)
    marlin_zp = marlin_zero_points(zero_pts, num_groups, size_n, bit_width)

    # Create result
    outputs = (w_ref, marlin_q_w, marlin_s, marlin_zp)
    return [tensor.to(w.device) for tensor in outputs]
|
||||
Reference in New Issue
Block a user