Refactor weight offloading logic (#8521)
This commit is contained in:
@@ -96,6 +96,11 @@ from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.offloader import (
|
||||
create_offloader_from_server_args,
|
||||
get_offloader,
|
||||
set_offloader,
|
||||
)
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -118,7 +123,6 @@ from sglang.srt.utils import (
|
||||
is_npu,
|
||||
monkey_patch_p2p_access_check,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
set_cpu_offload_max_bytes,
|
||||
set_cuda_arch,
|
||||
)
|
||||
from sglang.srt.weight_sync.tensor_bucket import (
|
||||
@@ -222,9 +226,6 @@ class ModelRunner:
|
||||
}
|
||||
)
|
||||
|
||||
# CPU offload
|
||||
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
|
||||
|
||||
# Init OpenMP threads binding for CPU
|
||||
if self.device == "cpu":
|
||||
self.init_threads_binding()
|
||||
@@ -232,6 +233,9 @@ class ModelRunner:
|
||||
# Get memory before model loading
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
|
||||
# CPU offload
|
||||
set_offloader(create_offloader_from_server_args(server_args))
|
||||
|
||||
# Update deep gemm configure
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
||||
@@ -690,6 +694,8 @@ class ModelRunner:
|
||||
monkey_patch_vllm_parallel_state(reverse=True)
|
||||
monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
|
||||
|
||||
get_offloader().post_init()
|
||||
|
||||
if self.server_args.kv_cache_dtype == "fp8_e4m3":
|
||||
if self.server_args.quantization_param_path is not None:
|
||||
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
||||
|
||||
Reference in New Issue
Block a user