Support for Qwen2.5-VL Model in bitsandbytes Format (#5003)
@@ -1071,6 +1071,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
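For reference, model_config.hf_config is the Hugging Face config of the checkpoint being loaded, and its model_type string is what the new branches below compare against. A minimal sketch of that lookup, assuming the transformers library is installed; the checkpoint name is an example, not taken from this commit:

# Illustrative sketch only; the checkpoint name is an assumed example.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
print(config.model_type)  # "qwen2_5_vl" for Qwen2.5-VL, "qwen2_vl" for Qwen2-VL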
@@ -1079,11 +1080,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 weight_name,
                 index,
             ) in model.bitsandbytes_stacked_params_mapping.items():
+                if (
+                    model_type in ["qwen2_vl", "qwen2_5_vl"]
+                    and "visual" in quant_param_name
+                ):
+                    break
                 if shard_name in quant_param_name:
                     shard_index = index
                     quant_param_name = quant_param_name.replace(shard_name, weight_name)
                     break
 
+            if (
+                model_type in ["qwen2_vl", "qwen2_5_vl"]
+                and "visual" in quant_param_name
+            ):
+                quant_param_name = quant_param_name.replace(
+                    r"attn.qkv.", r"attn.qkv_proj."
+                )
+
             if quant_param_name not in param_dict:
                 raise ValueError(
                     f"Parameter {quant_param_name} not found in the model."
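The added branches above skip the generic stacked-parameter remapping for the vision tower of qwen2_vl / qwen2_5_vl checkpoints, whose attention weights stay fused, and instead only rename the module from attn.qkv. to attn.qkv_proj. so the quantized parameter name matches the model's parameter dict. A minimal standalone sketch of that renaming, using hypothetical parameter names:

# Illustrative sketch; the parameter names are hypothetical examples.
def remap_visual_qkv(model_type: str, quant_param_name: str) -> str:
    # Vision-tower weights keep their fused qkv projection; only the
    # module name changes from "attn.qkv." to "attn.qkv_proj.".
    if model_type in ["qwen2_vl", "qwen2_5_vl"] and "visual" in quant_param_name:
        return quant_param_name.replace("attn.qkv.", "attn.qkv_proj.")
    return quant_param_name

print(remap_visual_qkv("qwen2_5_vl", "visual.blocks.0.attn.qkv.weight"))
# -> visual.blocks.0.attn.qkv_proj.weight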
@@ -1111,6 +1125,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
+            # Make torch infer_schema happy (compatible with vLLM)
+            offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
 
             if load_8bit:
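The two appended lines convert the NumPy offsets into a CPU torch tensor so that, per the in-line comment, torch's schema inference accepts the attribute. A minimal sketch of how the per-shard offsets are built from the quant-state shapes; the shard shapes and pack ratio below are hypothetical:

# Illustrative sketch; shapes and pack_ratio are hypothetical examples.
import math
import numpy as np
import torch

shard_shapes = [(1024, 4096), (1024, 4096), (2048, 4096)]  # e.g. q, k, v shards
pack_ratio = 2  # two 4-bit values packed per stored element

num_elements = [math.prod(shape) // pack_ratio for shape in shard_shapes]
offsets = np.concatenate(([0], np.cumsum(num_elements)))
offsets = torch.tensor(offsets).cpu()  # torch tensor instead of a NumPy array
print(offsets)  # tensor([      0, 2097152, 4194304, 8388608])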