Error occurs when loading the gemma model in bitsandbytes format. (#2557)
This commit is contained in:
committed by
GitHub
parent
60bd32723a
commit
08effbff35
@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
quant_state_dict,
|
||||
)
|
||||
|
||||
def _is_8bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {".scb", ".weight_format"}
|
||||
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
|
||||
|
||||
def _is_4bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {
|
||||
"absmax",
|
||||
"quant_map",
|
||||
"nested_absmax",
|
||||
"nested_quant_map",
|
||||
"bitsandbytes",
|
||||
}
|
||||
suffix = weight_name.split(".")[-1]
|
||||
return any(q_suffix in suffix for q_suffix in quantized_suffix)
|
||||
|
||||
def _quantized_8bit_generator(
|
||||
self, hf_weights_files, use_safetensors, quant_state_dict
|
||||
) -> Generator:
|
||||
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
if not weight_name.lower().endswith(".scb"):
|
||||
continue
|
||||
|
||||
weight_key = weight_name.lower().replace(".scb", ".qweight")
|
||||
weight_key = weight_name.lower().replace(".scb", ".weight")
|
||||
quant_state_dict[weight_key] = weight_tensor
|
||||
|
||||
for weight_name, weight_tensor in self._hf_weight_iter(
|
||||
hf_weights_files, use_safetensors
|
||||
):
|
||||
|
||||
if not weight_name.endswith((".weight", ".bias")):
|
||||
if self._is_8bit_weight_name(weight_name):
|
||||
continue
|
||||
|
||||
qweight_name = weight_name.replace(".weight", ".qweight")
|
||||
|
||||
if qweight_name in quant_state_dict:
|
||||
if weight_name in quant_state_dict:
|
||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||
yield qweight_name, weight_tensor
|
||||
yield weight_name, weight_tensor
|
||||
else:
|
||||
yield weight_name, weight_tensor
|
||||
|
||||
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
|
||||
temp_state_dict = {}
|
||||
for weight_name, weight_tensor in weight_iterator:
|
||||
if weight_name.endswith((".weight", ".bias")):
|
||||
if not self._is_4bit_weight_name(weight_name):
|
||||
continue
|
||||
# bitsandbytes library requires
|
||||
# weight.quant_state.bitsandbytes__* in CPU
|
||||
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
hf_weights_files, use_safetensors
|
||||
):
|
||||
|
||||
if not weight_name.endswith((".weight", ".bias")):
|
||||
if self._is_4bit_weight_name(weight_name):
|
||||
continue
|
||||
|
||||
if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
|
||||
f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
|
||||
):
|
||||
quant_state = _parse_quant_state(weight_name, temp_state_dict)
|
||||
weight_name = weight_name.replace(".weight", ".qweight")
|
||||
quant_state_dict[weight_name] = quant_state
|
||||
yield weight_name.replace(".weight", ".qweight"), weight_tensor
|
||||
yield weight_name, weight_tensor
|
||||
else:
|
||||
yield weight_name, weight_tensor
|
||||
|
||||
|
||||
@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
|
||||
|
||||
|
||||
class Gemma2ForCausalLM(nn.Module):
|
||||
# BitsAndBytes-specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".q_proj.",
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
Reference in New Issue
Block a user