Revert "feat: update grouped_topk to support softmax and sigmoid" (#4505)
@@ -88,6 +88,7 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
+# This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def grouped_topk(
     hidden_states: torch.Tensor,
@@ -96,17 +97,10 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
-    if scoring_func == "softmax":
-        scores = torch.softmax(gating_output, dim=-1)
-    elif scoring_func == "sigmoid":
-        scores = gating_output.sigmoid()
-    else:
-        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
-
+    scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
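
Note: the functional change in the hunk above, shown in isolation. This is a minimal sketch; compute_scores is a hypothetical helper written only for illustration and is not a function in the repository. The revert drops the scoring_func dispatch and restores unconditional softmax scoring of the gating logits.

import torch

def compute_scores(gating_output: torch.Tensor, scoring_func: str = "softmax") -> torch.Tensor:
    # Pre-revert behavior (the branch removed by this diff): softmax or sigmoid scoring.
    if scoring_func == "softmax":
        return torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        return gating_output.sigmoid()
    raise ValueError(f"Scoring function '{scoring_func}' is not supported.")

# Post-revert behavior: scores are always torch.softmax(gating_output, dim=-1).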
@@ -130,7 +124,6 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
-# DeepSeek V2/V3/R1 uses biased_grouped_top
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
@@ -185,7 +178,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeepSeek V2/V3/R1 uses biased_grouped_top
+    # DeekSeekv2 uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
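
Note: for context on what grouped_topk computes, below is a rough, standalone sketch of group-limited top-k routing. Only the group_scores line is taken from the diff context above; the group masking, final top-k, and renormalization steps are assumptions based on the usual DeepSeek-style grouped routing and may differ from the actual implementation (grouped_topk_sketch is a name invented here).

import torch

def grouped_topk_sketch(
    scores: torch.Tensor,  # [num_token, num_experts], e.g. softmax of the gating logits
    topk: int,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool = True,
):
    num_token = scores.shape[0]
    # Score each group by its best expert (this line mirrors the diff context above).
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    # Keep only the top `topk_group` groups per token.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    # Broadcast the group mask back to per-expert granularity and zero out the rest
    # (scores are non-negative after softmax/sigmoid, so zeroed experts never win the top-k).
    experts_per_group = scores.shape[-1] // num_expert_group
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, experts_per_group)
        .reshape(num_token, -1)
    )
    masked_scores = scores.masked_fill(score_mask == 0, 0.0)
    # Final expert selection among the surviving groups.
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

The float32/int32 conversion at the end only echoes the return statement visible in the hunk at line 130 above; everything else is illustrative.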