[AMD] Add the SGLANG_ROCM_DISABLE_LINEARQUANT flag to optionally disable quantization in the QKVParallelLinear and RowParallelLinear layers on ROCm (#11811)
This commit is contained in:
@@ -158,6 +158,7 @@ class Envs:
|
|||||||
# AMD & ROCm
|
# AMD & ROCm
|
||||||
SGLANG_USE_AITER = EnvBool(False)
|
SGLANG_USE_AITER = EnvBool(False)
|
||||||
SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
|
SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
|
||||||
|
SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
|
||||||
|
|
||||||
# Quantization
|
# Quantization
|
||||||
SGLANG_INT4_WEIGHT = EnvBool(False)
|
SGLANG_INT4_WEIGHT = EnvBool(False)
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from sglang.srt.layers.parameter import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.utils import pad_or_narrow_weight
|
from sglang.srt.layers.utils import pad_or_narrow_weight
|
||||||
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -40,6 +40,11 @@ if TYPE_CHECKING:
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
|
||||||
|
"SGLANG_ROCM_DISABLE_LINEARQUANT"
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||||
@@ -824,6 +829,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||||
]
|
]
|
||||||
self.use_presharded_weights = load_presharded_attn
|
self.use_presharded_weights = load_presharded_attn
|
||||||
|
quant_config = None if _disable_hip_linear_quant else quant_config
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_size=input_size,
|
input_size=input_size,
|
||||||
@@ -1225,6 +1231,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
use_presharded_weights: bool = False,
|
use_presharded_weights: bool = False,
|
||||||
):
|
):
|
||||||
|
quant_config = None if _disable_hip_linear_quant else quant_config
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user