Compare commits
4 Commits
v0.5.4_dev
...
v0.5.3_dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0fb70e9c1 | ||
|
|
848c5b8290 | ||
|
|
68277eac30 | ||
|
|
8f7453e3af |
@@ -5,6 +5,15 @@ from typing import List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
|
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
|
||||||
|
try:
|
||||||
|
from lmslim import quant_ops
|
||||||
|
from lmslim import quant_tools
|
||||||
|
except Exception:
|
||||||
|
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
|
||||||
|
try:
|
||||||
|
import lightop
|
||||||
|
except Exception:
|
||||||
|
print("INFO: Please install lightop if you want to infer awq of marlin.\n")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
use_vllm_custom_allreduce = get_bool_env_var(
|
use_vllm_custom_allreduce = get_bool_env_var(
|
||||||
@@ -175,3 +184,25 @@ def mscclpp_allreduce(
|
|||||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||||
) -> None:
|
) -> None:
|
||||||
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
|
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
|
||||||
|
|
||||||
|
def triton_scaled_mm(a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
scale_a: torch.Tensor,
|
||||||
|
scale_b: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
best_config:Optional[list] = None) -> torch.Tensor:
|
||||||
|
|
||||||
|
return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config)
|
||||||
|
|
||||||
|
def triton_int8_gemm_helper(m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
per_token_act_quant: bool,
|
||||||
|
per_out_channel_weight_quant: bool,
|
||||||
|
use_bias: bool,
|
||||||
|
out_dtype: type[torch.dtype] = torch.float16,
|
||||||
|
device: str = "cuda:0",
|
||||||
|
best_config:Optional[list] = None,
|
||||||
|
repeat:Optional[int] = 2):
|
||||||
|
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat)
|
||||||
@@ -508,6 +508,7 @@ class ModelConfig:
|
|||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
"quark",
|
"quark",
|
||||||
"mxfp4",
|
"mxfp4",
|
||||||
|
"slimquant_w4a8_marlin",
|
||||||
]
|
]
|
||||||
optimized_quantization_methods = [
|
optimized_quantization_methods = [
|
||||||
"fp8",
|
"fp8",
|
||||||
@@ -526,6 +527,7 @@ class ModelConfig:
|
|||||||
"qoq",
|
"qoq",
|
||||||
"w4afp8",
|
"w4afp8",
|
||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
|
"slimquant_w4a8_marlin",
|
||||||
]
|
]
|
||||||
compatible_quantization_methods = {
|
compatible_quantization_methods = {
|
||||||
"modelopt_fp4": ["modelopt"],
|
"modelopt_fp4": ["modelopt"],
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import time
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.managers.cache_controller import HiCacheController
|
from sglang.srt.managers.cache_controller import HiCacheController
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
|
|||||||
@@ -127,21 +127,45 @@ class RMSNorm(CustomOp):
|
|||||||
return output, residual_out
|
return output, residual_out
|
||||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||||
|
|
||||||
|
# def forward_hip(
|
||||||
|
# self,
|
||||||
|
# x: torch.Tensor,
|
||||||
|
# residual: Optional[torch.Tensor] = None,
|
||||||
|
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
# if not x.is_contiguous():
|
||||||
|
# # NOTE: Remove this if aiter kernel supports discontinuous input
|
||||||
|
# x = x.contiguous()
|
||||||
|
# if residual is not None:
|
||||||
|
# if _vllm_version < Version("0.9"):
|
||||||
|
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||||
|
# return x, residual
|
||||||
|
# else:
|
||||||
|
# residual_out = torch.empty_like(x)
|
||||||
|
# output = torch.empty_like(x)
|
||||||
|
# fused_add_rms_norm(
|
||||||
|
# output,
|
||||||
|
# x,
|
||||||
|
# residual_out,
|
||||||
|
# residual,
|
||||||
|
# self.weight.data,
|
||||||
|
# self.variance_epsilon,
|
||||||
|
# )
|
||||||
|
# return output, residual_out
|
||||||
|
# out = torch.empty_like(x)
|
||||||
|
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||||
|
# return out
|
||||||
def forward_hip(
|
def forward_hip(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor] = None,
|
residual: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
):
|
||||||
if not x.is_contiguous():
|
if not x.is_contiguous():
|
||||||
# NOTE: Remove this if aiter kernel supports discontinuous input
|
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
if _vllm_version < Version("0.9"):
|
try:
|
||||||
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
|
||||||
return x, residual
|
|
||||||
else:
|
|
||||||
residual_out = torch.empty_like(x)
|
|
||||||
output = torch.empty_like(x)
|
output = torch.empty_like(x)
|
||||||
|
residual_out = torch.empty_like(x)
|
||||||
fused_add_rms_norm(
|
fused_add_rms_norm(
|
||||||
output,
|
output,
|
||||||
x,
|
x,
|
||||||
@@ -151,10 +175,21 @@ class RMSNorm(CustomOp):
|
|||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
)
|
)
|
||||||
return output, residual_out
|
return output, residual_out
|
||||||
|
except TypeError:
|
||||||
|
fused_add_rms_norm(
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
self.weight.data,
|
||||||
|
self.variance_epsilon,
|
||||||
|
)
|
||||||
|
return x, residual
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
@@ -124,7 +125,6 @@ class EPMoE(FusedMoE):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
|
|
||||||
if isinstance(quant_config, Fp8Config):
|
if isinstance(quant_config, Fp8Config):
|
||||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
self.block_shape = (
|
self.block_shape = (
|
||||||
@@ -135,11 +135,23 @@ class EPMoE(FusedMoE):
|
|||||||
self.use_fp8_w8a8 = True
|
self.use_fp8_w8a8 = True
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
self.activation_scheme = quant_config.activation_scheme
|
self.activation_scheme = quant_config.activation_scheme
|
||||||
|
self.use_w4a8_marlin = False
|
||||||
|
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
|
||||||
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
|
self.block_shape = (
|
||||||
|
self.quant_method.quant_config.weight_block_size
|
||||||
|
if self.use_block_quant
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.use_fp8_w8a8 = False
|
||||||
|
self.activation_scheme = None
|
||||||
|
self.use_w4a8_marlin = True
|
||||||
else:
|
else:
|
||||||
self.use_fp8_w8a8 = False
|
self.use_fp8_w8a8 = False
|
||||||
self.use_block_quant = False
|
self.use_block_quant = False
|
||||||
self.block_shape = None
|
self.block_shape = None
|
||||||
self.activation_scheme = None
|
self.activation_scheme = None
|
||||||
|
self.use_w4a8_marlin = False
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||||
@@ -386,11 +398,11 @@ class DeepEPMoE(EPMoE):
|
|||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.deepep_mode.enable_low_latency() and not _is_npu:
|
# if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||||
# NPU supports low_latency deepep without deepgemm
|
# # NPU supports low_latency deepep without deepgemm
|
||||||
assert (
|
# assert (
|
||||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
# ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
# expert_mask is of size (self.num_local_experts + 1),
|
# expert_mask is of size (self.num_local_experts + 1),
|
||||||
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
||||||
@@ -404,23 +416,23 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
# the last one is invalid rank_id
|
# the last one is invalid rank_id
|
||||||
self.expert_mask[:-1] = 1
|
self.expert_mask[:-1] = 1
|
||||||
elif not _is_npu:
|
# elif not _is_npu:
|
||||||
self.w13_weight_fp8 = (
|
# self.w13_weight_fp8 = (
|
||||||
self.w13_weight,
|
# self.w13_weight,
|
||||||
(
|
# (
|
||||||
self.w13_weight_scale_inv
|
# self.w13_weight_scale_inv
|
||||||
if self.use_block_quant
|
# if self.use_block_quant
|
||||||
else self.w13_weight_scale
|
# else self.w13_weight_scale
|
||||||
),
|
# ),
|
||||||
)
|
# )
|
||||||
self.w2_weight_fp8 = (
|
# self.w2_weight_fp8 = (
|
||||||
self.w2_weight,
|
# self.w2_weight,
|
||||||
(
|
# (
|
||||||
self.w2_weight_scale_inv
|
# self.w2_weight_scale_inv
|
||||||
if self.use_block_quant
|
# if self.use_block_quant
|
||||||
else self.w2_weight_scale
|
# else self.w2_weight_scale
|
||||||
),
|
# ),
|
||||||
)
|
# )
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -466,8 +478,15 @@ class DeepEPMoE(EPMoE):
|
|||||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||||
return self.forward_npu(dispatch_output)
|
return self.forward_npu(dispatch_output)
|
||||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
#assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||||
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
|
elif self.use_w4a8_marlin:
|
||||||
|
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dispatch output is not supported"
|
||||||
|
)
|
||||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||||
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
||||||
return self.forward_flashinfer_cutedsl(dispatch_output)
|
return self.forward_flashinfer_cutedsl(dispatch_output)
|
||||||
@@ -526,6 +545,34 @@ class DeepEPMoE(EPMoE):
|
|||||||
expert_mask=self.expert_mask,
|
expert_mask=self.expert_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward_deepgemm_w4a8_marlin_contiguous(
|
||||||
|
self,
|
||||||
|
dispatch_output: DeepEPNormalOutput,
|
||||||
|
):
|
||||||
|
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
|
||||||
|
dispatch_output
|
||||||
|
)
|
||||||
|
assert self.quant_method is not None
|
||||||
|
assert self.moe_runner_config.activation == "silu"
|
||||||
|
# if num_recv_tokens_per_expert is None:
|
||||||
|
return hidden_states_int8.bfloat16()
|
||||||
|
# expert_output = self.quant_method.apply_ep(
|
||||||
|
# layer=self,
|
||||||
|
# x=dispatch_output,
|
||||||
|
# topk_weights=topk_weights,
|
||||||
|
# topk_ids=topk_idx,
|
||||||
|
# global_num_experts=self.global_num_experts,
|
||||||
|
# expert_map=self.expert_map,
|
||||||
|
# activation=self.activation,
|
||||||
|
# apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||||
|
# use_nn_moe=self.use_nn_moe,
|
||||||
|
# num_local_tokens=dispatch_recv_num_token,
|
||||||
|
# config_select_bs=hidden_states.shape[0],
|
||||||
|
# scales=dispatch_scales if self.use_int8_dispatch else None
|
||||||
|
# # routed_scaling_factor=self.routed_scaling_factor,
|
||||||
|
# )
|
||||||
|
# return expert_output
|
||||||
|
|
||||||
def forward_deepgemm_contiguous(
|
def forward_deepgemm_contiguous(
|
||||||
self,
|
self,
|
||||||
dispatch_output: DeepEPNormalOutput,
|
dispatch_output: DeepEPNormalOutput,
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ def inplace_fused_experts(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -79,6 +79,8 @@ def inplace_fused_experts(
|
|||||||
gemm1_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
gemm1_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if isinstance(activation, int):
|
||||||
|
activation = "silu" if activation == 0 else "gelu"
|
||||||
fused_experts_impl(
|
fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -154,7 +156,7 @@ def outplace_fused_experts(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -173,6 +175,8 @@ def outplace_fused_experts(
|
|||||||
gemm1_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
gemm1_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if isinstance(activation, int):
|
||||||
|
activation = "silu" if activation == 0 else "gelu"
|
||||||
return fused_experts_impl(
|
return fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -263,6 +267,13 @@ def fused_experts(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
act_id = (
|
||||||
|
0 if (
|
||||||
|
moe_runner_config.activation == 0
|
||||||
|
or (isinstance(moe_runner_config.activation, str)
|
||||||
|
and moe_runner_config.activation.lower() == "silu")
|
||||||
|
) else 1
|
||||||
|
)
|
||||||
if moe_runner_config.inplace:
|
if moe_runner_config.inplace:
|
||||||
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
||||||
torch.ops.sglang.inplace_fused_experts(
|
torch.ops.sglang.inplace_fused_experts(
|
||||||
@@ -273,7 +284,7 @@ def fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
b1,
|
b1,
|
||||||
b2,
|
b2,
|
||||||
moe_runner_config.activation,
|
act_id,
|
||||||
moe_runner_config.apply_router_weight_on_input,
|
moe_runner_config.apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
@@ -301,7 +312,7 @@ def fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
b1,
|
b1,
|
||||||
b2,
|
b2,
|
||||||
moe_runner_config.activation,
|
act_id,
|
||||||
moe_runner_config.apply_router_weight_on_input,
|
moe_runner_config.apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
@@ -345,7 +356,7 @@ def fused_experts_impl(
|
|||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -364,6 +375,9 @@ def fused_experts_impl(
|
|||||||
gemm1_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
gemm1_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
if isinstance(activation, int):
|
||||||
|
activation = "silu" if activation == 0 else "gelu"
|
||||||
|
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||||
padded_size = 0
|
padded_size = 0
|
||||||
|
|||||||
@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
deepep_post_reorder_triton_kernel,
|
deepep_post_reorder_triton_kernel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
#if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||||
output = hidden_states
|
output = hidden_states
|
||||||
else:
|
# else:
|
||||||
if hidden_states.shape[0] > 0:
|
# if hidden_states.shape[0] > 0:
|
||||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
# num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
output = torch.empty(
|
# output = torch.empty(
|
||||||
(num_tokens, hidden_states.shape[1]),
|
# (num_tokens, hidden_states.shape[1]),
|
||||||
device=hidden_states.device,
|
# device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
# dtype=hidden_states.dtype,
|
||||||
)
|
# )
|
||||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
# deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||||
hidden_states,
|
# hidden_states,
|
||||||
output,
|
# output,
|
||||||
self.src2dst,
|
# self.src2dst,
|
||||||
topk_idx,
|
# topk_idx,
|
||||||
topk_weights,
|
# topk_weights,
|
||||||
self.router_topk,
|
# self.router_topk,
|
||||||
hidden_states.shape[1],
|
# hidden_states.shape[1],
|
||||||
BLOCK_SIZE=512,
|
# BLOCK_SIZE=512,
|
||||||
)
|
# )
|
||||||
else:
|
# else:
|
||||||
output = torch.zeros(
|
# output = torch.zeros(
|
||||||
(0, hidden_states.shape[1]),
|
# (0, hidden_states.shape[1]),
|
||||||
device=hidden_states.device,
|
# device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
# dtype=hidden_states.dtype,
|
||||||
)
|
# )
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
return output, previous_event
|
return output, previous_event
|
||||||
|
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
|
|||||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
|
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||||
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
||||||
|
|
||||||
_is_mxfp_supported = mxfp_supported()
|
_is_mxfp_supported = mxfp_supported()
|
||||||
@@ -86,6 +87,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"w4afp8": W4AFp8Config,
|
"w4afp8": W4AFp8Config,
|
||||||
"petit_nvfp4": PetitNvFp4Config,
|
"petit_nvfp4": PetitNvFp4Config,
|
||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
|
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
415
python/sglang/srt/layers/quantization/slimquant_w4a8.py
Normal file
415
python/sglang/srt/layers/quantization/slimquant_w4a8.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sglang.srt.layers.linear import set_weight_attrs
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.quantization.base_config import LinearMethodBase, QuantizationConfig, QuantizeMethodBase, FusedMoEMethodBase
|
||||||
|
from sglang.srt.layers.parameter import (
|
||||||
|
ChannelQuantScaleParameter,
|
||||||
|
_ColumnvLLMParameter,
|
||||||
|
RowvLLMParameter,
|
||||||
|
)
|
||||||
|
from lmslim.layers.gemm.int8_utils import (
|
||||||
|
per_token_group_quant_int8,
|
||||||
|
per_token_quant_int8)
|
||||||
|
from sglang.srt import _custom_ops as ops
|
||||||
|
from vllm.utils import W8a8GetCacheJSON
|
||||||
|
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||||
|
"""
|
||||||
|
Parameter class for linear layer weights. Uses both column and
|
||||||
|
row parallelism.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
W8A8_TRITONJSON=W8a8GetCacheJSON()
|
||||||
|
|
||||||
|
def baseline_scaled_mm(a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
scale_a: torch.Tensor,
|
||||||
|
scale_b: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
|
||||||
|
scales= scale_a* scale_b.T
|
||||||
|
gemmout= torch.mm(
|
||||||
|
a.to(dtype=torch.float32), b.to(dtype=torch.float32))
|
||||||
|
output = (scales *gemmout).to(out_dtype)
|
||||||
|
if bias is not None:
|
||||||
|
output = output + bias
|
||||||
|
return output.to(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class SlimQuantW4A8Int8Config(QuantizationConfig):
|
||||||
|
"""Config class for W8A8 Int8 Quantization.
|
||||||
|
|
||||||
|
- Weight: static, per-channel, symmetric
|
||||||
|
- Activation: dynamic, per-token, symmetric
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
|
return [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 75
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "slimquant_w4a8"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
def get_quant_method(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
prefix: str,
|
||||||
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
|
||||||
|
if isinstance(layer, LinearBase):
|
||||||
|
return SlimQuantW4A8Int8LinearMethod(self)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return SlimQuantW4A8Int8MoEMethod(self)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
|
||||||
|
|
||||||
|
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
|
||||||
|
self.quantization_config = quantization_config
|
||||||
|
self.tritonsingleton= W8a8GetCacheJSON()
|
||||||
|
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
n=layer.weight.shape[0]
|
||||||
|
k=layer.weight.shape[1]
|
||||||
|
|
||||||
|
if self.w8a8_strategy==1:
|
||||||
|
if {n,k} not in self.tritonsingleton.weight_shapes:
|
||||||
|
self.tritonsingleton.weight_shapes.append({n,k})
|
||||||
|
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
|
||||||
|
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
|
||||||
|
|
||||||
|
if configs_dict:
|
||||||
|
self.tritonsingleton.triton_json_dict.update(configs_dict)
|
||||||
|
|
||||||
|
for key, value in configs_dict.items():
|
||||||
|
m=int(key.split('_')[0])
|
||||||
|
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
|
||||||
|
else:
|
||||||
|
weight_data=layer.weight.data
|
||||||
|
_weight=weight_data.T.contiguous().reshape(n,-1)
|
||||||
|
layer.weight.data=_weight
|
||||||
|
|
||||||
|
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||||
|
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size_per_partition: int,
|
||||||
|
output_partition_sizes: List[int],
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
|
||||||
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
self.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
|
weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("weight", weight)
|
||||||
|
|
||||||
|
weight_scale = ChannelQuantScaleParameter(
|
||||||
|
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
input_quant_args: Optional[list[torch.Tensor]] = None,
|
||||||
|
silu_quant_args: Optional[list[torch.Tensor]] = None
|
||||||
|
):
|
||||||
|
# if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
|
||||||
|
# assert len(input_quant_args) == 2
|
||||||
|
# x_q, x_scale = input_quant_args
|
||||||
|
# elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
|
||||||
|
# x_q, x_scale = silu_quant_args
|
||||||
|
# else:
|
||||||
|
x_q, x_scale = per_token_quant_int8(x)
|
||||||
|
|
||||||
|
if self.w8a8_strategy==1:
|
||||||
|
m=x_q.shape[0]
|
||||||
|
k=x_q.shape[1]
|
||||||
|
n=layer.weight.shape[1]
|
||||||
|
|
||||||
|
if len(W8A8_TRITONJSON.triton_json_dict)==0:
|
||||||
|
best_config=None
|
||||||
|
|
||||||
|
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
|
||||||
|
if m<=16:
|
||||||
|
m_=m
|
||||||
|
elif m<=64:
|
||||||
|
m_= (m + 3) & -4 #取值到最近的4的倍数
|
||||||
|
elif m<=160:
|
||||||
|
m_=(m + 7) & -8
|
||||||
|
|
||||||
|
elif m<200: #256
|
||||||
|
m_=160
|
||||||
|
elif m<480: #512
|
||||||
|
m_=256
|
||||||
|
elif m<960: #1024
|
||||||
|
m_=512
|
||||||
|
elif m<2048:
|
||||||
|
m_=1024
|
||||||
|
elif m<4096:
|
||||||
|
m_=2048
|
||||||
|
elif m<6000:
|
||||||
|
m_=4096
|
||||||
|
else:
|
||||||
|
m_=8192
|
||||||
|
|
||||||
|
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
best_config=None
|
||||||
|
|
||||||
|
#if best_config==None:
|
||||||
|
# print("m:{},n:{},k:{}".format(m,n,k))
|
||||||
|
# print("config not found!")
|
||||||
|
|
||||||
|
return ops.triton_scaled_mm(x_q,
|
||||||
|
layer.weight,
|
||||||
|
scale_a=x_scale,
|
||||||
|
scale_b=layer.weight_scale,
|
||||||
|
out_dtype=x.dtype,
|
||||||
|
bias=bias,best_config=best_config)
|
||||||
|
elif self.w8a8_strategy==2:
|
||||||
|
return ops.cutlass_scaled_mm(x_q,
|
||||||
|
layer.weight,
|
||||||
|
scale_a=x_scale,
|
||||||
|
scale_b=layer.weight_scale,
|
||||||
|
out_dtype=x.dtype,
|
||||||
|
bias=bias)
|
||||||
|
else:
|
||||||
|
return ops.rocblas_scaled_mm(x_q,
|
||||||
|
layer.weight,
|
||||||
|
scale_a=x_scale,
|
||||||
|
scale_b=layer.weight_scale,
|
||||||
|
out_dtype=x.dtype,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
class SlimQuantW4A8Int8MoEMethod:
|
||||||
|
"""MoE method for W4A8INT8.
|
||||||
|
Supports loading INT8 checkpoints with static weight scale and
|
||||||
|
dynamic/static activation scale.
|
||||||
|
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||||
|
activation scaling. The weight scaling factor will be initialized after
|
||||||
|
the model weights are loaded.
|
||||||
|
Args:
|
||||||
|
quant_config: The quantization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
|
||||||
|
if not hasattr(cls, "_initialized"):
|
||||||
|
original_init = cls.__init__
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(FusedMoEMethodBase,),
|
||||||
|
{
|
||||||
|
"__init__": original_init,
|
||||||
|
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||||
|
obj.__init__(*args, **kwargs)
|
||||||
|
return obj
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, quant_config):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.tritonsingleton= W8a8GetCacheJSON()
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
# WEIGHTS
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||||
|
)
|
||||||
|
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
w13_input_scale = None
|
||||||
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
|
|
||||||
|
w2_input_scale = None
|
||||||
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
E=layer.w13_weight.shape[0]
|
||||||
|
N1=layer.w13_weight.shape[1]
|
||||||
|
N2=layer.w2_weight.shape[1]
|
||||||
|
K=N1//2
|
||||||
|
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
|
||||||
|
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
|
||||||
|
|
||||||
|
TOPK= self.tritonsingleton.topk
|
||||||
|
|
||||||
|
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True)
|
||||||
|
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
|
||||||
|
|
||||||
|
#warmup
|
||||||
|
if configs_dict:
|
||||||
|
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
|
||||||
|
|
||||||
|
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
||||||
|
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
||||||
|
layer.w13_weight_scale = Parameter(
|
||||||
|
layer.w13_weight_scale.data, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w2_weight_scale = Parameter(
|
||||||
|
layer.w2_weight_scale.data, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_moe_runner(
|
||||||
|
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||||
|
):
|
||||||
|
self.moe_runner_config = moe_runner_config
|
||||||
|
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
use_nn_moe: Optional[bool] = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
use_fused_gate: Optional[bool] = False,
|
||||||
|
**_
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
if enable_eplb:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
|
||||||
|
# Expert selection
|
||||||
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
use_fused_gate=use_fused_gate
|
||||||
|
)
|
||||||
|
|
||||||
|
return fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
use_int4_w4a8=True,
|
||||||
|
per_channel_quant=True,
|
||||||
|
activation=activation,
|
||||||
|
expert_map=expert_map,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
w1_scale=(layer.w13_weight_scale),
|
||||||
|
w2_scale=(layer.w2_weight_scale),
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
use_nn_moe=use_nn_moe,
|
||||||
|
)
|
||||||
318
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
Normal file
318
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
||||||
|
import torch
|
||||||
|
from sglang.srt import _custom_ops as ops
|
||||||
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.quantization import QuantizationConfig
|
||||||
|
from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
|
||||||
|
from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase)
|
||||||
|
from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
|
||||||
|
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
|
||||||
|
except Exception:
|
||||||
|
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
|
||||||
|
|
||||||
|
|
||||||
|
class MarlinMoeWorkspace:
|
||||||
|
"""
|
||||||
|
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
|
||||||
|
global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device
|
||||||
|
"""
|
||||||
|
_instances = {}
|
||||||
|
def __new__(cls, device):
|
||||||
|
if device not in cls._instances:
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
instance._initialized = False
|
||||||
|
cls._instances[device] = instance
|
||||||
|
return cls._instances[device]
|
||||||
|
|
||||||
|
def __init__(self, device):
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||||
|
self.workspace = torch.zeros(
|
||||||
|
500, dtype=torch.int, device=device, requires_grad=False
|
||||||
|
)
|
||||||
|
self.global_reduce_buffer = torch.zeros(
|
||||||
|
sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False
|
||||||
|
)
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def get_buffers(self):
|
||||||
|
return self.workspace, self.global_reduce_buffer
|
||||||
|
|
||||||
|
def baseline_scaled_mm(a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
scale_a: torch.Tensor,
|
||||||
|
scale_b: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
|
||||||
|
scales= scale_a* scale_b.T
|
||||||
|
gemmout= torch.mm(
|
||||||
|
a.to(dtype=torch.float32), b.to(dtype=torch.float32))
|
||||||
|
output = (scales *gemmout).to(out_dtype)
|
||||||
|
if bias is not None:
|
||||||
|
output = output + bias
|
||||||
|
return output.to(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
|
||||||
|
"""Config class for W4A8 Int8 Quantization.
|
||||||
|
- Weight: static, per-channel, symmetric
|
||||||
|
- Activation: dynamic, per-token, symmetric
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
|
return [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 75
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "slimquant_w4a8_marlin"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_filenames(cls) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig":
|
||||||
|
return cls()
|
||||||
|
@classmethod
|
||||||
|
def override_quantization_method(
|
||||||
|
cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||||
|
if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
|
||||||
|
and user_quant == "slimquant_w4a8_marlin":
|
||||||
|
return cls.get_name()
|
||||||
|
return None
|
||||||
|
def get_quant_method(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
prefix: str,
|
||||||
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
|
||||||
|
if isinstance(layer, LinearBase):
|
||||||
|
return SlimQuantW4A8Int8LinearMethod(self)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return SlimQuantW4A8Int8MarlinMoEMethod(self)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||||
|
"""MoE method for W4A8INT8 Marlin.
|
||||||
|
Supports loading INT8 checkpoints with static weight scale and
|
||||||
|
dynamic/static activation scale.
|
||||||
|
Args:
|
||||||
|
quant_config: The quantization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
|
||||||
|
if not hasattr(cls, "_initialized"):
|
||||||
|
original_init = cls.__init__
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(FusedMoEMethodBase,),
|
||||||
|
{
|
||||||
|
"__init__": original_init,
|
||||||
|
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||||
|
obj.__init__(*args, **kwargs)
|
||||||
|
return obj
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, quant_config):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
intermediate_size = intermediate_size_per_partition
|
||||||
|
# WEIGHTS
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||||
|
)
|
||||||
|
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
w13_input_scale = None
|
||||||
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
|
|
||||||
|
w2_input_scale = None
|
||||||
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
layer.w13_weight_scale = Parameter(
|
||||||
|
layer.w13_weight_scale.data, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w2_weight_scale = Parameter(
|
||||||
|
layer.w2_weight_scale.data, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
|
||||||
|
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
|
||||||
|
|
||||||
|
def create_moe_runner(
|
||||||
|
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||||
|
):
|
||||||
|
self.moe_runner_config = moe_runner_config
|
||||||
|
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||||
|
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
dispatch_output: StandardDispatchOutput,
|
||||||
|
) -> CombineInput:
|
||||||
|
x = dispatch_output.hidden_states
|
||||||
|
topk_output = dispatch_output.topk_output
|
||||||
|
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||||
|
|
||||||
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
x, topk_weights = apply_topk_weights_cpu(
|
||||||
|
self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||||
|
)
|
||||||
|
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
|
||||||
|
output = fused_experts_impl_w4a8_marlin(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
workspace=workspace,
|
||||||
|
global_reduce_buffer=global_reduce_buffer,
|
||||||
|
inplace=True,
|
||||||
|
use_int4_w4a8=True,
|
||||||
|
per_channel_quant=True,
|
||||||
|
activation=layer.moe_runner_config.activation,
|
||||||
|
expert_map=layer.expert_map_gpu,
|
||||||
|
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||||
|
global_num_experts=layer.moe_runner_config.num_experts,
|
||||||
|
w1_scale=(layer.w13_weight_scale),
|
||||||
|
w2_scale=(layer.w2_weight_scale),
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
use_nn_moe=False,
|
||||||
|
)
|
||||||
|
return StandardCombineInput(hidden_states=output)
|
||||||
|
# def _apply(
|
||||||
|
# self,
|
||||||
|
# layer: torch.nn.Module,
|
||||||
|
# x: torch.Tensor,
|
||||||
|
# router_logits: torch.Tensor,
|
||||||
|
# top_k: int,
|
||||||
|
# #renormalize: bool,
|
||||||
|
# #use_grouped_topk: bool = False,
|
||||||
|
# topk_group: Optional[int] = None,
|
||||||
|
# num_expert_group: Optional[int] = None,
|
||||||
|
# global_num_experts: int = -1,
|
||||||
|
# expert_map: Optional[torch.Tensor] = None,
|
||||||
|
# custom_routing_function: Optional[Callable] = None,
|
||||||
|
# scoring_func: str = "softmax",
|
||||||
|
# e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
# apply_router_weight_on_input: bool = False,
|
||||||
|
# activation: str = "silu",
|
||||||
|
# enable_eplb: bool = False,
|
||||||
|
# use_nn_moe: Optional[bool] = False,
|
||||||
|
# routed_scaling_factor: Optional[float] = None,
|
||||||
|
# use_fused_gate: Optional[bool] = False,
|
||||||
|
# **_
|
||||||
|
# ) -> torch.Tensor:
|
||||||
|
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||||
|
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
# if enable_eplb:
|
||||||
|
# raise NotImplementedError(
|
||||||
|
# "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
|
||||||
|
# # Expert selection
|
||||||
|
# topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
# hidden_states=x,
|
||||||
|
# router_logits=router_logits,
|
||||||
|
# #use_grouped_topk=use_grouped_topk,
|
||||||
|
# top_k=top_k,
|
||||||
|
# #renormalize=renormalize,
|
||||||
|
# topk_group=topk_group,
|
||||||
|
# num_expert_group=num_expert_group,
|
||||||
|
# custom_routing_function=custom_routing_function,
|
||||||
|
# scoring_func=scoring_func,
|
||||||
|
# e_score_correction_bias=e_score_correction_bias,
|
||||||
|
# routed_scaling_factor=routed_scaling_factor,
|
||||||
|
# use_fused_gate=use_fused_gate
|
||||||
|
# )
|
||||||
|
# workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
|
||||||
|
# return fused_experts_impl_w4a8_marlin(
|
||||||
|
# x,
|
||||||
|
# layer.w13_weight,
|
||||||
|
# layer.w2_weight,
|
||||||
|
# topk_weights=topk_weights,
|
||||||
|
# topk_ids=topk_ids,
|
||||||
|
# workspace=workspace,
|
||||||
|
# global_reduce_buffer=global_reduce_buffer,
|
||||||
|
# inplace=True,
|
||||||
|
# use_int4_w4a8=True,
|
||||||
|
# per_channel_quant=True,
|
||||||
|
# activation=activation,
|
||||||
|
# expert_map=expert_map,
|
||||||
|
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
# global_num_experts=global_num_experts,
|
||||||
|
# w1_scale=(layer.w13_weight_scale),
|
||||||
|
# w2_scale=(layer.w2_weight_scale),
|
||||||
|
# a1_scale=layer.w13_input_scale,
|
||||||
|
# a2_scale=layer.w2_input_scale,
|
||||||
|
# use_nn_moe=use_nn_moe,
|
||||||
|
# )
|
||||||
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
92
python/sglang/srt/layers/quantization/w4a8_utils.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lightop import awq_marlin_repack_w4a8
|
||||||
|
use_lightop = False
|
||||||
|
except Exception:
|
||||||
|
use_lightop = False
|
||||||
|
|
||||||
|
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||||||
|
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||||||
|
"""
|
||||||
|
if tensor_int8.dtype != torch.int8:
|
||||||
|
raise ValueError("Input tensor must be of type torch.int8")
|
||||||
|
|
||||||
|
N, K_half = tensor_int8.shape
|
||||||
|
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||||||
|
high4 = tensor_uint8 & 0x0F
|
||||||
|
low4 = (tensor_uint8 >> 4) & 0x0F
|
||||||
|
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||||||
|
unpacked[:, 0::2] = low4.to(torch.int32)
|
||||||
|
unpacked[:, 1::2] = high4.to(torch.int32)
|
||||||
|
|
||||||
|
return unpacked
|
||||||
|
|
||||||
|
def get_weight_perms(interleave: bool=True):
|
||||||
|
perm = []
|
||||||
|
for i in range(64):
|
||||||
|
|
||||||
|
for col in range(4):
|
||||||
|
cur_col = (i % 16) * 4 + col
|
||||||
|
for row in range(8):
|
||||||
|
cur_row = (i // 16) * 8 + row
|
||||||
|
cur_idx = cur_row * 64 + cur_col
|
||||||
|
perm.append(cur_idx)
|
||||||
|
|
||||||
|
perm = np.array(perm)
|
||||||
|
if interleave:
|
||||||
|
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||||||
|
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||||||
|
|
||||||
|
perm = torch.from_numpy(perm)
|
||||||
|
|
||||||
|
return perm
|
||||||
|
|
||||||
|
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||||||
|
size_k, size_n = q_w.shape
|
||||||
|
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||||||
|
q_w = q_w.permute((0, 2, 1, 3))
|
||||||
|
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||||||
|
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||||||
|
|
||||||
|
orig_device = q_w.device
|
||||||
|
q_w = q_w.contiguous().to(torch.int32)
|
||||||
|
M, N = q_w.shape
|
||||||
|
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||||||
|
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||||||
|
for i in range(pack_factor):
|
||||||
|
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||||||
|
|
||||||
|
return q_packed
|
||||||
|
|
||||||
|
def w4a8_2_marlin_weight(w4a8_w):
|
||||||
|
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||||||
|
full_w4a8_w = full_w4a8_w.T
|
||||||
|
weight_perm = get_weight_perms()
|
||||||
|
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||||||
|
return marlin_q_w
|
||||||
|
|
||||||
|
def w4a8_weight_repack_impl(input):
|
||||||
|
if use_lightop:
|
||||||
|
size_batch = input.shape[0]
|
||||||
|
size_n = input.shape[1]
|
||||||
|
size_k = input.shape[2] * 2
|
||||||
|
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||||||
|
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||||||
|
else:
|
||||||
|
w_marlin_list = []
|
||||||
|
for e in range(input.shape[0]):
|
||||||
|
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||||||
|
w_marlin_list.append(w_marlin_in)
|
||||||
|
output = torch.stack(w_marlin_list, dim=0)
|
||||||
|
|
||||||
|
return output
|
||||||
@@ -516,7 +516,7 @@ class ModelRunner:
|
|||||||
):
|
):
|
||||||
server_args.attention_backend = "fa3"
|
server_args.attention_backend = "fa3"
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
server_args.attention_backend = "aiter"
|
server_args.attention_backend = "triton"
|
||||||
elif _is_npu:
|
elif _is_npu:
|
||||||
server_args.attention_backend = "ascend"
|
server_args.attention_backend = "ascend"
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ QUANTIZATION_CHOICES = [
|
|||||||
"qoq",
|
"qoq",
|
||||||
"w4afp8",
|
"w4afp8",
|
||||||
"mxfp4",
|
"mxfp4",
|
||||||
|
"slimquant_w4a8_marlin",
|
||||||
]
|
]
|
||||||
|
|
||||||
ATTENTION_BACKEND_CHOICES = [
|
ATTENTION_BACKEND_CHOICES = [
|
||||||
|
|||||||
@@ -165,10 +165,10 @@ DINLINE void start_sync(
|
|||||||
if (threadIdx.x < ngpus) {
|
if (threadIdx.x < ngpus) {
|
||||||
// simultaneously write to the corresponding flag of all ranks.
|
// simultaneously write to the corresponding flag of all ranks.
|
||||||
// Latency = 1 p2p write
|
// Latency = 1 p2p write
|
||||||
__scoped_atomic_store_n(
|
__hip_atomic_store(
|
||||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
|
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
|
||||||
// wait until we got true from all ranks
|
// wait until we got true from all ranks
|
||||||
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
|
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
|
||||||
flag)
|
flag)
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
@@ -211,16 +211,16 @@ DINLINE void end_sync(
|
|||||||
if (threadIdx.x < ngpus) {
|
if (threadIdx.x < ngpus) {
|
||||||
// simultaneously write to the corresponding flag of all ranks.
|
// simultaneously write to the corresponding flag of all ranks.
|
||||||
// Latency = 1 p2p write
|
// Latency = 1 p2p write
|
||||||
__scoped_atomic_store_n(
|
__hip_atomic_store(
|
||||||
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||||
flag,
|
flag,
|
||||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||||
__MEMORY_SCOPE_SYSTEM);
|
__HIP_MEMORY_SCOPE_SYSTEM);
|
||||||
// wait until we got true from all ranks
|
// wait until we got true from all ranks
|
||||||
while (__scoped_atomic_load_n(
|
while (__hip_atomic_load(
|
||||||
&self_sg->end[blockIdx.x][threadIdx.x],
|
&self_sg->end[blockIdx.x][threadIdx.x],
|
||||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||||
__MEMORY_SCOPE_DEVICE) < flag)
|
__HIP_MEMORY_SCOPE_AGENT) < flag)
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
#define WARP_SIZE 64
|
||||||
#define VEC_SIZE 4
|
#define VEC_SIZE 4
|
||||||
using Vec = int4;
|
using Vec = int4;
|
||||||
|
|
||||||
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
|
|||||||
int original = v;
|
int original = v;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
||||||
int n = __shfl_up_sync(mask, v, offset);
|
int n = __shfl_up(v, offset);
|
||||||
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
|
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
|
||||||
}
|
}
|
||||||
return v - original;
|
return v - original;
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ template <typename T>
|
|||||||
__device__ float convert_to_float(T x) {
|
__device__ float convert_to_float(T x) {
|
||||||
if constexpr (std::is_same_v<T, __half>) {
|
if constexpr (std::is_same_v<T, __half>) {
|
||||||
return __half2float(x);
|
return __half2float(x);
|
||||||
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
|
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||||
return __bfloat162float(x);
|
return __bfloat162float(x);
|
||||||
} else if constexpr (std::is_same_v<T, float>) {
|
} else if constexpr (std::is_same_v<T, float>) {
|
||||||
return x;
|
return x;
|
||||||
@@ -575,8 +575,8 @@ void topk_softmax(
|
|||||||
renormalize,
|
renormalize,
|
||||||
stream);
|
stream);
|
||||||
} else if (dtype == at::ScalarType::BFloat16) {
|
} else if (dtype == at::ScalarType::BFloat16) {
|
||||||
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
|
topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
|
||||||
reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
||||||
topk_weights.data_ptr<float>(),
|
topk_weights.data_ptr<float>(),
|
||||||
topk_indices.data_ptr<int>(),
|
topk_indices.data_ptr<int>(),
|
||||||
softmax_workspace.data_ptr<float>(),
|
softmax_workspace.data_ptr<float>(),
|
||||||
|
|||||||
@@ -358,25 +358,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// add FP8 support
|
// add FP8 support
|
||||||
#ifndef USE_ROCM
|
// #ifndef USE_ROCM
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
// #include <c10/util/Float8_e4m3fn.h>
|
||||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
// using FP8_TYPE = c10::Float8_e4m3fn;
|
||||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||||
#else // USE_ROCM
|
// #else // USE_ROCM
|
||||||
#if HIP_FP8_TYPE_FNUZ
|
// #if HIP_FP8_TYPE_FNUZ
|
||||||
#include <c10/util/Float8_e4m3fnuz.h>
|
// #include <c10/util/Float8_e4m3fnuz.h>
|
||||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
// using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
// constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||||
#else
|
// #else
|
||||||
#if HIP_FP8_TYPE_E4M3
|
// #if HIP_FP8_TYPE_E4M3
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
// #include <c10/util/Float8_e4m3fn.h>
|
||||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
// using FP8_TYPE = c10::Float8_e4m3fn;
|
||||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||||
#else
|
// #else
|
||||||
#error "fp8 is not supported in this processor (arch < gfx942)."
|
// #error "fp8 is not supported in this processor (arch < gfx942)."
|
||||||
#endif // HIP_FP8_TYPE_E4M3
|
// #endif // HIP_FP8_TYPE_E4M3
|
||||||
#endif // HIP_FP8_TYPE_FNUZ
|
// #endif // HIP_FP8_TYPE_FNUZ
|
||||||
#endif // USE_ROCM
|
// #endif // USE_ROCM
|
||||||
|
|
||||||
#define FULL_MASK 0xffffffff
|
#define FULL_MASK 0xffffffff
|
||||||
|
|
||||||
|
|||||||
100
sgl-kernel/setup_hip.py
Normal file
100
sgl-kernel/setup_hip.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
root = Path(__file__).parent.resolve()
|
||||||
|
arch = platform.machine().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_version():
|
||||||
|
with open(root / "pyproject.toml") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.startswith("version"):
|
||||||
|
return line.split("=")[1].strip().strip('"')
|
||||||
|
|
||||||
|
|
||||||
|
operator_namespace = "sgl_kernel"
|
||||||
|
include_dirs = [
|
||||||
|
root / "include",
|
||||||
|
root / "csrc",
|
||||||
|
]
|
||||||
|
|
||||||
|
sources = [
|
||||||
|
"csrc/allreduce/custom_all_reduce.hip",
|
||||||
|
"csrc/allreduce/quick_all_reduce.cu",
|
||||||
|
"csrc/common_extension_rocm.cc",
|
||||||
|
"csrc/elementwise/activation.cu",
|
||||||
|
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
|
||||||
|
"csrc/moe/moe_align_kernel.cu",
|
||||||
|
"csrc/moe/moe_topk_softmax_kernels.cu",
|
||||||
|
"csrc/speculative/eagle_utils.cu",
|
||||||
|
"csrc/kvcacheio/transfer.cu",
|
||||||
|
]
|
||||||
|
|
||||||
|
cxx_flags = [
|
||||||
|
"-O3",
|
||||||
|
"-Wno-switch-bool",
|
||||||
|
"-Wno-macro-redefined",
|
||||||
|
"-Wno-deprecated-declarations",
|
||||||
|
"-w",
|
||||||
|
]
|
||||||
|
libraries = ["c10", "torch", "torch_python"]
|
||||||
|
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
|
||||||
|
|
||||||
|
hipcc_flags = [
|
||||||
|
"-fPIC",
|
||||||
|
"-O3",
|
||||||
|
"-std=c++17",
|
||||||
|
"-D__HIP_PLATFORM_HCC__=1",
|
||||||
|
"--offload-arch=gfx928",
|
||||||
|
"--offload-arch=gfx936",
|
||||||
|
"--gpu-max-threads-per-block=1024",
|
||||||
|
"-Wno-macro-redefined",
|
||||||
|
"-Wno-deprecated-declarations",
|
||||||
|
"-funroll-loops",
|
||||||
|
"-Rpass-analysis=unroll-loops",
|
||||||
|
"-w",
|
||||||
|
]
|
||||||
|
|
||||||
|
ext_modules = [
|
||||||
|
CUDAExtension(
|
||||||
|
name="sgl_kernel.common_ops",
|
||||||
|
sources=sources,
|
||||||
|
include_dirs=include_dirs,
|
||||||
|
extra_compile_args={
|
||||||
|
"nvcc": hipcc_flags,
|
||||||
|
"cxx": cxx_flags,
|
||||||
|
},
|
||||||
|
libraries=libraries,
|
||||||
|
extra_link_args=extra_link_args,
|
||||||
|
py_limited_api=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="sgl-kernel",
|
||||||
|
version=_get_version(),
|
||||||
|
packages=find_packages(where="python"),
|
||||||
|
package_dir={"": "python"},
|
||||||
|
ext_modules=ext_modules,
|
||||||
|
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
|
||||||
|
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user