Replace sglang.srt.layers.quantization.scalar_types with sgl_kernel.scalar_type (#8951)
This commit is contained in:
@@ -29,9 +29,8 @@ from sglang.srt.layers.quantization.marlin_utils import (
|
||||
verify_marlin_supported,
|
||||
verify_marlin_supports_shape,
|
||||
)
|
||||
from sglang.srt.layers.quantization.scalar_type import scalar_types
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import replace_parameter
|
||||
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
@@ -52,6 +51,7 @@ _is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import awq_dequantize, fused_marlin_moe
|
||||
|
||||
elif _is_hip:
|
||||
from sglang.srt.layers.quantization.awq_triton import (
|
||||
awq_dequantize_triton as awq_dequantize,
|
||||
@@ -64,6 +64,9 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ScalarType, scalar_types = get_scalar_types()
|
||||
|
||||
|
||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user