Reduce MoE memory usage (#6147)
This commit is contained in:
@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
|
||||
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -92,6 +92,7 @@ class GroupedGemmRunner(torch.nn.Module):
|
||||
scale_a: torch.Tensor = None,
|
||||
scale_b: torch.Tensor = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
c_dtype=None,
|
||||
):
|
||||
if self.use_flashinfer:
|
||||
# TODO: flashinfer
|
||||
@@ -119,6 +120,7 @@ class GroupedGemmRunner(torch.nn.Module):
|
||||
scale_a,
|
||||
scale_b,
|
||||
block_shape=block_shape,
|
||||
c_dtype=c_dtype,
|
||||
)
|
||||
return c
|
||||
|
||||
@@ -210,6 +212,10 @@ class EPMoE(torch.nn.Module):
|
||||
self.grouped_gemm_runner = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
if self.grouped_gemm_runner is None:
|
||||
@@ -265,25 +271,21 @@ class EPMoE(torch.nn.Module):
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
dispose_tensor(hidden_states)
|
||||
|
||||
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
|
||||
weight_indices_cur_rank = torch.arange(
|
||||
0,
|
||||
self.num_experts_per_partition,
|
||||
device=hidden_states.device,
|
||||
device=hidden_states_device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
# GroupGemm-0
|
||||
gateup_output = torch.empty(
|
||||
gateup_input.shape[0],
|
||||
self.w13_weight.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
gateup_output = self.grouped_gemm_runner(
|
||||
a=gateup_input,
|
||||
b=self.w13_weight,
|
||||
c=gateup_output,
|
||||
c=None,
|
||||
c_dtype=hidden_states_dtype,
|
||||
batch_size=self.num_experts_per_partition,
|
||||
weight_column_major=True,
|
||||
seg_indptr=seg_indptr_cur_rank,
|
||||
@@ -297,6 +299,7 @@ class EPMoE(torch.nn.Module):
|
||||
),
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
del gateup_input
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
@@ -306,14 +309,14 @@ class EPMoE(torch.nn.Module):
|
||||
dtype=(
|
||||
self.fp8_dtype
|
||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
||||
else hidden_states.dtype
|
||||
else hidden_states_dtype
|
||||
),
|
||||
)
|
||||
if self.w2_input_scale is None and not self.use_block_quant:
|
||||
self.w2_input_scale = torch.ones(
|
||||
self.num_experts_per_partition,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
device=hidden_states_device,
|
||||
)
|
||||
|
||||
if self.activation == "silu":
|
||||
@@ -340,13 +343,14 @@ class EPMoE(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||
del gateup_output
|
||||
|
||||
# GroupGemm-1
|
||||
down_output = torch.empty(
|
||||
down_input.shape[0],
|
||||
self.w2_weight.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states_device,
|
||||
dtype=hidden_states_dtype,
|
||||
)
|
||||
down_output = self.grouped_gemm_runner(
|
||||
a=down_input,
|
||||
@@ -365,10 +369,13 @@ class EPMoE(torch.nn.Module):
|
||||
),
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
del down_input
|
||||
|
||||
# PostReorder
|
||||
output = torch.empty_like(hidden_states)
|
||||
post_reorder_triton_kernel[(hidden_states.size(0),)](
|
||||
output = torch.empty(
|
||||
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
|
||||
)
|
||||
post_reorder_triton_kernel[(hidden_states_shape[0],)](
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
@@ -377,7 +384,7 @@ class EPMoE(torch.nn.Module):
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
self.top_k,
|
||||
hidden_states.size(1),
|
||||
hidden_states_shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return output
|
||||
@@ -881,6 +888,9 @@ class DeepEPMoE(EPMoE):
|
||||
reorder_topk_ids: torch.Tensor,
|
||||
seg_indptr: torch.Tensor,
|
||||
):
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
|
||||
assert self.quant_method is not None
|
||||
assert self.activation == "silu"
|
||||
if self.grouped_gemm_runner is None:
|
||||
@@ -903,18 +913,12 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
|
||||
# GroupGemm-0
|
||||
gateup_output = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.w13_weight.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
if hidden_states.shape[0] > 0:
|
||||
gateup_output = self.grouped_gemm_runner(
|
||||
a=hidden_states,
|
||||
b=self.w13_weight,
|
||||
c=gateup_output,
|
||||
c=None,
|
||||
c_dtype=hidden_states.dtype,
|
||||
batch_size=self.num_experts_per_partition,
|
||||
weight_column_major=True,
|
||||
seg_indptr=seg_indptr,
|
||||
@@ -928,6 +932,13 @@ class DeepEPMoE(EPMoE):
|
||||
),
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
else:
|
||||
gateup_output = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.w13_weight.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
@@ -937,14 +948,14 @@ class DeepEPMoE(EPMoE):
|
||||
dtype=(
|
||||
self.fp8_dtype
|
||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
||||
else hidden_states.dtype
|
||||
else hidden_states_dtype
|
||||
),
|
||||
)
|
||||
if self.w2_input_scale is None and not self.use_block_quant:
|
||||
self.w2_input_scale = torch.ones(
|
||||
self.num_experts_per_partition,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
device=hidden_states_device,
|
||||
)
|
||||
|
||||
if self.activation == "silu":
|
||||
@@ -961,12 +972,14 @@ class DeepEPMoE(EPMoE):
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||
|
||||
del gateup_output
|
||||
|
||||
# GroupGemm-1
|
||||
down_output = torch.empty(
|
||||
down_input.shape[0],
|
||||
self.w2_weight.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states_device,
|
||||
dtype=hidden_states_dtype,
|
||||
)
|
||||
if down_input.shape[0] > 0:
|
||||
down_output = self.grouped_gemm_runner(
|
||||
@@ -1007,11 +1020,9 @@ class DeepEPMoE(EPMoE):
|
||||
N = self.w13_weight.size(1)
|
||||
scale_block_size = 128
|
||||
|
||||
gather_out = torch.empty_like(
|
||||
hidden_states_fp8,
|
||||
device=hidden_states_fp8.device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
hidden_states_fp8_shape = hidden_states_fp8.shape
|
||||
hidden_states_fp8_device = hidden_states_fp8.device
|
||||
hidden_states_fp8_dtype = hidden_states_fp8.dtype
|
||||
|
||||
input_tensor = [
|
||||
torch.empty(
|
||||
@@ -1049,16 +1060,18 @@ class DeepEPMoE(EPMoE):
|
||||
m_indices,
|
||||
output_index,
|
||||
)
|
||||
dispose_tensor(hidden_states_fp8)
|
||||
|
||||
gateup_output = torch.empty(
|
||||
(all_tokens, N),
|
||||
device=hidden_states_fp8.device,
|
||||
device=hidden_states_fp8_device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
input_tensor[1] = tma_align_input_scale(input_tensor[1])
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
|
||||
)
|
||||
del input_tensor
|
||||
down_input = torch.empty(
|
||||
(
|
||||
all_tokens,
|
||||
@@ -1068,14 +1081,16 @@ class DeepEPMoE(EPMoE):
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
silu_and_mul(gateup_output.view(-1, N), down_input)
|
||||
del gateup_output
|
||||
down_output = torch.empty(
|
||||
(all_tokens, K),
|
||||
device=hidden_states_fp8.device,
|
||||
device=hidden_states_fp8_device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
|
||||
down_input, scale_block_size
|
||||
)
|
||||
del down_input
|
||||
down_input_scale = tma_align_input_scale(down_input_scale)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(down_input_fp8, down_input_scale),
|
||||
@@ -1083,7 +1098,13 @@ class DeepEPMoE(EPMoE):
|
||||
down_output,
|
||||
m_indices,
|
||||
)
|
||||
del down_input_fp8, down_input_scale
|
||||
|
||||
gather_out = torch.empty(
|
||||
hidden_states_fp8_shape,
|
||||
device=hidden_states_fp8_device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
|
||||
|
||||
return gather_out
|
||||
@@ -1107,6 +1128,7 @@ class DeepEPMoE(EPMoE):
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
|
||||
)
|
||||
dispose_tensor(hidden_states_fp8[0])
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
@@ -1135,6 +1157,7 @@ class DeepEPMoE(EPMoE):
|
||||
scale_block_size,
|
||||
masked_m,
|
||||
)
|
||||
del gateup_output
|
||||
|
||||
# GroupGemm-1
|
||||
n = self.w2_weight.size(1)
|
||||
|
||||
Reference in New Issue
Block a user