Resubmit MoE-EP (#2371)
This commit is contained in:
6
.github/workflows/pr-test.yml
vendored
6
.github/workflows/pr-test.yml
vendored
@@ -105,6 +105,12 @@ jobs:
|
|||||||
cd test/srt
|
cd test/srt
|
||||||
python3 test_update_weights_from_distributed.py
|
python3 test_update_weights_from_distributed.py
|
||||||
|
|
||||||
|
- name: Evaluate MoE EP accuracy (TP=2)
|
||||||
|
timeout-minutes: 10
|
||||||
|
run: |
|
||||||
|
cd test/srt
|
||||||
|
python3 test_moe_ep.py
|
||||||
|
|
||||||
performance-test-1-gpu-part-1:
|
performance-test-1-gpu-part-1:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
runs-on: 1-gpu-runner
|
runs-on: 1-gpu-runner
|
||||||
|
|||||||
0
python/sglang/srt/layers/ep_moe/__init__.py
Normal file
0
python/sglang/srt/layers/ep_moe/__init__.py
Normal file
349
python/sglang/srt/layers/ep_moe/kernels.py
Normal file
349
python/sglang/srt/layers/ep_moe/kernels.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    """Store the end offset of one expert's segment into ``seg_indptr``.

    Each program handles the expert whose id equals its program id. It binary
    searches ``reorder_topk_ids`` (the search assumes ascending order) for the
    last position whose value is <= that expert id and stores
    ``position + 1`` at ``seg_indptr[expert + 1]``. Slot 0 of ``seg_indptr``
    is never written by this kernel.
    """
    expert = tl.program_id(0)
    lo = 0
    hi = num_toks - 1
    last_le = -1  # stays -1 when every id is greater than ``expert``
    while lo <= hi:
        mid = (lo + hi) // 2

        if tl.load(reorder_topk_ids + mid) <= expert:
            # ``mid`` is still inside the prefix we are measuring: go right.
            lo = mid + 1
            last_le = mid
        else:
            hi = mid - 1
    tl.store(seg_indptr + expert + 1, last_le + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    """Invert a permutation: ``src2dst[reorder_ids[i]] = i`` for ``i < num_toks``.

    Each program covers ``BLOCK_SIZE`` consecutive destination positions;
    positions at or beyond ``num_toks`` are masked out of both the load and
    the store.
    """
    block_id = tl.program_id(axis=0)
    sorted_pos = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = sorted_pos < num_toks
    orig_pos = tl.load(reorder_ids + sorted_pos, mask=in_bounds)
    tl.store(src2dst + orig_pos, sorted_pos, mask=in_bounds)
|
||||||
|
|
||||||
|
|
||||||
|
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
    """Sort the routed expert ids and build the lookup tables used downstream.

    Args:
        topk_ids: Expert ids chosen per token; flattened before sorting.
        num_experts: Number of experts; one kernel program is launched per
            expert to fill ``seg_indptr``.

    Returns:
        A tuple ``(reorder_topk_ids, src2dst, seg_indptr)``:
        the stably sorted expert ids, an int32 table mapping each flat
        position of ``topk_ids`` to its position in the sorted order, and an
        int64 tensor of length ``num_experts + 1`` whose entry ``e + 1`` is
        filled by ``compute_seg_indptr_triton_kernel`` (entry 0 stays 0).
    """
    device = topk_ids.device
    num_toks = topk_ids.numel()

    # A stable sort keeps entries of the same expert in their original order.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

    seg_indptr = torch.zeros(num_experts + 1, device=device, dtype=torch.int64)
    src2dst = torch.empty(num_toks, device=device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, num_toks
    )

    BLOCK_SIZE = 512
    num_blocks = triton.cdiv(num_toks, BLOCK_SIZE)
    compute_src2dst_triton_kernel[(num_blocks,)](
        reorder_ids, src2dst, num_toks, BLOCK_SIZE
    )
    return reorder_topk_ids, src2dst, seg_indptr
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def pre_reorder_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one input row to each destination row owned by a local expert.

    One program handles one source row. For each of its ``topk`` routed
    experts that lies in ``[start_expert_id, end_expert_id]``, the row is
    read as float32, multiplied by ``1 / a1_scales[expert - start_expert_id]``
    (or by 1.0 when ``a1_scales_ptr`` is None), cast to the element type of
    ``gateup_input_ptr`` and written to the row given by ``src2dst``.
    Rows routed to experts outside the range are not written.
    """
    OutDtype = gateup_input_ptr.dtype.element_ty

    token = tl.program_id(0)
    # Per-token views into the (token, topk) routing tables.
    token_src2dst = src2dst_ptr + token * topk
    token_topk_ids = topk_ids_ptr + token * topk

    row_in = input_ptr + token * hidden_size
    for idx in range(topk):
        expert_id = tl.load(token_topk_ids + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            if a1_scales_ptr is not None:
                # Scales are indexed relative to the first local expert.
                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
            else:
                scale = 1.0

            dst_idx = tl.load(token_src2dst + idx)
            row_out = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                cols = start_offset + tl.arange(0, BLOCK_SIZE)
                valid = cols < hidden_size
                values = tl.load(row_in + cols, mask=valid).to(tl.float32)
                scaled = (values * scale).to(OutDtype)
                tl.store(row_out + cols, scaled, mask=valid)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def silu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply ``silu(gate) * up`` to one row and write the scaled result.

    One program handles one row of ``gateup_output``. The first half of the
    row is treated as the gate and the second half as the up projection.
    Rows whose expert id (from ``reorder_topk_ids``) is outside
    ``[start_expert_id, end_expert_id]`` are left untouched. When ``scales``
    is not None the product is additionally multiplied by
    ``1 / scales[expert - start_expert_id]`` before being cast to the element
    type of ``down_input``.
    """
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    row = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + row)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        row_base = gateup_output + row * hidden_size
        gate_ptr = row_base
        up_ptr = row_base + half_hidden_size
        out_ptr = down_input + row * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            # The reciprocal is taken first, then cast to the input dtype.
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            cols = start_offset + tl.arange(0, BLOCK_SIZE)
            valid = cols < half_hidden_size

            gate = tl.load(gate_ptr + cols, mask=valid).to(tl.float32)
            up = tl.load(up_ptr + cols, mask=valid)

            # silu & mul & quantize
            gate = gate * tl.sigmoid(gate)
            gate = gate.to(InDtype)

            result = gate * up * scale
            result = result.to(OutDtype)
            tl.store(out_ptr + cols, result, mask=valid)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Combine the expert outputs of one token into its output row.

    One program handles one token. For each column block it accumulates, in
    the element type of ``down_output_ptr``, the rows of ``down_output``
    selected through ``src2dst`` for every routed expert inside
    ``[start_expert_id, end_expert_id]``, each multiplied by its routing
    weight, and stores the sum. If no routed expert is in range the row is
    written as zeros.
    """
    InDtype = down_output_ptr.dtype.element_ty

    token = tl.program_id(0)
    # Per-token views into the (token, topk) routing tables.
    token_src2dst = src2dst_ptr + token * topk
    token_topk_ids = topk_ids_ptr + token * topk
    token_weights = topk_weights_ptr + token * topk

    computed = False
    row_out = output_ptr + token * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        cols = start_offset + tl.arange(0, BLOCK_SIZE)
        valid = cols < hidden_size

        acc = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(token_topk_ids + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(token_src2dst + idx)
                weight = tl.load(token_weights + idx).to(InDtype)
                row_in = down_output_ptr + dst_idx * hidden_size
                values = tl.load(row_in + cols, mask=valid)
                acc += values * weight
        tl.store(row_out + cols, acc, mask=valid)

    # Explicit zero fill for tokens that had no expert on this rank.
    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            cols = start_offset + tl.arange(0, BLOCK_SIZE)
            valid = cols < hidden_size
            tl.store(
                row_out + cols, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=valid
            )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def compute_m_range(
    pid,
    batch_size,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Map a row-tile program id to its row range and weight index.

    Scans ``m_num_tiles_indptr[0:batch_size]`` and keeps the last segment
    whose first tile index is <= ``pid``. Returns
    ``(m_range_start, m_range_end, expert_id)`` where the row range is the
    ``BLOCK_SIZE_M``-sized tile inside that segment, clipped to the segment
    end taken from ``seg_indptr``, and ``expert_id`` is
    ``weight_indices[segment]``.
    """
    seg = 0
    for bs in range(batch_size):
        first_tile = tl.load(m_num_tiles_indptr + bs)
        if pid >= first_tile:
            seg = bs

    seg_first_tile = tl.load(m_num_tiles_indptr + seg)

    # Offset of this tile inside its segment, in rows.
    m_range_start = tl.load(seg_indptr + seg) + (pid - seg_first_tile) * BLOCK_SIZE_M
    m_range_end = min(tl.load(seg_indptr + seg + 1), m_range_start + BLOCK_SIZE_M)
    expert_id = tl.load(weight_indices + seg)
    return m_range_start, m_range_end, expert_id
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def grouped_gemm_triton_kernel(
    a,
    b,
    c,
    batch_size,
    N,
    K,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    use_fp8_w8a8,
    scale_a,
    scale_b,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of a segmented GEMM.

    Program axis 0 selects a row tile and axis 1 a column tile. The row tile
    is resolved by ``compute_m_range`` to a range of rows of ``a`` and a
    weight index; the tile of ``c`` is the product of those rows with the
    transpose of the selected ``b[expert]`` slice (``b`` is addressed as
    ``expert * b_stride_0 + n * b_stride_1 + k``). Accumulation is in
    float32. When ``use_fp8_w8a8`` is set the accumulator is multiplied by
    ``scale_a[expert] * scale_b[expert]`` before the cast to ``c``'s element
    type. ``c`` is addressed with a row stride of ``N``.
    """
    c_dtype = c.dtype.element_ty

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # The grid may contain more row tiles than are actually needed.
    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
    if pid_m >= total_m_block:
        return

    m_range_start, m_range_end, expert_id = compute_m_range(
        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
    )
    if m_range_end - m_range_start == 0:
        return

    n_range_start = pid_n * BLOCK_SIZE_N
    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)

    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    # Out-of-range lanes are redirected to offset 0 so that the loads below
    # stay inside the tile; their results are dropped by the store mask.
    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
    b_ptr = b + (
        (expert_id * b_stride_0)
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Only the tail of the K dimension needs masking.
        k_remaining = K - k * BLOCK_SIZE_K
        a_tile = tl.load(a_ptr, mask=offs_k[None, :] < k_remaining, other=0.0)
        b_tile = tl.load(b_ptr, mask=offs_k[None, :] < k_remaining, other=0.0)
        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8:
        scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value
    c_tile = accumulator.to(c_dtype)

    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
    tl.store(c_ptr, c_tile, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def compute_m_num_tiles_indptr(
    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
    """Fill ``m_num_tiles_indptr`` with a running total of row tiles.

    For each segment ``bs`` the number of ``BLOCK_SIZE_M``-row tiles needed
    to cover ``seg_indptr[bs + 1] - seg_indptr[bs]`` rows is added to entry
    ``bs`` and stored in entry ``bs + 1``. Entry 0 is read but never written,
    so the caller must initialise it.
    """
    for bs in range(batch_size):
        seg_rows = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
        seg_tiles = tl.cdiv(seg_rows, BLOCK_SIZE_M)
        tiles_before = tl.load(m_num_tiles_indptr + bs)
        tl.store(m_num_tiles_indptr + bs + 1, tiles_before + seg_tiles)
|
||||||
|
|
||||||
|
|
||||||
|
def grouped_gemm_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    batch_size: int,
    weight_column_major: bool,
    seg_indptr: Optional[torch.Tensor] = None,
    weight_indices: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    scale_a: torch.Tensor = None,
    scale_b: torch.Tensor = None,
):
    """Launch the Triton segmented GEMM and return ``c``.

    Args:
        a: Left operand; its rows are grouped into ``batch_size`` segments.
        b: Stacked weights; ``b.size(1)`` is passed to the kernel as ``N``
            and ``b.size(2)`` as ``K``.
        c: Pre-allocated output, written in place.
        batch_size: Number of segments.
        weight_column_major: Must be True; no other layout is implemented.
        seg_indptr: Segment boundaries over the rows of ``a``.
        weight_indices: For each segment, the index into ``b`` to use.
        use_fp8_w8a8: When True, ``scale_a`` and ``scale_b`` are required and
            applied by the kernel.
        scale_a: Per-weight-index scale for ``a``.
        scale_b: Per-weight-index scale for ``b``.

    Returns:
        The same tensor object passed as ``c``.
    """
    assert weight_column_major == True  # TODO: more
    if use_fp8_w8a8:
        assert scale_a is not None and scale_b is not None

    config = {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
    }

    n_dim = b.size(1)
    k_dim = b.size(2)

    # Zero-initialised because the kernel reads entry 0 as its starting total.
    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
    compute_m_num_tiles_indptr[(1,)](
        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

    def grid(META):
        # Each segment can add at most one partially filled row tile, hence
        # the extra ``batch_size``; surplus programs exit early in the kernel.
        return (
            triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
            triton.cdiv(n_dim, META["BLOCK_SIZE_N"]),
        )

    grouped_gemm_triton_kernel[grid](
        a,
        b,
        c,
        batch_size,
        n_dim,
        k_dim,
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        use_fp8_w8a8,
        scale_a,
        scale_b,
        a.stride(0),
        b.stride(0),
        b.stride(1),
        **config,
    )
    return c
|
||||||
661
python/sglang/srt/layers/ep_moe/layer.py
Normal file
661
python/sglang/srt/layers/ep_moe/layer.py
Normal file
@@ -0,0 +1,661 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import Module
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||||
|
|
||||||
|
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||||
|
from sglang.srt.layers.ep_moe.kernels import (
|
||||||
|
grouped_gemm_triton,
|
||||||
|
post_reorder_triton_kernel,
|
||||||
|
pre_reorder_triton_kernel,
|
||||||
|
run_moe_ep_preproess,
|
||||||
|
silu_and_mul_triton_kernel,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
|
||||||
|
from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
|
||||||
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import is_hip, set_weight_attrs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupedGemmRunner(torch.nn.Module):
    """Dispatch a grouped GEMM to the available backend.

    Only the Triton backend is implemented. Requesting flashinfer creates the
    shared wrapper at construction time, but ``forward`` refuses to run it.
    """

    # Wrapper shared by every instance; created on first use when flashinfer
    # is requested. The attribute name (including its spelling) is kept as is
    # because it is part of the class's visible surface.
    flashinfer_gemm_warpper = None

    def __init__(self, device, use_flashinfer: bool = False):
        """Remember the target device and, if requested, set up flashinfer.

        Args:
            device: Device on which the flashinfer workspace is allocated.
            use_flashinfer: Select the (not yet implemented) flashinfer path.
        """
        super().__init__()
        self.device = device
        self.use_flashinfer = use_flashinfer
        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
            GroupedGemmRunner._init_flashinfer_wrapper(device)

    @classmethod
    def _init_flashinfer_wrapper(cls, device):
        """Allocate a 128 MiB int8 workspace and build the shared wrapper."""
        # Imported lazily so flashinfer is only required when it is used.
        from flashinfer import SegmentGEMMWrapper

        workspace_buffer = torch.empty(
            128 * 1024 * 1024, dtype=torch.int8, device=device
        )
        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)

    # c = a * b
    def forward(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
        batch_size: int,
        weight_column_major: bool,
        seg_indptr: Optional[torch.Tensor] = None,
        weight_indices: Optional[torch.Tensor] = None,
        use_fp8_w8a8: bool = False,
        scale_a: torch.Tensor = None,
        scale_b: torch.Tensor = None,
    ):
        """Run the grouped GEMM and return the output tensor.

        All arguments are forwarded unchanged to ``grouped_gemm_triton``.

        Raises:
            NotImplementedError: If the runner was created with
                ``use_flashinfer=True``.
        """
        if self.use_flashinfer:
            # TODO: flashinfer. This used to be guarded by ``assert False``,
            # which is stripped under ``python -O`` and would have let
            # execution fall through into an unimplemented wrapper call.
            raise NotImplementedError(
                "GroupedGemmRunner: the flashinfer backend is not implemented"
            )

        assert weight_column_major, "only column-major weights are supported"
        return grouped_gemm_triton(
            a,
            b,
            c,
            batch_size,
            weight_column_major,
            seg_indptr,
            weight_indices,
            use_fp8_w8a8,
            scale_a,
            scale_b,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class EPMoE(torch.nn.Module):
    """
    Mixture-of-experts layer using expert parallelism.

    The ``num_experts`` experts are split into ``tp_size`` contiguous,
    equally sized ranges; this rank owns the inclusive range
    ``[start_expert_id, end_expert_id]``. ``forward`` routes every token,
    evaluates only the locally owned experts and returns that partial result.
    Nothing in this class combines results across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        """Work out the local expert range and create the weights.

        ``params_dtype`` defaults to the torch default dtype and ``tp_size``
        to the tensor-parallel world size. ``num_experts`` must be divisible
        by ``tp_size``. With ``quant_config`` set, the FP8 method is used;
        otherwise the unquantized one. ``prefix`` is accepted but unused.
        """
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        if tp_size is None:
            tp_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_size
        self.tp_rank = get_tensor_model_parallel_rank()

        # Contiguous slice of experts owned by this rank (bounds inclusive).
        self.num_experts = num_experts
        assert self.num_experts % self.tp_size == 0
        self.num_experts_per_partition = self.num_experts // self.tp_size
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1

        self.top_k = top_k
        self.intermediate_size = intermediate_size
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group

        if quant_config is not None:
            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                quant_config
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme
        else:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
            self.use_fp8_w8a8 = False
            self.activation_scheme = None

        self.quant_method.create_weights(
            layer=self,
            num_experts_per_partition=self.num_experts_per_partition,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

        # Built on the first forward call, once the device is known.
        self.grouped_gemm_runner = None

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        """Evaluate the locally owned experts for every token.

        Steps: select experts, sort the routing, scatter the input rows
        (pre-reorder), gate/up grouped GEMM, SiLU-and-multiply, down grouped
        GEMM, then a weighted gather back into token order (post-reorder).

        Returns:
            A tensor shaped like ``hidden_states``.
        """
        assert self.quant_method is not None

        device = hidden_states.device
        num_tokens = hidden_states.shape[0]
        hidden_dim = hidden_states.shape[1]

        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                device, use_flashinfer=False  # TODO: use flashinfer
            )

        topk_weights, topk_ids = self.select_experts(
            hidden_states,
            router_logits,
            self.top_k,
            self.renormalize,
            self.topk_group,
            self.num_expert_group,
        )

        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
            topk_ids, self.num_experts
        )

        # GEMM inputs are FP8 when quantized, otherwise the activation dtype.
        gemm_input_dtype = (
            self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype
        )

        gateup_input = torch.empty(
            (int(num_tokens * self.top_k), hidden_dim),
            device=device,
            dtype=gemm_input_dtype,
        )
        if self.activation_scheme == "dynamic":
            # One scale derived from the largest input value, repeated for
            # every local expert.
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

        # PreReorder
        pre_reorder_triton_kernel[(num_tokens,)](
            hidden_states,
            gateup_input,
            src2dst,
            topk_ids,
            self.w13_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_dim,
            BLOCK_SIZE=512,
        )

        # Segment boundaries of the local experts only (one extra end entry).
        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
        weight_indices_cur_rank = torch.arange(
            0,
            self.num_experts_per_partition,
            device=device,
            dtype=torch.int64,
        )

        # GroupGemm-0
        gateup_output = torch.empty(
            gateup_input.shape[0],
            self.w13_weight.shape[1],
            device=device,
            dtype=hidden_states.dtype,
        )
        gateup_output = self.grouped_gemm_runner(
            a=gateup_input,
            b=self.w13_weight,
            c=gateup_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w13_input_scale,
            scale_b=self.w13_weight_scale,
        )

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=gemm_input_dtype,
        )
        if self.w2_input_scale is None:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=device,
            )
        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
            gateup_output,
            down_input,
            gateup_output.shape[1],
            reorder_topk_ids,
            self.w2_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            BLOCK_SIZE=512,
        )

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=device,
            dtype=hidden_states.dtype,
        )
        down_output = self.grouped_gemm_runner(
            a=down_input,
            b=self.w2_weight,
            c=down_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w2_input_scale,
            scale_b=self.w2_weight_scale,
        )

        # PostReorder
        output = torch.empty_like(hidden_states)
        post_reorder_triton_kernel[(num_tokens,)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_dim,
            BLOCK_SIZE=512,
        )
        return output

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
    ):
        """Pick the ``top_k`` experts per token.

        Uses ``grouped_topk`` when the layer was built with
        ``use_grouped_topk`` (both group arguments are then required) and
        ``fused_topk`` otherwise.

        Returns:
            ``(topk_weights, topk_ids)`` with ``topk_ids`` cast to int32.
        """
        if not self.use_grouped_topk:
            topk_weights, topk_ids = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
            )
        else:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
            )
        return topk_weights, topk_ids.to(torch.int32)

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:
        """List how per-expert checkpoint weights map onto the fused params.

        Returns one ``(param_name, weight_name, expert_id, shard_id)`` tuple
        per expert and shard, ordered by expert and then by ``w1``, ``w2``,
        ``w3``. Gate and up projections map to the ``experts.w13_`` prefix and
        the down projection to ``experts.w2_``.
        """
        shards = (
            ("w1", ckpt_gate_proj_name),
            ("w2", ckpt_down_proj_name),
            ("w3", ckpt_up_proj_name),
        )
        fused_gate_up_names = [ckpt_gate_proj_name, ckpt_up_proj_name]

        mapping = []
        for expert_id in range(num_experts):
            for shard_id, weight_name in shards:
                if weight_name in fused_gate_up_names:
                    param_prefix = "experts.w13_"
                else:
                    param_prefix = "experts.w2_"
                # (param_name, weight_name, expert_id, shard_id)
                mapping.append(
                    (
                        param_prefix,
                        f"experts.{expert_id}.{weight_name}.",
                        expert_id,
                        shard_id,
                    )
                )
        return mapping

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        """Copy one checkpoint tensor into this rank's slice of ``param``.

        Experts outside the local range are ignored. ``w1`` fills the first
        ``intermediate_size`` rows of the fused gate/up weight, ``w3`` the
        remaining rows, and ``w2`` the whole down-projection entry. Tensors
        whose name contains ``"scale"`` are handled by ``_load_fp8_scale``.

        Raises:
            ValueError: If ``shard_id`` is not ``w1``, ``w2`` or ``w3``.
        """
        if not (self.start_expert_id <= expert_id <= self.end_expert_id):
            return
        # From here on ``expert_id`` indexes the local partition.
        expert_id = expert_id - self.start_expert_id

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
            )

        # Special case for fp8 scales.
        if "scale" in weight_name:
            self._load_fp8_scale(
                param.data, loaded_weight, weight_name, shard_id, expert_id
            )
            return

        # View into the parameter storage; writes through it update ``param``.
        expert_data = param.data[expert_id]
        if shard_id == "w2":
            param.data[expert_id] = loaded_weight
        elif shard_id == "w1":
            expert_data[: self.intermediate_size, :] = loaded_weight
        elif shard_id == "w3":
            expert_data[self.intermediate_size :, :] = loaded_weight
        else:
            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")

    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        """Store an FP8 input or weight scale for one local expert.

        Input scales overwrite the expert's entry after checking that an
        already loaded value (anything other than the initial 1) agrees to
        within 1e-5. Weight scales for ``w1``/``w3`` go into slot 0/1 of the
        expert's entry; for ``w2`` they overwrite the entry.

        Raises:
            ValueError: If a previously loaded input scale disagrees.
        """
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            current = param_data[expert_id]
            if current != 1 and (current - loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            if shard_id in ("w1", "w3"):
                # Merged column case (gate_up_proj): the scales of w1 and w3
                # are kept separately because the w1/w3 weights need to be
                # re-quantized after weight loading.
                slot = 0 if shard_id == "w1" else 1
                param_data[expert_id][slot] = loaded_weight
            else:
                # Row parallel case (down_proj).
                param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
|
||||||
|
@register_custom_op("sglang_unquantized_ep_moe")
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
    """Weight factory for the unquantized expert-parallel MoE layer.

    Only ``create_weights`` is implemented; ``apply`` raises
    ``NotImplementedError``.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights and unit scales on ``layer``.

        Registers, in this order: ``w13_weight``
        ``(experts, 2 * intermediate_size, hidden_size)``, ``w2_weight``
        ``(experts, hidden_size, intermediate_size)``, then the float32
        all-ones scale vectors ``w13_input_scale``, ``w2_input_scale``,
        ``w13_weight_scale`` and ``w2_weight_scale`` of length
        ``num_experts_per_partition``. ``extra_weight_attrs`` is attached to
        every parameter through ``set_weight_attrs``.
        """
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # scale
        # Every scale gets its own tensor. Wrapping a single shared tensor in
        # several Parameters would make them alias one storage, so an in-place
        # write to one scale would silently change the others.
        for scale_name in (
            "w13_input_scale",
            "w2_input_scale",
            "w13_weight_scale",
            "w2_weight_scale",
        ):
            scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter(scale_name, scale)
            set_weight_attrs(scale, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Not implemented for this method; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8EPMoEMethod(Fp8MoEMethod):
    """FP8 quantization method for expert-partitioned MoE layers.

    Works with FP8 checkpoints (static weight scales, dynamic or static
    activation scales) and with checkpoints that are not FP8-serialized,
    whose expert weights are quantized to FP8 after loading.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        # Only the config is stored; the parent initializer is not called here.
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights and FP8 scale parameters on ``layer``.

        Parameters are registered in this order: ``w13_weight``,
        ``w2_weight``, ``w13_weight_scale``, ``w2_weight_scale`` and, for the
        static activation scheme, ``w13_input_scale`` and ``w2_input_scale``.
        For any other activation scheme the two input-scale attributes are
        set to ``None``.

        Args:
            layer: Module that receives the parameters.
            num_experts_per_partition: Leading dimension of every parameter.
            hidden_size: Hidden dimension of the expert weights.
            intermediate_size: Intermediate dimension of the expert weights.
            params_dtype: Weight dtype; replaced by ``torch.float8_e4m3fn``
                when the checkpoint is FP8-serialized.
            **extra_weight_attrs: Attributes attached to the parameters via
                ``set_weight_attrs``.

        Raises:
            ValueError: If the activation scheme is static but the checkpoint
                is not FP8-serialized.
        """
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype

        # WEIGHTS
        weight_shapes = (
            (
                "w13_weight",
                (num_experts_per_partition, 2 * intermediate_size, hidden_size),
            ),
            (
                "w2_weight",
                (num_experts_per_partition, hidden_size, intermediate_size),
            ),
        )
        for weight_name, weight_shape in weight_shapes:
            weight = torch.nn.Parameter(
                torch.empty(*weight_shape, dtype=weight_dtype),
                requires_grad=False,
            )
            layer.register_parameter(weight_name, weight)
            set_weight_attrs(weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # w13 gets two scales per expert (one for w1, one for w3).
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Record the quantization granularity (per tensor/grouped/channel)
        # so the weight scales are loaded correctly. This is added only
        # after the weights above have received their attributes.
        extra_weight_attrs["quant_method"] = "tensor"

        # Only an fp8 checkpoint carries scales to load, so only then do the
        # scale parameters get the loader attributes. Otherwise the scales
        # are computed in process_weights_after_loading().
        if fp8_serialized:
            for weight_scale in (w13_weight_scale, w2_weight_scale):
                set_weight_attrs(weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8."
            )

        for input_scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter(input_scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the weights of ``layer`` once the checkpoint is loaded.

        For an FP8-serialized checkpoint this only checks that static
        activation scales exist. Otherwise every expert's ``w13`` and ``w2``
        weight is quantized to FP8 and the weight parameters are replaced by
        the quantized tensors.

        Raises:
            ValueError: If the activation scheme is static but an input scale
                on ``layer`` is ``None``.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            # The checkpoint is already fp8: nothing is rewritten here, only
            # the presence of static activation scales is validated.
            needs_input_scales = self.quant_config.activation_scheme == "static"
            if needs_input_scales and (
                layer.w13_input_scale is None or layer.w2_input_scale is None
            ):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            return

        # The checkpoint is not fp8: quantize each expert's weights.
        # ROCm uses float8_e4m3fnuz as its fp8 dtype.
        fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
        quantized_w13 = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        quantized_w2 = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # Replace the two-column w13 scale with one scale per expert.
        layer.w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                layer.num_experts_per_partition,
                dtype=torch.float32,
                device=quantized_w13.device,
            ),
            requires_grad=False,
        )

        for expert in range(layer.num_experts_per_partition):
            w13_fp8, w13_scale = ops.scaled_fp8_quant(
                layer.w13_weight.data[expert, :, :]
            )
            quantized_w13[expert, :, :] = w13_fp8
            layer.w13_weight_scale[expert] = w13_scale

            w2_fp8, w2_scale = ops.scaled_fp8_quant(
                layer.w2_weight.data[expert, :, :]
            )
            quantized_w2[expert, :, :] = w2_fp8
            layer.w2_weight_scale[expert] = w2_scale

        layer.w13_weight = torch.nn.Parameter(quantized_w13, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(quantized_w2, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Placeholder forward hook; always raises ``NotImplementedError``."""
        raise NotImplementedError
|
||||||
@@ -58,6 +58,7 @@ global_server_args_dict = {
|
|||||||
"torchao_config": ServerArgs.torchao_config,
|
"torchao_config": ServerArgs.torchao_config,
|
||||||
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
"enable_nan_detection": ServerArgs.enable_nan_detection,
|
||||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||||
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ class ModelRunner:
|
|||||||
"torchao_config": server_args.torchao_config,
|
"torchao_config": server_args.torchao_config,
|
||||||
"enable_nan_detection": server_args.enable_nan_detection,
|
"enable_nan_detection": server_args.enable_nan_detection,
|
||||||
"enable_dp_attention": server_args.enable_dp_attention,
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
|
"enable_ep_moe": server_args.enable_ep_moe,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
from sglang.srt.layers.ep_moe.layer import EPMoE
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
@@ -113,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
"Only silu is supported for now."
|
"Only silu is supported for now."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.experts = FusedMoE(
|
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||||
|
self.experts = MoEImpl(
|
||||||
num_experts=config.n_routed_experts,
|
num_experts=config.n_routed_experts,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
reduce_results=False,
|
|
||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
use_grouped_topk=True,
|
use_grouped_topk=True,
|
||||||
@@ -834,7 +835,8 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||||
|
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
@@ -21,9 +21,13 @@ from typing import Iterable, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
|
from sglang.srt.layers.ep_moe.layer import EPMoE
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
@@ -38,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
@@ -63,6 +68,7 @@ class MixtralMoE(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
# Gate always runs at half / full precision for now.
|
# Gate always runs at half / full precision for now.
|
||||||
@@ -74,14 +80,13 @@ class MixtralMoE(nn.Module):
|
|||||||
quant_config=None,
|
quant_config=None,
|
||||||
prefix=f"{prefix}.gate",
|
prefix=f"{prefix}.gate",
|
||||||
)
|
)
|
||||||
|
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||||
self.experts = FusedMoE(
|
self.experts = MoEImpl(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
reduce_results=True,
|
|
||||||
renormalize=True,
|
renormalize=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
@@ -95,6 +100,8 @@ class MixtralMoE(nn.Module):
|
|||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||||
|
if self.tp_size > 1:
|
||||||
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states.view(orig_shape)
|
return final_hidden_states.view(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
@@ -319,7 +326,8 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||||
|
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="w1",
|
ckpt_gate_proj_name="w1",
|
||||||
ckpt_down_proj_name="w2",
|
ckpt_down_proj_name="w2",
|
||||||
ckpt_up_proj_name="w3",
|
ckpt_up_proj_name="w3",
|
||||||
|
|||||||
@@ -93,6 +93,8 @@ class ServerArgs:
|
|||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
load_balance_method: str = "round_robin"
|
load_balance_method: str = "round_robin"
|
||||||
|
# Expert parallelism
|
||||||
|
ep_size: int = 1
|
||||||
|
|
||||||
# Multi-node distributed serving
|
# Multi-node distributed serving
|
||||||
dist_init_addr: Optional[str] = None
|
dist_init_addr: Optional[str] = None
|
||||||
@@ -130,6 +132,7 @@ class ServerArgs:
|
|||||||
disable_overlap_schedule: bool = False
|
disable_overlap_schedule: bool = False
|
||||||
enable_mixed_chunk: bool = False
|
enable_mixed_chunk: bool = False
|
||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
|
enable_ep_moe: bool = False
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
cuda_graph_max_bs: Optional[int] = None
|
cuda_graph_max_bs: Optional[int] = None
|
||||||
@@ -216,6 +219,12 @@ class ServerArgs:
|
|||||||
"Data parallel size is adjusted to be the same as tensor parallel size. "
|
"Data parallel size is adjusted to be the same as tensor parallel size. "
|
||||||
"Overlap scheduler is disabled."
|
"Overlap scheduler is disabled."
|
||||||
)
|
)
|
||||||
|
# Expert parallelism
|
||||||
|
if self.enable_ep_moe:
|
||||||
|
self.ep_size = self.tp_size
|
||||||
|
logger.info(
|
||||||
|
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||||
|
)
|
||||||
|
|
||||||
# GGUF
|
# GGUF
|
||||||
if (
|
if (
|
||||||
@@ -526,6 +535,14 @@ class ServerArgs:
|
|||||||
"shortest_queue",
|
"shortest_queue",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
# Expert parallelism
|
||||||
|
parser.add_argument(
|
||||||
|
"--expert-parallel-size",
|
||||||
|
"--ep-size",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.ep_size,
|
||||||
|
help="The expert parallelism size.",
|
||||||
|
)
|
||||||
|
|
||||||
# Multi-node distributed serving
|
# Multi-node distributed serving
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -681,6 +698,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
|
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-ep-moe",
|
||||||
|
action="store_true",
|
||||||
|
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-torch-compile",
|
"--enable-torch-compile",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -760,6 +782,7 @@ class ServerArgs:
|
|||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
args.dp_size = args.data_parallel_size
|
args.dp_size = args.data_parallel_size
|
||||||
|
args.ep_size = args.expert_parallel_size
|
||||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
|
|
||||||
|
|||||||
113
test/srt/test_moe_ep.py
Normal file
113
test/srt/test_moe_ep.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEpMoE(unittest.TestCase):
    """Accuracy checks against a server launched with EP MoE enabled.

    One server process is started for the whole class with ``--tp 2``,
    ``--ep-size 2`` and ``--enable-ep-moe``; each test runs an evaluation
    against it and checks the score against a minimum threshold.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server shared by every test in this class."""
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--ep-size",
                "2",
                "--enable-ep-moe",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process and all of its children."""
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """MMLU on 64 examples must score at least 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        # A unittest assertion survives `python -O` and reports the actual
        # score on failure, unlike a bare `assert`.
        self.assertGreaterEqual(metrics["score"], 0.5)

    def test_mgsm_en(self):
        """MGSM (English) with ``num_examples=None`` must score at least 0.8."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.8)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEpMoEFP8(unittest.TestCase):
    """Accuracy checks against a server launched with EP MoE and FP8.

    One server process is started for the whole class with ``--tp 2``,
    ``--ep-size 2``, ``--enable-ep-moe`` and ``--quantization fp8``; each
    test runs an evaluation against it and checks the score against a
    minimum threshold.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the FP8-quantized server shared by every test in this class."""
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--ep-size",
                "2",
                "--enable-ep-moe",
                "--quantization",
                "fp8",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process and all of its children."""
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """MMLU on 64 examples must score at least 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        # A unittest assertion survives `python -O` and reports the actual
        # score on failure, unlike a bare `assert`.
        self.assertGreaterEqual(metrics["score"], 0.5)

    def test_mgsm_en(self):
        """MGSM (English) with ``num_examples=None`` must score at least 0.8."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.8)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test cases in this file when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user