Refactor weight offloading logic (#8521)

This commit is contained in:
fzyzcjy
2025-08-21 18:48:13 +08:00
committed by GitHub
parent de4990a5b2
commit 55d336cb08
3 changed files with 141 additions and 74 deletions

View File

@@ -96,6 +96,11 @@ from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import (
create_offloader_from_server_args,
get_offloader,
set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
is_npu,
monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config,
set_cpu_offload_max_bytes,
set_cuda_arch,
)
from sglang.srt.weight_sync.tensor_bucket import (
@@ -222,9 +226,6 @@ class ModelRunner:
}
)
# CPU offload
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
# Init OpenMP threads binding for CPU
if self.device == "cpu":
self.init_threads_binding()
@@ -232,6 +233,9 @@ class ModelRunner:
# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()
# CPU offload
set_offloader(create_offloader_from_server_args(server_args))
# Update deep gemm configure
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -690,6 +694,8 @@ class ModelRunner:
monkey_patch_vllm_parallel_state(reverse=True)
monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
get_offloader().post_init()
if self.server_args.kv_cache_dtype == "fp8_e4m3":
if self.server_args.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):