Warn users when release_memory_occupation is called without memory saver enabled (#4566)
This commit is contained in:
12
.github/workflows/pr-test-amd.yml
vendored
12
.github/workflows/pr-test-amd.yml
vendored
@@ -22,7 +22,7 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
accuracy-test-1-gpu-amd:
|
accuracy-test-1-gpu-amd:
|
||||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||||
github.event.pull_request.draft == false
|
github.event.pull_request.draft == false
|
||||||
runs-on: linux-mi300-gpu-1
|
runs-on: linux-mi300-gpu-1
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
@@ -56,13 +56,13 @@ jobs:
|
|||||||
- name: Evaluate Accuracy
|
- name: Evaluate Accuracy
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_accuracy_large.py
|
docker exec -w /sglang-checkout/test/srt -e SGLANG_IS_IN_CI=1 ci_sglang python3 test_eval_accuracy_large.py
|
||||||
docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_fp8_accuracy.py
|
docker exec -w /sglang-checkout/test/srt -e SGLANG_IS_IN_CI=1 ci_sglang python3 test_eval_fp8_accuracy.py
|
||||||
docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py
|
docker exec -w /sglang-checkout/test/srt -e SGLANG_IS_IN_CI=1 ci_sglang python3 models/test_qwen_models.py
|
||||||
|
|
||||||
mla-test-1-gpu-amd:
|
mla-test-1-gpu-amd:
|
||||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||||
github.event.pull_request.draft == false
|
github.event.pull_request.draft == false
|
||||||
runs-on: linux-mi300-gpu-1
|
runs-on: linux-mi300-gpu-1
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
@@ -96,7 +96,7 @@ jobs:
|
|||||||
- name: MLA TEST
|
- name: MLA TEST
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
run: |
|
run: |
|
||||||
docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_mla.py
|
docker exec -w /sglang-checkout/test/srt -e SGLANG_IS_IN_CI=1 ci_sglang python3 test_mla.py
|
||||||
|
|
||||||
finish:
|
finish:
|
||||||
if: always()
|
if: always()
|
||||||
|
|||||||
2
.github/workflows/release-docs.yml
vendored
2
.github/workflows/release-docs.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
pip install -r docs/requirements.txt
|
pip install -r docs/requirements.txt
|
||||||
apt-get update
|
apt-get update
|
||||||
apt-get install -y pandoc
|
apt-get install -y pandoc
|
||||||
apt-get update && apt-get install -y parallel
|
apt-get update && apt-get install -y parallel retry
|
||||||
|
|
||||||
- name: Setup Jupyter Kernel
|
- name: Setup Jupyter Kernel
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ srt_cpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11", "torch"]
|
|||||||
openai = ["openai>=1.0", "tiktoken"]
|
openai = ["openai>=1.0", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0"]
|
anthropic = ["anthropic>=0.20.0"]
|
||||||
litellm = ["litellm>=1.0.0"]
|
litellm = ["litellm>=1.0.0"]
|
||||||
torch_memory_saver = ["torch_memory_saver"]
|
torch_memory_saver = ["torch_memory_saver>=0.0.3"]
|
||||||
test = [
|
test = [
|
||||||
"jsonlines",
|
"jsonlines",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
|
|||||||
@@ -1790,6 +1790,9 @@ class Scheduler(
|
|||||||
return GetWeightsByNameReqOutput(parameter)
|
return GetWeightsByNameReqOutput(parameter)
|
||||||
|
|
||||||
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
|
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
|
||||||
|
self.memory_saver_adapter.check_validity(
|
||||||
|
caller_name="release_memory_occupation"
|
||||||
|
)
|
||||||
self.stashed_model_static_state = _export_static_state(
|
self.stashed_model_static_state = _export_static_state(
|
||||||
self.tp_worker.worker.model_runner.model
|
self.tp_worker.worker.model_runner.model
|
||||||
)
|
)
|
||||||
@@ -1798,6 +1801,7 @@ class Scheduler(
|
|||||||
return ReleaseMemoryOccupationReqOutput()
|
return ReleaseMemoryOccupationReqOutput()
|
||||||
|
|
||||||
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
|
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
|
||||||
|
self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
|
||||||
self.memory_saver_adapter.resume()
|
self.memory_saver_adapter.resume()
|
||||||
_import_static_state(
|
_import_static_state(
|
||||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
||||||
|
|||||||
@@ -287,7 +287,14 @@ class ModelRunner:
|
|||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
logger.info("Init torch distributed begin.")
|
logger.info("Init torch distributed begin.")
|
||||||
|
|
||||||
torch.get_device_module(self.device).set_device(self.gpu_id)
|
try:
|
||||||
|
torch.get_device_module(self.device).set_device(self.gpu_id)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
backend = "nccl"
|
backend = "nccl"
|
||||||
elif self.device == "xpu":
|
elif self.device == "xpu":
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
@@ -8,6 +9,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TorchMemorySaverAdapter(ABC):
|
class TorchMemorySaverAdapter(ABC):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -16,6 +19,13 @@ class TorchMemorySaverAdapter(ABC):
|
|||||||
_TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
|
_TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_validity(self, caller_name):
|
||||||
|
if not self.enabled:
|
||||||
|
logger.warning(
|
||||||
|
f"`{caller_name}` will not save memory because torch_memory_saver is not enabled. "
|
||||||
|
f"Potential causes: `enable_memory_saver` is false, or torch_memory_saver has installation issues."
|
||||||
|
)
|
||||||
|
|
||||||
def configure_subprocess(self):
|
def configure_subprocess(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -28,6 +38,10 @@ class TorchMemorySaverAdapter(ABC):
|
|||||||
def resume(self):
|
def resume(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
||||||
def configure_subprocess(self):
|
def configure_subprocess(self):
|
||||||
@@ -42,6 +56,10 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
|
|||||||
def resume(self):
|
def resume(self):
|
||||||
return _primary_memory_saver.resume()
|
return _primary_memory_saver.resume()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self):
|
||||||
|
return _primary_memory_saver.enabled
|
||||||
|
|
||||||
|
|
||||||
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -57,3 +75,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
|
|||||||
|
|
||||||
def resume(self):
|
def resume(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self):
|
||||||
|
return False
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBac
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
class MockModelRunner:
|
class MockModelRunner:
|
||||||
@@ -39,7 +40,7 @@ class MockReqToTokenPool:
|
|||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
|
||||||
class TestFlashAttentionBackend(unittest.TestCase):
|
class TestFlashAttentionBackend(CustomTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Set up test fixtures before each test method."""
|
"""Set up test fixtures before each test method."""
|
||||||
self.model_runner = MockModelRunner()
|
self.model_runner = MockModelRunner()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -922,6 +923,10 @@ def run_mulit_request_test(
|
|||||||
|
|
||||||
|
|
||||||
def write_github_step_summary(content):
|
def write_github_step_summary(content):
|
||||||
|
if not os.environ.get("GITHUB_STEP_SUMMARY"):
|
||||||
|
logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
|
||||||
|
return
|
||||||
|
|
||||||
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class TestMLADeepseekV3ChannelInt8(CustomTestCase):
|
|||||||
metrics = run_eval_few_shot_gsm8k(args)
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
self.assertGreater(metrics["accuracy"], 0.62)
|
self.assertGreaterEqual(metrics["accuracy"], 0.61)
|
||||||
|
|
||||||
|
|
||||||
class TestDeepseekV3MTPChannelInt8(CustomTestCase):
|
class TestDeepseekV3MTPChannelInt8(CustomTestCase):
|
||||||
|
|||||||
@@ -624,7 +624,6 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
|
|||||||
"minicpmo",
|
"minicpmo",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
"0.7",
|
"0.7",
|
||||||
"--tp=2",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|||||||
Reference in New Issue
Block a user