Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)
Co-authored-by: Kan Wu <wukanustc@gmail.com>
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
 
+_is_hip = is_hip()
 
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
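The only change in the first hunk is hoisting the platform probe to import time: `is_hip()` is evaluated once and cached in the module-level `_is_hip`, so the warp-count heuristics below branch on a cached boolean instead of re-invoking the helper at every kernel launch. A minimal sketch of the pattern, assuming sglang's `is_hip` amounts to the usual `torch.version.hip` probe (the `pick_max_warps` helper is illustrative, not shipped code):

import torch

def is_hip() -> bool:
    # ROCm builds of PyTorch set torch.version.hip; CUDA builds leave it None.
    return torch.version.hip is not None

_is_hip = is_hip()  # evaluated once, at module import

def pick_max_warps() -> int:
    # Illustrative: the cap the kernels below derive from the cached flag.
    return 16 if _is_hip else 32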
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
 
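Two things change in this hunk: the cap is renamed from the misleading `min_num_warps` to `max_warps` (it is an upper bound, sitting as the second argument of `min`), and its computation moves inside the `else` branch so the autotuned path skips it. The heuristic sizes one warp per roughly 256 elements, rounded up to a power of two and clamped to [4, max_warps]. A plain-Python sketch of the arithmetic, with local stand-ins for `triton.cdiv` and `triton.next_power_of_2` so it runs without a GPU:

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def next_power_of_2(n: int) -> int:
    return 1 << max(n - 1, 0).bit_length()

def num_warps(hidden_dim: int, is_hip: bool) -> int:
    max_warps = 16 if is_hip else 32
    return max(min(next_power_of_2(cdiv(hidden_dim, 256)), max_warps), 4)

# hidden_dim=4096 needs 16 warps on both platforms; 16384 hits the cap on HIP.
assert num_warps(4096, False) == num_warps(4096, True) == 16
assert num_warps(16384, False) == 32 and num_warps(16384, True) == 16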
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
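The `fused_rmsnorm` hunk is the same rename: `max_warps` carries exactly the 16-or-32 value the old `min_num_warps` held, so the chosen `num_warps` is numerically unchanged; only the name now matches the variable's role as the upper bound inside `min(...)`.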
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual
 
 
+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
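The new kernel fuses the tail of a shared-expert MoE block: one Triton program per token sums that token's `combine_k` routed-expert outputs, adds the shared-MLP output, and rescales by `1 / sqrt(2)` (the literal `1.4142135623730951` is `sqrt(2)`), writing the result in a single pass. A hedged PyTorch reference of the same math, useful as a cross-check but not the shipped code:

import math
import torch

def experts_combine_ref(moe_hidden_states: torch.Tensor,
                        mlp_hidden_states: torch.Tensor) -> torch.Tensor:
    # moe_hidden_states: [bs, combine_k, hidden_dim], or [bs, hidden_dim]
    # when the top-k expert outputs arrive already summed ("pre-combined").
    if moe_hidden_states.dim() == 3:
        moe_hidden_states = moe_hidden_states.sum(dim=1)
    # mlp_hidden_states: [bs, hidden_dim] from the shared dense MLP.
    return (moe_hidden_states + mlp_hidden_states) / math.sqrt(2)

The `output_buffer` argument of `experts_combine_triton` lets a caller reuse scratch memory: the wrapper reinterprets the buffer in the MLP tensor's dtype and carves out the first `mlp_hidden_states.numel()` elements instead of allocating a fresh tensor.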
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True
 
+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
         ),
     }
 
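Here the previously hardcoded cap of 32 becomes platform-dependent, in line with the other kernels. The likely rationale for 16 on HIP: AMD wavefronts are 64 lanes wide, so 16 "warps" already reach the 1024-threads-per-block limit that 32 warps of 32 lanes reach on NVIDIA hardware. The `# 8 ele per thread (not tuned)` comment explains the divisor: `8 * 32` budgets 8 elements to each lane of a 32-lane warp when sizing the launch.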