[Feature] Support gpt-oss and update model list (#71)

* [Docs] Update Support Models

* [Feature] Support gpt-oss

* [Docs] fix model support list

* Fix Moe

* Fix

* Fix moe_ep

* remove gpt-oss graph support, not yet supported

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
Xinyu Dong
2026-01-04 21:19:49 +08:00
committed by GitHub
parent ded24f5026
commit fe666fb24f
6 changed files with 537 additions and 340 deletions

View File

@@ -418,6 +418,7 @@ class KunlunOps:
w2: torch.Tensor,
router_logits: torch.Tensor,
linear_weights: torch.Tensor,
ep_rank: int,
moe_top_k: int,
renormalize: bool,
inplace: bool = False,
@@ -430,7 +431,7 @@ class KunlunOps:
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""fused_moe"""
global_num_experts = linear_weights.shape[0]
global_num_experts, up_gate_size, _ = w1.shape
M, N = hidden_states.shape
hidden_dim = w2.shape[1]
normed_score = torch.empty(M,
@@ -445,82 +446,119 @@ class KunlunOps:
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
torch.ops._C.moe_sigmoid_group_topk_norm(
router_logits = router_logits.to(torch.float)
if scoring_func == "softmax":
torch.ops._C.moe_softmax_topk_norm(
x=router_logits,
normed_score=normed_score,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=1,
block_statistic=None,
stable=True)
elif scoring_func == "sigmoid":
torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
if w1_bias is not None or w2_bias is not None:
                # Right now this branch is for gpt-oss
# TODO (@xyDong23): faster here using moe_fc kernel
normed_score = normed_score.to(hidden_states.dtype)
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
topk_ids_flat = topk_ids.flatten()
for i in range(global_num_experts):
experts_id = ep_rank * global_num_experts + i
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
groupgemm1 = cur_token@ w1[i].T
# Add w13 bias
if w1_bias is not None:
groupgemm1 = groupgemm1 + w1_bias[i]
up_gate = torch.ops._C.swigluoai_and_mul(groupgemm1)
groupgemm2 = up_gate @ w2[i].T
# Add w2 bias
if w2_bias is not None:
groupgemm2 = groupgemm2 + w2_bias[i]
out[selected_token] = groupgemm2
ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
return ouput
else:
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
y = torch.empty(M,moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
)
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
out1 = out1.reshape(-1, out1.shape[-1])
y = torch.empty(M,moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops._C.moe_fc(
x=out1,
weight=w2,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=out,
)
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=w2,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=out)
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
torch.ops._C.moe_post(
x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
)
return output
torch.ops._C.moe_post(
x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
)
return output
@staticmethod
def fused_moe_ep(

View File

@@ -108,6 +108,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
layer.w2_weight,
router_logits,
linear_weights,
self.moe.ep_rank,
top_k,
renormalize=renormalize,
inplace=True,
@@ -116,6 +117,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
w1_bias = layer.w13_bias,
w2_bias = layer.w2_bias,
)
class FusedMoE(VllmFusedMoE):
@@ -144,6 +147,7 @@ class FusedMoE(VllmFusedMoE):
enable_eplb: bool = False,
num_redundant_experts: int = 0,
is_sequence_parallel=False,
has_bias: bool = False,
):
super().__init__(
num_experts=num_experts, # Global number of experts
@@ -186,10 +190,12 @@ class FusedMoE(VllmFusedMoE):
moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
# quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
self.has_bias=has_bias
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.

View File

@@ -147,7 +147,6 @@ RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
def Split_Norm_Rope(