Support loading of larger models with on-the-fly quantization (#3061)

This commit is contained in:
Ke Wen
2025-01-22 21:33:17 -08:00
committed by GitHub
parent 8b84e69f25
commit 862bcff833
6 changed files with 116 additions and 14 deletions

View File

@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
return model.eval()
class LayeredModelLoader(DefaultModelLoader):
    """Model loader that loads weights layer by layer so that one can quantize a
    layer before loading another to make the peak memory envelope smaller."""

    def __init__(self, load_config: LoadConfig):
        # Back to the default load format: this loader reuses DefaultModelLoader's
        # weight-discovery machinery (`_get_all_weights`), which expects AUTO.
        load_config.load_format = LoadFormat.AUTO
        super().__init__(load_config)

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        """Build the model on the meta device, then materialize and fill it
        one submodule at a time (deepest modules first), optionally applying
        torchao quantization to each filled projection layer before the next
        one is materialized — keeping peak memory below a full-model load.

        Args:
            model_config: Source of dtype, model path, and architecture used
                by `_initialize_model` and `_get_all_weights`.
            device_config: Provides the target device the weights are
                materialized onto.

        Returns:
            The fully loaded model, switched to eval mode.

        Raises:
            ValueError: If the model class does not implement
                `load_weights_to_module` (required for per-module filling).
        """
        # Imported lazily here (not at module top), presumably to avoid an
        # import cycle with the managers/layers packages — keep as-is.
        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
        from sglang.srt.managers.schedule_batch import global_server_args_dict

        # Falsy/absent torchao_config disables the quantization step below.
        torchao_config = global_server_args_dict.get("torchao_config")
        target_device = torch.device(device_config.device)

        with set_default_torch_dtype(model_config.dtype):
            # Create model on meta device: parameters get shape/dtype but no
            # storage, so construction costs near-zero memory.
            with torch.device("meta"):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                )

            # Check model's layered load support before touching any weights.
            if not hasattr(model, "load_weights_to_module"):
                raise ValueError(
                    "LayeredModelLoader requires the model to have a "
                    "`load_weights_to_module` method. "
                    f"{model_config.model_path} does not support it."
                )

            # Get all weights from disk.
            # NOTE(review): `weights` is threaded through every recursive call
            # below; this assumes `load_weights_to_module` only consumes the
            # entries matching the given fqn (or that `weights` is re-iterable)
            # — verify against the model's implementation.
            weights = self._get_all_weights(model_config, model)

            # Helper function to recursively fill the weights of a module.
            def fill_module(module, fqn: List[str], weights):
                """
                fqn: list of strings representing the fully qualified name of `module`.
                """
                # Layer by layer: recurse first, so children are materialized
                # and filled before their parent's own (non-recursive) params.
                for name, submod in module.named_children():
                    fill_module(submod, fqn + [name], weights)

                # First materialize on target device. recurse=False so already
                # -processed children are not re-allocated.
                module.to_empty(device=target_device, recurse=False)
                fqn_path = ".".join(fqn)
                # Fill weights for this module only (addressed by its fqn).
                model.load_weights_to_module(
                    fqn_path,
                    weights,
                )
                # Quantize weights if applicable. The "proj" substring match
                # targets projection layers; other modules are left unquantized.
                if torchao_config and "proj" in fqn_path:
                    # Note: `None` here is needed to indicate no filter, see
                    # `apply_torchao_config_to_model` for details.
                    apply_torchao_config_to_model(module, torchao_config, None)

            # Start calling on root module (fqn == [] -> fqn_path == "").
            fill_module(model, [], weights)

        if torchao_config:
            # Mark so downstream code does not re-apply torchao quantization.
            model.torchao_applied = True

        return model.eval()
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)
if load_config.load_format == LoadFormat.LAYERED:
return LayeredModelLoader(load_config)
return DefaultModelLoader(load_config)