Reduce MoE memory usage (#6147)
@@ -3,10 +3,9 @@ from typing import List, Optional
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
     if block_shape is not None:
+        a_original = a
+
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)
@@ -667,6 +669,8 @@ def grouped_gemm_triton(
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
 
+        dispose_tensor(a_original)
+
     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {
@@ -680,6 +684,10 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )
 
+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
 
 _is_hip = is_hip()
 
@@ -92,6 +92,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer
@@ -119,6 +120,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c
 
@@ -210,6 +212,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
 
         if self.grouped_gemm_runner is None:
@@ -265,25 +271,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)
 
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )
         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +299,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input
 
         # Act
         down_input = torch.empty(
@@ -306,14 +309,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -340,13 +343,14 @@ class EPMoE(torch.nn.Module):
             )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output
 
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,
@@ -365,10 +369,13 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del down_input
 
         # PostReorder
-        output = torch.empty_like(hidden_states)
-        post_reorder_triton_kernel[(hidden_states.size(0),)](
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
             down_output,
             output,
             src2dst,
@@ -377,7 +384,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-            hidden_states.size(1),
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
@@ -881,6 +888,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
@@ -903,18 +913,12 @@ class DeepEPMoE(EPMoE):
         )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
-
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,
@@ -928,6 +932,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
 
         # Act
         down_input = torch.empty(
@@ -937,14 +948,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -961,12 +972,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
 
+        del gateup_output
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(
@@ -1007,11 +1020,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
-        gather_out = torch.empty_like(
-            hidden_states_fp8,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype
 
         input_tensor = [
             torch.empty(
@@ -1049,16 +1060,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)
 
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,
@@ -1068,14 +1081,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),
@@ -1083,7 +1098,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
 
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
 
         return gather_out
@@ -1107,6 +1128,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])
 
         # Act
         down_input = torch.empty(
@@ -1135,6 +1157,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output
 
         # GroupGemm-1
         n = self.w2_weight.size(1)
@@ -311,10 +311,10 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        final_hidden_states = (
-            self.experts(hidden_states=hidden_states, router_logits=router_logits)
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
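Note on the DeepseekV2MoE hunk above: applying routed_scaling_factor with an in-place *= on the experts' output avoids allocating a second full-size tensor for the scaled result, in line with the memory-reduction goal of this commit. A minimal sketch of the difference (sizes are illustrative, not taken from the model):

import torch

x = torch.randn(8192, 7168)  # stand-in for the MoE output activation
s = 2.5                      # stand-in for routed_scaling_factor

y = x * s   # out-of-place: materializes a new 8192 x 7168 tensor next to x
x *= s      # in-place: rescales x's own storage, no extra full-size buffer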
@@ -2100,3 +2100,7 @@ def log_info_on_rank0(logger, msg):
 
     if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))