Error occurs when loading the gemma model in bitsandbytes format. (#2557)
commit 08effbff35 (parent 60bd32723a), committed via GitHub
@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 quant_state_dict,
             )
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(
         self, hf_weights_files, use_safetensors, quant_state_dict
     ) -> Generator:
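For reference, the two new predicates can be exercised on their own. Below is a minimal standalone sketch: the helper bodies mirror the diff, while the sample checkpoint keys are hypothetical.

def is_8bit_weight_name(weight_name: str) -> bool:
    # 8-bit bitsandbytes checkpoints store a per-tensor scale under ".SCB"
    # and a ".weight_format" marker alongside each quantized weight.
    quantized_suffix = {".scb", ".weight_format"}
    return any(weight_name.lower().endswith(s) for s in quantized_suffix)

def is_4bit_weight_name(weight_name: str) -> bool:
    # 4-bit checkpoints append quant-state tensors whose last dotted
    # component contains one of these markers.
    quantized_suffix = {
        "absmax", "quant_map", "nested_absmax", "nested_quant_map",
        "bitsandbytes",
    }
    suffix = weight_name.split(".")[-1]
    return any(q in suffix for q in quantized_suffix)

assert is_8bit_weight_name("model.layers.0.mlp.down_proj.SCB")
assert not is_8bit_weight_name("model.layers.0.mlp.down_proj.weight")
assert is_4bit_weight_name("model.layers.0.mlp.down_proj.weight.absmax")
assert is_4bit_weight_name(
    "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4")
assert not is_4bit_weight_name("model.layers.0.mlp.down_proj.weight")

Matching an explicit allow-list of bitsandbytes suffixes, rather than taking the complement of .weight/.bias names as before, is what lets checkpoints whose tensor names fall outside that pattern (the gemma case) load correctly.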
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if not weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
         for weight_name, weight_tensor in self._hf_weight_iter(
             hf_weights_files, use_safetensors
         ):
-
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
-            qweight_name = weight_name.replace(".weight", ".qweight")
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
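The 8-bit generator now keys everything by the weight name exactly as it appears on disk. A simplified sketch of the fixed two-pass flow, where the weights iterable stands in for vLLM's _hf_weight_iter output and set_weight_attrs is reduced to a comment:

from typing import Dict, Generator, Iterable, Tuple

def quantized_8bit_generator(
        weights: Iterable[Tuple[str, object]],
        quant_state_dict: Dict[str, object]) -> Generator:
    weights = list(weights)
    # Pass 1: record each ".SCB" scale under the matching ".weight" name so
    # pass 2 can test membership with on-disk names (the old code renamed
    # the key to ".qweight", so the lookup below could never match).
    for name, tensor in weights:
        if name.lower().endswith(".scb"):
            quant_state_dict[name.lower().replace(".scb", ".weight")] = tensor
    # Pass 2: skip bitsandbytes metadata and pass everything else through.
    # The real loader additionally tags names found in quant_state_dict
    # with load_in_8bit=True via set_weight_attrs before yielding.
    for name, tensor in weights:
        if name.lower().endswith((".scb", ".weight_format")):
            continue
        yield name, tensor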
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
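Under the old complement check, any tensor whose name did not end in .weight or .bias was swept into temp_state_dict; the allow-list keeps only genuine bitsandbytes quant state. A self-contained sketch of the collection pass, with a hypothetical key in the comment:

from typing import Dict, Iterable, Tuple

def collect_4bit_quant_state(
        weights: Iterable[Tuple[str, object]]) -> Dict[str, object]:
    # The weights iterable stands in for vLLM's _hf_weight_iter(...) output.
    quantized_suffix = ("absmax", "quant_map", "nested_absmax",
                        "nested_quant_map", "bitsandbytes")
    temp_state_dict = {}
    for name, tensor in weights:
        # Keep only quant-state tensors, e.g. (hypothetical key)
        # "model.layers.0.mlp.down_proj.weight.absmax".
        if not any(q in name.split(".")[-1] for q in quantized_suffix):
            continue
        temp_state_dict[name] = tensor
    return temp_state_dict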
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             hf_weights_files, use_safetensors
         ):
 
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
                 f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
             ):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
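The second pass then matches each remaining weight against the collected state. Since the old ".weight" to ".qweight" rename is gone, the quant state is stored and the tensor yielded under the unmodified on-disk name. A condensed sketch, where parse_quant_state stands in for vLLM's _parse_quant_state:

from typing import Callable, Dict, Generator, Iterable, Tuple

def quantized_4bit_generator(
        weights: Iterable[Tuple[str, object]],
        temp_state_dict: Dict[str, object],
        quant_state_dict: Dict[str, object],
        parse_quant_state: Callable) -> Generator:
    quantized_suffix = ("absmax", "quant_map", "nested_absmax",
                        "nested_quant_map", "bitsandbytes")
    for name, tensor in weights:
        if any(q in name.split(".")[-1] for q in quantized_suffix):
            continue  # quant state was already consumed by the first pass
        if (f"{name}.quant_state.bitsandbytes__nf4" in temp_state_dict
                or f"{name}.quant_state.bitsandbytes__fp4" in temp_state_dict):
            # Store the parsed quant state under the unrenamed weight name.
            quant_state_dict[name] = parse_quant_state(name, temp_state_dict)
        yield name, tensor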
@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
 
 
 class Gemma2ForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
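These class attributes tell the bitsandbytes loader which Gemma2 linear layers to quantize and how per-projection checkpoint shards map into vLLM's fused qkv_proj and gate_up_proj parameters. A small illustrative helper (hypothetical, not part of the commit) showing how the stacked-params mapping is read:

from typing import Tuple

bitsandbytes_stacked_params_mapping = {
    # shard_name: (fused_weight_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def to_fused_name(param_name: str) -> Tuple[str, int]:
    """Map a per-shard checkpoint name to the fused vLLM parameter name
    and the index of the shard inside it."""
    for shard, (fused, idx) in bitsandbytes_stacked_params_mapping.items():
        if shard in param_name:
            return param_name.replace(shard, fused), idx
    return param_name, 0  # unstacked params keep their own name

# e.g. ("model.layers.0.self_attn.qkv_proj.weight", 0)
print(to_fused_name("model.layers.0.self_attn.q_proj.weight"))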