Reduce MoE memory usage (#6147)

Author: fzyzcjy
Date: 2025-05-16 00:38:28 +08:00
Committed by: GitHub
Parent: cfc9f9ab8d
Commit: f194e14fb7
4 changed files with 75 additions and 40 deletions


@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
_is_hip = is_hip()
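
The newly imported dispose_tensor helper is the core of this commit: it releases a tensor's underlying storage immediately instead of waiting for the last Python reference to go away. A minimal sketch of such a helper, assuming it swaps in a zero-element storage via Tensor.set_() (the real implementation lives in sglang.srt.utils and may differ):

    import torch

    def dispose_tensor(x: torch.Tensor) -> None:
        # Rebind x to an empty tensor so its original (possibly multi-GiB)
        # storage is returned to the caching allocator right away, even
        # though the Python object itself stays alive.
        x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))

    x = torch.randn(8, 16)
    dispose_tensor(x)
    assert x.untyped_storage().nbytes() == 0  # storage is gone, the name is not
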
@@ -92,6 +92,7 @@ class GroupedGemmRunner(torch.nn.Module):
scale_a: torch.Tensor = None,
scale_b: torch.Tensor = None,
block_shape: Optional[List[int]] = None,
c_dtype=None,
):
if self.use_flashinfer:
# TODO: flashinfer
@@ -119,6 +120,7 @@ class GroupedGemmRunner(torch.nn.Module):
scale_a,
scale_b,
block_shape=block_shape,
c_dtype=c_dtype,
)
return c
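
GroupedGemmRunner now accepts c=None together with a c_dtype, so the output buffer is allocated inside the runner right before the kernel writes to it, rather than by the caller while other large buffers are still alive. A hedged sketch of the runner-side allocation, mirroring the shape the caller used to allocate (torch.empty(a.shape[0], b.shape[1], ...)); the actual kernel launch is elided:

    def grouped_gemm(a, b, c=None, c_dtype=None, **kwargs):
        if c is None:
            assert c_dtype is not None, "c_dtype is required when c is not preallocated"
            # Allocate the output as late as possible so it never coexists
            # with buffers the caller has already released.
            c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
        # ... launch the grouped GEMM that writes into c ...
        return c
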
@@ -210,6 +212,10 @@ class EPMoE(torch.nn.Module):
self.grouped_gemm_runner = None
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
assert self.quant_method is not None
if self.grouped_gemm_runner is None:
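
forward() now records the input's shape, dtype, and device before doing anything else. That is what makes the later dispose_tensor(hidden_states) safe: every subsequent allocation (scale tensors, the final output) is made from the cached metadata rather than from the, by then empty, input tensor. The pattern in isolation (illustrative; scatter_to_experts is a hypothetical stand-in for the PreReorder kernel):

    shape, dtype, device = x.shape, x.dtype, x.device
    scatter_to_experts(x)          # last reader of x's data
    dispose_tensor(x)              # x's storage is freed here
    # ... expert GEMMs run on the scattered buffers ...
    out = torch.empty(shape, dtype=dtype, device=device)   # no longer needs x
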
@@ -265,25 +271,21 @@ class EPMoE(torch.nn.Module):
hidden_states.shape[1],
BLOCK_SIZE=512,
)
dispose_tensor(hidden_states)
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
device=hidden_states.device,
device=hidden_states_device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = torch.empty(
gateup_input.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gateup_output = self.grouped_gemm_runner(
a=gateup_input,
b=self.w13_weight,
c=gateup_output,
c=None,
c_dtype=hidden_states_dtype,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +299,7 @@ class EPMoE(torch.nn.Module):
),
block_shape=self.block_shape,
)
del gateup_input
# Act
down_input = torch.empty(
@@ -306,14 +309,14 @@ class EPMoE(torch.nn.Module):
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
else hidden_states_dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
device=hidden_states_device,
)
if self.activation == "silu":
@@ -340,13 +343,14 @@ class EPMoE(torch.nn.Module):
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
del gateup_output
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
device=hidden_states_device,
dtype=hidden_states_dtype,
)
down_output = self.grouped_gemm_runner(
a=down_input,
@@ -365,10 +369,13 @@ class EPMoE(torch.nn.Module):
),
block_shape=self.block_shape,
)
del down_input
# PostReorder
output = torch.empty_like(hidden_states)
post_reorder_triton_kernel[(hidden_states.size(0),)](
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
@@ -377,7 +384,7 @@ class EPMoE(torch.nn.Module):
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states.size(1),
hidden_states_shape[1],
BLOCK_SIZE=512,
)
return output
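
The other recurring change in this forward pass is an explicit del of each intermediate right after its last consumer: gateup_input after GroupGemm-0, gateup_output after the activation, down_input after GroupGemm-1. Dropping the last reference hands the block back to PyTorch's caching allocator so the next allocation can reuse it instead of growing the pool. A toy illustration of the effect (not SGLang code):

    import torch

    def stage(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for a GEMM or activation that produces a fresh buffer.
        return torch.empty_like(x)

    a = torch.empty(1024, 1024, device="cuda")
    b = stage(a)
    del a          # a's block returns to the caching allocator...
    c = stage(b)   # ...and can be reused here instead of enlarging the peak
    del b
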
@@ -881,6 +888,9 @@ class DeepEPMoE(EPMoE):
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
):
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
assert self.quant_method is not None
assert self.activation == "silu"
if self.grouped_gemm_runner is None:
@@ -903,18 +913,12 @@ class DeepEPMoE(EPMoE):
)
# GroupGemm-0
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if hidden_states.shape[0] > 0:
gateup_output = self.grouped_gemm_runner(
a=hidden_states,
b=self.w13_weight,
c=gateup_output,
c=None,
c_dtype=hidden_states.dtype,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr,
@@ -928,6 +932,13 @@ class DeepEPMoE(EPMoE):
),
block_shape=self.block_shape,
)
else:
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# Act
down_input = torch.empty(
@@ -937,14 +948,14 @@ class DeepEPMoE(EPMoE):
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
else hidden_states_dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
device=hidden_states_device,
)
if self.activation == "silu":
@@ -961,12 +972,14 @@ class DeepEPMoE(EPMoE):
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
del gateup_output
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
device=hidden_states_device,
dtype=hidden_states_dtype,
)
if down_input.shape[0] > 0:
down_output = self.grouped_gemm_runner(
@@ -1007,11 +1020,9 @@ class DeepEPMoE(EPMoE):
N = self.w13_weight.size(1)
scale_block_size = 128
gather_out = torch.empty_like(
hidden_states_fp8,
device=hidden_states_fp8.device,
dtype=torch.bfloat16,
)
hidden_states_fp8_shape = hidden_states_fp8.shape
hidden_states_fp8_device = hidden_states_fp8.device
hidden_states_fp8_dtype = hidden_states_fp8.dtype
input_tensor = [
torch.empty(
@@ -1049,16 +1060,18 @@ class DeepEPMoE(EPMoE):
m_indices,
output_index,
)
dispose_tensor(hidden_states_fp8)
gateup_output = torch.empty(
(all_tokens, N),
device=hidden_states_fp8.device,
device=hidden_states_fp8_device,
dtype=torch.bfloat16,
)
input_tensor[1] = tma_align_input_scale(input_tensor[1])
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
)
del input_tensor
down_input = torch.empty(
(
all_tokens,
@@ -1068,14 +1081,16 @@ class DeepEPMoE(EPMoE):
dtype=torch.bfloat16,
)
silu_and_mul(gateup_output.view(-1, N), down_input)
del gateup_output
down_output = torch.empty(
(all_tokens, K),
device=hidden_states_fp8.device,
device=hidden_states_fp8_device,
dtype=torch.bfloat16,
)
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
down_input, scale_block_size
)
del down_input
down_input_scale = tma_align_input_scale(down_input_scale)
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(down_input_fp8, down_input_scale),
@@ -1083,7 +1098,13 @@ class DeepEPMoE(EPMoE):
down_output,
m_indices,
)
del down_input_fp8, down_input_scale
gather_out = torch.empty(
hidden_states_fp8_shape,
device=hidden_states_fp8_device,
dtype=torch.bfloat16,
)
ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
return gather_out
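
The FP8 DeepGEMM contiguous path applies the same ideas: hidden_states_fp8 is disposed as soon as ep_scatter has copied it into the per-expert buffers, each intermediate (input_tensor, gateup_output, down_input, the quantized pair) is deleted after its last use, and gather_out moves from an up-front torch.empty_like to a late allocation from cached metadata. Side by side, condensed from the hunks above (illustrative):

    # Before: allocated before ep_scatter, so it coexists with hidden_states_fp8
    # and with every intermediate of the two grouped GEMMs.
    gather_out = torch.empty_like(hidden_states_fp8, dtype=torch.bfloat16)

    # After: allocated only when ep_gather needs it, from metadata cached earlier,
    # long after hidden_states_fp8's storage has been released.
    gather_out = torch.empty(
        hidden_states_fp8_shape, device=hidden_states_fp8_device, dtype=torch.bfloat16
    )
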
@@ -1107,6 +1128,7 @@ class DeepEPMoE(EPMoE):
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)
dispose_tensor(hidden_states_fp8[0])
# Act
down_input = torch.empty(
@@ -1135,6 +1157,7 @@ class DeepEPMoE(EPMoE):
scale_block_size,
masked_m,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
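
A quick way to confirm the effect of changes like these is to compare the allocator's peak across one MoE forward pass before and after the patch, using standard PyTorch memory statistics (moe_layer, hidden_states, and router_logits are placeholders for whatever module and inputs are being profiled):

    import torch

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    with torch.no_grad():
        out = moe_layer(hidden_states, router_logits)   # placeholder call

    peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"peak allocated during MoE forward: {peak_gib:.2f} GiB")
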