Add a simple torch native attention backend (#2241)
@@ -180,15 +180,21 @@ class ServerArgs:
         else:
             self.cuda_graph_max_bs = 160
 
-        # Set kernel backends
-        if not is_flashinfer_available():
-            self.attention_backend = "triton"
-            self.sampling_backend = "pytorch"
-
+        # Choose kernel backends
         if self.attention_backend is None:
-            self.attention_backend = "flashinfer"
+            self.attention_backend = (
+                "flashinfer" if is_flashinfer_available() else "triton"
+            )
         if self.sampling_backend is None:
-            self.sampling_backend = "flashinfer"
+            self.sampling_backend = (
+                "flashinfer" if is_flashinfer_available() else "pytorch"
+            )
+
+        if self.attention_backend == "torch_native":
+            logger.info(
+                "Cuda graph is disabled because of using torch native attention backend"
+            )
+            self.disable_cuda_graph = True
 
         # Others
         if self.enable_dp_attention:
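Note: the backend implementation itself is not part of this hunk. As a rough sketch of what a "torch native" attention call looks like (the function name and shapes below are assumptions for illustration, not the backend's actual code), PyTorch's built-in scaled_dot_product_attention can be called directly:

    # Illustrative sketch only; names and shapes are assumptions, not the backend's code.
    import torch
    import torch.nn.functional as F

    def torch_native_attention(q, k, v, causal=True):
        # q, k, v: [batch, num_heads, seq_len, head_dim]
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    q = k = v = torch.randn(1, 8, 16, 64)
    out = torch_native_attention(q, k, v)  # same shape as q

As the added lines above show, selecting torch_native also disables CUDA graphs.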
@@ -586,7 +592,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton"],
+            choices=["flashinfer", "triton", "torch_native"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
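With the new choice registered above, the backend can be selected at launch time. A minimal usage sketch (assuming the usual sglang.launch_server entry point and --model-path flag, which are not part of this diff):

    python -m sglang.launch_server --model-path <model> --attention-backend torch_native

Per the setup logic in the first hunk, this selection turns off CUDA graphs automatically.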