Add a simple torch native attention backend (#2241)
This commit is contained in:
285
python/sglang/srt/layers/attention/torch_native_backend.py
Normal file
285
python/sglang/srt/layers/attention/torch_native_backend.py
Normal file
@@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from sglang.srt.layers.attention import AttentionBackend
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
|
||||
class TorchNativeAttnBackend(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
super().__init__()
|
||||
self.forward_metadata = None
|
||||
self.device = model_runner.device
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init the metadata for a forward pass."""
|
||||
pass
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
# TODO: Support CUDA graph
|
||||
raise ValueError(
|
||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# TODO: Support CUDA graph
|
||||
raise ValueError(
|
||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# TODO: Support CUDA graph
|
||||
raise ValueError(
|
||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
# TODO: Support CUDA graph
|
||||
raise ValueError(
|
||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
||||
)
|
||||
|
||||
def _run_sdpa_forward_extend(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
extend_prefix_lens: torch.Tensor,
|
||||
extend_seq_lens: torch.Tensor,
|
||||
scaling=None,
|
||||
enable_gqa=False,
|
||||
causal=False,
|
||||
):
|
||||
"""Run the extend forward by using torch native sdpa op.
|
||||
|
||||
Args:
|
||||
query: [num_tokens, num_heads, head_size]
|
||||
output: [num_tokens, num_heads, head_size]
|
||||
k_cache: [max_total_num_tokens, num_heads, head_size]
|
||||
v_cache: [max_total_num_tokens, num_heads, head_size]
|
||||
req_to_token: [max_num_reqs, max_context_len]
|
||||
req_pool_indices: [num_seqs]
|
||||
seq_lens: [num_seqs]
|
||||
extend_prefix_lens: [num_seqs]
|
||||
extend_seq_lens: [num_seqs]
|
||||
scaling: float or None
|
||||
enable_gqa: bool
|
||||
causal: bool
|
||||
|
||||
Returns:
|
||||
output: [num_tokens, num_heads, head_size]
|
||||
"""
|
||||
|
||||
assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
|
||||
assert seq_lens.shape[0] == extend_seq_lens.shape[0]
|
||||
|
||||
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_idx in range(seq_lens.shape[0]):
|
||||
# TODO: this loop process a sequence per iter, this is inefficient.
|
||||
# Need optimize the performance later.
|
||||
|
||||
extend_seq_len_q = extend_seq_lens[seq_idx]
|
||||
prefill_seq_len_q = extend_prefix_lens[seq_idx]
|
||||
|
||||
seq_len_kv = seq_lens[seq_idx]
|
||||
end_q = start_q + extend_seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
|
||||
per_req_query = query[:, start_q:end_q, :]
|
||||
per_req_query_redudant = torch.empty(
|
||||
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
|
||||
dtype=per_req_query.dtype,
|
||||
device=per_req_query.device,
|
||||
)
|
||||
|
||||
per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
|
||||
|
||||
# get key and value from cache. per_req_tokens contains the kv cache
|
||||
# index for each token in the sequence.
|
||||
req_pool_idx = req_pool_indices[seq_idx]
|
||||
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
|
||||
per_req_out_redudant = (
|
||||
scaled_dot_product_attention(
|
||||
per_req_query_redudant.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
is_causal=causal,
|
||||
)
|
||||
.squeeze(0)
|
||||
.movedim(query.dim() - 2, 0)
|
||||
)
|
||||
output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
|
||||
start_q, start_kv = end_q, end_kv
|
||||
return output
|
||||
|
||||
def _run_sdpa_forward_decode(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scaling=None,
|
||||
enable_gqa=False,
|
||||
causal=False,
|
||||
):
|
||||
"""Run the decode forward by using torch native sdpa op.
|
||||
|
||||
Args:
|
||||
query: [num_tokens, num_heads, head_size]
|
||||
output: [num_tokens, num_heads, head_size]
|
||||
k_cache: [max_total_num_tokens, num_heads, head_size]
|
||||
v_cache: [max_total_num_tokens, num_heads, head_size]
|
||||
req_to_token: [max_num_reqs, max_context_len]
|
||||
req_pool_indices: [num_seqs]
|
||||
seq_lens: [num_seqs]
|
||||
scaling: float or None
|
||||
enable_gqa: bool
|
||||
causal: bool
|
||||
|
||||
Returns:
|
||||
output: [num_tokens, num_heads, head_size]
|
||||
"""
|
||||
|
||||
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_idx in range(seq_lens.shape[0]):
|
||||
# TODO: this loop process a sequence per iter, this is inefficient.
|
||||
# Need optimize the performance later.
|
||||
|
||||
seq_len_q = 1
|
||||
seq_len_kv = seq_lens[seq_idx]
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
|
||||
per_req_query = query[:, start_q:end_q, :]
|
||||
|
||||
# get key and value from cache. per_req_tokens contains the kv cache
|
||||
# index for each token in the sequence.
|
||||
req_pool_idx = req_pool_indices[seq_idx]
|
||||
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
|
||||
per_req_out = (
|
||||
scaled_dot_product_attention(
|
||||
per_req_query.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
is_causal=causal,
|
||||
)
|
||||
.squeeze(0)
|
||||
.movedim(query.dim() - 2, 0)
|
||||
)
|
||||
output[start_q:end_q, :, :] = per_req_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
|
||||
return output
|
||||
|
||||
def forward_extend(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||
|
||||
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
|
||||
self._run_sdpa_forward_extend(
|
||||
q_,
|
||||
o_,
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.extend_prefix_lens,
|
||||
forward_batch.extend_seq_lens,
|
||||
scaling=layer.scaling,
|
||||
enable_gqa=use_gqa,
|
||||
causal=not layer.is_cross_attention,
|
||||
)
|
||||
return o
|
||||
|
||||
def forward_decode(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||
|
||||
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
|
||||
self._run_sdpa_forward_decode(
|
||||
q_,
|
||||
o_,
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
scaling=layer.scaling,
|
||||
enable_gqa=use_gqa,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
return o
|
||||
@@ -743,20 +743,24 @@ class ScheduleBatch:
|
||||
extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
write_req_to_token_pool_triton[(bs,)](
|
||||
self.req_to_token_pool.req_to_token,
|
||||
self.req_pool_indices,
|
||||
pre_lens,
|
||||
self.seq_lens,
|
||||
extend_lens,
|
||||
self.out_cache_loc,
|
||||
self.req_to_token_pool.req_to_token.shape[1],
|
||||
)
|
||||
# The triton kernel is equivalent to the following python code.
|
||||
# self.req_to_token_pool.write(
|
||||
# (req.req_pool_idx, slice(pre_len, seq_len)),
|
||||
# out_cache_loc[pt : pt + req.extend_input_len],
|
||||
# )
|
||||
if global_server_args_dict["attention_backend"] != "torch_native":
|
||||
write_req_to_token_pool_triton[(bs,)](
|
||||
self.req_to_token_pool.req_to_token,
|
||||
self.req_pool_indices,
|
||||
pre_lens,
|
||||
self.seq_lens,
|
||||
extend_lens,
|
||||
self.out_cache_loc,
|
||||
self.req_to_token_pool.req_to_token.shape[1],
|
||||
)
|
||||
else:
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
|
||||
self.out_cache_loc[pt : pt + self.extend_lens[i]],
|
||||
)
|
||||
pt += self.extend_lens[i]
|
||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
|
||||
@@ -256,10 +256,15 @@ class ForwardBatch:
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
ret.positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
|
||||
)
|
||||
if model_runner.server_args.attention_backend != "torch_native":
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
ret.positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
|
||||
)
|
||||
else:
|
||||
ret.positions, ret.extend_start_loc = compute_position_torch(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens
|
||||
)
|
||||
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
@@ -40,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
|
||||
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
@@ -570,6 +571,8 @@ class ModelRunner:
|
||||
self.attn_backend = DoubleSparseAttnBackend(self)
|
||||
else:
|
||||
self.attn_backend = TritonAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "torch_native":
|
||||
self.attn_backend = TorchNativeAttnBackend(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||
|
||||
@@ -180,15 +180,21 @@ class ServerArgs:
|
||||
else:
|
||||
self.cuda_graph_max_bs = 160
|
||||
|
||||
# Set kernel backends
|
||||
if not is_flashinfer_available():
|
||||
self.attention_backend = "triton"
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# Choose kernel backends
|
||||
if self.attention_backend is None:
|
||||
self.attention_backend = "flashinfer"
|
||||
self.attention_backend = (
|
||||
"flashinfer" if is_flashinfer_available() else "triton"
|
||||
)
|
||||
if self.sampling_backend is None:
|
||||
self.sampling_backend = "flashinfer"
|
||||
self.sampling_backend = (
|
||||
"flashinfer" if is_flashinfer_available() else "pytorch"
|
||||
)
|
||||
|
||||
if self.attention_backend == "torch_native":
|
||||
logger.info(
|
||||
"Cuda graph is disabled because of using torch native attention backend"
|
||||
)
|
||||
self.disable_cuda_graph = True
|
||||
|
||||
# Others
|
||||
if self.enable_dp_attention:
|
||||
@@ -586,7 +592,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--attention-backend",
|
||||
type=str,
|
||||
choices=["flashinfer", "triton"],
|
||||
choices=["flashinfer", "triton", "torch_native"],
|
||||
default=ServerArgs.attention_backend,
|
||||
help="Choose the kernels for attention layers.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user