From 4057ea82c9a11f4f2379189c390f4a4f88f73854 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 28 Nov 2024 23:36:55 -0800
Subject: [PATCH] Revert "Add simple CPU offloading support" (#2252)

We'll re-add the commit to correctly acknowledge Kaichao's authorship.
---
 .../sglang/srt/model_executor/model_runner.py |  3 -
 python/sglang/srt/models/gemma2.py            | 15 +--
 python/sglang/srt/models/llama.py             | 14 +--
 python/sglang/srt/models/olmo.py              | 13 +--
 python/sglang/srt/models/olmoe.py             | 13 +--
 python/sglang/srt/models/qwen2.py             | 13 +--
 python/sglang/srt/server_args.py              |  8 --
 python/sglang/srt/utils.py                    | 91 +------------------
 test/srt/test_srt_engine.py                   | 32 +------
 9 files changed, 29 insertions(+), 173 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7c1c51a8f..5667084a6 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -61,7 +61,6 @@ from sglang.srt.utils import (
     is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
-    set_cpu_offload_max_bytes,
 )
 
 logger = logging.getLogger(__name__)
@@ -146,8 +145,6 @@ class ModelRunner:
             }
         )
 
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
         # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 0fa6a5393..d5972c110 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -38,7 +38,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -268,15 +267,11 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda idx, prefix: Gemma2DecoderLayer(
-                layer_id=idx,
-                config=config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-            ),
-            prefix="",
+        self.layers = nn.ModuleList(
+            [
+                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
+                for layer_id in range(config.num_hidden_layers)
+            ]
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 7e9fd0f72..f7dc0f249 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -43,7 +43,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import make_layers
 
 
 class LlamaMLP(nn.Module):
@@ -256,12 +255,13 @@ class LlamaModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda idx, prefix: LlamaDecoderLayer(
-                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
-            ),
-            prefix="model.layers",
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
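The model-file hunks above and below all make the same swap: the reverted `make_layers` helper built each decoder layer through a callback, giving one central place to post-process every layer (here, CPU offloading), while the restored code fills an `nn.ModuleList` directly. A minimal, self-contained sketch of the two idioms, using a stand-in `DecoderLayer` rather than sglang's real classes:

    import torch.nn as nn

    class DecoderLayer(nn.Module):  # stand-in for LlamaDecoderLayer etc.
        def __init__(self, hidden_size: int, layer_id: int):
            super().__init__()
            self.layer_id = layer_id
            self.proj = nn.Linear(hidden_size, hidden_size)

    def make_layers(num_layers, layer_fn):
        # Factory style (the reverted helper, minus the offload wrapping):
        # the callback lets one central hook wrap every constructed layer.
        return nn.ModuleList([layer_fn(idx=i) for i in range(num_layers)])

    hidden, n = 64, 4
    # Restored idiom: build the list directly.
    direct = nn.ModuleList([DecoderLayer(hidden, i) for i in range(n)])
    # Reverted idiom: build through the factory callback.
    via_factory = make_layers(n, lambda idx: DecoderLayer(hidden, idx))
    assert len(direct) == len(via_factory) == n

The factory indirection is what allowed `make_layers` to call `maybe_offload_to_cpu` on each layer as it was built (see the utils.py hunk further below).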
diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py
index 80fd64a53..729039f93 100644
--- a/python/sglang/srt/models/olmo.py
+++ b/python/sglang/srt/models/olmo.py
@@ -38,7 +38,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import make_layers
 
 
 class OlmoAttention(nn.Module):
@@ -221,13 +220,11 @@ class OlmoModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
-        self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda idx, prefix: OlmoDecoderLayer(
-                layer_id=idx,
-                config=config,
-                quant_config=quant_config,
-            ),
+        self.layers = nn.ModuleList(
+            [
+                OlmoDecoderLayer(config, layer_id, quant_config)
+                for layer_id in range(config.num_hidden_layers)
+            ]
         )
         self.norm = nn.LayerNorm(
             config.hidden_size, elementwise_affine=False, bias=False
diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py
index 407eb98cb..1c8ba52cb 100644
--- a/python/sglang/srt/models/olmoe.py
+++ b/python/sglang/srt/models/olmoe.py
@@ -48,7 +48,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import make_layers
 
 
 class OlmoeMoE(nn.Module):
@@ -262,13 +261,11 @@ class OlmoeModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda idx, prefix: OlmoeDecoderLayer(
-                config=config,
-                quant_config=quant_config,
-                layer_id=idx,
-            ),
+        self.layers = nn.ModuleList(
+            [
+                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
+                for layer_id in range(config.num_hidden_layers)
+            ]
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
 
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 634ce1cf1..308e9008a 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -40,7 +40,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import make_layers
 
 
 Qwen2Config = None
@@ -231,13 +230,11 @@ class Qwen2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda idx, prefix: Qwen2DecoderLayer(
-                layer_id=idx,
-                config=config,
-                quant_config=quant_config,
-            ),
+        self.layers = nn.ModuleList(
+            [
+                Qwen2DecoderLayer(config, i, quant_config=quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
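The next two files remove the user-facing knob and its plumbing. At startup, `--cpu-offload-gb` was converted into a global byte budget (see the `model_runner.py` hunk above); a condensed sketch of that bookkeeping, simplified from the utils.py code removed below:

    # Condensed from the removed utils.py: the flag's value becomes a byte
    # budget that caps how many parameter bytes may be moved to host RAM.
    _CPU_OFFLOAD_BYTES = 0      # bytes offloaded so far
    _CPU_OFFLOAD_MAX_BYTES = 0  # budget, set once at startup

    def set_cpu_offload_max_bytes(max_bytes: int) -> None:
        global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
        _CPU_OFFLOAD_BYTES = 0
        _CPU_OFFLOAD_MAX_BYTES = max_bytes

    # model_runner.py did: set_cpu_offload_max_bytes(int(cpu_offload_gb * 1024**3))
    set_cpu_offload_max_bytes(int(3 * 1024**3))  # e.g. --cpu-offload-gb 3

Note the conversion uses 1024**3, so despite the help text saying "GBs" the budget was counted in GiB.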
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 144ade58e..67c3df71e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -62,7 +62,6 @@ class ServerArgs:
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
-    cpu_offload_gb: int = 0
 
     # Other runtime options
     tp_size: int = 1
@@ -368,13 +367,6 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. "
             "Use a larger value if you see requests being retracted frequently.",
         )
-        parser.add_argument(
-            "--cpu-offload-gb",
-            type=int,
-            default=ServerArgs.cpu_offload_gb,
-            help="How many GBs of RAM to reserve for CPU offloading",
-        )
-
         # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 46b4db8e8..7820276cf 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -32,7 +32,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import psutil
@@ -45,7 +45,6 @@ from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
@@ -193,94 +192,6 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)
 
 
-def is_pin_memory_available() -> bool:
-    return torch.cuda.is_available()
-
-
-_CPU_OFFLOAD_BYTES = 0
-_CPU_OFFLOAD_MAX_BYTES = 0
-
-
-def set_cpu_offload_max_bytes(max_bytes: int) -> None:
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    _CPU_OFFLOAD_BYTES = 0
-    _CPU_OFFLOAD_MAX_BYTES = max_bytes
-
-
-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    device = next(module.parameters()).device
-
-    if device == torch.device("cpu"):
-        return module
-
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-
-    pin_memory = is_pin_memory_available()
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(
-            size=p.data.size(),
-            stride=p.data.stride(),
-            dtype=p.data.dtype,
-            layout=p.data.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-
-    if offloaded_parameters:
-        original_forward = module.forward
-
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(module, device_state, args=args, kwargs=kwargs)
-            module.forward = forward
-            return output
-
-        module.forward = forward
-
-    return module
-
-
-class LayerFn(Protocol):
-
-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
-
-
-def make_layers(
-    num_hidden_layers: int,
-    layer_fn: LayerFn,
-    prefix: str = "",
-) -> Tuple[int, int, torch.nn.ModuleList]:
-    """Make a list of layers with the given layer function"""
-    modules = torch.nn.ModuleList(
-        [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
-            for idx in range(num_hidden_layers)
-        ]
-    )
-    return modules
-
-
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
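The heart of the removed `maybe_offload_to_cpu` is the `torch.func.functional_call` trick: parameters stay in (pinned) host memory, and a wrapped `forward` materializes a device-resident copy of the state dict only for the duration of one call. A toy demonstration of that trick in isolation (a hedged sketch, not sglang code):

    import torch
    from torch.func import functional_call

    layer = torch.nn.Linear(8, 8)  # weights stay on the CPU
    x = torch.randn(2, 8)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Build a temporary device-resident copy of the state; `to()` is a
    # no-op for tensors already on the target device, as the removed
    # code's comment notes.
    device_state = {k: v.to(device) for k, v in layer.state_dict().items()}
    # Run the module with the borrowed state; `layer` itself never moves.
    out = functional_call(layer, device_state, args=(x.to(device),))
    assert out.device.type == torch.device(device).type
    assert next(layer.parameters()).device.type == "cpu"

Pinning the host copies (the `pin_memory=pin_memory` in the removed `torch.empty_strided` call) keeps the per-call host-to-device copies asynchronous, which the original comment notes also helps CUDA graph capture speed.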
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index a985c8dda..5d7f95440 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -152,37 +152,7 @@ class TestSRTEngine(unittest.TestCase):
 
         self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
 
-    def test_7_engine_cpu_offload(self):
-        prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
-
-        engine = sgl.Engine(
-            model_path=model_path,
-            random_seed=42,
-            max_total_tokens=128,
-        )
-        out1 = engine.generate(prompt, sampling_params)["text"]
-        engine.shutdown()
-
-        engine = sgl.Engine(
-            model_path=model_path,
-            random_seed=42,
-            max_total_tokens=128,
-            cpu_offload_gb=3,
-        )
-        out2 = engine.generate(prompt, sampling_params)["text"]
-        engine.shutdown()
-
-        print("==== Answer 1 ====")
-        print(out1)
-
-        print("==== Answer 2 ====")
-        print(out2)
-        self.assertEqual(out1, out2)
-
-    def test_8_engine_offline_throughput(self):
+    def test_7_engine_offline_throughput(self):
         server_args = ServerArgs(
             model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
         )
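For reference, the deleted test encoded the contract the feature had to satisfy: offloading must be numerically transparent, so greedy decoding with and without `cpu_offload_gb` produces identical text. A condensed sketch of that pattern (illustration only: after this revert the `cpu_offload_gb` engine argument no longer exists, and the model constant is assumed to come from `sglang.test.test_utils` as in the test suite):

    import sglang as sgl
    from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

    def greedy_text(**extra) -> str:
        engine = sgl.Engine(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            random_seed=42,
            max_total_tokens=128,
            **extra,
        )
        try:
            return engine.generate(
                "Today is a sunny day and I like",
                {"temperature": 0, "max_new_tokens": 8},
            )["text"]
        finally:
            engine.shutdown()

    # With temperature 0 and a fixed seed, the knob must not change output.
    assert greedy_text() == greedy_text(cpu_offload_gb=3)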