Support compressed tensors fp8w8a8 (#4743)

This commit is contained in:
Xiaoyu Zhang
2025-03-27 04:21:25 +08:00
committed by GitHub
parent 45fdf1f7f3
commit 04e3ff6975
30 changed files with 2386 additions and 113 deletions

View File

@@ -8,7 +8,6 @@ from typing import Callable, Optional
import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts
@@ -69,6 +68,8 @@ def moe_forward_native(
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,

View File

@@ -305,6 +305,7 @@ class FusedMoE(torch.nn.Module):
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.local_num_experts = num_experts
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -629,8 +630,6 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
inplace=self.inplace,
no_combine=self.no_combine,
)
if self.reduce_results and self.tp_size > 1:

View File

@@ -17,11 +17,12 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.utils import get_compiler_backend, is_cuda
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
from sglang.srt.managers.utils import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
expert_distribution_recorder = ExpertDistributionRecorder()
@@ -53,10 +54,10 @@ def fused_topk(
topk: int,
renormalize: bool,
):
if _is_cuda:
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
else:
from vllm import _custom_ops as ops
from vllm import _custom_ops as vllm_ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -70,7 +71,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
if _is_cuda:
if _is_cuda or _is_hip:
topk_softmax(
topk_weights,
topk_ids,
@@ -78,7 +79,7 @@ def fused_topk(
gating_output.float(),
)
else:
ops.topk_softmax(
vllm_ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,