Add a simple torch native attention backend (#2241)

This commit is contained in:
Qun Yang
2024-12-01 19:01:25 +08:00
committed by GitHub
parent fc78640e00
commit 62c516ac45
7 changed files with 388 additions and 26 deletions

View File

@@ -40,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
@@ -570,6 +571,8 @@ class ModelRunner:
self.attn_backend = DoubleSparseAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
self.attn_backend = TorchNativeAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"