Revert "Revert "Add simple CPU offloading support"" (#2253)
Co-authored-by: Jani Monoses <jani.monoses@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>

python/sglang/srt/model_executor/model_runner.py
@@ -61,6 +61,7 @@ from sglang.srt.utils import (
     is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
+    set_cpu_offload_max_bytes,
 )
 
 logger = logging.getLogger(__name__)
@@ -145,6 +146,8 @@ class ModelRunner:
             }
         )
 
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
         # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
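
As a hedged aside (not part of the diff): the call above converts the user-facing GiB setting into a byte budget once, before any layers are built, and set_cpu_offload_max_bytes also resets the running counter, so each fresh ModelRunner starts with an unspent budget. A minimal sketch, assuming the sglang package from this commit is importable:

    from sglang.srt.utils import set_cpu_offload_max_bytes

    cpu_offload_gb = 3  # what --cpu-offload-gb supplies
    set_cpu_offload_max_bytes(int(cpu_offload_gb * 1024**3))  # 3 GiB = 3221225472 bytes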

python/sglang/srt/models/gemma2.py
@@ -38,6 +38,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -267,11 +268,15 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ),
+            prefix="",
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 

python/sglang/srt/models/llama.py
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class LlamaMLP(nn.Module):
@@ -255,13 +256,12 @@ class LlamaModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                LlamaDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: LlamaDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            prefix="model.layers",
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

python/sglang/srt/models/olmo.py
@@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class OlmoAttention(nn.Module):
@@ -220,11 +221,13 @@ class OlmoModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoDecoderLayer(config, layer_id, quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
         )
         self.norm = nn.LayerNorm(
             config.hidden_size, elementwise_affine=False, bias=False

python/sglang/srt/models/olmoe.py
@@ -48,6 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 class OlmoeMoE(nn.Module):
@@ -261,11 +262,13 @@ class OlmoeModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: OlmoeDecoderLayer(
+                config=config,
+                quant_config=quant_config,
+                layer_id=idx,
+            ),
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
 

python/sglang/srt/models/qwen2.py
@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 Qwen2Config = None
 
@@ -230,11 +231,13 @@ class Qwen2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Qwen2DecoderLayer(config, i, quant_config=quant_config)
-                for i in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Qwen2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 

python/sglang/srt/server_args.py
@@ -62,6 +62,7 @@ class ServerArgs:
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
+    cpu_offload_gb: int = 0
 
     # Other runtime options
     tp_size: int = 1
@@ -367,6 +368,13 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
 
+        parser.add_argument(
+            "--cpu-offload-gb",
+            type=int,
+            default=ServerArgs.cpu_offload_gb,
+            help="How many GBs of RAM to reserve for CPU offloading",
+        )
+
         # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
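
A hedged usage sketch (not part of the diff) of the new knob through the Python engine API, mirroring the test at the end of this commit; the model name is illustrative:

    import sglang as sgl

    engine = sgl.Engine(
        model_path="meta-llama/Llama-3.2-1B-Instruct",  # assumed model; any small model works
        cpu_offload_gb=3,  # reserve a 3 GiB CPU budget; layers are offloaded until it is spent
    )
    out = engine.generate(
        "Today is a sunny day and I like",
        {"temperature": 0, "max_new_tokens": 8},
    )["text"]
    engine.shutdown()

The equivalent server-side invocation passes --cpu-offload-gb to the launch arguments parsed above.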

python/sglang/srt/utils.py
@@ -32,7 +32,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
 
 import numpy as np
 import psutil
@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
+from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
@@ -192,6 +193,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)
 
 
+def is_pin_memory_available() -> bool:
+    return torch.cuda.is_available()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function"""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
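
The core trick in maybe_offload_to_cpu is swapping the module's forward for a wrapper that re-materializes the offloaded weights on the compute device at call time via torch.func.functional_call, then reinstalls itself for the next call. A self-contained sketch of just that mechanism (not part of the diff; it runs on CPU, where the device copies are no-ops):

    import torch
    from torch import nn
    from torch.func import functional_call

    def wrap_with_offload(module: nn.Module, compute_device: torch.device) -> nn.Module:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Restore the original forward so functional_call does not recurse
            # back into this wrapper.
            module.forward = original_forward
            device_state = {
                k: v.to(compute_device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module, device_state, args=args, kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward
        return module

    layer = wrap_with_offload(nn.Linear(4, 4), torch.device("cpu"))
    print(layer(torch.randn(2, 4)).shape)  # torch.Size([2, 4])

The commit additionally keeps the CPU copies in pinned memory when CUDA is available, which speeds up the per-call host-to-device transfers and CUDA graph capture.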

test/srt/test_srt_engine.py
@@ -152,7 +152,37 @@ class TestSRTEngine(unittest.TestCase):
 
         self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
 
-    def test_7_engine_offline_throughput(self):
+    def test_7_engine_cpu_offload(self):
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+        )
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        engine = sgl.Engine(
+            model_path=model_path,
+            random_seed=42,
+            max_total_tokens=128,
+            cpu_offload_gb=3,
+        )
+        out2 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        print("==== Answer 1 ====")
+        print(out1)
+
+        print("==== Answer 2 ====")
+        print(out2)
+        self.assertEqual(out1, out2)
+
+    def test_8_engine_offline_throughput(self):
         server_args = ServerArgs(
             model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
         )