CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)
This commit is contained in:
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
set_cpu_offload_max_bytes,
|
||||
)
|
||||
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -166,6 +167,10 @@ class ModelRunner:
|
||||
# Get memory before model loading
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=self.server_args.enable_memory_saver
|
||||
)
|
||||
|
||||
# Load the model
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
@@ -272,11 +277,12 @@ class ModelRunner:
|
||||
monkey_patch_vllm_gguf_config()
|
||||
|
||||
# Load the model
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=DeviceConfig(self.device),
|
||||
)
|
||||
with self.memory_saver_adapter.region():
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=DeviceConfig(self.device),
|
||||
)
|
||||
|
||||
if self.server_args.kv_cache_dtype == "fp8_e4m3":
|
||||
if self.server_args.quantization_param_path is not None:
|
||||
@@ -417,7 +423,7 @@ class ModelRunner:
|
||||
|
||||
logger.info(
|
||||
f"init custom process group: master_address={master_address}, master_port={master_port}, "
|
||||
f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
|
||||
f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -590,6 +596,7 @@ class ModelRunner:
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=self.device,
|
||||
use_records=False,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
@@ -602,6 +609,7 @@ class ModelRunner:
|
||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
elif self.server_args.enable_double_sparsity:
|
||||
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
||||
@@ -612,6 +620,7 @@ class ModelRunner:
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
heavy_channel_num=self.server_args.ds_heavy_channel_num,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
@@ -621,6 +630,7 @@ class ModelRunner:
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
logger.info(
|
||||
f"Memory pool end. "
|
||||
|
||||
Reference in New Issue
Block a user