[Model Support] unsloth/Phi-4-mini bnb model (#4982)

Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Author: eigen
Date: 2025-04-16 22:58:20 -04:00
Committed by: GitHub
Parent: 90faf9018e
Commit: 8f783c1943
3 changed files with 235 additions and 8 deletions


@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
-    total, _ = qkv_offsets["total"]
-    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]
     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
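
The rename from qkv_offsets to shard_offsets reflects that the helper is no longer QKV-specific: any merged column-parallel layer (for example a fused gate_up_proj) can now pass its own shard table. A rough sketch of the rescaling arithmetic with hypothetical numbers (the shard size is presumably rescaled by the same rule; that line falls outside the hunk shown above):

# Illustrative only, not part of the commit.
# The unquantized fused output has 16384 rows; bitsandbytes packs two 4-bit
# values per byte, so assume the stored parameter has 8192 rows.
total = 16384            # shard_offsets["total"][0]
quantized_total = 8192   # param.data.shape[0]

orig_offset, orig_size = 8192, 8192  # second shard of the fused weight
quantized_offset = orig_offset * quantized_total // total  # -> 4096
quantized_size = orig_size * quantized_total // total      # -> 4096 (assumed)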
@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param, shard_size, shard_offset
                     )
+                if use_bitsandbytes_4bit:
+                    index = list(itertools.accumulate([0] + self.output_sizes))
+                    orig_offsets = {
+                        str(i): (index[i], size)
+                        for i, size in enumerate(self.output_sizes)
+                    }
+                    orig_offsets["total"] = (self.output_size, 0)
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_offsets, str(shard_id)
+                    )
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
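
In the fused-checkpoint path (loaded_shard_id is None), the new branch derives the per-shard offset table from self.output_sizes instead of assuming a QKV layout, then hands it to adjust_bitsandbytes_4bit_shard. A standalone illustration of the itertools.accumulate construction, using hypothetical sizes for a merged gate/up projection:

# Illustrative only, not part of the commit.
import itertools

output_sizes = [8192, 8192]          # hypothetical [gate_proj, up_proj] sizes
output_size = sum(output_sizes)

index = list(itertools.accumulate([0] + output_sizes))  # [0, 8192, 16384]
orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
orig_offsets["total"] = (output_size, 0)
# {'0': (0, 8192), '1': (8192, 8192), 'total': (16384, 0)}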


@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
     }
 
     def __init__(
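
The leading dots anchor each key on a module-name boundary. Presumably this matters for the unsloth/Phi-4-mini checkpoint, which stores fused projections: the bare key "up_proj" is also a substring of "gate_up_proj", so an already-fused weight name could be rewritten by mistake, whereas ".up_proj" cannot match inside ".gate_up_proj". A sketch of the substring-based rewrite (map_name is a hypothetical helper, not the loader's actual code):

# Illustrative only, not part of the commit.
mapping = {
    ".q_proj": (".qkv_proj", 0),
    ".k_proj": (".qkv_proj", 1),
    ".v_proj": (".qkv_proj", 2),
    ".gate_proj": (".gate_up_proj", 0),
    ".up_proj": (".gate_up_proj", 1),
}

def map_name(checkpoint_name):
    # Rewrite a per-projection checkpoint name to its fused parameter name.
    for suffix, (fused, shard_id) in mapping.items():
        if suffix in checkpoint_name:
            return checkpoint_name.replace(suffix, fused), shard_id
    return checkpoint_name, None

print(map_name("model.layers.0.mlp.up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)
print(map_name("model.layers.0.mlp.gate_up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', None)  fused name left alone;
# with the old dot-less keys, "up_proj" would have matched here.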