Replace sglang.srt.layers.quantization.scalar_types with sgl_kernel.scalar_type (#8951)

This commit is contained in:
Hongbo Xu
2025-08-14 10:41:41 +08:00
committed by GitHub
parent 6b7c24712c
commit a669bc2f74
8 changed files with 44 additions and 362 deletions

View File

@@ -29,9 +29,8 @@ from sglang.srt.layers.quantization.marlin_utils import (
verify_marlin_supported,
verify_marlin_supports_shape,
)
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import replace_parameter
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
@@ -52,6 +51,7 @@ _is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
from sgl_kernel import awq_dequantize, fused_marlin_moe
elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_triton as awq_dequantize,
@@ -64,6 +64,9 @@ else:
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# The scalar-type names are obtained through the get_scalar_types() helper
# (imported from sglang.srt.layers.quantization.utils) instead of a direct
# module import.
# NOTE(review): per the commit title these presumably originate from
# sgl_kernel.scalar_type — confirm in get_scalar_types().
ScalarType, scalar_types = get_scalar_types()
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]) -> bool:
    """Return True if ``prefix`` contains any of the given module names.

    The check is a plain substring test: each entry of
    ``modules_to_not_convert`` is looked up inside ``prefix`` with ``in``,
    so an entry matches wherever it occurs in the prefix string.

    Args:
        prefix: Name/prefix string of the layer being inspected.
        modules_to_not_convert: Module-name fragments to look for.

    Returns:
        True when at least one fragment occurs in ``prefix``; False
        otherwise, including when the list is empty.
    """
    return any(module_name in prefix for module_name in modules_to_not_convert)