CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)

This commit is contained in:
fzyzcjy
2025-01-14 03:38:51 +08:00
committed by GitHub
parent d08c77c434
commit 923f518337
12 changed files with 406 additions and 60 deletions

View File

@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
import torch
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -57,6 +59,8 @@ from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -255,6 +259,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
async def release_memory_occupation(
    obj: ReleaseMemoryOccupationReqInput, request: Request
):
    """Release GPU occupation temporarily"""
    # Per the commit description, this frees the KV-cache / model-weight
    # memory via the tokenizer manager; success yields an implicit 200 with
    # a null body, failure is converted into a structured error response.
    try:
        await tokenizer_manager.release_memory_occupation(obj, request)
    except Exception as err:
        return _create_error_response(err)
@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
async def resume_memory_occupation(
    obj: ResumeMemoryOccupationReqInput, request: Request
):
    """Resume GPU occupation"""
    # Counterpart of /release_memory_occupation: re-acquires the memory that
    # was previously released. Errors are wrapped instead of propagated.
    try:
        await tokenizer_manager.resume_memory_occupation(obj, request)
    except Exception as err:
        return _create_error_response(err)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
@@ -438,6 +464,10 @@ def launch_engine(
server_args.model_path, server_args.tokenizer_path
)
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
if server_args.dp_size == 1:
# Launch tensor parallel scheduler processes
scheduler_procs = []
@@ -454,7 +484,8 @@ def launch_engine(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
)
proc.start()
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
@@ -471,7 +502,8 @@ def launch_engine(
target=run_data_parallel_controller_process,
args=(server_args, port_args, writer),
)
proc.start()
with memory_saver_adapter.configure_subprocess():
proc.start()
# Launch detokenizer process
detoken_proc = mp.Process(
@@ -897,6 +929,18 @@ class Engine:
loop = asyncio.get_event_loop()
return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
def release_memory_occupation(self):
    """Release GPU occupation temporarily.

    Synchronous wrapper: builds an empty request object and drives the
    tokenizer manager's async release call to completion on the current
    event loop. The result (if any) is intentionally discarded.
    """
    release_req = ReleaseMemoryOccupationReqInput()
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(
        tokenizer_manager.release_memory_occupation(release_req, None)
    )
def resume_memory_occupation(self):
    """Resume GPU occupation.

    Synchronous wrapper mirroring release_memory_occupation: re-acquires
    the previously released memory by running the tokenizer manager's
    async resume call to completion on the current event loop.
    """
    resume_req = ResumeMemoryOccupationReqInput()
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(
        tokenizer_manager.resume_memory_occupation(resume_req, None)
    )
class Runtime:
"""