Deprecate --disable-flashinfer and introduce --attention-backend (#1380)

This commit is contained in:
Lianmin Zheng
2024-09-10 17:11:16 -07:00
committed by GitHub
parent 3a6e8b6d78
commit 46094e0c1b
13 changed files with 99 additions and 61 deletions

View File

@@ -17,7 +17,6 @@ limitations under the License.
import argparse
import dataclasses
import json
import logging
import random
from typing import List, Optional, Union
@@ -50,7 +49,6 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_num_reqs: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: int = 8192
max_prefill_tokens: int = 16384
@@ -85,6 +83,9 @@ class ServerArgs:
json_model_override_args: str = "{}"
# Optimization/debug options
attention_backend: str = "flashinfer"
sampling_backend: str = "flashinfer"
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
disable_radix_cache: bool = False
@@ -101,6 +102,7 @@ class ServerArgs:
triton_attention_reduce_in_fp32: bool = False
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@@ -111,6 +113,7 @@ class ServerArgs:
# Disable chunked prefill
self.chunked_prefill_size = None
# Mem fraction depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
self.mem_fraction_static = 0.79
@@ -131,6 +134,29 @@ class ServerArgs:
if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)
# Deprecation warnings
if self.disable_flashinfer:
logger.warning(
"The option '--disable-flashinfer' will be deprecated in the next release. "
"Please use '--attention-backend triton' instead."
)
if self.disable_flashinfer_sampling:
logger.warning(
"The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
"Please use '--sampling-backend pytorch' instead. "
)
# Model-specific patches
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
)
self.trust_remote_code = False
if "gemma-2" in self.model_path.lower():
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.attention_backend = "flashinfer"
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
@@ -214,11 +240,6 @@ class ServerArgs:
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--context-length",
type=int,
@@ -253,6 +274,11 @@ class ServerArgs:
default=ServerArgs.chat_template,
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--mem-fraction-static",
type=float,
@@ -265,17 +291,12 @@ class ServerArgs:
default=ServerArgs.max_running_requests,
help="The maximum number of running requests.",
)
parser.add_argument(
"--max-num-reqs",
type=int,
default=ServerArgs.max_num_reqs,
help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
)
parser.add_argument(
"--max-total-tokens",
type=int,
default=ServerArgs.max_total_tokens,
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
"This option is typically used for development and debugging purposes.",
)
parser.add_argument(
"--chunked-prefill-size",
@@ -395,15 +416,29 @@ class ServerArgs:
)
# Optimization/debug options
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--sampling-backend",
type=str,
choices=["flashinfer", "pytorch"],
default=ServerArgs.sampling_backend,
help="Choose the kernels for sampling layers.",
)
parser.add_argument(
"--disable-flashinfer",
action="store_true",
help="Disable flashinfer attention kernels.",
help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action="store_true",
help="Disable flashinfer sampling kernels.",
help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
)
parser.add_argument(
"--disable-radix-cache",
@@ -491,14 +526,6 @@ class ServerArgs:
assert not (
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
)
self.trust_remote_code = False
if "gemma-2" in self.model_path.lower():
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.disable_flashinfer = False
def prepare_server_args(argv: List[str]) -> ServerArgs: