[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 8ba49a7723
9 changed files with 777 additions and 36 deletions

View File

@@ -275,6 +275,7 @@ def fused_moe_kernel(
a_ptr,
b_ptr,
c_ptr,
b_bias_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
@@ -303,6 +304,8 @@ def fused_moe_kernel(
stride_bse,
stride_bsk,
stride_bsn,
stride_bbe, # bias expert stride
stride_bbn, # bias N stride
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
@@ -321,6 +324,7 @@ def fused_moe_kernel(
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
HAS_BIAS: tl.constexpr,
UPGRADE: tl.constexpr,
UPGRADE_A_OFFS: tl.constexpr,
UPGRADE_B_OFFS: tl.constexpr,
@@ -447,6 +451,10 @@ def fused_moe_kernel(
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
if HAS_BIAS:
# bias shape: [num_experts, N]
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
@@ -494,7 +502,8 @@ def fused_moe_kernel(
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak * SPLIT_K
b_ptrs += BLOCK_SIZE_K * stride_bk * SPLIT_K
if HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
@@ -548,7 +557,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int4_w4a16: bool,
orig_acc_dtype: torch.dtype,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None) -> None:
block_shape: Optional[list[int]] = None,
B_bias: Optional[torch.Tensor] = None) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -580,7 +590,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), META['SPLIT_K'])
HAS_BIAS = B_bias is not None
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
assert B_scale is not None and B_scale.ndim == 3
@@ -592,19 +602,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
num_experts=B.shape[0],
bit=4 if use_int4_w4a16 else 8)
# TODO: missing config for BLOCK_SIZE_K
# config = config.copy()
# config.update(
# get_moe_wna16_block_config(config=config,
# use_moe_wna16_cuda=use_moe_wna16_cuda,
# num_valid_tokens=num_tokens,
# size_k=A.shape[1],
# size_n=B.shape[1],
# num_experts=B.shape[1],
# group_size=block_shape[1],
# real_top_k=top_k,
# block_size_m=config["BLOCK_SIZE_M"]))
config = config.copy()
config.update(
get_moe_wna16_block_config(config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=num_tokens,
size_k=A.shape[1],
size_n=B.shape[1],
num_experts=B.shape[1],
group_size=block_shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"]))
if False and use_moe_wna16_cuda:
if use_moe_wna16_cuda:
bit = 4 if use_int4_w4a16 else 8
ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
topk_weights if mul_routed_weight else None,
@@ -661,6 +671,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A,
B,
C,
B_bias,
A_scale,
B_scale,
topk_weights,
@@ -689,6 +700,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1)
if B_scale is not None and B_scale.ndim >= 2 else 0,
B_bias.stride(0) if B_bias is not None else 0,
B_bias.stride(1) if B_bias is not None else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
@@ -699,6 +712,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
FAST_F32_TO_BF16 = True,
**config,
@@ -1103,13 +1117,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[List[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
block_shape, w1_bias, w2_bias)
def inplace_fused_experts_fake(
@@ -1133,7 +1149,9 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[List[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> None:
pass
@@ -1167,14 +1185,16 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[List[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, per_channel_quant,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
block_shape, w1_bias, w2_bias)
def outplace_fused_experts_fake(
@@ -1197,7 +1217,9 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[List[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1248,7 +1270,9 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
allow_deep_gemm: bool = False,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.shape[1]
@@ -1293,7 +1317,10 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
block_shape=block_shape,
w1_bias=w1_bias,
w2_bias=w2_bias,
)
def fused_experts_impl(
@@ -1319,6 +1346,8 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
@@ -1498,7 +1527,19 @@ def fused_experts_impl(
use_int4_w4a16=use_int4_w4a16,
orig_acc_dtype=hidden_states.dtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
block_shape=block_shape,
B_bias=w1_bias)
# TODO fused kernel
def swiglu_oai(gate_up):
    """Apply the gpt-oss SwiGLU variant to an interleaved gate/up tensor.

    The last dimension of ``gate_up`` interleaves gate channels (even
    indices) and up channels (odd indices). Gate values are clamped from
    above only; up values are clamped symmetrically. The result is
    ``(up + 1) * gate * sigmoid(alpha * gate)``.
    """
    ALPHA = 1.702  # sigmoid sharpness used by gpt-oss
    LIMIT = 7.0    # clamp bound for numerical stability
    gate_part = gate_up[..., 0::2].clamp(max=LIMIT)
    up_part = gate_up[..., 1::2].clamp(min=-LIMIT, max=LIMIT)
    silu_term = gate_part * torch.sigmoid(ALPHA * gate_part)
    return (up_part + 1) * silu_term
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -1506,6 +1547,8 @@ def fused_experts_impl(
elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "swiglu_oai":
intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
@@ -1543,7 +1586,8 @@ def fused_experts_impl(
use_int4_w4a16=use_int4_w4a16,
orig_acc_dtype=hidden_states.dtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
block_shape=block_shape,
B_bias=w2_bias)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
@@ -1578,6 +1622,8 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1661,7 +1707,9 @@ def fused_moe(
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
block_shape=block_shape,
w1_bias=w1_bias,
w2_bias=w2_bias)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -1805,7 +1853,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
@@ -1835,7 +1885,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
)
return intermediate_cache3

View File

@@ -226,6 +226,8 @@ class MoEConfig:
max_num_tokens: int = MOE_DP_CHUNK_SIZE
has_bias: bool = False
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@@ -443,6 +445,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe
self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
@@ -502,6 +505,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.has_bias:
w13_bias = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(
@@ -512,6 +523,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.has_bias:
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
@@ -634,6 +652,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_bias=layer.w13_bias if self.has_bias else None,
w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
@@ -840,6 +860,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
has_bias: bool = False,
):
super().__init__()
if params_dtype is None:
@@ -920,6 +941,7 @@ class FusedMoE(torch.nn.Module):
in_dtype=params_dtype,
quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
)
self.moe_config = moe
self.quant_config = quant_config