[Minor] Improve code style (#2422)

This commit is contained in:
Lianmin Zheng
2024-12-09 06:30:35 -08:00
committed by GitHub
parent 0ce091a82d
commit 641b7d0ae0
15 changed files with 33 additions and 21 deletions

View File

@@ -242,20 +242,22 @@ class ModelRunner:
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")
# Prepare the vllm model config
# Prepare the model config
self.load_config = LoadConfig(
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
)
if self.server_args.load_format == "gguf":
monkey_patch_vllm_gguf_config()
# Load the model
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
# Parse other args
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +272,10 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def update_weights_from_disk(self, model_path: str, load_format: str):
"""Update engine weights online from disk."""
def update_weights_from_disk(
self, model_path: str, load_format: str
) -> tuple[bool, str]:
"""Update engine weights in-place from the disk."""
from sglang.srt.model_loader.loader import (
DefaultModelLoader,
device_loading_context,