Support loading of larger models with on-the-fly quantization (#3061)
@@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum):
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
+    LAYERED = "layered"
 
 
 @dataclass
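Since LoadFormat is a str-enum, the literal "layered" accepted on the command line maps directly onto the new member with no explicit conversion at the dispatch sites. A minimal sketch of that pattern (a hypothetical standalone copy, not the sglang class itself):

```python
import enum


# Hypothetical standalone copy of the str-enum pattern used by LoadFormat.
class LoadFormat(str, enum.Enum):
    AUTO = "auto"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    LAYERED = "layered"


# Because the enum mixes in str, the raw CLI string "layered" compares equal
# to the member, so load-format checks can stay simple string comparisons.
assert LoadFormat("layered") is LoadFormat.LAYERED
assert LoadFormat.LAYERED == "layered"
```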
@@ -5,6 +5,7 @@ Common utilities for torchao.
 import logging
 import os
 import pwd
+from typing import Callable, Optional
 
 import torch
 
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
     return True
 
 
+def proj_filter(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    """Filter function for quantizing projection layers."""
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
-    model: torch.nn.Module, torchao_config: str, filter_fn=None
+    model: torch.nn.Module,
+    torchao_config: str,
+    filter_fn: Optional[Callable] = proj_filter,
 ):
     """Quantize a model with torchao quantization specified by torchao_config
 
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    if filter_fn is None:
-
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
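The default filter previously lived behind an inline `filter_fn is None` fallback; hoisting it out as the named default `proj_filter` frees `None` to mean "no filter at all", which the layered loader below relies on when it quantizes an already-selected projection module. A small sketch of what the default policy selects, using a toy block with hypothetical layer names:

```python
from torch import nn


def proj_filter(module: nn.Module, fqn: str) -> bool:
    # Same policy as the named default above: quantize only projection layers.
    return "proj" in fqn


# Toy block with hypothetical names mirroring a transformer layer.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(16, 48)
        self.o_proj = nn.Linear(16, 16)
        self.norm = nn.LayerNorm(16)


block = Block()
selected = [fqn for fqn, mod in block.named_modules() if proj_filter(mod, fqn)]
print(selected)  # ['qkv_proj', 'o_proj'] -- the norm layer is left alone
```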
@@ -185,9 +185,12 @@ class ModelRunner:
         self.load_model()
 
         # Apply torchao quantization
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )
 
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
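The runner now consults a `torchao_applied` flag before quantizing, because LayeredModelLoader (below) quantizes during loading and sets the flag itself. A minimal sketch of this idempotency-guard pattern, with a hypothetical helper standing in for the real call:

```python
from torch import nn


def apply_quantization_once(model: nn.Module, torchao_config: str) -> None:
    """Idempotency-guard sketch (hypothetical helper, not the sglang API)."""
    if getattr(model, "torchao_applied", False):
        return  # the layered loader already quantized during weight loading
    # ... apply the torchao config here (elided) ...
    model.torchao_applied = True  # later callers become no-ops


model = nn.Linear(4, 4)
apply_quantization_once(model, "int8wo")
apply_quantization_once(model, "int8wo")  # second call does nothing
```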
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()
 
 
+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""
 
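`fill_module` is a post-order traversal: children are filled before the parent, and `to_empty(recurse=False)` materializes only the module's own tensors off the meta device, so at any instant only the layer currently being loaded (plus, briefly, its unquantized weights) occupies the target device. A runnable sketch of just the traversal and materialization, with the load/quantize steps reduced to comments:

```python
import torch
from torch import nn

# Post-order materialization sketch (toy model; real loading and quantization
# steps from the loader above are elided).
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

target_device = torch.device("cpu")


def fill_module(module: nn.Module, fqn: list) -> None:
    # Children first: each leaf is materialized and filled before its parent.
    for name, submod in module.named_children():
        fill_module(submod, fqn + [name])
    # Materialize only this module's own tensors (recurse=False); children
    # already live on the target device, so memory grows one layer at a time.
    module.to_empty(device=target_device, recurse=False)
    # Here the real loader streams the weights for ".".join(fqn) from disk and
    # may quantize this module immediately, shrinking the peak memory envelope.


fill_module(model, [])
assert all(not p.is_meta for p in model.parameters())
```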
@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights_to_module(
+        self,
+        fqn: str,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto submodule pointed by path `fqn`."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
+        module = self.get_submodule(fqn)
+        params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
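Scoping `params_dict` to the target submodule is what makes per-layer loading work: `get_submodule("")` returns the model itself (so the old full-model behavior is simply the `fqn=""` case), and `named_parameters(prefix=fqn, recurse=False)` keeps checkpoint-style fully qualified keys. A toy demonstration, with hypothetical layer names:

```python
from torch import nn

# Toy demonstration of the submodule-scoped params_dict (hypothetical names).
model = nn.Sequential()
model.add_module("layers", nn.Sequential(nn.Linear(4, 4)))

fqn = "layers.0"
module = model.get_submodule(fqn)  # get_submodule("") returns the model itself
params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
print(sorted(params_dict))  # ['layers.0.bias', 'layers.0.weight']

# Checkpoint keys belonging to *other* layers are simply absent from this
# dict, which is why the bias check in the next hunks is widened from `and`
# to `or name not in params_dict`: foreign weights must be skipped, not fail.
```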
@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                     break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto the full model."""
+        self.load_weights_to_module("", weights)
+
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
@@ -317,6 +317,7 @@ class ServerArgs:
             "dummy",
             "gguf",
             "bitsandbytes",
+            "layered",
         ],
         help="The format of the model weights to load. "
         '"auto" will try to load the weights in the safetensors format '
@@ -330,7 +331,10 @@ class ServerArgs:
         "which is mainly for profiling."
         '"gguf" will load the weights in the gguf format. '
         '"bitsandbytes" will load the weights using bitsandbytes '
-        "quantization.",
+        "quantization. "
+        '"layered" loads weights layer by layer so that one can quantize a '
+        "layer before loading another to make the peak memory envelope "
+        "smaller.",
     )
     parser.add_argument(
         "--trust-remote-code",
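For reference, the new choice is plumbed through argparse; a server started with something like `--load-format layered` (together with a torchao config such as `int8wo`, assuming the existing `--torchao-config` flag) routes model loading through LayeredModelLoader. A minimal argparse sketch, with a hypothetical subset of the real choices list:

```python
import argparse

# Minimal sketch of the extended flag (hypothetical subset of the real
# choices; the full help text is the concatenated string shown above).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--load-format",
    type=str,
    default="auto",
    choices=["auto", "dummy", "gguf", "bitsandbytes", "layered"],
)
args = parser.parse_args(["--load-format", "layered"])
print(args.load_format)  # layered
```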