CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)
@@ -44,6 +44,7 @@ srt_hpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
+torch_memory_saver = ["torch_memory_saver"]
 test = [
     "jsonlines",
     "matplotlib",
@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
+from typing import Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -459,6 +457,26 @@ class GetWeightsByNameReqOutput:
     parameter: list


+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -47,6 +47,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -88,6 +92,7 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -357,6 +362,10 @@ class Scheduler:
         t.start()
         self.parent_process = psutil.Process().parent()

+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
@@ -519,6 +528,12 @@ class Scheduler:
            elif isinstance(recv_req, GetWeightsByNameReqInput):
                parameter = self.get_weights_by_name(recv_req)
                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+           elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
+               self.release_memory_occupation()
+               self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
+           elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
+               self.resume_memory_occupation()
+               self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
@@ -1538,6 +1553,20 @@ class Scheduler:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter

+    def release_memory_occupation(self):
+        self.stashed_model_static_state = _export_static_state(
+            self.tp_worker.worker.model_runner.model
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+
+    def resume_memory_occupation(self):
+        self.memory_saver_adapter.resume()
+        _import_static_state(
+            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
@@ -1576,6 +1605,20 @@
             del self.sessions[session_id]


+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor
+
+
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
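The two helpers above stash and restore the model's persistent buffers around a pause/resume cycle. A reasonable reading of this patch is that pausing releases the physical memory backing the pausable allocations, so buffer contents do not survive the cycle on their own: parameters are expected to be re-synced by the caller (for example via update_weights_from_tensor), while named buffers are cloned before pausing and copied back after resuming. A minimal stand-alone sketch of that round trip, using a hypothetical module in place of model_runner.model:

    import torch

    # Hypothetical stand-in for model_runner.model; BatchNorm carries named buffers
    # (running_mean, running_var, num_batches_tracked).
    model = torch.nn.BatchNorm1d(4)

    # Snapshot the persistent buffers, as _export_static_state does.
    stash = {name: buf.detach().clone() for name, buf in model.named_buffers()}

    # ... memory_saver_adapter.pause() and resume() would happen in between ...

    # Copy the saved contents back in place, as _import_static_state does.
    named_buffers = dict(model.named_buffers())
    for name, saved in stash.items():
        named_buffers[name][...] = saved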
@@ -53,6 +53,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -188,6 +192,12 @@ class TokenizerManager:
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.release_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.resume_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         # Metrics
         if self.enable_metrics:
@@ -548,6 +558,22 @@ class TokenizerManager:
         else:
             return all_parameters

+    async def release_memory_occupation(
+        self,
+        obj: ReleaseMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.release_memory_occupation_communicator(obj)
+
+    async def resume_memory_occupation(
+        self,
+        obj: ResumeMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.resume_memory_occupation_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -627,6 +653,8 @@ class TokenizerManager:
                 UpdateWeightsFromDistributedReqOutput,
                 GetWeightsByNameReqOutput,
                 InitWeightsUpdateGroupReqOutput,
+                ReleaseMemoryOccupationReqOutput,
+                ResumeMemoryOccupationReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()

             if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
@@ -709,6 +737,10 @@ class TokenizerManager:
                 self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 self.get_weights_by_name_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
+                self.release_memory_occupation_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
+                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
             else:
                 raise ValueError(f"Invalid object: {recv_obj=}")

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.

@@ -42,13 +44,25 @@ GB = 1024 * 1024 * 1024
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""

-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records
@@ -189,8 +203,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
@@ -202,24 +222,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         )

     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]

     def _clear_buffers(self):
         del self.k_buffer
@@ -307,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)

         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]

     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
@@ -360,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )

-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]

     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter

 logger = logging.getLogger(__name__)

@@ -166,6 +167,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -272,11 +277,12 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()

         # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )

         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
@@ -417,7 +423,7 @@ class ModelRunner:

         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )

         try:
@@ -590,6 +596,7 @@ class ModelRunner:
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
             use_records=False,
+            enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -602,6 +609,7 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -612,6 +620,7 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -621,6 +630,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         logger.info(
             f"Memory pool end. "
@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union

 import torch

+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -57,6 +59,8 @@ from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -255,6 +259,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
         return _create_error_response(e)


+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU occupation temporarily"""
+    try:
+        await tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU occupation"""
+    try:
+        await tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
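The two new endpoints take empty request bodies (their Req dataclasses carry no fields), so any HTTP client can drive them. A minimal sketch, assuming a server launched with --enable-memory-saver and reachable on the default local port 30000 (adjust the URL for your deployment):

    import requests

    base = "http://localhost:30000"  # assumed address, not part of this patch

    requests.post(f"{base}/release_memory_occupation", json={})
    # ... run other GPU work while the server's weights and KV cache are released ...
    requests.post(f"{base}/resume_memory_occupation", json={})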
@@ -438,6 +464,10 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )

+    memory_saver_adapter = TorchMemorySaverAdapter.create(
+        enable=server_args.enable_memory_saver
+    )
+
     if server_args.dp_size == 1:
         # Launch tensor parallel scheduler processes
         scheduler_procs = []
@@ -454,7 +484,8 @@ def launch_engine(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, None, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)

@@ -471,7 +502,8 @@ def launch_engine(
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        proc.start()
+        with memory_saver_adapter.configure_subprocess():
+            proc.start()

     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -897,6 +929,18 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))

+    def release_memory_occupation(self):
+        """Release GPU occupation temporarily"""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
+
+    def resume_memory_occupation(self):
+        """Resume GPU occupation"""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+

 class Runtime:
     """
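For offline use, the same pair of calls is exposed on sgl.Engine. The sketch below mirrors the new test file later in this diff; the model path is illustrative and not part of the patch:

    import sglang as sgl

    engine = sgl.Engine(
        model_path="meta-llama/Llama-3.2-1B-Instruct",  # illustrative model
        enable_memory_saver=True,
    )
    print(engine.generate("Today is a sunny day and I like",
                          {"temperature": 0, "max_new_tokens": 8})["text"])

    engine.release_memory_occupation()  # KV cache and weight memory handed back to the GPU
    # ... run other GPU work here, e.g. a training step of another framework ...
    engine.resume_memory_occupation()   # memory re-acquired; captured CUDA graphs stay usable

    engine.shutdown()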
@@ -23,7 +23,6 @@ from typing import List, Optional
 import torch

 from sglang.srt.hf_transformers_utils import check_gguf_file
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -157,6 +156,7 @@ class ServerArgs:
     triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
+    enable_memory_saver: bool = False

     def __post_init__(self):
         # Set missing default values
@@ -854,6 +854,11 @@ class ServerArgs:
             action="store_true",
             help="Delete the model checkpoint after loading the model.",
         )
+        parser.add_argument(
+            "--enable-memory-saver",
+            action="store_true",
+            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
python/sglang/torch_memory_saver_adapter.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+from abc import ABC
+from contextlib import contextmanager
+
+try:
+    import torch_memory_saver
+
+    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+except ImportError:
+    pass
+
+
+class TorchMemorySaverAdapter(ABC):
+    @staticmethod
+    def create(enable: bool):
+        return (
+            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
+        )
+
+    def configure_subprocess(self):
+        raise NotImplementedError
+
+    def region(self):
+        raise NotImplementedError
+
+    def pause(self):
+        raise NotImplementedError
+
+    def resume(self):
+        raise NotImplementedError
+
+
+class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    def configure_subprocess(self):
+        return torch_memory_saver.configure_subprocess()
+
+    def region(self):
+        return _primary_memory_saver.region()
+
+    def pause(self):
+        return _primary_memory_saver.pause()
+
+    def resume(self):
+        return _primary_memory_saver.resume()
+
+
+class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
+    @contextmanager
+    def configure_subprocess(self):
+        yield
+
+    @contextmanager
+    def region(self):
+        yield
+
+    def pause(self):
+        pass
+
+    def resume(self):
+        pass
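A minimal sketch of the adapter's contract as defined above, assuming the optional torch_memory_saver package is installed (with enable=False the no-op variant is returned and every call does nothing). Allocations made inside region() become pausable; the commit title's CUDA-graph compatibility rests on the virtual addresses staying fixed across pause/resume, which is the torch_memory_saver design. configure_subprocess() is the context manager that launch_engine wraps around proc.start() for the scheduler processes.

    import torch
    from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter

    adapter = TorchMemorySaverAdapter.create(enable=True)

    with adapter.region():  # allocations made here become pausable
        buf = torch.empty((1 << 24,), dtype=torch.uint8, device="cuda")

    adapter.pause()   # physical memory released back to the system
    adapter.resume()  # the same virtual addresses are backed by fresh physical memory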
@@ -12,8 +12,9 @@ bash "${SCRIPT_DIR}/killall_sglang.sh"
 pip install --upgrade pip
 pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/

-# Force reinstall flashinfer
+# Force reinstall flashinfer and torch_memory_saver
 pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
+pip install torch_memory_saver --force-reinstall

 pip install transformers==4.45.2 sentence_transformers accelerate peft

@@ -29,6 +29,7 @@ suites = {
         "test_openai_server.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
+        "test_release_memory_occupation.py",
         "test_retract_decode.py",
         "test_server_args.py",
         "test_session_control.py",
test/srt/test_release_memory_occupation.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+import time
+import unittest
+
+import torch
+from transformers import AutoModelForCausalLM
+
+import sglang as sgl
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
+_DEBUG_EXTRA = True
+
+
+class TestReleaseMemoryOccupation(unittest.TestCase):
+    def test_release_and_resume_occupation(self):
+        prompt = "Today is a sunny day and I like"
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        expect_output = " to spend it outdoors. I decided to"
+
+        engine = sgl.Engine(
+            model_path=model_name,
+            random_seed=42,
+            enable_memory_saver=True,
+            # disable_cuda_graph=True,  # for debugging only
+        )
+        hf_model_new = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype="bfloat16"
+        )
+
+        print("generate (#1)")
+        outputs = engine.generate(prompt, sampling_params)["text"]
+        self.assertEqual(outputs, expect_output)
+
+        if _DEBUG_EXTRA:
+            time.sleep(3)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            False,
+            "Should not be able to allocate big tensors before releasing",
+        )
+
+        print("release_memory_occupation start")
+        t = time.time()
+        engine.release_memory_occupation()
+        if _DEBUG_EXTRA:
+            print("release_memory_occupation", time.time() - t)
+
+        if _DEBUG_EXTRA:
+            time.sleep(5)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            True,
+            "Should be able to allocate big tensors after releasing",
+        )
+
+        if _DEBUG_EXTRA:
+            time.sleep(5)
+
+        print("resume_memory_occupation start")
+        t = time.time()
+        engine.resume_memory_occupation()
+        if _DEBUG_EXTRA:
+            print("resume_memory_occupation", time.time() - t)
+
+        self.assertEqual(
+            _try_allocate_big_tensor(),
+            False,
+            "Should not be able to allocate big tensors after resuming",
+        )
+
+        print("update_weights_from_tensor")
+        # As if: PPO has updated hf model's weights, and now we sync it to SGLang
+        engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
+
+        print("generate (#2)")
+        outputs = engine.generate(prompt, sampling_params)["text"]
+        self.assertEqual(outputs, expect_output)
+
+        if _DEBUG_EXTRA:
+            time.sleep(4)
+
+        engine.shutdown()
+
+
+def _try_allocate_big_tensor(size: int = 20_000_000_000):
+    try:
+        torch.empty((size,), dtype=torch.uint8, device="cuda")
+        torch.cuda.empty_cache()
+        return True
+    except torch.cuda.OutOfMemoryError:
+        return False
+
+
+if __name__ == "__main__":
+    unittest.main()