[Minor] move triton attention kernels into a separate folder (#1379)
This commit is contained in:
@@ -22,13 +22,20 @@ from flashinfer.cascade import merge_state
|
||||
from torch import nn
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.model_runner import global_server_args_dict
|
||||
|
||||
|
||||
class RadixAttention(nn.Module):
|
||||
"""
|
||||
The attention layer implementation.
|
||||
Now it has two backends: FlashInfer and Triton.
|
||||
FlashInfer is faster and Triton is easier to customize.
|
||||
It supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -49,8 +56,10 @@ class RadixAttention(nn.Module):
|
||||
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
||||
self.scaling = scaling
|
||||
self.layer_id = layer_id
|
||||
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
|
||||
self.sliding_window_size = sliding_window_size if sliding_window_size else -1
|
||||
|
||||
# Choose backend
|
||||
if (
|
||||
not global_server_args_dict.get("disable_flashinfer", False)
|
||||
and self.qk_head_dim == self.v_head_dim
|
||||
@@ -61,8 +70,6 @@ class RadixAttention(nn.Module):
|
||||
self.extend_forward = self.extend_forward_triton
|
||||
self.decode_forward = self.decode_forward_triton
|
||||
|
||||
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
|
||||
|
||||
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
||||
if self.qk_head_dim != self.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
|
||||
|
||||
Reference in New Issue
Block a user