diff --git a/README.md b/README.md index 8af73c49d..7ebada73d 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ sky status --endpoint 30000 sglang ### Common Notes -- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue. +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. ## Backend: SGLang Runtime (SRT) diff --git a/docs/en/install.md b/docs/en/install.md index 656bc6840..60645ce84 100644 --- a/docs/en/install.md +++ b/docs/en/install.md @@ -92,5 +92,5 @@ sky status --endpoint 30000 sglang ### Common Notes -- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue. +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index adada7cda..48567e43d 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -61,14 +61,18 @@ class RadixAttention(nn.Module): # Choose backend if ( - not global_server_args_dict.get("disable_flashinfer", False) + global_server_args_dict["attention_backend"] == "flashinfer" and self.qk_head_dim == self.v_head_dim ): self.extend_forward = self.extend_forward_flashinfer self.decode_forward = self.decode_forward_flashinfer - else: + elif global_server_args_dict["attention_backend"] == "triton": self.extend_forward = self.extend_forward_triton self.decode_forward = self.decode_forward_triton + else: + raise ValueError( + f"Invalid attention backend: {global_server_args_dict['attention_backend']}" + ) def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): if self.qk_head_dim != self.v_head_dim: diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 6cb7d5b55..16b6b80e9 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -78,7 +78,7 @@ class Sampler(CustomOp): probs = self._get_probs(logits, sampling_info) - if not global_server_args_dict["disable_flashinfer_sampling"]: + if global_server_args_dict["sampling_backend"] == "flashinfer": max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device @@ -93,11 +93,15 @@ class Sampler(CustomOp): batch_next_token_ids, success = flashinfer_top_k_top_p( probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps ) - else: + elif global_server_args_dict["sampling_backend"] == "pytorch": # Here we provide a slower fallback implementation. batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch( probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps ) + else: + raise ValueError( + f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" + ) return SampleOutput(success, probs, batch_next_token_ids) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2e2489cd2..e46177bdb 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: from sglang.srt.layers.sampler import SampleOutput @@ -40,10 +41,11 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access global_server_args_dict = { - "disable_flashinfer": False, - "disable_flashinfer_sampling": False, - "triton_attention_reduce_in_fp32": False, - "enable_mla": False, + "attention_backend": ServerArgs.attention_backend, + "sampling_backend": ServerArgs.sampling_backend, + "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32, + "enable_mla": ServerArgs.enable_mla, + "torchao_config": ServerArgs.torchao_config, } diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index fe7c4bcab..b1131b011 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -128,7 +128,7 @@ class ModelTpServer: if server_args.max_running_requests is None else server_args.max_running_requests ), - self.model_runner.req_to_token_pool.size - 1, + self.model_runner.req_to_token_pool.size, ) self.max_req_input_len = min( self.model_config.context_len - 1, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 867bd95a1..c158b3ce2 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -203,17 +203,17 @@ class InputMetadata: ret.compute_extend_infos(batch) fm = batch.forward_mode - if not fm.is_decode() or model_runner.server_args.disable_flashinfer: + if not fm.is_decode() or model_runner.server_args.attention_backend == "triton": ret.total_num_tokens = int(torch.sum(ret.seq_lens)) if not fm.is_decode(): ret.init_multimuldal_info(batch) - if model_runner.server_args.disable_flashinfer: + if model_runner.server_args.attention_backend == "triton": ret.init_triton_args(batch) flashinfer_use_ragged = False - if not model_runner.server_args.disable_flashinfer: + if model_runner.server_args.attention_backend == "flashinfer": if ( not fm.is_decode() and int(torch.sum(ret.seq_lens)) > 4096 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3033a7ce4..b04b0d7c0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -53,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import ( MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, @@ -92,8 +92,8 @@ class ModelRunner: ) global_server_args_dict.update( { - "disable_flashinfer": server_args.disable_flashinfer, - "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, + "attention_backend": server_args.attention_backend, + "sampling_backend": server_args.sampling_backend, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "enable_mla": server_args.enable_mla, "torchao_config": server_args.torchao_config, @@ -111,7 +111,7 @@ class ModelRunner: self.load_model() self.init_memory_pool( min_per_gpu_memory, - server_args.max_num_reqs, + server_args.max_running_requests, server_args.max_total_tokens, ) self.init_cublas() @@ -344,8 +344,8 @@ class ModelRunner: def init_memory_pool( self, total_gpu_memory: int, - max_num_reqs: int = None, - max_total_tokens: int = None, + max_num_reqs: Optional[int] = None, + max_total_tokens: Optional[int] = None, ): if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype @@ -379,7 +379,7 @@ class ModelRunner: ), 2048, ), - 5120, + 4096, ) self.req_to_token_pool = ReqToTokenPool( @@ -399,7 +399,7 @@ class ModelRunner: ) logger.info("using MLA Triton implementaion, flashinfer is disabled") # FIXME: temporarily only Triton MLA is supported - self.server_args.disable_flashinfer = True + self.server_args.attention_backend = "triton" else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, @@ -424,7 +424,7 @@ class ModelRunner: def init_flashinfer(self): """Init flashinfer attention kernel wrappers.""" - if self.server_args.disable_flashinfer: + if self.server_args.attention_backend != "flashinfer": assert ( self.sliding_window_size is None ), "turn on flashinfer to support window attention" @@ -491,7 +491,10 @@ class ModelRunner: from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner - if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: + if ( + self.server_args.disable_cuda_graph + or self.server_args.attention_backend != "flashinfer" + ): self.cuda_graph_runner = None return diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index b73a01265..4aaf018a1 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -425,7 +425,7 @@ def _set_envs_and_config(server_args: ServerArgs): maybe_set_triton_cache_manager() # Check flashinfer version - if not server_args.disable_flashinfer: + if server_args.attention_backend == "flashinfer": assert_pkg_version( "flashinfer", "0.1.6", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3dfb1dc41..0881344c0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -17,7 +17,6 @@ limitations under the License. import argparse import dataclasses -import json import logging import random from typing import List, Optional, Union @@ -50,7 +49,6 @@ class ServerArgs: # Memory and scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None - max_num_reqs: Optional[int] = None max_total_tokens: Optional[int] = None chunked_prefill_size: int = 8192 max_prefill_tokens: int = 16384 @@ -85,6 +83,9 @@ class ServerArgs: json_model_override_args: str = "{}" # Optimization/debug options + attention_backend: str = "flashinfer" + sampling_backend: str = "flashinfer" + disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False disable_radix_cache: bool = False @@ -101,6 +102,7 @@ class ServerArgs: triton_attention_reduce_in_fp32: bool = False def __post_init__(self): + # Set missing default values if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -111,6 +113,7 @@ class ServerArgs: # Disable chunked prefill self.chunked_prefill_size = None + # Mem fraction depends on the tensor parallelism size if self.mem_fraction_static is None: if self.tp_size >= 16: self.mem_fraction_static = 0.79 @@ -131,6 +134,29 @@ class ServerArgs: if self.random_seed is None: self.random_seed = random.randint(0, 1 << 30) + # Deprecation warnings + if self.disable_flashinfer: + logger.warning( + "The option '--disable-flashinfer' will be deprecated in the next release. " + "Please use '--attention-backend triton' instead." + ) + if self.disable_flashinfer_sampling: + logger.warning( + "The option '--disable-flashinfer-sampling' will be deprecated in the next release. " + "Please use '--sampling-backend pytorch' instead. " + ) + + # Model-specific patches + if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: + logger.info( + "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" + ) + self.trust_remote_code = False + + if "gemma-2" in self.model_path.lower(): + logger.info("When using sliding window in gemma-2, turn on flashinfer.") + self.attention_backend = "flashinfer" + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -214,11 +240,6 @@ class ServerArgs: action="store_true", help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", ) - parser.add_argument( - "--is-embedding", - action="store_true", - help="Whether to use a CausalLM as an embedding model.", - ) parser.add_argument( "--context-length", type=int, @@ -253,6 +274,11 @@ class ServerArgs: default=ServerArgs.chat_template, help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.", ) + parser.add_argument( + "--is-embedding", + action="store_true", + help="Whether to use a CausalLM as an embedding model.", + ) parser.add_argument( "--mem-fraction-static", type=float, @@ -265,17 +291,12 @@ class ServerArgs: default=ServerArgs.max_running_requests, help="The maximum number of running requests.", ) - parser.add_argument( - "--max-num-reqs", - type=int, - default=ServerArgs.max_num_reqs, - help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", - ) parser.add_argument( "--max-total-tokens", type=int, default=ServerArgs.max_total_tokens, - help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.", + help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. " + "This option is typically used for development and debugging purposes.", ) parser.add_argument( "--chunked-prefill-size", @@ -395,15 +416,29 @@ class ServerArgs: ) # Optimization/debug options + parser.add_argument( + "--attention-backend", + type=str, + choices=["flashinfer", "triton"], + default=ServerArgs.attention_backend, + help="Choose the kernels for attention layers.", + ) + parser.add_argument( + "--sampling-backend", + type=str, + choices=["flashinfer", "pytorch"], + default=ServerArgs.sampling_backend, + help="Choose the kernels for sampling layers.", + ) parser.add_argument( "--disable-flashinfer", action="store_true", - help="Disable flashinfer attention kernels.", + help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.", ) parser.add_argument( "--disable-flashinfer-sampling", action="store_true", - help="Disable flashinfer sampling kernels.", + help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.", ) parser.add_argument( "--disable-radix-cache", @@ -491,14 +526,6 @@ class ServerArgs: assert not ( self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" - if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: - logger.info( - "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" - ) - self.trust_remote_code = False - if "gemma-2" in self.model_path.lower(): - logger.info("When using sliding window in gemma-2, turn on flashinfer.") - self.disable_flashinfer = False def prepare_server_args(argv: List[str]) -> ServerArgs: diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index 2acf626c1..e0c851d4e 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -14,13 +14,12 @@ from sglang.test.test_utils import ( class TestServingThroughput(unittest.TestCase): - def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size): + def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size): # Launch the server other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - if disable_flashinfer: - other_args.append("--disable-flashinfer") + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) other_args.extend(["--tensor-parallel-size", "2"]) @@ -70,7 +69,7 @@ class TestServingThroughput(unittest.TestCase): def test_default(self): res = self.run_test( disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) @@ -80,7 +79,7 @@ class TestServingThroughput(unittest.TestCase): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index d4ed12612..1b458e9e6 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -14,13 +14,12 @@ from sglang.test.test_utils import ( class TestServingThroughput(unittest.TestCase): - def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size): + def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size): # Launch the server other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - if disable_flashinfer: - other_args.append("--disable-flashinfer") + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) model = DEFAULT_MODEL_NAME_FOR_TEST @@ -69,7 +68,7 @@ class TestServingThroughput(unittest.TestCase): def test_default(self): res = self.run_test( disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) @@ -79,7 +78,7 @@ class TestServingThroughput(unittest.TestCase): def test_default_without_radix_cache(self): res = self.run_test( disable_radix_cache=True, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=ServerArgs.chunked_prefill_size, ) @@ -89,7 +88,7 @@ class TestServingThroughput(unittest.TestCase): def test_default_without_chunked_prefill(self): res = self.run_test( disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, + attention_backend=ServerArgs.attention_backend, chunked_prefill_size=-1, ) diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py index b3f65ac13..9c6519d91 100644 --- a/test/srt/test_triton_attn_backend.py +++ b/test/srt/test_triton_attn_backend.py @@ -20,7 +20,7 @@ class TestTritonAttnBackend(unittest.TestCase): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--disable-flashinfer"], + other_args=["--attention-backend", "triton"], ) @classmethod