[Feature] Support gpt-oss and update model list (#71)
* [Docs] Update supported models * [Feature] Support gpt-oss * [Docs] Fix model support list * Fix MoE * Fix * Fix moe_ep * Remove gpt-oss graph support (not supported yet) --------- Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
@@ -418,6 +418,7 @@ class KunlunOps:
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
linear_weights: torch.Tensor,
|
||||
ep_rank: int,
|
||||
moe_top_k: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
@@ -430,7 +431,7 @@ class KunlunOps:
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""fused_moe"""
|
||||
global_num_experts = linear_weights.shape[0]
|
||||
global_num_experts, up_gate_size, _ = w1.shape
|
||||
M, N = hidden_states.shape
|
||||
hidden_dim = w2.shape[1]
|
||||
normed_score = torch.empty(M,
|
||||
@@ -445,82 +446,119 @@ class KunlunOps:
|
||||
block_statistic = torch.zeros(
|
||||
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
torch.ops._C.moe_sigmoid_group_topk_norm(
|
||||
router_logits = router_logits.to(torch.float)
|
||||
if scoring_func == "softmax":
|
||||
torch.ops._C.moe_softmax_topk_norm(
|
||||
x=router_logits,
|
||||
normed_score=normed_score,
|
||||
topk_index=topk_ids,
|
||||
norm_score=normed_score,
|
||||
block_static=block_statistic,
|
||||
bias=e_score_correction_bias,
|
||||
scale=1.0,
|
||||
n_group=num_expert_group,
|
||||
topk_group=1,
|
||||
block_statistic=None,
|
||||
stable=True)
|
||||
elif scoring_func == "sigmoid":
|
||||
torch.ops._C.moe_sigmoid_group_topk_norm(
|
||||
x=router_logits,
|
||||
topk_index=topk_ids,
|
||||
norm_score=normed_score,
|
||||
block_static=block_statistic,
|
||||
bias=e_score_correction_bias,
|
||||
scale=1.0,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
if w1_bias is not None or w2_bias is not None:
|
||||
# Right now this branch is for gpt oss
|
||||
# TODO (@xyDong23): faster here using moe_fc kernel
|
||||
normed_score = normed_score.to(hidden_states.dtype)
|
||||
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
|
||||
topk_ids_flat = topk_ids.flatten()
|
||||
for i in range(global_num_experts):
|
||||
experts_id = ep_rank * global_num_experts + i
|
||||
selected_token = topk_ids_flat == experts_id
|
||||
if selected_token.sum():
|
||||
cur_token = repeat_x[selected_token]
|
||||
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
|
||||
dtype=cur_token.dtype, device=cur_token.device)
|
||||
groupgemm1 = cur_token@ w1[i].T
|
||||
# Add w13 bias
|
||||
if w1_bias is not None:
|
||||
groupgemm1 = groupgemm1 + w1_bias[i]
|
||||
up_gate = torch.ops._C.swigluoai_and_mul(groupgemm1)
|
||||
groupgemm2 = up_gate @ w2[i].T
|
||||
# Add w2 bias
|
||||
if w2_bias is not None:
|
||||
groupgemm2 = groupgemm2 + w2_bias[i]
|
||||
out[selected_token] = groupgemm2
|
||||
ouput = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
|
||||
return ouput
|
||||
else:
|
||||
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
|
||||
y = torch.empty(M,moe_top_k,
|
||||
w1.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=w1,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_top_k,
|
||||
y=y,
|
||||
)
|
||||
|
||||
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
|
||||
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
|
||||
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
|
||||
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out = torch.empty(M,moe_top_k,
|
||||
w2.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
torch.ops._C.moe_pre_sorted(
|
||||
x=hidden_states,
|
||||
topk_index=topk_ids,
|
||||
block_statistic=block_statistic,
|
||||
moe_expand=moe_expand,
|
||||
moe_index=sorted_tokens_idx,
|
||||
expert_m=expert_m,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod)
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
y = torch.empty(M,moe_top_k,
|
||||
w1.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=w2,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_top_k,
|
||||
y=out,
|
||||
)
|
||||
|
||||
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
|
||||
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=moe_expand,
|
||||
weight=w1,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_top_k,
|
||||
y=y)
|
||||
|
||||
d = y.shape[-1] // 2
|
||||
output_shape = (y.shape[:-1] + (d, ))
|
||||
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
|
||||
torch.ops._C.silu_and_mul(out1, y)
|
||||
|
||||
out = torch.empty(M,moe_top_k,
|
||||
w2.shape[1],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
out1 = out1.reshape(-1, out1.shape[-1])
|
||||
|
||||
torch.ops._C.moe_fc(
|
||||
x=out1,
|
||||
weight=w2,
|
||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||
sorted_tokens_idx=sorted_tokens_idx,
|
||||
moe_topk=moe_top_k,
|
||||
y=out)
|
||||
|
||||
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
|
||||
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
|
||||
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
|
||||
return output
|
||||
torch.ops._C.moe_post(
|
||||
x=out,
|
||||
moe_index=sorted_tokens_idx,
|
||||
normed_scale=normed_score,
|
||||
dequant_scale=dequant_scale,
|
||||
y=output
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def fused_moe_ep(
|
||||
|
||||
@@ -108,6 +108,7 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
linear_weights,
|
||||
self.moe.ep_rank,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
@@ -116,6 +117,8 @@ class UnquantizedFusedMoEMethod(VllmUnquantizedFusedMoEMethod):
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
w1_bias = layer.w13_bias,
|
||||
w2_bias = layer.w2_bias,
|
||||
)
|
||||
|
||||
class FusedMoE(VllmFusedMoE):
|
||||
@@ -144,6 +147,7 @@ class FusedMoE(VllmFusedMoE):
|
||||
enable_eplb: bool = False,
|
||||
num_redundant_experts: int = 0,
|
||||
is_sequence_parallel=False,
|
||||
has_bias: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts, # Global number of experts
|
||||
@@ -186,10 +190,12 @@ class FusedMoE(VllmFusedMoE):
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
has_bias=has_bias,
|
||||
# quant_config=quant_config,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
self.has_bias=has_bias
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
|
||||
@@ -147,7 +147,6 @@ RotaryEmbedding.forward_cuda = vllm_kunlun_forward_cuda
|
||||
RotaryEmbedding.forward = vllm_kunlun_forward_cuda
|
||||
MRotaryEmbedding.forward_cuda = vllm_kunlun_mrope_forward_cuda
|
||||
MRotaryEmbedding.forward = vllm_kunlun_mrope_forward_cuda
|
||||
YaRNScalingRotaryEmbedding._compute_inv_freq = RotaryEmbedding._compute_inv_freq
|
||||
|
||||
|
||||
def Split_Norm_Rope(
|
||||
|
||||
Reference in New Issue
Block a user