Clean up imports (#5467)
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# ==============================================================================
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
@@ -29,6 +28,10 @@ _is_hip = is_hip()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import topk_softmax
|
||||
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
|
||||
@@ -59,11 +62,6 @@ def fused_topk(
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import topk_softmax
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
M, _ = hidden_states.shape
|
||||
@@ -76,20 +74,12 @@ def fused_topk(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
else:
|
||||
vllm_ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
del token_expert_indicies
|
||||
|
||||
if renormalize:
|
||||
|
||||
Reference in New Issue
Block a user