From 1df6eabd5d362b4e3ac53b8d195e0f8c14d22b54 Mon Sep 17 00:00:00 2001 From: Andrew Smith <101757907+andjsmi@users.noreply.github.com> Date: Fri, 21 Feb 2025 22:31:09 +1100 Subject: [PATCH] feat: Add SageMaker support (#3740) --- docker/Dockerfile.sagemaker | 78 ++++++++ docker/serve | 31 ++++ python/sglang/srt/entrypoints/http_server.py | 12 ++ test/srt/test_sagemaker_server.py | 178 +++++++++++++++++++ 4 files changed, 299 insertions(+) create mode 100644 docker/Dockerfile.sagemaker create mode 100755 docker/serve create mode 100644 test/srt/test_sagemaker_server.py diff --git a/docker/Dockerfile.sagemaker b/docker/Dockerfile.sagemaker new file mode 100644 index 000000000..fde8d556e --- /dev/null +++ b/docker/Dockerfile.sagemaker @@ -0,0 +1,78 @@ +ARG CUDA_VERSION=12.5.1 + +FROM nvcr.io/nvidia/tritonserver:24.04-py3-min + +ARG BUILD_TYPE=all +ENV DEBIAN_FRONTEND=noninteractive + +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt update -y \ + && apt install software-properties-common -y \ + && add-apt-repository ppa:deadsnakes/ppa -y && apt update \ + && apt install python3.10 python3.10-dev -y \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 \ + && update-alternatives --set python3 /usr/bin/python3.10 && apt install python3.10-distutils -y \ + && apt install curl git sudo libibverbs-dev -y \ + && apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 \ + && curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py \ + && python3 --version \ + && python3 -m pip --version \ + && rm -rf /var/lib/apt/lists/* \ + && apt clean + +# For openbmb/MiniCPM models +RUN pip3 install datamodel_code_generator + +WORKDIR /sgl-workspace + +ARG CUDA_VERSION +RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ + && git clone --depth=1 https://github.com/sgl-project/sglang.git \ + && if [ "$CUDA_VERSION" = "12.1.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \ + elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ + else \ + echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ + fi \ + && cd sglang \ + && if [ "$BUILD_TYPE" = "srt" ]; then \ + if [ "$CUDA_VERSION" = "12.1.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \ + elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ + elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ + else \ + echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ + fi; \ + else \ + if [ "$CUDA_VERSION" = "12.1.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \ + elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \ + elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ + else \ + echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ + fi; \ + fi + +ENV DEBIAN_FRONTEND=interactive + +COPY serve /usr/bin/serve +RUN chmod 777 /usr/bin/serve + +ENTRYPOINT [ "/usr/bin/serve" ] diff --git a/docker/serve b/docker/serve new file mode 100755 index 000000000..493ecbd23 --- /dev/null +++ b/docker/serve @@ -0,0 +1,31 @@ +#!/bin/bash + +echo "Starting server" + +SERVER_ARGS="--host 0.0.0.0 --port 8080" + +if [ -n "$TENSOR_PARALLEL_DEGREE" ]; then + SERVER_ARGS="${SERVER_ARGS} --tp-size ${TENSOR_PARALLEL_DEGREE}" +fi + +if [ -n "$DATA_PARALLEL_DEGREE" ]; then + SERVER_ARGS="${SERVER_ARGS} --dp-size ${DATA_PARALLEL_DEGREE}" +fi + +if [ -n "$EXPERT_PARALLEL_DEGREE" ]; then + SERVER_ARGS="${SERVER_ARGS} --ep-size ${EXPERT_PARALLEL_DEGREE}" +fi + +if [ -n "$MEM_FRACTION_STATIC" ]; then + SERVER_ARGS="${SERVER_ARGS} --mem-fraction-static ${MEM_FRACTION_STATIC}" +fi + +if [ -n "$QUANTIZATION" ]; then + SERVER_ARGS="${SERVER_ARGS} --quantization ${QUANTIZATION}" +fi + +if [ -n "$CHUNKED_PREFILL_SIZE" ]; then + SERVER_ARGS="${SERVER_ARGS} --chunked-prefill-size ${CHUNKED_PREFILL_SIZE}" +fi + +python3 -m sglang.launch_server --model-path /opt/ml/model $SERVER_ARGS diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1759cd2bb..2b2421a37 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -463,6 +463,18 @@ async def retrieve_file_content(file_id: str): return await v1_retrieve_file_content(file_id) +## SageMaker API +@app.get("/ping") +async def sagemaker_health() -> Response: + """Check the health of the http server.""" + return Response(status_code=200) + + +@app.post("/invocations") +async def sagemaker_chat_completions(raw_request: Request): + return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) + + def _create_error_response(e): return ORJSONResponse( {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST diff --git a/test/srt/test_sagemaker_server.py b/test/srt/test_sagemaker_server.py new file mode 100644 index 000000000..fab7ca4dc --- /dev/null +++ b/test/srt/test_sagemaker_server.py @@ -0,0 +1,178 @@ +""" +python3 -m unittest test_sagemaker_server.TestSageMakerServer.test_chat_completion +""" + +import json +import unittest + +import requests + +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestSageMakerServer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + ) + cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_chat_completion(self, logprobs, parallel_sample_num): + data = { + "model": self.model, + "messages": [ + {"role": "system", "content": "You are a helpful AI assistant"}, + { + "role": "user", + "content": "What is the capital of France? Answer in a few words.", + }, + ], + "temperature": 0, + "logprobs": logprobs is not None and logprobs > 0, + "top_logprobs": logprobs, + "n": parallel_sample_num, + } + + headers = {"Authorization": f"Bearer {self.api_key}"} + + response = requests.post( + f"{self.base_url}/invocations", json=data, headers=headers + ).json() + + if logprobs: + assert isinstance( + response["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0][ + "token" + ], + str, + ) + + ret_num_top_logprobs = len( + response["choices"][0]["logprobs"]["content"][0]["top_logprobs"] + ) + assert ( + ret_num_top_logprobs == logprobs + ), f"{ret_num_top_logprobs} vs {logprobs}" + + assert len(response["choices"]) == parallel_sample_num + assert response["choices"][0]["message"]["role"] == "assistant" + assert isinstance(response["choices"][0]["message"]["content"], str) + assert response["id"] + assert response["created"] + assert response["usage"]["prompt_tokens"] > 0 + assert response["usage"]["completion_tokens"] > 0 + assert response["usage"]["total_tokens"] > 0 + + def run_chat_completion_stream(self, logprobs, parallel_sample_num=1): + data = { + "model": self.model, + "messages": [ + {"role": "system", "content": "You are a helpful AI assistant"}, + { + "role": "user", + "content": "What is the capital of France? Answer in a few words.", + }, + ], + "temperature": 0, + "logprobs": logprobs is not None and logprobs > 0, + "top_logprobs": logprobs, + "stream": True, + "stream_options": {"include_usage": True}, + "n": parallel_sample_num, + } + + headers = {"Authorization": f"Bearer {self.api_key}"} + + response = requests.post( + f"{self.base_url}/invocations", json=data, stream=True, headers=headers + ) + + is_firsts = {} + for line in response.iter_lines(): + line = line.decode("utf-8").replace("data: ", "") + if len(line) < 1 or line == "[DONE]": + continue + print(f"value: {line}") + line = json.loads(line) + usage = line.get("usage") + if usage is not None: + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] > 0 + continue + + index = line.get("choices")[0].get("index") + data = line.get("choices")[0].get("delta") + + if is_firsts.get(index, True): + assert data["role"] == "assistant" + is_firsts[index] = False + continue + + if logprobs: + assert line.get("choices")[0].get("logprobs") + assert isinstance( + line.get("choices")[0] + .get("logprobs") + .get("content")[0] + .get("top_logprobs")[0] + .get("token"), + str, + ) + assert isinstance( + line.get("choices")[0] + .get("logprobs") + .get("content")[0] + .get("top_logprobs"), + list, + ) + ret_num_top_logprobs = len( + line.get("choices")[0] + .get("logprobs") + .get("content")[0] + .get("top_logprobs") + ) + assert ( + ret_num_top_logprobs == logprobs + ), f"{ret_num_top_logprobs} vs {logprobs}" + + assert isinstance(data["content"], str) + assert line["id"] + assert line["created"] + + for index in [i for i in range(parallel_sample_num)]: + assert not is_firsts.get( + index, True + ), f"index {index} is not found in the response" + + def test_chat_completion(self): + for logprobs in [None, 5]: + for parallel_sample_num in [1, 2]: + self.run_chat_completion(logprobs, parallel_sample_num) + + def test_chat_completion_stream(self): + for logprobs in [None, 5]: + for parallel_sample_num in [1, 2]: + self.run_chat_completion_stream(logprobs, parallel_sample_num) + + +if __name__ == "__main__": + unittest.main()