Support compressed tensors fp8w8a8 (#4743)
This commit is contained in:
@@ -8,7 +8,6 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
|
||||
@@ -69,6 +68,8 @@ def moe_forward_native(
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -305,6 +305,7 @@ class FusedMoE(torch.nn.Module):
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
self.inplace = inplace
|
||||
self.no_combine = no_combine
|
||||
self.local_num_experts = num_experts
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
@@ -629,8 +630,6 @@ class FusedMoE(torch.nn.Module):
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
activation=self.activation,
|
||||
inplace=self.inplace,
|
||||
no_combine=self.no_combine,
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
|
||||
@@ -17,11 +17,12 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
-from sglang.srt.utils import get_compiler_backend, is_cuda
+from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
+_is_hip = is_hip()
|
||||
|
||||
-from sglang.srt.managers.utils import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
@@ -53,10 +54,10 @@ def fused_topk(
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
-if _is_cuda:
+if _is_cuda or _is_hip:
|
||||
from sgl_kernel import topk_softmax
|
||||
else:
|
||||
-from vllm import _custom_ops as ops
+from vllm import _custom_ops as vllm_ops
|
||||
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -70,7 +71,7 @@ def fused_topk(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
-if _is_cuda:
+if _is_cuda or _is_hip:
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@@ -78,7 +79,7 @@ def fused_topk(
|
||||
gating_output.float(),
|
||||
)
|
||||
else:
|
||||
-ops.topk_softmax(
+vllm_ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
|
||||
Reference in New Issue
Block a user