适配w8a8模型
This commit is contained in:
@@ -615,6 +615,7 @@ class ModelConfig:
|
|||||||
"quark",
|
"quark",
|
||||||
"mxfp4",
|
"mxfp4",
|
||||||
"slimquant_w4a8_marlin",
|
"slimquant_w4a8_marlin",
|
||||||
|
"w8a8_int8",
|
||||||
]
|
]
|
||||||
optimized_quantization_methods = [
|
optimized_quantization_methods = [
|
||||||
"fp8",
|
"fp8",
|
||||||
|
|||||||
@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import (
|
from sglang.srt.layers.quantization.int8_kernel import (
|
||||||
per_token_group_quant_int8,
|
per_token_group_quant_int8,
|
||||||
per_token_quant_int8,
|
# per_token_quant_int8,
|
||||||
sglang_per_token_group_quant_int8,
|
sglang_per_token_group_quant_int8,
|
||||||
)
|
)
|
||||||
|
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
|
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
apply_module_patch,
|
apply_module_patch,
|
||||||
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
|
|||||||
CombineInput,
|
CombineInput,
|
||||||
StandardDispatchOutput,
|
StandardDispatchOutput,
|
||||||
)
|
)
|
||||||
|
from lmslim import quant_ops
|
||||||
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||||
|
|
||||||
output = int8_scaled_mm(
|
output = quant_ops.triton_scaled_mm(
|
||||||
x_q_2d,
|
x_q_2d,
|
||||||
layer.weight,
|
layer.weight,
|
||||||
x_scale_2d,
|
x_scale_2d,
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
|
|||||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||||
|
|
||||||
# Detect stragger ranks in model loading
|
# Detect stragger ranks in model loading
|
||||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 36000
|
||||||
|
|
||||||
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
|
||||||
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
|
||||||
|
|||||||
Reference in New Issue
Block a user