From ef959d7b857d64c3e10aa8827e9af742283c1571 Mon Sep 17 00:00:00 2001
From: Zaili Wang <109502517+ZailiWang@users.noreply.github.com>
Date: Thu, 11 Sep 2025 14:52:22 +0800
Subject: [PATCH] [CPU] fix OOM when mem-fraction is not set (#9090)

---
 docker/Dockerfile.xeon                        |  5 ++---
 docs/platforms/cpu_server.md                  | 14 ++++++++------
 python/pyproject.toml                         |  2 +-
 .../sglang/srt/model_executor/model_runner.py | 18 ++++++++++++++----
 python/sglang/srt/utils.py                    |  4 +++-
 test/srt/test_intel_amx_attention_backend.py  |  2 +-
 6 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/docker/Dockerfile.xeon b/docker/Dockerfile.xeon
index 087e12cca..fdc439b30 100644
--- a/docker/Dockerfile.xeon
+++ b/docker/Dockerfile.xeon
@@ -31,8 +31,7 @@ ENV PIP_ROOT_USER_ACTION=ignore
 ENV CONDA_PREFIX=/sgl-workspace/miniforge3
 
 RUN pip config set global.index-url https://download.pytorch.org/whl/cpu && \
-    pip config set global.extra-index-url https://pypi.org/simple && \
-    pip install intel-openmp
+    pip config set global.extra-index-url https://pypi.org/simple
 
 RUN git clone https://github.com/sgl-project/sglang.git && \
     cd sglang && \
@@ -41,7 +40,7 @@ RUN git clone https://github.com/sgl-project/sglang.git && \
     pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION} triton==${VER_TRITON} --force-reinstall && \
     cd sgl-kernel && \
     cp pyproject_cpu.toml pyproject.toml && \
-    pip install -v .
+    pip install .
 
 ENV SGLANG_USE_CPU_ENGINE=1
 ENV LD_PRELOAD=/sgl-workspace/miniforge3/lib/libiomp5.so:/sgl-workspace/miniforge3/lib/libtcmalloc.so:/sgl-workspace/miniforge3/lib/libtbbmalloc.so.2
diff --git a/docs/platforms/cpu_server.md b/docs/platforms/cpu_server.md
index 4e91e7b88..97fad918d 100644
--- a/docs/platforms/cpu_server.md
+++ b/docs/platforms/cpu_server.md
@@ -84,13 +84,13 @@ git checkout
 # Install SGLang dependent libs, and build SGLang main package
 pip install --upgrade pip setuptools
 conda install -y libsqlite==3.48.0 gperftools tbb libnuma numactl
-pip install intel-openmp
 pip install -e "python[all_cpu]"
+pip install torch==2.7.1 torchvision==0.22.1 triton==3.3.1 --force-reinstall
 
 # Build the CPU backend kernels
 cd sgl-kernel
 cp pyproject_cpu.toml pyproject.toml
-pip install -v .
+pip install .
 
 # Other required environment variables
 # Recommend to set these in ~/.bashrc in order not to set every time in a new terminal
@@ -134,13 +134,17 @@ Notes:
    export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
    ```
 
+   Please be aware that with SGLANG_CPU_OMP_THREADS_BIND set,
+   the available memory of each rank may not be determined in advance.
+   You may need to set a proper `--max-total-tokens` to avoid out-of-memory errors.
+
 3. For optimizing decoding with torch.compile, please add the flag `--enable-torch-compile`.
 To specify the maximum batch size when using torch compile, set the flag `--torch-compile-max-bs`.
 For example, `--enable-torch-compile --torch-compile-max-bs 4` means using torch compile
 and setting the maximum batch size to 4.
 
 4. A warmup step is automatically triggered when the service is started.
-The server is ready when you see the log `The server is fired up and ready to roll!`.
+   The server is ready when you see the log `The server is fired up and ready to roll!`.
 ## Benchmarking with Requests
 
@@ -164,7 +168,7 @@ python -m sglang.bench_serving -h
 ```
 
 Additionally, the requests can be formed with
-[OpenAI Completions API](https://docs.sglang.ai/backend/openai_api_completions.html)
+[OpenAI Completions API](https://docs.sglang.ai/basic_usage/openai_api_completions.html)
 and sent via the command line (e.g. using `curl`) or via your own script.
 
 ## Example: Running DeepSeek-R1
@@ -180,7 +184,6 @@ python -m sglang.launch_server \
     --quantization w8a8_int8 \
     --host 0.0.0.0 \
     --mem-fraction-static 0.8 \
-    --max-total-token 65536 \
     --tp 6
 ```
 
@@ -194,7 +197,6 @@ python -m sglang.launch_server \
     --device cpu \
     --host 0.0.0.0 \
     --mem-fraction-static 0.8 \
-    --max-total-token 65536 \
     --tp 6
 ```
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 2327575f4..a51bc915b 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -87,7 +87,7 @@ srt_hip = [
 ]
 
 # https://docs.sglang.ai/platforms/cpu_server.html
-srt_cpu = ["sglang[runtime_common]"]
+srt_cpu = ["sglang[runtime_common]", "intel-openmp"]
 
 # https://docs.sglang.ai/platforms/ascend_npu.html
 srt_npu = ["sglang[runtime_common]"]
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 2548ea59e..56cdee7a2 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1673,10 +1673,9 @@ class ModelRunner:
 
     def init_threads_binding(self):
         omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        cpu_ids_by_node = get_cpu_ids_by_node()
+        n_numa_node = len(cpu_ids_by_node)
         if omp_cpuids == "all":
-            cpu_ids_by_node = get_cpu_ids_by_node()
-            n_numa_node = len(cpu_ids_by_node)
-
             assert self.tp_size <= n_numa_node, (
                 f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                 f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
@@ -1693,7 +1692,18 @@
             )
             self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
         else:
-            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+            threads_bind_list = omp_cpuids.split("|")
+            assert self.tp_size == len(threads_bind_list), (
+                f"The number of thread-binding groups in SGLANG_CPU_OMP_THREADS_BIND must match the TP size ({self.tp_size}). "
+                f"Please double-check your settings."
+            )
+            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
+            if self.tp_size > n_numa_node:
+                logger.warning(
+                    f"TP size ({self.tp_size}) is larger than the number of NUMA nodes ({n_numa_node}), "
+                    f"so the available memory of each rank cannot be determined in advance. "
+                    f"Please set a proper `--max-total-tokens` to avoid out-of-memory errors."
+                )
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 7ea3f36d5..846baeb01 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -434,7 +434,9 @@ def get_available_gpu_memory(
     elif device == "cpu":
         # TODO: rename the variables in the current function to be not GPU specific
-        free_gpu_memory = psutil.virtual_memory().available
+        total_free_memory = psutil.virtual_memory().available
+        n_numa_node: int = len(get_cpu_ids_by_node())
+        free_gpu_memory = round(total_free_memory / n_numa_node, 3)
 
     elif device == "npu":
         num_gpus = torch.npu.device_count()
         assert gpu_id < num_gpus
diff --git a/test/srt/test_intel_amx_attention_backend.py b/test/srt/test_intel_amx_attention_backend.py
index 22f7057ce..5534c57f9 100644
--- a/test/srt/test_intel_amx_attention_backend.py
+++ b/test/srt/test_intel_amx_attention_backend.py
@@ -109,7 +109,7 @@ class TestIntelAMXAttnBackend(CustomTestCase):
                 "--attention-backend",
                 "intel_amx",
                 "--mem-fraction-static",
-                "0.05",
+                "0.3",
                 "--disable-radix",
                 "--trust-remote-code",
                 "--disable-overlap-schedule",
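
Reviewer note: the heart of this fix is the new CPU branch of `get_available_gpu_memory`, which divides the host-wide available memory by the NUMA node count before any memory fraction is applied. A minimal sketch of that arithmetic follows; the helper name, the 0.8 fraction, and the 512 GiB figure are illustrative assumptions, not part of the patch.

```python
# Illustrative sketch only; the real logic is the CPU branch of
# get_available_gpu_memory() in python/sglang/srt/utils.py. Each TP rank
# is pinned to one NUMA node, so the host-wide free memory reported by
# psutil must be split across nodes before a fraction is applied;
# otherwise every rank budgets the whole host and OOM follows.
import psutil


def per_rank_memory_budget(n_numa_node: int, mem_fraction_static: float) -> float:
    """Approximate the memory budget of one NUMA-bound rank, in bytes."""
    total_free_memory = psutil.virtual_memory().available  # host-wide
    per_node = round(total_free_memory / n_numa_node, 3)
    return per_node * mem_fraction_static


# Example: with 512 GiB free and 2 NUMA nodes, a 0.8 fraction budgets
# ~204.8 GiB per rank rather than ~409.6 GiB from the host-wide figure.
print(per_rank_memory_budget(n_numa_node=2, mem_fraction_static=0.8))
```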
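
Likewise, a self-contained restatement of the `SGLANG_CPU_OMP_THREADS_BIND` checks that `init_threads_binding` now performs. The function below is hypothetical (the real code reads `self.tp_size` and `self.tp_rank` and returns nothing); it only illustrates the two new guards: the binding string must carry exactly one group per rank, and TP sizes exceeding the NUMA node count trigger the `--max-total-tokens` warning.

```python
# Hypothetical standalone version of the checks added to
# ModelRunner.init_threads_binding; illustrative only.
import logging
import os

logger = logging.getLogger(__name__)


def resolve_omp_binding(tp_size: int, tp_rank: int, n_numa_node: int) -> str:
    omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
    if omp_cpuids == "all":
        # Default binding: one rank per NUMA node, so TP size is capped
        # by the number of NUMA nodes on the machine.
        assert tp_size <= n_numa_node
        return "whole-numa-node"
    # The variable holds one "|"-separated CPU-id group per TP rank,
    # e.g. "0-39|43-82|86-125" for tp_size == 3.
    threads_bind_list = omp_cpuids.split("|")
    assert len(threads_bind_list) == tp_size
    if tp_size > n_numa_node:
        # Ranks now share NUMA nodes, so a per-rank memory budget cannot
        # be derived up front; the KV cache must be capped explicitly.
        logger.warning("Set a proper --max-total-tokens to avoid OOM.")
    return threads_bind_list[tp_rank]
```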