[Model Support] unsloth/Phi-4-mini bnb model (#4982)
Co-authored-by: yhyang201 <yhyang201@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
|
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
|
||||||
|
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
|
|||||||
|
|
||||||
|
|
||||||
def adjust_bitsandbytes_4bit_shard(
|
def adjust_bitsandbytes_4bit_shard(
|
||||||
param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
|
param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
||||||
|
|
||||||
total, _ = qkv_offsets["total"]
|
total, _ = shard_offsets["total"]
|
||||||
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
|
orig_offset, orig_size = shard_offsets[loaded_shard_id]
|
||||||
|
|
||||||
quantized_total = param.data.shape[0]
|
quantized_total = param.data.shape[0]
|
||||||
quantized_offset = orig_offset * quantized_total // total
|
quantized_offset = orig_offset * quantized_total // total
|
||||||
@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
shard_offsets.append((i, current_shard_offset, output_size))
|
shard_offsets.append((i, current_shard_offset, output_size))
|
||||||
current_shard_offset += output_size
|
current_shard_offset += output_size
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
packed_dim = getattr(param, "packed_dim", None)
|
||||||
|
|
||||||
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||||
# Special case for Quantization.
|
# Special case for Quantization.
|
||||||
# If quantized, we need to adjust the offset and size to account
|
# If quantized, we need to adjust the offset and size to account
|
||||||
@@ -585,6 +588,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
param, shard_size, shard_offset
|
param, shard_size, shard_offset
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if use_bitsandbytes_4bit:
|
||||||
|
index = list(itertools.accumulate([0] + self.output_sizes))
|
||||||
|
orig_offsets = {
|
||||||
|
str(i): (index[i], size)
|
||||||
|
for i, size in enumerate(self.output_sizes)
|
||||||
|
}
|
||||||
|
orig_offsets["total"] = (self.output_size, 0)
|
||||||
|
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||||
|
param, orig_offsets, str(shard_id)
|
||||||
|
)
|
||||||
|
|
||||||
loaded_weight_shard = loaded_weight.narrow(
|
loaded_weight_shard = loaded_weight.narrow(
|
||||||
output_dim, shard_offset, shard_size
|
output_dim, shard_offset, shard_size
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||||
bitsandbytes_stacked_params_mapping = {
|
bitsandbytes_stacked_params_mapping = {
|
||||||
# shard_name, weight_name, index
|
# shard_name, weight_name, index
|
||||||
"q_proj": ("qkv_proj", 0),
|
".q_proj": (".qkv_proj", 0),
|
||||||
"k_proj": ("qkv_proj", 1),
|
".k_proj": (".qkv_proj", 1),
|
||||||
"v_proj": ("qkv_proj", 2),
|
".v_proj": (".qkv_proj", 2),
|
||||||
"gate_proj": ("gate_up_proj", 0),
|
".gate_proj": (".gate_up_proj", 0),
|
||||||
"up_proj": ("gate_up_proj", 1),
|
".up_proj": (".gate_up_proj", 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
213
test/srt/models/test_unsloth_models.py
Normal file
213
test/srt/models/test_unsloth_models.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnslothPhi4(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "unsloth/phi-4"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.78)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnslothPhi4Bnb4bit(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "unsloth/phi-4-bnb-4bit"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--load-format",
|
||||||
|
"bitsandbytes",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.75)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnslothPhi4UnslothBnb4bit(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "unsloth/phi-4-unsloth-bnb-4bit"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--load-format",
|
||||||
|
"bitsandbytes",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.75)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnslothPhi4MiniInstruct(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "unsloth/Phi-4-mini-instruct"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.65)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnslothPhi4MiniBnb4bit(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "unsloth/Phi-4-mini-instruct-bnb-4bit"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--load-format",
|
||||||
|
"bitsandbytes",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.6)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--load-format",
|
||||||
|
"bitsandbytes",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.6)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user