Deprecate --disable-flashinfer and --disable-flashinfer-sampling (#2065)

This commit is contained in:
Lianmin Zheng
2024-11-17 16:20:58 -08:00
committed by GitHub
parent 38625e2139
commit 11f881d173
3 changed files with 25 additions and 28 deletions

View File

@@ -116,8 +116,6 @@ class ServerArgs:
grammar_backend: Optional[str] = "outlines" grammar_backend: Optional[str] = "outlines"
# Optimization/debug options # Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
disable_radix_cache: bool = False disable_radix_cache: bool = False
disable_jump_forward: bool = False disable_jump_forward: bool = False
disable_cuda_graph: bool = False disable_cuda_graph: bool = False
@@ -179,20 +177,6 @@ class ServerArgs:
self.chunked_prefill_size //= 4 # make it 2048 self.chunked_prefill_size //= 4 # make it 2048
self.cuda_graph_max_bs = 4 self.cuda_graph_max_bs = 4
# Deprecation warnings
if self.disable_flashinfer:
logger.warning(
"The option '--disable-flashinfer' will be deprecated in the next release. "
"Please use '--attention-backend triton' instead."
)
self.attention_backend = "triton"
if self.disable_flashinfer_sampling:
logger.warning(
"The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
"Please use '--sampling-backend pytorch' instead. "
)
self.sampling_backend = "pytorch"
if not is_flashinfer_available(): if not is_flashinfer_available():
self.attention_backend = "triton" self.attention_backend = "triton"
self.sampling_backend = "pytorch" self.sampling_backend = "pytorch"
@@ -615,16 +599,6 @@ class ServerArgs:
) )
# Optimization/debug options # Optimization/debug options
parser.add_argument(
"--disable-flashinfer",
action="store_true",
help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action="store_true",
help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
)
parser.add_argument( parser.add_argument(
"--disable-radix-cache", "--disable-radix-cache",
action="store_true", action="store_true",
@@ -733,6 +707,18 @@ class ServerArgs:
help="Delete the model checkpoint after loading the model.", help="Delete the model checkpoint after loading the model.",
) )
# Deprecated arguments
parser.add_argument(
"--disable-flashinfer",
action=DeprecatedAction,
help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action=DeprecatedAction,
help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytorch' instead.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size
@@ -826,3 +812,13 @@ class LoRAPathAction(argparse.Action):
getattr(namespace, self.dest)[name] = path getattr(namespace, self.dest)[name] = path
else: else:
getattr(namespace, self.dest)[lora_path] = lora_path getattr(namespace, self.dest)[lora_path] = lora_path
class DeprecatedAction(argparse.Action):
    """Argparse action that rejects a deprecated flag outright.

    Registering an option with ``action=DeprecatedAction`` makes any use of
    that option raise ``ValueError`` carrying the option's ``help`` text,
    which should name the replacement flag (e.g. "use '--attention-backend
    triton' instead").
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        # nargs=0: the deprecated options were boolean switches, so they
        # take no value; parsing still fails before anything is stored.
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Fail hard (rather than warn) so users migrate to the new flag.
        raise ValueError(self.help)

View File

@@ -71,6 +71,8 @@ def is_flashinfer_available():
Check whether flashinfer is available. Check whether flashinfer is available.
As of Oct. 6, 2024, it is only available on NVIDIA GPUs. As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
""" """
if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
return False
return torch.cuda.is_available() and not is_hip() return torch.cuda.is_available() and not is_hip()

View File

@@ -65,8 +65,7 @@ class TestTorchCompile(unittest.TestCase):
tok = time.time() tok = time.time()
print(f"{res=}") print(f"{res=}")
throughput = max_tokens / (tok - tic) throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s") self.assertGreaterEqual(throughput, 285)
self.assertGreaterEqual(throughput, 290)
if __name__ == "__main__": if __name__ == "__main__":