diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 0f76743b3..ce78f87b6 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -158,6 +158,7 @@ class Envs: # AMD & ROCm SGLANG_USE_AITER = EnvBool(False) SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False) + SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False) # Quantization SGLANG_INT4_WEIGHT = EnvBool(False) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 2b34a2965..3aaf301bb 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -32,7 +32,7 @@ from sglang.srt.layers.parameter import ( ) from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.utils import pad_or_narrow_weight -from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs +from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import ( @@ -40,6 +40,11 @@ if TYPE_CHECKING: QuantizeMethodBase, ) +_is_hip = is_hip() +_disable_hip_linear_quant = _is_hip and get_bool_env_var( + "SGLANG_ROCM_DISABLE_LINEARQUANT" +) + logger = logging.getLogger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ @@ -824,6 +829,7 @@ class QKVParallelLinear(ColumnParallelLinear): self.num_kv_heads * self.head_size * tp_size, # v_proj ] self.use_presharded_weights = load_presharded_attn + quant_config = None if _disable_hip_linear_quant else quant_config super().__init__( input_size=input_size, @@ -1225,6 +1231,7 @@ class RowParallelLinear(LinearBase): tp_size: Optional[int] = None, use_presharded_weights: bool = False, ): + quant_config = None if _disable_hip_linear_quant else quant_config super().__init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix )