Fix dockerfile and triton cache manager (#720)
This commit is contained in:
@@ -23,18 +23,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
|||||||
# Base OS packages needed to build/run sglang.
RUN apt-get update -y \
    && apt-get install -y python3-pip git curl sudo

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

WORKDIR /sgl-workspace

# Install sglang, then swap the pinned triton for triton-nightly
# (--no-deps so the nightly index doesn't pull in other packages),
# and add the flashinfer kernels matching CUDA 12.1 / torch 2.3.
RUN pip3 --no-cache-dir install --upgrade pip \
    && pip3 --no-cache-dir install "sglang[all]" \
    && pip3 --no-cache-dir uninstall -y triton triton-nightly \
    && pip3 --no-cache-dir install --no-deps --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly \
    && pip3 --no-cache-dir install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/

# Restore interactive prompts for users entering the container.
ENV DEBIAN_FRONTEND=interactive
|
|||||||
@@ -52,6 +52,7 @@ from sglang.srt.utils import (
|
|||||||
allocate_init_ports,
|
allocate_init_ports,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
|
maybe_set_triton_cache_manager,
|
||||||
set_ulimit,
|
set_ulimit,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
@@ -201,6 +202,11 @@ def launch_server(
|
|||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if server_args.tp_size // server_args.dp_size > 1:
|
||||||
|
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
||||||
|
maybe_set_triton_cache_manager()
|
||||||
|
|
||||||
if server_args.chat_template:
|
if server_args.chat_template:
|
||||||
# TODO: replace this with huggingface transformers template
|
# TODO: replace this with huggingface transformers template
|
||||||
load_chat_template_for_openai_api(server_args.chat_template)
|
load_chat_template_for_openai_api(server_args.chat_template)
|
||||||
|
|||||||
@@ -18,10 +18,15 @@ import psutil
|
|||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import triton
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from triton.runtime.cache import (
|
||||||
|
FileCacheManager,
|
||||||
|
default_cache_dir,
|
||||||
|
default_dump_dir,
|
||||||
|
default_override_dir,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -460,6 +465,44 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
|
|||||||
setattr(GroupCoordinator, "all_gather", all_gather)
|
setattr(GroupCoordinator, "all_gather", all_gather)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_set_triton_cache_manager() -> None:
    """Point Triton at SGLang's custom cache manager, unless one is set.

    Exports ``TRITON_CACHE_MANAGER=sglang.srt.utils:CustomCacheManager``
    so Triton loads :class:`CustomCacheManager`, whose default cache
    directory is made unique per PID. This is a workaround for
    multi-process cache races (see the FIXME at the call site referencing
    https://github.com/triton-lang/triton/pull/4295); any value the user
    already placed in the environment is left untouched.
    """
    # Fixes the original's `cache_manger` typo; behavior is unchanged.
    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    if cache_manager is None:
        manager = "sglang.srt.utils:CustomCacheManager"
        logger.info("Setting Triton cache manager to: %s", manager)
        os.environ["TRITON_CACHE_MANAGER"] = manager
|
||||||
|
|
||||||
|
|
||||||
|
class CustomCacheManager(FileCacheManager):
    """Triton file-cache manager whose default cache dir is unique per PID.

    Adapted from:
    https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py

    Appending ``os.getpid()`` to the default cache directory keeps
    concurrent processes from stepping on the same on-disk Triton cache.
    ``dump`` and ``override`` behave exactly like the stock manager.
    """

    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None

        if dump:
            self.cache_dir = os.path.join(default_dump_dir(), self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            # NOTE: like the upstream manager, the override dir is not created here.
            self.cache_dir = os.path.join(default_override_dir(), self.key)
        else:
            # Default location: honor TRITON_CACHE_DIR when set, otherwise
            # fall back to Triton's default; either way, suffix with the
            # PID so each process gets its own cache tree.
            base_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if not base_dir:
                raise RuntimeError("Could not create or locate cache dir")
            self.cache_dir = os.path.join(f"{base_dir}_{os.getpid()}", self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
API_KEY_HEADER_NAME = "X-API-Key"
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user