diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index e0b03d771..776b69aaf 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader): quant_state_dict, ) + def _is_8bit_weight_name(self, weight_name: str): + quantized_suffix = {".scb", ".weight_format"} + return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix) + + def _is_4bit_weight_name(self, weight_name: str): + quantized_suffix = { + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "bitsandbytes", + } + suffix = weight_name.split(".")[-1] + return any(q_suffix in suffix for q_suffix in quantized_suffix) + def _quantized_8bit_generator( self, hf_weights_files, use_safetensors, quant_state_dict ) -> Generator: @@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader): if not weight_name.lower().endswith(".scb"): continue - weight_key = weight_name.lower().replace(".scb", ".qweight") + weight_key = weight_name.lower().replace(".scb", ".weight") quant_state_dict[weight_key] = weight_tensor for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors ): - - if not weight_name.endswith((".weight", ".bias")): + if self._is_8bit_weight_name(weight_name): continue - qweight_name = weight_name.replace(".weight", ".qweight") - - if qweight_name in quant_state_dict: + if weight_name in quant_state_dict: set_weight_attrs(weight_tensor, {"load_in_8bit": True}) - yield qweight_name, weight_tensor + yield weight_name, weight_tensor else: yield weight_name, weight_tensor @@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) temp_state_dict = {} for weight_name, weight_tensor in weight_iterator: - if weight_name.endswith((".weight", ".bias")): + if not self._is_4bit_weight_name(weight_name): continue # 
bitsandbytes library requires # weight.quant_state.bitsandbytes__* in CPU @@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader): hf_weights_files, use_safetensors ): - if not weight_name.endswith((".weight", ".bias")): + if self._is_4bit_weight_name(weight_name): continue if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or ( f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict ): quant_state = _parse_quant_state(weight_name, temp_state_dict) - weight_name = weight_name.replace(".weight", ".qweight") quant_state_dict[weight_name] = quant_state - yield weight_name.replace(".weight", ".qweight"), weight_tensor + yield weight_name, weight_tensor else: yield weight_name, weight_tensor diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 0c0e6155d..58d9ce02f 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -307,6 +307,25 @@ class Gemma2Model(nn.Module): class Gemma2ForCausalLM(nn.Module): + # BitsAndBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + packed_modules_mapping = { "qkv_proj": [ "q_proj",