[Model Support] unsloth/Phi-4-mini bnb model (#4982)
Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
-    total, _ = qkv_offsets["total"]
-    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]
 
     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
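A minimal standalone sketch of the offset arithmetic in adjust_bitsandbytes_4bit_shard above, assuming a fused weight split into two equal shards; the helper name, the shard sizes, and the packed row count are invented for illustration:

# Illustrative re-implementation of the scaling performed above: unquantized
# shard offsets/sizes are rescaled onto the rows of the packed 4-bit parameter.
from typing import Dict, Tuple


def scale_shard(
    shard_offsets: Dict[str, Tuple[int, int]],
    loaded_shard_id: str,
    quantized_total: int,
) -> Tuple[int, int]:
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total
    return quantized_size, quantized_offset


# Hypothetical numbers: two 8192-row shards fused into a 16384-row weight that
# bitsandbytes has packed into 4096 rows.
offsets = {"0": (0, 8192), "1": (8192, 8192), "total": (16384, 0)}
print(scale_shard(offsets, "1", 4096))  # -> (2048, 2048)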
@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
+
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param, shard_size, shard_offset
                     )
 
+                if use_bitsandbytes_4bit:
+                    index = list(itertools.accumulate([0] + self.output_sizes))
+                    orig_offsets = {
+                        str(i): (index[i], size)
+                        for i, size in enumerate(self.output_sizes)
+                    }
+                    orig_offsets["total"] = (self.output_size, 0)
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_offsets, str(shard_id)
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
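For context, a small self-contained sketch of how the orig_offsets dict in the hunk above is assembled from self.output_sizes; the concrete projection sizes are made up for illustration:

import itertools

# Hypothetical gate/up projection sizes for a merged column-parallel layer.
output_sizes = [11008, 11008]
output_size = sum(output_sizes)

# Cumulative start offsets per shard: [0, 11008, 22016].
index = list(itertools.accumulate([0] + output_sizes))
orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
orig_offsets["total"] = (output_size, 0)

print(orig_offsets)
# {'0': (0, 11008), '1': (11008, 11008), 'total': (22016, 0)}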
@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
     }
 
     def __init__(
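The mapping keys and values now carry a leading dot, presumably so that the loader's substring replacement only fires at module-name boundaries. A hedged illustration; the remap helper and the parameter names below are invented for this example and are not SGLang's actual loader code:

bitsandbytes_stacked_params_mapping = {
    ".q_proj": (".qkv_proj", 0),
    ".k_proj": (".qkv_proj", 1),
    ".v_proj": (".qkv_proj", 2),
    ".gate_proj": (".gate_up_proj", 0),
    ".up_proj": (".gate_up_proj", 1),
}


def remap(weight_name: str) -> str:
    """Rewrite a per-shard checkpoint name onto its fused module name."""
    for shard_name, (fused_name, _index) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in weight_name:
            return weight_name.replace(shard_name, fused_name)
    return weight_name


print(remap("model.layers.0.self_attn.q_proj.weight"))
# -> model.layers.0.self_attn.qkv_proj.weight
# With the anchored ".up_proj" key, a name that merely embeds "up_proj" inside a
# longer token (e.g. a hypothetical "model.lm_head_up_proj_fused.weight") is left
# untouched, whereas the old "up_proj" key would have rewritten it.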