Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)

Co-authored-by: Kan Wu <wukanustc@gmail.com>
Author: Lianmin Zheng
Date: 2025-06-29 23:16:19 -07:00
Committed by: GitHub
Commit: 22352d47a9
Parent: c5131f7a2f

24 changed files with 626 additions and 160 deletions


@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
_is_hip = is_hip()
fused_softcap_autotune = triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
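For context on the autotune pattern touched in this hunk, here is a minimal standalone sketch of how such a config list is usually attached to a Triton kernel. The `key` and the dummy kernel below are illustrative assumptions, not taken from the diff:

import triton
import triton.language as tl

# Sketch only: a reusable autotuner built from a config sweep, later applied
# as a decorator on top of a @triton.jit kernel. The real sweep and key differ.
example_autotune = triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
    ],
    key=["n_elements"],  # hypothetical tuning key, not shown in the diff
)


@example_autotune
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Trivial elementwise kernel, included only to show the decorator placement.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)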
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
    assert x.shape == residual.shape and x.dtype == residual.dtype
    output, mid = torch.empty_like(x), torch.empty_like(x)
    bs, hidden_dim = x.shape
    min_num_warps = 16 if _is_hip else 32
    if autotune:
        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
        )
    else:
        max_warps = 16 if _is_hip else 32
        config = {
            "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
            "num_warps": max(
                min(
                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
                ),
                4,
                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
            ),
        }
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
    else:
        output = torch.empty_like(x)
    bs, hidden_dim = x.shape
    min_num_warps = 16 if _is_hip else 32
    max_warps = 16 if _is_hip else 32
    config = {
        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
        ),
    }
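The two hunks above are a rename (min_num_warps becomes max_warps, capped at 16 on ROCm and 32 on CUDA) plus a one-line rewrite of the same clamp; the computed warp count does not change. A small standalone sketch of that clamp, useful for sanity-checking the values it produces:

def next_pow2(x: int) -> int:
    # Smallest power of two >= x, mirroring triton.next_power_of_2.
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def pick_num_warps(hidden_dim: int, is_hip: bool) -> int:
    # Same arithmetic as the updated config: roughly one warp per 256
    # elements, clamped to the range [4, max_warps].
    max_warps = 16 if is_hip else 32
    target = next_pow2(-(-hidden_dim // 256))  # ceil(hidden_dim / 256)
    return max(min(target, max_warps), 4)


# hidden_dim=4096 -> 16 warps on both backends; hidden_dim=8192 -> 32 on CUDA
# but capped at 16 on ROCm, which is what the _is_hip branch is for.
assert pick_num_warps(4096, is_hip=False) == 16
assert pick_num_warps(8192, is_hip=False) == 32
assert pick_num_warps(8192, is_hip=True) == 16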
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
        return self.rmsnorm2.forward_native(residual), residual


@triton.jit
def experts_combine_kernel(
    out_hidden_states,
    moe_hidden_states,
    mlp_hidden_states,
    combine_k: tl.constexpr,
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    start_index_mlp = pid * hidden_dim
    start_index_rmoe = pid * hidden_dim * combine_k
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_dim
    combine_k_offsets = tl.arange(0, combine_k)

    moe_x = tl.load(
        moe_hidden_states
        + start_index_rmoe
        + combine_k_offsets[:, None] * hidden_dim
        + offsets[None, :],
        mask=mask[None, :],
        other=0.0,
    )
    moe_x = tl.sum(moe_x, axis=0)
    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
    combined_x = (moe_x + mlp_x) / 1.4142135623730951
    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)


def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
    assert moe_hidden_states.is_contiguous()
    assert mlp_hidden_states.is_contiguous()

    if len(moe_hidden_states.shape) == 2:
        combine_k = 1  # pre-combined
    else:
        combine_k = moe_hidden_states.shape[1]

    if output_buffer is None:
        out_hidden_states = torch.empty_like(mlp_hidden_states)
    else:
        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
            mlp_hidden_states.shape
        )

    bs, hidden_dim = mlp_hidden_states.shape
    config = {
        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
        ),
    }

    experts_combine_kernel[(bs,)](
        out_hidden_states,
        moe_hidden_states,
        mlp_hidden_states,
        combine_k,
        hidden_dim,
        **config,
    )
    return out_hidden_states


# gelu on first half of vector
@triton.jit
def gelu_and_mul_kernel(
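The newly added experts_combine_kernel / experts_combine_triton pair above sums the combine_k expert outputs per token, adds the dense MLP branch, and rescales by 1/sqrt(2) (the 1.4142135623730951 constant). A hedged eager-mode PyTorch equivalent, useful only as a correctness reference; the shapes below are illustrative assumptions:

import math

import torch


def experts_combine_reference(
    moe_hidden_states: torch.Tensor, mlp_hidden_states: torch.Tensor
) -> torch.Tensor:
    # My reading of experts_combine_kernel: sum over the combine_k axis
    # (a no-op when the MoE output is already 2D / pre-combined), add the
    # MLP branch, and divide by sqrt(2).
    if moe_hidden_states.dim() == 3:  # [bs, combine_k, hidden_dim]
        moe_sum = moe_hidden_states.sum(dim=1)
    else:  # [bs, hidden_dim], combine_k == 1
        moe_sum = moe_hidden_states
    return (moe_sum + mlp_hidden_states) / math.sqrt(2)


# Illustrative shapes only; on a GPU build this could be compared against
# experts_combine_triton with torch.testing.assert_close.
bs, combine_k, hidden_dim = 4, 2, 1024
moe = torch.randn(bs, combine_k, hidden_dim)
mlp = torch.randn(bs, hidden_dim)
ref = experts_combine_reference(moe, mlp)
assert ref.shape == (bs, hidden_dim)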
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
        out_scales = scales
        static_scale = True

    max_warps = 16 if _is_hip else 32
    config = {
        # 8 ele per thread (not tuned)
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
        ),
    }
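For reference, the gelu_and_mul pattern named in the context lines above ("gelu on first half of vector") gates the second half of the hidden dimension with GELU of the first half. A rough PyTorch sketch of that semantics; whether the kernel uses exact or tanh-approximate GELU, and how it applies the optional output scales, is not shown in this diff:

import torch
import torch.nn.functional as F


def gelu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, apply GELU to the first half, and
    # multiply elementwise by the second half. Sketch only; the quantization
    # scales handled by the Triton kernel are omitted here.
    gate, up = x.chunk(2, dim=-1)
    return F.gelu(gate) * up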