Deprecate --disable-flashinfer and introduce --attention-backend (#1380)
@@ -61,14 +61,18 @@ class RadixAttention(nn.Module):

         # Choose backend
         if (
-            not global_server_args_dict.get("disable_flashinfer", False)
+            global_server_args_dict["attention_backend"] == "flashinfer"
             and self.qk_head_dim == self.v_head_dim
         ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-        else:
+        elif global_server_args_dict["attention_backend"] == "triton":
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+        else:
+            raise ValueError(
+                f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
+            )

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         if self.qk_head_dim != self.v_head_dim:
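For readers skimming the hunk above: the backend choice becomes a three-way dispatch on a string key instead of a boolean flag, and unknown values now fail loudly. A minimal, self-contained sketch of that pattern (the function and kernel names below are illustrative placeholders, not the actual sglang methods):

def choose_attention_impl(backend: str):
    # Stand-ins for the real flashinfer/triton kernel methods.
    def flashinfer_forward():
        return "flashinfer kernels"

    def triton_forward():
        return "triton kernels"

    if backend == "flashinfer":
        return flashinfer_forward
    elif backend == "triton":
        return triton_forward
    else:
        # Unknown values raise instead of silently falling back.
        raise ValueError(f"Invalid attention backend: {backend}")


print(choose_attention_impl("triton")())  # -> "triton kernels"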
@@ -78,7 +78,7 @@ class Sampler(CustomOp):

         probs = self._get_probs(logits, sampling_info)

-        if not global_server_args_dict["disable_flashinfer_sampling"]:
+        if global_server_args_dict["sampling_backend"] == "flashinfer":
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
@@ -93,11 +93,15 @@ class Sampler(CustomOp):
             batch_next_token_ids, success = flashinfer_top_k_top_p(
                 probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
             )
-        else:
+        elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
             batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
+        else:
+            raise ValueError(
+                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+            )

         return SampleOutput(success, probs, batch_next_token_ids)
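The "pytorch" branch calls top_k_top_p_min_p_sampling_from_probs_torch, whose body is not part of this diff. Purely as an illustration of what a torch-only top-k/top-p fallback looks like (assuming scalar thresholds and omitting min-p; this is not the sglang implementation):

import torch

def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # probs: (batch, vocab), already softmax-normalized.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # Drop everything outside the top-k.
    sorted_probs[:, top_k:] = 0.0
    # Drop tokens once the cumulative probability before them exceeds top_p
    # (the first token is always kept).
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumsum - sorted_probs > top_p] = 0.0
    # Renormalize and sample, then map back to original vocabulary indices.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, sampled).squeeze(-1)

probs = torch.softmax(torch.randn(2, 10), dim=-1)
print(top_k_top_p_sample(probs, top_k=5, top_p=0.9))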
@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.server_args import ServerArgs

 if TYPE_CHECKING:
     from sglang.srt.layers.sampler import SampleOutput
@@ -40,10 +41,11 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
 global_server_args_dict = {
-    "disable_flashinfer": False,
-    "disable_flashinfer_sampling": False,
-    "triton_attention_reduce_in_fp32": False,
-    "enable_mla": False,
+    "attention_backend": ServerArgs.attention_backend,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+    "enable_mla": ServerArgs.enable_mla,
+    "torchao_config": ServerArgs.torchao_config,
 }

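One note on the new defaults: because ServerArgs is a dataclass, its field defaults are readable as plain class attributes, so the module-level dict can reference them instead of repeating literals that could drift out of sync. A small sketch of that mechanism (Args is a stand-in, not the real ServerArgs):

from dataclasses import dataclass

@dataclass
class Args:  # illustrative stand-in for ServerArgs
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"

# Dataclass defaults are accessible on the class itself, so a module-level
# dict can mirror them without duplicating the literal values.
defaults = {
    "attention_backend": Args.attention_backend,
    "sampling_backend": Args.sampling_backend,
}
print(defaults)  # {'attention_backend': 'flashinfer', 'sampling_backend': 'flashinfer'}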
@@ -128,7 +128,7 @@ class ModelTpServer:
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
             ),
-            self.model_runner.req_to_token_pool.size - 1,
+            self.model_runner.req_to_token_pool.size,
         )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
@@ -203,17 +203,17 @@ class InputMetadata:
         ret.compute_extend_infos(batch)

         fm = batch.forward_mode
-        if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
+        if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
             ret.total_num_tokens = int(torch.sum(ret.seq_lens))

         if not fm.is_decode():
             ret.init_multimuldal_info(batch)

-        if model_runner.server_args.disable_flashinfer:
+        if model_runner.server_args.attention_backend == "triton":
             ret.init_triton_args(batch)

         flashinfer_use_ragged = False
-        if not model_runner.server_args.disable_flashinfer:
+        if model_runner.server_args.attention_backend == "flashinfer":
             if (
                 not fm.is_decode()
                 and int(torch.sum(ret.seq_lens)) > 4096
@@ -53,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -92,8 +92,8 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "disable_flashinfer": server_args.disable_flashinfer,
-                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
                 "torchao_config": server_args.torchao_config,
@@ -111,7 +111,7 @@ class ModelRunner:
         self.load_model()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.max_num_reqs,
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
@@ -344,8 +344,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
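The signature change here is typing hygiene only: a parameter whose default is None should be annotated Optional[int]. A short generic sketch of the corrected pattern (not the real method; the fallback value is a made-up placeholder):

from typing import Optional

def init_pool(total_gpu_memory: int,
              max_num_reqs: Optional[int] = None,
              max_total_tokens: Optional[int] = None) -> None:
    # Optional[int] states explicitly that None is an accepted value,
    # which the previous `max_num_reqs: int = None` annotation did not.
    if max_num_reqs is None:
        max_num_reqs = 2048  # illustrative fallback, not sglang's heuristic
    print(total_gpu_memory, max_num_reqs, max_total_tokens)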
@@ -379,7 +379,7 @@ class ModelRunner:
                 ),
                 2048,
             ),
-            5120,
+            4096,
         )

         self.req_to_token_pool = ReqToTokenPool(
@@ -399,7 +399,7 @@ class ModelRunner:
             )
             logger.info("using MLA Triton implementaion, flashinfer is disabled")
             # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
+            self.server_args.attention_backend = "triton"
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -424,7 +424,7 @@ class ModelRunner:

     def init_flashinfer(self):
         """Init flashinfer attention kernel wrappers."""
-        if self.server_args.disable_flashinfer:
+        if self.server_args.attention_backend != "flashinfer":
             assert (
                 self.sliding_window_size is None
             ), "turn on flashinfer to support window attention"
@@ -491,7 +491,10 @@ class ModelRunner:

         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+        if (
+            self.server_args.disable_cuda_graph
+            or self.server_args.attention_backend != "flashinfer"
+        ):
             self.cuda_graph_runner = None
             return
@@ -425,7 +425,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         maybe_set_triton_cache_manager()

     # Check flashinfer version
-    if not server_args.disable_flashinfer:
+    if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer",
             "0.1.6",
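assert_pkg_version is an sglang utility whose body is not shown here; assuming it verifies the installed distribution against a minimum version, a standard-library-only sketch of that kind of check might look like:

from importlib.metadata import PackageNotFoundError, version

def check_min_version(pkg: str, minimum: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed; need {pkg}>={minimum}")

    def as_tuple(v: str):
        # Compare only the numeric dotted prefix, e.g. "0.1.6" -> (0, 1, 6).
        return tuple(int(x) for x in v.split("+")[0].split(".")[:3] if x.isdigit())

    if as_tuple(installed) < as_tuple(minimum):
        raise RuntimeError(f"{pkg}=={installed} is too old; need >= {minimum}")

# Example (only meaningful where flashinfer is actually installed):
# check_min_version("flashinfer", "0.1.6")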
@@ -17,7 +17,6 @@ limitations under the License.

 import argparse
 import dataclasses
-import json
 import logging
 import random
 from typing import List, Optional, Union
@@ -50,7 +49,6 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
-    max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: int = 8192
     max_prefill_tokens: int = 16384
@@ -85,6 +83,9 @@ class ServerArgs:
     json_model_override_args: str = "{}"

     # Optimization/debug options
+    attention_backend: str = "flashinfer"
+    sampling_backend: str = "flashinfer"
+
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -101,6 +102,7 @@ class ServerArgs:
     triton_attention_reduce_in_fp32: bool = False

     def __post_init__(self):
+        # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -111,6 +113,7 @@ class ServerArgs:
             # Disable chunked prefill
             self.chunked_prefill_size = None

+        # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -131,6 +134,29 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

+        # Deprecation warnings
+        if self.disable_flashinfer:
+            logger.warning(
+                "The option '--disable-flashinfer' will be deprecated in the next release. "
+                "Please use '--attention-backend triton' instead."
+            )
+        if self.disable_flashinfer_sampling:
+            logger.warning(
+                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
+                "Please use '--sampling-backend pytorch' instead. "
+            )
+
+        # Model-specific patches
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.attention_backend = "flashinfer"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
@@ -214,11 +240,6 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
-        parser.add_argument(
-            "--is-embedding",
-            action="store_true",
-            help="Whether to use a CausalLM as an embedding model.",
-        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -253,6 +274,11 @@ class ServerArgs:
             default=ServerArgs.chat_template,
             help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -265,17 +291,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
-        parser.add_argument(
-            "--max-num-reqs",
-            type=int,
-            default=ServerArgs.max_num_reqs,
-            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
-        )
         parser.add_argument(
             "--max-total-tokens",
             type=int,
             default=ServerArgs.max_total_tokens,
-            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
+            "This option is typically used for development and debugging purposes.",
         )
         parser.add_argument(
             "--chunked-prefill-size",
@@ -395,15 +416,29 @@ class ServerArgs:
         )

         # Optimization/debug options
+        parser.add_argument(
+            "--attention-backend",
+            type=str,
+            choices=["flashinfer", "triton"],
+            default=ServerArgs.attention_backend,
+            help="Choose the kernels for attention layers.",
+        )
+        parser.add_argument(
+            "--sampling-backend",
+            type=str,
+            choices=["flashinfer", "pytorch"],
+            default=ServerArgs.sampling_backend,
+            help="Choose the kernels for sampling layers.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer attention kernels.",
+            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
         )
         parser.add_argument(
             "--disable-flashinfer-sampling",
             action="store_true",
-            help="Disable flashinfer sampling kernels.",
+            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
         )
         parser.add_argument(
             "--disable-radix-cache",
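The two new flags are ordinary argparse options; a stand-alone sketch mirroring their choices and defaults (without importing sglang) behaves like this:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    type=str,
    choices=["flashinfer", "triton"],
    default="flashinfer",
    help="Choose the kernels for attention layers.",
)
parser.add_argument(
    "--sampling-backend",
    type=str,
    choices=["flashinfer", "pytorch"],
    default="flashinfer",
    help="Choose the kernels for sampling layers.",
)

args = parser.parse_args(["--attention-backend", "triton"])
print(args.attention_backend, args.sampling_backend)  # -> triton flashinfer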
@@ -491,14 +526,6 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
-            logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
-            )
-            self.trust_remote_code = False
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.disable_flashinfer = False


 def prepare_server_args(argv: List[str]) -> ServerArgs: