Deprecate --disable-flashinfer and --disable-flashinfer-sampling (#2065)
@@ -116,8 +116,6 @@ class ServerArgs:
     grammar_backend: Optional[str] = "outlines"
 
     # Optimization/debug options
-    disable_flashinfer: bool = False
-    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
@@ -179,20 +177,6 @@ class ServerArgs:
             self.chunked_prefill_size //= 4  # make it 2048
             self.cuda_graph_max_bs = 4
 
-        # Deprecation warnings
-        if self.disable_flashinfer:
-            logger.warning(
-                "The option '--disable-flashinfer' will be deprecated in the next release. "
-                "Please use '--attention-backend triton' instead."
-            )
-            self.attention_backend = "triton"
-        if self.disable_flashinfer_sampling:
-            logger.warning(
-                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
-                "Please use '--sampling-backend pytorch' instead. "
-            )
-            self.sampling_backend = "pytorch"
-
         if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"
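
With the soft-deprecation branch removed, the retained fallback above is the only place these backends are overridden automatically. A minimal standalone sketch of that decision, assuming None means the user left the backend unset (the function name is illustrative, not from the codebase):

    from typing import Optional

    def pick_backends(
        attention_backend: Optional[str],
        sampling_backend: Optional[str],
        flashinfer_available: bool,
    ):
        # Mirrors the retained fallback: when flashinfer cannot be used,
        # force the portable Triton attention and PyTorch sampling backends.
        if not flashinfer_available:
            attention_backend = "triton"
            sampling_backend = "pytorch"
        return attention_backend, sampling_backend
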
@@ -615,16 +599,6 @@ class ServerArgs:
         )
 
         # Optimization/debug options
-        parser.add_argument(
-            "--disable-flashinfer",
-            action="store_true",
-            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer-sampling",
-            action="store_true",
-            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
-        )
         parser.add_argument(
             "--disable-radix-cache",
             action="store_true",
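
The replacement flags that the removed help strings point to are not part of this diff. A hedged sketch of what they might look like; only the names "--attention-backend" and "--sampling-backend" and the two fallback values come from the messages above, the choices lists and help texts are assumptions:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention-backend",
        type=str,
        choices=["flashinfer", "triton"],  # assumed backend set
        default=None,
        help="The attention kernel backend.",
    )
    parser.add_argument(
        "--sampling-backend",
        type=str,
        choices=["flashinfer", "pytorch"],  # assumed backend set
        default=None,
        help="The sampling kernel backend.",
    )
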
@@ -733,6 +707,18 @@ class ServerArgs:
             help="Delete the model checkpoint after loading the model.",
         )
 
+        # Deprecated arguments
+        parser.add_argument(
+            "--disable-flashinfer",
+            action=DeprecatedAction,
+            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action=DeprecatedAction,
+            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytorch' instead.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
@@ -826,3 +812,13 @@ class LoRAPathAction(argparse.Action):
                 getattr(namespace, self.dest)[name] = path
             else:
                 getattr(namespace, self.dest)[lora_path] = lora_path
+
+
+class DeprecatedAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=0, **kwargs):
+        super(DeprecatedAction, self).__init__(
+            option_strings, dest, nargs=nargs, **kwargs
+        )
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        raise ValueError(self.help)
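
Taken together with the registrations above, the deprecated flags now fail loudly at parse time instead of being silently remapped. A self-contained sketch of that behavior; the DeprecatedAction class is copied from the diff, while the surrounding parser setup is reduced to a single flag:

    import argparse

    class DeprecatedAction(argparse.Action):
        def __init__(self, option_strings, dest, nargs=0, **kwargs):
            super().__init__(option_strings, dest, nargs=nargs, **kwargs)

        def __call__(self, parser, namespace, values, option_string=None):
            # The help text doubles as the error message with the migration hint.
            raise ValueError(self.help)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--disable-flashinfer",
        action=DeprecatedAction,
        help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
    )

    try:
        parser.parse_args(["--disable-flashinfer"])
    except ValueError as exc:
        print(exc)  # prints the migration hint; the server never starts with the old flag
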
@@ -71,6 +71,8 @@ def is_flashinfer_available():
     Check whether flashinfer is available.
     As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
     """
+    if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
+        return False
     return torch.cuda.is_available() and not is_hip()
 
 
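
Since the boolean dataclass fields are gone, this environment variable becomes the remaining switch for forcing the non-flashinfer path without uninstalling the package. A minimal sketch of the new guard in isolation; the function name is illustrative and the CUDA/HIP checks are stubbed out:

    import os

    def flashinfer_enabled() -> bool:  # illustrative name, not the real function
        # Mirrors the new guard: SGLANG_IS_FLASHINFER_AVAILABLE=false
        # short-circuits the check before any GPU probing happens.
        if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
            return False
        return True  # the real check also requires torch.cuda.is_available() and not is_hip()

    os.environ["SGLANG_IS_FLASHINFER_AVAILABLE"] = "false"
    assert flashinfer_enabled() is False
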
@@ -65,8 +65,7 @@ class TestTorchCompile(unittest.TestCase):
         tok = time.time()
         print(f"{res=}")
         throughput = max_tokens / (tok - tic)
-        print(f"Throughput: {throughput} tokens/s")
-        self.assertGreaterEqual(throughput, 285)
+        self.assertGreaterEqual(throughput, 290)
 
 
 if __name__ == "__main__":