[Minor] move triton attention kernels into a separate folder (#1379)

This commit is contained in:
Lianmin Zheng
2024-09-10 15:15:08 -07:00
committed by GitHub
parent fbb4754cb8
commit 3a6e8b6d78
13 changed files with 24 additions and 15 deletions

View File

@@ -22,13 +22,20 @@ from flashinfer.cascade import merge_state
 from torch import nn
 from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
+from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict
class RadixAttention(nn.Module):
"""
The attention layer implementation.
Now it has two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
It supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
def __init__(
self,
num_heads: int,
@@ -49,8 +56,10 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         self.sliding_window_size = sliding_window_size if sliding_window_size else -1
+        # Choose backend
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
             and self.qk_head_dim == self.v_head_dim
@@ -61,8 +70,6 @@ class RadixAttention(nn.Module):
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))