diff --git a/README.md b/README.md
index e6f518d5c..bc2976d23 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,10 @@ pip install --upgrade pip
 pip install -e "python[all]"
 ```
 
+### Notes
+- If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid known bugs in the Triton compiler.
+- If you only need to use the OpenAI backend, you can skip installing other dependencies with `pip install sglang[openai]`.
+
 ## Quick Start
 The example below shows how to use sglang to answer a multi-turn question.
@@ -197,7 +201,7 @@
 ## Backend: SGLang Runtime (SRT)
 The SGLang Runtime (SRT) is designed to work best with the SGLang frontend. However, it can also be used as a standalone API server.
-In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases.
+In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.
 
 ### Usage
 Launch a server
@@ -221,6 +225,10 @@ curl http://localhost:30000/v1/completions \
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
 ```
+- If you see out-of-memory errors during serving, try reducing the memory usage of the KV cache pool by setting a smaller value for `--mem-fraction-static`. The default value is `0.9`.
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
+```
 
 ### Supported Models
 - Llama
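The README hunk above exercises the `/v1/completions` endpoint with curl. As an illustrative sketch (not part of this patch) of the same call from Python, assuming a server already launched on port 30000 with the Llama-2-7b-chat model and using the third-party `requests` library; the JSON fields follow the OpenAI completions schema that this endpoint mimics:

```python
# Illustrative only -- assumes a server launched with:
#   python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "meta-llama/Llama-2-7b-chat-hf",  # model assumed from the README example
        "prompt": "Say this is a test",
        "max_tokens": 16,
        "temperature": 0,
    },
)
print(resp.json())
```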
diff --git a/python/pyproject.toml b/python/pyproject.toml
index ba91943f5..1d4677b8d 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sglang"
-version = "0.1.3"
+version = "0.1.4"
 description = "A structured generation language for LLMs."
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index ed3fa966e..957532f3e 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.3"
+__version__ = "0.1.4"
 
 from sglang.api import *
 from sglang.global_config import global_config
diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/context_flashattention_nopad.py
index 6159e9a51..657cf9f9b 100644
--- a/python/sglang/srt/layers/context_flashattention_nopad.py
+++ b/python/sglang/srt/layers/context_flashattention_nopad.py
@@ -6,6 +6,9 @@ import triton.language as tl
 
 from sglang.srt.utils import wrap_kernel_launcher
 
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+
 @triton.jit
 def _fwd_kernel(
     Q,
@@ -120,7 +123,11 @@ cached_kernel = None
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    BLOCK = 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK = 128
+    else:
+        BLOCK = 64
+
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
     assert Lk in {16, 32, 64, 128}
diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py
index 18f403ae6..d1269e726 100644
--- a/python/sglang/srt/layers/extend_attention.py
+++ b/python/sglang/srt/layers/extend_attention.py
@@ -2,6 +2,10 @@ import torch
 import triton
 import triton.language as tl
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.utils import wrap_kernel_launcher
+
+
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -153,6 +157,9 @@ def _fwd_kernel(
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
 
 
+cached_kernel = None
+
+
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -175,7 +182,11 @@ def extend_attention_fwd(
     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    BLOCK_M, BLOCK_N = 128, 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = 128, 128
+    else:
+        BLOCK_M, BLOCK_N = 64, 64
+
     Lq, Lk, Lv, Lo = (
         q_extend.shape[-1],
         k_extend.shape[-1],
         v_extend.shape[-1],
@@ -193,6 +204,40 @@
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
+    global cached_kernel
+    if cached_kernel:
+        cached_kernel(
+            grid,
+            num_warps,
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            req_to_tokens,
+            b_req_idx,
+            b_seq_len,
+            b_start_loc_extend,
+            b_seq_len_extend,
+            sm_scale,
+            kv_group_num,
+            q_extend.stride(0),
+            q_extend.stride(1),
+            k_extend.stride(0),
+            k_extend.stride(1),
+            v_extend.stride(0),
+            v_extend.stride(1),
+            o_extend.stride(0),
+            o_extend.stride(1),
+            k_buffer.stride(0),
+            k_buffer.stride(1),
+            v_buffer.stride(0),
+            v_buffer.stride(1),
+            req_to_tokens.stride(0),
+        )
+        return
+
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -226,6 +271,7 @@
         num_warps=num_warps,
         num_stages=num_stages,
     )
+    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
 
 
 def redundant_attention(
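Both Triton files above now pick their tile sizes from the GPU's compute capability. The standalone sketch below restates that dispatch rule; the block sizes mirror the diff, while the rationale in the comments (pre-Ampere parts have smaller shared-memory and register budgets, so 128x128 tiles can fail to compile or launch) is my reading, not a statement from the patch:

```python
# A standalone sketch of the dispatch rule this diff adds. Assumes a CUDA
# device is present. torch.cuda.get_device_capability() returns a
# (major, minor) tuple, e.g. (7, 0) for V100, (7, 5) for T4, (8, 0) for A100.
import torch

major, _minor = torch.cuda.get_device_capability()
if major >= 8:
    # Ampere or newer: large tiles fit comfortably.
    BLOCK_M, BLOCK_N = 128, 128
else:
    # Older GPUs (T4, V100): fall back to smaller tiles, presumably to stay
    # within per-SM shared-memory/register limits (assumption, not from patch).
    BLOCK_M, BLOCK_N = 64, 64
print(f"using tile sizes BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}")
```

The same `(major, minor)` tuple drives `BLOCK` in `context_attention_fwd` and `BLOCK_M, BLOCK_N` in `extend_attention_fwd`.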
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 5aec63311..bedd60914 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -5,6 +5,7 @@ import time
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto
 from typing import Dict, List, Optional, Tuple, Union
+import warnings
 
 import numpy as np
 import rpyc
@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
                 + self.tree_cache.evictable_size()
             )
             if available_size != self.max_total_num_token:
-                logger.warning(
+                warnings.warn(
                     "Warning: "
                     f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                     "KV cache pool leak detected!"
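The final hunk swaps `logger.warning` for `warnings.warn` in the KV cache pool leak check. One plausible benefit, sketched below with a hypothetical `check_pool` helper that is not part of the patch: the `warnings` module deduplicates repeated messages by default and can be promoted to a hard error under test filters.

```python
import warnings

def check_pool(available_size: int, max_total_num_token: int) -> None:
    """Hypothetical stand-in for the leak check in ModelRpcServer."""
    if available_size != max_total_num_token:
        warnings.warn(
            f"available_size={available_size}, "
            f"max_total_num_token={max_total_num_token}\n"
            "KV cache pool leak detected!"
        )

# In a test suite, promote the warning to an exception so leaks fail loudly:
warnings.simplefilter("error")
check_pool(100, 100)    # sizes match: no warning
# check_pool(99, 100)   # would raise UserWarning under the "error" filter
```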