Resubmit MoE-EP (#2371)

python/sglang/srt/layers/ep_moe/__init__.py (new file, 0 lines)

python/sglang/srt/layers/ep_moe/kernels.py (new file, 349 lines)
@@ -0,0 +1,349 @@
import logging
from typing import Optional

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)
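

# One program per expert id: binary-search the sorted expert ids for the last
# slot holding this expert, so seg_indptr[expert + 1] is the exclusive end of
# this expert's contiguous segment in the reordered token list.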
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    expert = tl.program_id(0)
    low = 0
    high = num_toks - 1
    target_location = -1
    while low <= high:
        mid = (low + high) // 2

        if tl.load(reorder_topk_ids + mid) > expert:
            high = mid - 1
        else:
            low = mid + 1
            target_location = mid
    tl.store(seg_indptr + expert + 1, target_location + 1)
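

# Invert the sort permutation: reorder_ids[dst] = src, so storing dst at
# src2dst[src] maps each original (token, k) slot to its expert-sorted position.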
@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    tl.store(src2dst + src_id, dst_id, mask=mask)
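

# Host-side entry point: stable-sort the flattened top-k assignments by expert
# id, then launch the two kernels above to build seg_indptr and src2dst for
# the reorder kernels below.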
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
    )
    return reorder_topk_ids, src2dst, seg_indptr
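

# Scatter phase: each program copies one source token's hidden vector into the
# expert-sorted gateup_input buffer for every routed expert owned by this rank,
# applying the optional fp8 activation scale along the way.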
@triton.jit
def pre_reorder_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk

    src_ptr = input_ptr + src_idx * hidden_size
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            if a1_scales_ptr is not None:
                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
            else:
                scale = 1.0

            dst_idx = tl.load(src2dst_ptr + idx)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
                out_data = (in_data * scale).to(OutDtype)
                tl.store(dst_ptr + offset, out_data, mask=mask)
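

# Fused SiLU(gate) * up over rows laid out as [gate | up] halves, with optional
# requantization into the (possibly fp8) down_input buffer. Rows belonging to
# other ranks' experts are skipped.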
@triton.jit
def silu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # silu & mul & quantize
            gate_output = gate_output * tl.sigmoid(gate_output)
            gate_output = gate_output.to(InDtype)

            silu_mul_output = gate_output * up_output * scale
            silu_mul_output = silu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
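

# Gather phase: for each source token, accumulate the weighted down-projection
# outputs of its locally owned experts back into the original token order;
# tokens with no local expert are explicitly zero-filled so downstream
# reductions see valid data.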
@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    computed = False
    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(src2dst_ptr + idx)
                weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weight_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)

    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < hidden_size
            tl.store(
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
            )
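

# Map a flat tile index (pid) to its expert segment: scan m_num_tiles_indptr to
# find the owning segment, then derive the tile's row range and the index of
# the expert weight to use.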
@triton.jit
def compute_m_range(
    pid,
    batch_size,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    BLOCK_SIZE_M: tl.constexpr,
):
    idx = 0
    for bs in range(batch_size):
        tiles = tl.load(m_num_tiles_indptr + bs)
        if pid >= tiles:
            idx = bs

    idx_start = tl.load(m_num_tiles_indptr + idx)

    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
    expert_id = tl.load(weight_indices + idx)
    return m_range_start, m_range_end, expert_id
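

# Segment-aware GEMM: each (pid_m, pid_n) program multiplies one BLOCK_SIZE_M
# slice of activations by its segment's expert weight (b is column-major per
# expert, hence the transposed dot), rescaling the fp32 accumulator per expert
# when fp8 is enabled.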
@triton.jit
def grouped_gemm_triton_kernel(
    a,
    b,
    c,
    batch_size,
    N,
    K,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    use_fp8_w8a8,
    scale_a,
    scale_b,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    c_dtype = c.dtype.element_ty

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
    if pid_m >= total_m_block:
        return

    m_range_start, m_range_end, expert_id = compute_m_range(
        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
    )
    if m_range_end - m_range_start == 0:
        return

    n_range_start = pid_n * BLOCK_SIZE_N
    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)

    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
    b_ptr = b + (
        (expert_id * b_stride_0)
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_tile = tl.load(
            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        b_tile = tl.load(
            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8:
        scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value
    c_tile = accumulator.to(c_dtype)

    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
    tl.store(c_ptr, c_tile, mask=c_mask)
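

# Prefix sum of per-segment tile counts (launched as a single program), so tile
# indices can be assigned contiguously across ragged expert segments.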
@triton.jit
def compute_m_num_tiles_indptr(
    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
    for bs in range(batch_size):
        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
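

# Host-side launcher: builds the tile prefix sum, then launches the grouped
# GEMM on a 2D grid. The extra `+ batch_size` rows in the M dimension cover the
# worst case of one partial tile per segment.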
def grouped_gemm_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    batch_size: int,
    weight_column_major: bool,
    seg_indptr: Optional[torch.Tensor] = None,
    weight_indices: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    scale_a: torch.Tensor = None,
    scale_b: torch.Tensor = None,
):
    assert weight_column_major == True  # TODO: more
    if use_fp8_w8a8:
        assert scale_a is not None and scale_b is not None

    config = {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
    }

    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
    compute_m_num_tiles_indptr[(1,)](
        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

    grid = lambda META: (
        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
    )

    grouped_gemm_triton_kernel[grid](
        a,
        b,
        c,
        batch_size,
        b.size(1),
        b.size(2),
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        use_fp8_w8a8,
        scale_a,
        scale_b,
        a.stride(0),
        b.stride(0),
        b.stride(1),
        **config,
    )
    return c

python/sglang/srt/layers/ep_moe/layer.py (new file, 661 lines)
@@ -0,0 +1,661 @@
import logging
from typing import Callable, List, Optional, Tuple

import torch
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.ep_moe.kernels import (
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import is_hip, set_weight_attrs

logger = logging.getLogger(__name__)
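

# Thin dispatcher for the grouped GEMM: the Triton path above is used today;
# the flashinfer segment-GEMM path is stubbed out (see the TODO in forward).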
class GroupedGemmRunner(torch.nn.Module):
    flashinfer_gemm_wrapper = None

    def __init__(self, device, use_flashinfer: bool = False):
        super().__init__()
        self.device = device
        self.use_flashinfer = use_flashinfer
        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_wrapper is None:
            GroupedGemmRunner._init_flashinfer_wrapper(device)

    @classmethod
    def _init_flashinfer_wrapper(cls, device):
        from flashinfer import SegmentGEMMWrapper

        workspace_buffer = torch.empty(
            128 * 1024 * 1024, dtype=torch.int8, device=device
        )
        cls.flashinfer_gemm_wrapper = SegmentGEMMWrapper(workspace_buffer)

    # c = a * b
    def forward(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
        batch_size: int,
        weight_column_major: bool,
        seg_indptr: Optional[torch.Tensor] = None,
        weight_indices: Optional[torch.Tensor] = None,
        use_fp8_w8a8: bool = False,
        scale_a: torch.Tensor = None,
        scale_b: torch.Tensor = None,
    ):
        if self.use_flashinfer:
            # TODO: flashinfer
            assert False
            assert GroupedGemmRunner.flashinfer_gemm_wrapper is not None
            c = GroupedGemmRunner.flashinfer_gemm_wrapper.run(
                x=a,
                weights=b,
                batch_size=batch_size,
                weight_column_major=weight_column_major,
                seg_indptr=seg_indptr,
                weight_indices=weight_indices,
            )
        else:
            assert weight_column_major == True
            c = grouped_gemm_triton(
                a,
                b,
                c,
                batch_size,
                weight_column_major,
                seg_indptr,
                weight_indices,
                use_fp8_w8a8,
                scale_a,
                scale_b,
            )
        return c


class EPMoE(torch.nn.Module):
    """
    MoE Expert Parallel Impl

    Each tensor-parallel rank owns a contiguous partition of the experts
    ([start_expert_id, end_expert_id]) and only computes the tokens routed
    to that partition.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = get_tensor_model_parallel_rank()

        self.num_experts = num_experts
        assert self.num_experts % self.tp_size == 0
        self.num_experts_per_partition = self.num_experts // self.tp_size
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1

        self.top_k = top_k
        self.intermediate_size = intermediate_size
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
        else:
            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                quant_config
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme

        self.quant_method.create_weights(
            layer=self,
            num_experts_per_partition=self.num_experts_per_partition,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

        self.grouped_gemm_runner = None
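
    # Route tokens, reorder them into expert-contiguous segments, run the two
    # grouped GEMMs with the fused activation in between, then scatter the
    # weighted results back to the original token order.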
    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        assert self.quant_method is not None

        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
            )

        topk_weights, topk_ids = self.select_experts(
            hidden_states,
            router_logits,
            self.top_k,
            self.renormalize,
            self.topk_group,
            self.num_expert_group,
        )

        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
            topk_ids, self.num_experts
        )

        gateup_input = torch.empty(
            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
            device=hidden_states.device,
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
        )
        if self.activation_scheme == "dynamic":
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

        # PreReorder
        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
            hidden_states,
            gateup_input,
            src2dst,
            topk_ids,
            self.w13_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.shape[1],
            BLOCK_SIZE=512,
        )

        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
        weight_indices_cur_rank = torch.arange(
            0,
            self.num_experts_per_partition,
            device=hidden_states.device,
            dtype=torch.int64,
        )
        # GroupGemm-0
        gateup_output = torch.empty(
            gateup_input.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        gateup_output = self.grouped_gemm_runner(
            a=gateup_input,
            b=self.w13_weight,
            c=gateup_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w13_input_scale,
            scale_b=self.w13_weight_scale,
        )

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
        )
        if self.w2_input_scale is None:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )
        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
            gateup_output,
            down_input,
            gateup_output.shape[1],
            reorder_topk_ids,
            self.w2_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            BLOCK_SIZE=512,
        )

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        down_output = self.grouped_gemm_runner(
            a=down_input,
            b=self.w2_weight,
            c=down_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w2_input_scale,
            scale_b=self.w2_weight_scale,
        )

        # PostReorder
        output = torch.empty_like(hidden_states)
        post_reorder_triton_kernel[(hidden_states.size(0),)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.size(1),
            BLOCK_SIZE=512,
        )
        return output

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
    ):
        if self.use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
            )
        return topk_weights, topk_ids.to(torch.int32)

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        # Experts outside this rank's partition are simply skipped.
        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
            return
        expert_id = expert_id - self.start_expert_id

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
            )

        # Special case for fp8 scales.
        if "scale" in weight_name:
            self._load_fp8_scale(
                param.data, loaded_weight, weight_name, shard_id, expert_id
            )
            return

        if shard_id == "w2":
            param.data[expert_id] = loaded_weight
        elif shard_id == "w1":
            param.data[expert_id][: self.intermediate_size, :] = loaded_weight
        elif shard_id == "w3":
            param.data[expert_id][self.intermediate_size :, :] = loaded_weight
        else:
            raise ValueError(f"Expected shard_id w1, w2 or w3 but got {shard_id}")

    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight
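

# Unquantized path: plain fp16/bf16 expert weights plus unit scales, so the
# grouped-GEMM call signature stays identical to the fp8 path.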
@register_custom_op("sglang_unquantized_ep_moe")
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # scale
        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
        w13_input_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        # EPMoE drives the computation itself in forward(); the FusedMoE-style
        # apply() entry point is intentionally unimplemented.
        raise NotImplementedError
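

# fp8 path: weights either come pre-quantized (fp8-serialized checkpoints) or
# are quantized per expert in process_weights_after_loading(); activation
# scales are static (loaded) or computed dynamically in EPMoE.forward().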
class Fp8EPMoEMethod(Fp8MoEMethod):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update({"quant_method": "tensor"})
        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:

        # If the checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If ROCm, use float8_e4m3fnuz as dtype
            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.num_experts_per_partition,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )

            for expert in range(layer.num_experts_per_partition):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            return

        # If the checkpoint is fp8, handle the fact that the MoE kernels
        # require a single activation scale and a single weight scale for
        # w13 per expert.
        else:
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
            return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

@@ -58,6 +58,7 @@ global_server_args_dict = {
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
 }

@@ -141,6 +141,7 @@ class ModelRunner:
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
             }
         )

@@ -31,6 +31,7 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -113,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,

@@ -834,7 +835,8 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",

@@ -21,9 +21,13 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -38,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 

@@ -63,6 +68,7 @@ class MixtralMoE(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.hidden_size = hidden_size
 
         # Gate always runs at half / full precision for now.

@@ -74,14 +80,13 @@
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,

@@ -95,6 +100,8 @@
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)

@@ -319,7 +326,8 @@ class MixtralForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",

@@ -93,6 +93,8 @@
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    # Expert parallelism
+    ep_size: int = 1
 
     # Multi-node distributed serving
     dist_init_addr: Optional[str] = None

@@ -130,6 +132,7 @@
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
+    enable_ep_moe: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None

@@ -216,6 +219,12 @@
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
             )
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
 
         # GGUF
         if (

@@ -526,6 +535,14 @@
                 "shortest_queue",
             ],
         )
+        # Expert parallelism
+        parser.add_argument(
+            "--expert-parallel-size",
+            "--ep-size",
+            type=int,
+            default=ServerArgs.ep_size,
+            help="The expert parallelism size.",
+        )
 
         # Multi-node distributed serving
         parser.add_argument(

@@ -681,6 +698,11 @@
             action="store_true",
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
+        parser.add_argument(
+            "--enable-ep-moe",
+            action="store_true",
+            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",

@@ -760,6 +782,7 @@
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
+        args.ep_size = args.expert_parallel_size
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})