Deprecate --disable-flashinfer and introduce --attention-backend (#1380)
This commit is contained in:
@@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
@@ -50,7 +49,6 @@ class ServerArgs:
|
||||
# Memory and scheduling
|
||||
mem_fraction_static: Optional[float] = None
|
||||
max_running_requests: Optional[int] = None
|
||||
max_num_reqs: Optional[int] = None
|
||||
max_total_tokens: Optional[int] = None
|
||||
chunked_prefill_size: int = 8192
|
||||
max_prefill_tokens: int = 16384
|
||||
@@ -85,6 +83,9 @@ class ServerArgs:
|
||||
json_model_override_args: str = "{}"
|
||||
|
||||
# Optimization/debug options
|
||||
attention_backend: str = "flashinfer"
|
||||
sampling_backend: str = "flashinfer"
|
||||
|
||||
disable_flashinfer: bool = False
|
||||
disable_flashinfer_sampling: bool = False
|
||||
disable_radix_cache: bool = False
|
||||
@@ -101,6 +102,7 @@ class ServerArgs:
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
|
||||
@@ -111,6 +113,7 @@ class ServerArgs:
|
||||
# Disable chunked prefill
|
||||
self.chunked_prefill_size = None
|
||||
|
||||
# Mem fraction depends on the tensor parallelism size
|
||||
if self.mem_fraction_static is None:
|
||||
if self.tp_size >= 16:
|
||||
self.mem_fraction_static = 0.79
|
||||
@@ -131,6 +134,29 @@ class ServerArgs:
|
||||
if self.random_seed is None:
|
||||
self.random_seed = random.randint(0, 1 << 30)
|
||||
|
||||
# Deprecation warnings
|
||||
if self.disable_flashinfer:
|
||||
logger.warning(
|
||||
"The option '--disable-flashinfer' will be deprecated in the next release. "
|
||||
"Please use '--attention-backend triton' instead."
|
||||
)
|
||||
if self.disable_flashinfer_sampling:
|
||||
logger.warning(
|
||||
"The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
|
||||
"Please use '--sampling-backend pytorch' instead. "
|
||||
)
|
||||
|
||||
# Model-specific patches
|
||||
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
|
||||
logger.info(
|
||||
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
|
||||
)
|
||||
self.trust_remote_code = False
|
||||
|
||||
if "gemma-2" in self.model_path.lower():
|
||||
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
|
||||
self.attention_backend = "flashinfer"
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
@@ -214,11 +240,6 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is-embedding",
|
||||
action="store_true",
|
||||
help="Whether to use a CausalLM as an embedding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-length",
|
||||
type=int,
|
||||
@@ -253,6 +274,11 @@ class ServerArgs:
|
||||
default=ServerArgs.chat_template,
|
||||
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is-embedding",
|
||||
action="store_true",
|
||||
help="Whether to use a CausalLM as an embedding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-fraction-static",
|
||||
type=float,
|
||||
@@ -265,17 +291,12 @@ class ServerArgs:
|
||||
default=ServerArgs.max_running_requests,
|
||||
help="The maximum number of running requests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-num-reqs",
|
||||
type=int,
|
||||
default=ServerArgs.max_num_reqs,
|
||||
help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-total-tokens",
|
||||
type=int,
|
||||
default=ServerArgs.max_total_tokens,
|
||||
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
|
||||
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
|
||||
"This option is typically used for development and debugging purposes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunked-prefill-size",
|
||||
@@ -395,15 +416,29 @@ class ServerArgs:
|
||||
)
|
||||
|
||||
# Optimization/debug options
|
||||
parser.add_argument(
|
||||
"--attention-backend",
|
||||
type=str,
|
||||
choices=["flashinfer", "triton"],
|
||||
default=ServerArgs.attention_backend,
|
||||
help="Choose the kernels for attention layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling-backend",
|
||||
type=str,
|
||||
choices=["flashinfer", "pytorch"],
|
||||
default=ServerArgs.sampling_backend,
|
||||
help="Choose the kernels for sampling layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-flashinfer",
|
||||
action="store_true",
|
||||
help="Disable flashinfer attention kernels.",
|
||||
help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-flashinfer-sampling",
|
||||
action="store_true",
|
||||
help="Disable flashinfer sampling kernels.",
|
||||
help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-radix-cache",
|
||||
@@ -491,14 +526,6 @@ class ServerArgs:
|
||||
assert not (
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
), "multi-node data parallel is not supported"
|
||||
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
|
||||
logger.info(
|
||||
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
|
||||
)
|
||||
self.trust_remote_code = False
|
||||
if "gemma-2" in self.model_path.lower():
|
||||
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
|
||||
self.disable_flashinfer = False
|
||||
|
||||
|
||||
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
|
||||
Reference in New Issue
Block a user