Support loading of larger models with on-the-fly quantization (#3061)
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()


+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""

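The core trick in this hunk is to allocate the model on PyTorch's meta device (which records only shapes and dtypes, no data) and then materialize, fill, and optionally quantize one submodule at a time in post-order, so unquantized weights for at most one layer are resident at any moment. Below is a minimal, self-contained sketch of that pattern outside sglang's loader plumbing; the toy build_model and the plain state-dict checkpoint are stand-ins for a real model and its weight iterator:

    from typing import List

    import torch
    import torch.nn as nn

    def build_model() -> nn.Module:
        # Toy stand-in for a large model.
        return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

    # Stand-in for the on-disk checkpoint: a CPU state dict.
    checkpoint = build_model().state_dict()
    target_device = torch.device("cpu")  # "cuda" in a real deployment

    # Allocate on the meta device: no weight memory is consumed yet.
    with torch.device("meta"):
        model = build_model()

    def fill_module(module: nn.Module, fqn: List[str]) -> None:
        # Post-order: recurse into children first, one layer at a time.
        for name, child in module.named_children():
            fill_module(child, fqn + [name])

        # Materialize only this module's own parameters (recurse=False),
        # leaving already-filled children untouched.
        module.to_empty(device=target_device, recurse=False)

        # Copy this module's weights in from the checkpoint.
        prefix = ".".join(fqn) + "." if fqn else ""
        with torch.no_grad():
            for pname, param in module.named_parameters(recurse=False):
                param.copy_(checkpoint[prefix + pname])
        # A quantization step would run here, shrinking this layer before
        # the next sibling is materialized, which is what keeps the peak
        # memory envelope small.

    fill_module(model, [])
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])

Note that the real loader delegates the checkpoint lookup to the model's own `load_weights_to_module`, since weight names on disk generally do not map one-to-one onto module FQNs (fused projections being the usual example); that is also why the loader refuses models that lack the method.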
@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
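With the dispatch above in place, opting into layered loading is just a matter of requesting `LoadFormat.LAYERED` in the `LoadConfig`. A usage sketch follows; the import paths are assumptions for illustration, not verified against this tree:

    from sglang.srt.configs.load_config import LoadConfig, LoadFormat  # assumed path
    from sglang.srt.model_loader.loader import get_model_loader        # assumed path

    # get_model_loader routes LAYERED to LayeredModelLoader, whose __init__
    # immediately resets load_format back to AUTO so that weight discovery
    # on disk behaves exactly as in the default loader.
    loader = get_model_loader(LoadConfig(load_format=LoadFormat.LAYERED))
    # model = loader.load_model(model_config=model_config, device_config=device_config)

On the server side this would presumably pair a `layered` choice for `--load-format` with sglang's existing `--torchao-config` flag to select the quantization scheme; that wiring is not shown in these hunks.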