Clean up imports (#5467)

This commit is contained in:
Lianmin Zheng
2025-04-16 15:26:49 -07:00
committed by GitHub
parent d7bc19a46a
commit 177320a582
51 changed files with 376 additions and 573 deletions

View File

@@ -13,7 +13,6 @@
# ==============================================================================
import math
import os
from typing import Callable, Optional
import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
if _is_cuda:
from sgl_kernel import moe_fused_gate
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
expert_distribution_recorder = ExpertDistributionRecorder()
@@ -59,11 +62,6 @@ def fused_topk(
topk: int,
renormalize: bool,
):
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
else:
from vllm import _custom_ops as vllm_ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
if _is_cuda or _is_hip:
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
else:
vllm_ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
del token_expert_indicies
if renormalize: