适配w8a8模型

This commit is contained in:
maxiao1
2025-10-29 09:06:22 +08:00
parent a5718531b7
commit 4b9b337b39
4 changed files with 9 additions and 4 deletions

View File

@@ -615,6 +615,7 @@ class ModelConfig:
         "quark",
         "mxfp4",
         "slimquant_w4a8_marlin",
+        "w8a8_int8",
     ]
     optimized_quantization_methods = [
         "fp8",

View File

@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.int8_kernel import (
     per_token_group_quant_int8,
-    per_token_quant_int8,
+    # per_token_quant_int8,
     sglang_per_token_group_quant_int8,
 )
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,

View File

@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
     CombineInput,
     StandardDispatchOutput,
 )
+from lmslim import quant_ops
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
         output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
-        output = int8_scaled_mm(
+        output = quant_ops.triton_scaled_mm(
             x_q_2d,
             layer.weight,
             x_scale_2d,

View File

@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 # Detect stragger ranks in model loading
-UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 36000
 # the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
 MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3