Support loading of larger models with on-the-fly quantization (#3061)
@@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum):
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
+    LAYERED = "layered"


 @dataclass
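The new member gives `--load-format layered` a concrete enum value. As a minimal sketch (not part of this change; the enum body is abbreviated), str-valued enums let the parsed CLI string convert directly to a member:

import enum

class LoadFormat(str, enum.Enum):
    AUTO = "auto"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    LAYERED = "layered"

# str-valued enums convert by value, so a parsed CLI string maps directly:
assert LoadFormat("layered") is LoadFormat.LAYERED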
@@ -5,6 +5,7 @@ Common utilities for torchao.
 import logging
 import os
 import pwd
+from typing import Callable, Optional

 import torch
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
     return True


+def proj_filter(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    """Filter function for quantizing projection layers."""
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
-    model: torch.nn.Module, torchao_config: str, filter_fn=None
+    model: torch.nn.Module,
+    torchao_config: str,
+    filter_fn: Optional[Callable] = proj_filter,
 ):
     """Quantize a model with torchao quantization specified by torchao_config
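`proj_filter` replaces the ad-hoc default removed below: it accepts any module whose fully qualified name contains `"proj"`. A quick sketch of its behavior, assuming `proj_filter` from the hunk above and hypothetical Llama-style module names:

import torch

linear = torch.nn.Linear(8, 8)  # the module argument is unused by the filter
assert proj_filter(linear, "model.layers.0.self_attn.qkv_proj")  # matches
assert proj_filter(linear, "model.layers.0.mlp.gate_up_proj")    # matches
assert not proj_filter(linear, "model.embed_tokens")             # no "proj"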
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
     )
     from torchao.quantization.observer import PerRow, PerTensor

-    if filter_fn is None:
-
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
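With the default now carried by the signature, the removed run-time fallback leaves three call patterns. A hedged sketch (`model` and `my_filter` are placeholders; `"int8wo"` is one of the configs handled below):

apply_torchao_config_to_model(model, "int8wo")             # default: proj_filter
apply_torchao_config_to_model(model, "int8wo", my_filter)  # caller-supplied filter
apply_torchao_config_to_model(model, "int8wo", None)       # explicit None: no filter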
@@ -185,9 +185,12 @@ class ModelRunner:
         self.load_model()

         # Apply torchao quantization
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )

         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
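The getattr guard makes quantization idempotent across both load paths: LayeredModelLoader quantizes while loading and sets the flag; ModelRunner quantizes only when the flag is absent. A tiny self-contained sketch of the flag pattern:

import torch

model = torch.nn.Linear(4, 4)
assert getattr(model, "torchao_applied", False) is False  # default: not yet applied
model.torchao_applied = True                              # set by LayeredModelLoader
assert getattr(model, "torchao_applied", False) is True   # ModelRunner now skips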
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()


+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""
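The heart of the loader is the post-order traversal: each module's children are filled (and possibly quantized) before the parent materializes its own directly-owned parameters, so at most one unquantized layer is resident at a time. A self-contained sketch of the same materialize-then-process pattern, assuming a recent PyTorch (meta-device context manager and the `recurse` keyword of `to_empty`); the `process` callback stands in for weight loading plus optional quantization:

import torch
import torch.nn as nn

def materialize_post_order(module: nn.Module, device: torch.device, process) -> None:
    # Children first, so each child is already processed before the parent
    # materializes its own parameters.
    for _, child in module.named_children():
        materialize_post_order(child, device, process)
    # recurse=False: only parameters owned directly by `module`.
    module.to_empty(device=device, recurse=False)
    process(module)

# Build on the meta device: no real memory is allocated yet.
with torch.device("meta"):
    net = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

materialize_post_order(net, torch.device("cpu"), lambda m: None)
assert not any(p.is_meta for p in net.parameters())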
@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)

+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
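A sketch of the dispatch from the caller's side; the `LoadConfig` import path here is an assumption based on the modules touched by this diff and may differ:

from sglang.srt.configs.load_config import LoadConfig, LoadFormat

load_config = LoadConfig(load_format=LoadFormat.LAYERED)
loader = get_model_loader(load_config)  # -> LayeredModelLoader
# LayeredModelLoader.__init__ then resets load_format to LoadFormat.AUTO, so
# the weight files themselves are discovered as in the default loading path.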
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights_to_module(
+        self,
+        fqn: str,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto submodule pointed by path `fqn`."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
+        module = self.get_submodule(fqn)
+        params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))

         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
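Scoping `params_dict` to one submodule is what lets each `load_weights_to_module` call touch only that module's own weights: `prefix=fqn` keeps parameter names aligned with the checkpoint stream, while `recurse=False` excludes children that the traversal has already filled. Illustratively (module names are hypothetical):

qkv = model.get_submodule("model.layers.0.self_attn.qkv_proj")
params_dict = dict(
    qkv.named_parameters(prefix="model.layers.0.self_attn.qkv_proj", recurse=False)
)
# e.g. {"model.layers.0.self_attn.qkv_proj.weight": Parameter(...)}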
@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 continue
             name = name.replace(weight_name, param_name)
             # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
+            if name.endswith(".bias") or name not in params_dict:
                 continue
             param = params_dict[name]
             weight_loader = param.weight_loader
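The switch from `and` to `or` matters for layered loading: previously `params_dict` held every parameter of the model, so the only names legitimately absent were GPTQ-style extra biases. Now it holds only the current submodule's parameters, so on any given call most names in the weight stream are absent and must be skipped as well. Note the new condition also skips every `.bias` entry outright, which is safe for the bias-free Llama-style projections this file targets.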
@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto the full model."""
+        self.load_weights_to_module("", weights)
+

 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
@@ -317,6 +317,7 @@ class ServerArgs:
                 "dummy",
                 "gguf",
                 "bitsandbytes",
+                "layered",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
@@ -330,7 +331,10 @@
             "which is mainly for profiling."
             '"gguf" will load the weights in the gguf format. '
             '"bitsandbytes" will load the weights using bitsandbytes '
-            "quantization.",
+            "quantization."
+            '"layered" loads weights layer by layer so that one can quantize a '
+            "layer before loading another to make the peak memory envelope "
+            "smaller.",
         )
         parser.add_argument(
             "--trust-remote-code",
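End to end, the feature is exercised by combining the new flag with a torchao setting, e.g. `--load-format layered --torchao-config int8wo` on the usual `python -m sglang.launch_server` command line: weights stream in layer by layer, each projection is quantized as soon as it is materialized, and peak memory should stay near the quantized footprint rather than the full-precision one.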