[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
@@ -275,6 +275,7 @@ def fused_moe_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
b_bias_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
topk_weights_ptr,
|
||||
@@ -303,6 +304,8 @@ def fused_moe_kernel(
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
stride_bbe, # bias expert stride
|
||||
stride_bbn, # bias N stride
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
@@ -321,6 +324,7 @@ def fused_moe_kernel(
|
||||
use_int8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
per_channel_quant: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
UPGRADE: tl.constexpr,
|
||||
UPGRADE_A_OFFS: tl.constexpr,
|
||||
UPGRADE_B_OFFS: tl.constexpr,
|
||||
@@ -447,6 +451,10 @@ def fused_moe_kernel(
|
||||
else:
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
b_scale = tl.load(b_scale_ptr + off_experts)
|
||||
if HAS_BIAS:
|
||||
# bias shape: [num_experts, N]
|
||||
bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
|
||||
bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
@@ -494,7 +502,8 @@ def fused_moe_kernel(
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak * SPLIT_K
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk * SPLIT_K
|
||||
|
||||
if HAS_BIAS:
|
||||
accumulator = accumulator + bias[None, :]
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
@@ -548,7 +557,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
use_int4_w4a16: bool,
|
||||
orig_acc_dtype: torch.dtype,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None) -> None:
|
||||
block_shape: Optional[list[int]] = None,
|
||||
B_bias: Optional[torch.Tensor] = None) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
@@ -580,7 +590,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
|
||||
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
B.shape[1], META['BLOCK_SIZE_N']), META['SPLIT_K'])
|
||||
|
||||
HAS_BIAS = B_bias is not None
|
||||
if (use_int8_w8a16 or use_int4_w4a16) and \
|
||||
block_shape is not None and block_shape[1] > 0:
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
@@ -592,19 +602,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
num_experts=B.shape[0],
|
||||
bit=4 if use_int4_w4a16 else 8)
|
||||
# TODO: missing config for BLOCK_SIZE_K
|
||||
# config = config.copy()
|
||||
# config.update(
|
||||
# get_moe_wna16_block_config(config=config,
|
||||
# use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
# num_valid_tokens=num_tokens,
|
||||
# size_k=A.shape[1],
|
||||
# size_n=B.shape[1],
|
||||
# num_experts=B.shape[1],
|
||||
# group_size=block_shape[1],
|
||||
# real_top_k=top_k,
|
||||
# block_size_m=config["BLOCK_SIZE_M"]))
|
||||
config = config.copy()
|
||||
config.update(
|
||||
get_moe_wna16_block_config(config=config,
|
||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.shape[1],
|
||||
size_n=B.shape[1],
|
||||
num_experts=B.shape[1],
|
||||
group_size=block_shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"]))
|
||||
|
||||
if False and use_moe_wna16_cuda:
|
||||
if use_moe_wna16_cuda:
|
||||
bit = 4 if use_int4_w4a16 else 8
|
||||
ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
|
||||
topk_weights if mul_routed_weight else None,
|
||||
@@ -661,6 +671,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_bias,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
@@ -689,6 +700,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1)
|
||||
if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_bias.stride(0) if B_bias is not None else 0,
|
||||
B_bias.stride(1) if B_bias is not None else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
@@ -699,6 +712,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
HAS_BIAS=HAS_BIAS,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
FAST_F32_TO_BF16 = True,
|
||||
**config,
|
||||
@@ -1103,13 +1117,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||
per_channel_quant, global_num_experts, expert_map,
|
||||
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
block_shape, w1_bias, w2_bias)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
@@ -1133,7 +1149,9 @@ def inplace_fused_experts_fake(
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1167,14 +1185,16 @@ def outplace_fused_experts(
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, activation, apply_router_weight_on_input,
|
||||
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, per_channel_quant,
|
||||
global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
block_shape, w1_bias, w2_bias)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@@ -1197,7 +1217,9 @@ def outplace_fused_experts_fake(
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
block_shape: Optional[List[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
@@ -1248,7 +1270,9 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
allow_deep_gemm: bool = False,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# For now, disable DeepGemm for small N (<= 512) until better
|
||||
# permute/unpermute ops are available.
|
||||
N = w1.shape[1]
|
||||
@@ -1293,7 +1317,10 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
)
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
@@ -1319,6 +1346,8 @@ def fused_experts_impl(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
@@ -1498,7 +1527,19 @@ def fused_experts_impl(
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
orig_acc_dtype=hidden_states.dtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
B_bias=w1_bias)
|
||||
|
||||
# TODO fused kernel
|
||||
def swiglu_oai(gate_up: torch.Tensor,
               alpha: float = 1.702,
               limit: float = 7.0) -> torch.Tensor:
    """Apply the clamped, interleaved SwiGLU used for ``"swiglu_oai"``.

    The last dimension of ``gate_up`` interleaves the two projections:
    even indices are the gate values and odd indices are the up values.

    Args:
        gate_up: Tensor whose last dimension has even length.
        alpha: Scale applied to the gate inside the sigmoid.
        limit: Clamp bound. The gate is clamped from above only
            (``gate <= limit``); the up values are clamped to
            ``[-limit, limit]``.

    Returns:
        A new tensor whose last dimension is half that of ``gate_up``,
        equal to ``(up + 1) * gate * sigmoid(alpha * gate)``.

    Raises:
        ValueError: If the last dimension of ``gate_up`` is odd, because
            the gate/up slices would no longer pair up element-wise.
    """
    if gate_up.shape[-1] % 2 != 0:
        # Without this guard a last dim of 3 would broadcast a length-1
        # "up" slice against a length-2 "gate" slice and return garbage.
        raise ValueError(
            "swiglu_oai expects an even-sized last dimension, got "
            f"{gate_up.shape[-1]}")

    gate = gate_up[..., ::2].clamp(max=limit)
    up = gate_up[..., 1::2].clamp(min=-limit, max=limit)

    # gate * sigmoid(alpha * gate), then gated by (up + 1).
    glu = gate * torch.sigmoid(gate * alpha)
    return (up + 1) * glu
|
||||
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
@@ -1506,6 +1547,8 @@ def fused_experts_impl(
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
elif activation == "swiglu_oai":
|
||||
intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
@@ -1543,7 +1586,8 @@ def fused_experts_impl(
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
orig_acc_dtype=hidden_states.dtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
B_bias=w2_bias)
|
||||
|
||||
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||
@@ -1578,6 +1622,8 @@ def fused_moe(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@@ -1661,7 +1707,9 @@ def fused_moe(
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
block_shape=block_shape,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias)
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@@ -1805,7 +1853,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
block_shape=self.block_shape,
|
||||
B_bias=None # TODO support B_bias
|
||||
)
|
||||
|
||||
self.activation(activation, intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
@@ -1835,7 +1885,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
block_shape=self.block_shape,
|
||||
B_bias=None # TODO support B_bias
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
@@ -226,6 +226,8 @@ class MoEConfig:
|
||||
|
||||
max_num_tokens: int = MOE_DP_CHUNK_SIZE
|
||||
|
||||
has_bias: bool = False
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
@@ -443,6 +445,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.fused_experts = fused_experts # type: ignore
|
||||
self.topk_indices_dtype = None
|
||||
self.moe = moe
|
||||
self.has_bias = self.moe.has_bias
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
@@ -502,6 +505,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
@@ -512,6 +523,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
|
||||
# Pad the weight tensor. This is an optimization on ROCm platform, which
|
||||
@@ -634,6 +652,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_bias=layer.w13_bias if self.has_bias else None,
|
||||
w2_bias=layer.w2_bias if self.has_bias else None,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
@@ -840,6 +860,7 @@ class FusedMoE(torch.nn.Module):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
has_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if params_dtype is None:
|
||||
@@ -920,6 +941,7 @@ class FusedMoE(torch.nn.Module):
|
||||
in_dtype=params_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
|
||||
Reference in New Issue
Block a user