diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 2b2b341fa..6cb35ab47 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum): GGUF = "gguf" BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" + LAYERED = "layered" @dataclass diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index c5bca25df..e08abd5ae 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -5,6 +5,7 @@ Common utilities for torchao. import logging import os import pwd +from typing import Callable, Optional import torch @@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool: return True +def proj_filter( + module: torch.nn.Module, + fqn: str, +): + """Filter function for quantizing projection layers.""" + return "proj" in fqn + + def apply_torchao_config_to_model( - model: torch.nn.Module, torchao_config: str, filter_fn=None + model: torch.nn.Module, + torchao_config: str, + filter_fn: Optional[Callable] = proj_filter, ): """Quantize a modelwith torchao quantization specified by torchao_config @@ -49,11 +60,6 @@ def apply_torchao_config_to_model( ) from torchao.quantization.observer import PerRow, PerTensor - if filter_fn is None: - - def filter_fn(module, fqn): - return "proj" in fqn - if torchao_config == "" or torchao_config is None: return model elif "int8wo" in torchao_config: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d5cdcf2be..e7dc6bd66 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -185,9 +185,12 @@ class ModelRunner: self.load_model() # Apply torchao quantization - apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] - ) + torchao_applied = 
getattr(self.model, "torchao_applied", False) + # In layered loading, torchao may have been applied + if not torchao_applied: + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] + ) # Apply torch TP if the model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 677d716d4..9e6b09488 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader): return model.eval() +class LayeredModelLoader(DefaultModelLoader): + """Model loader that loads weights layer by layer so that one can quantize a + layer before loading another to make the peak memory envelope smaller.""" + + def __init__(self, load_config: LoadConfig): + # Back to the default load format + load_config.load_format = LoadFormat.AUTO + super().__init__(load_config) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model + from sglang.srt.managers.schedule_batch import global_server_args_dict + + torchao_config = global_server_args_dict.get("torchao_config") + target_device = torch.device(device_config.device) + + with set_default_torch_dtype(model_config.dtype): + # Create model on meta device + with torch.device("meta"): + model = _initialize_model( + model_config, + self.load_config, + ) + + # Check model's layered load support + if not hasattr(model, "load_weights_to_module"): + raise ValueError( + "LayeredModelLoader requires the model to have a " + "`load_weights_to_module` method. " + f"{model_config.model_path} does not support it." 
+ ) + + # Get all weights from disk + weights = self._get_all_weights(model_config, model) + + # Helper function to recursively fill the weights of a module + def fill_module(module, fqn: List[str], weights): + """ + fqn: list of strings representing the fully qualified name of `module`. + """ + # Layer by layer + for name, submod in module.named_children(): + fill_module(submod, fqn + [name], weights) + + # First materialize on target device + module.to_empty(device=target_device, recurse=False) + fqn_path = ".".join(fqn) + # Fill weights + model.load_weights_to_module( + fqn_path, + weights, + ) + # Quantize weights if applicable + if torchao_config and "proj" in fqn_path: + # Note: `None` here is needed to indicate no filter, see + # `apply_torchao_config_to_model` for details. + apply_torchao_config_to_model(module, torchao_config, None) + + # Start calling on root module + fill_module(model, [], weights) + + if torchao_config: + model.torchao_applied = True + + return model.eval() + + class DummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values.""" @@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.GGUF: return GGUFModelLoader(load_config) + if load_config.load_format == LoadFormat.LAYERED: + return LayeredModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 024a6f317..7b3e5bc5d 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) return len(params_dict) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights_to_module( + self, + fqn: str, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto 
submodule pointed by path `fqn`.""" stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - params_dict = dict(self.named_parameters()) + module = self.get_submodule(fqn) + params_dict = dict(module.named_parameters(prefix=fqn, recurse=False)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto the full model.""" + self.load_weights_to_module("", weights) + class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4a7a28751..330c38132 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -317,6 +317,7 @@ class ServerArgs: "dummy", "gguf", "bitsandbytes", + "layered", ], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' @@ -330,7 +331,10 @@ class ServerArgs: "which is mainly for profiling." '"gguf" will load the weights in the gguf format. 
' '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization.", + "quantization. " + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code",