CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)

This commit is contained in:
fzyzcjy
2025-01-14 03:38:51 +08:00
committed by GitHub
parent d08c77c434
commit 923f518337
12 changed files with 406 additions and 60 deletions

View File

@@ -60,6 +60,7 @@ from sglang.srt.utils import (
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
)
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
logger = logging.getLogger(__name__)
@@ -166,6 +167,10 @@ class ModelRunner:
# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
)
# Load the model
self.sampler = Sampler()
self.load_model()
@@ -272,11 +277,12 @@ class ModelRunner:
monkey_patch_vllm_gguf_config()
# Load the model
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
with self.memory_saver_adapter.region():
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
if self.server_args.kv_cache_dtype == "fp8_e4m3":
if self.server_args.quantization_param_path is not None:
@@ -417,7 +423,7 @@ class ModelRunner:
logger.info(
f"init custom process group: master_address={master_address}, master_port={master_port}, "
f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
)
try:
@@ -590,6 +596,7 @@ class ModelRunner:
max_context_len=self.model_config.context_len + 4,
device=self.device,
use_records=False,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if (
self.model_config.attention_arch == AttentionArch.MLA
@@ -602,6 +609,7 @@ class ModelRunner:
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -612,6 +620,7 @@ class ModelRunner:
layer_num=self.model_config.num_hidden_layers,
device=self.device,
heavy_channel_num=self.server_args.ds_heavy_channel_num,
enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
@@ -621,6 +630,7 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
logger.info(
f"Memory pool end. "