Refactor attention backend (#1381)
This commit is contained in:
383
python/sglang/srt/layers/attention_backend.py
Normal file
383
python/sglang/srt/layers/attention_backend.py
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
"""
|
||||||
|
Support different attention backends.
|
||||||
|
Now there are two backends: FlashInfer and Triton.
|
||||||
|
FlashInfer is faster and Triton is easier to customize.
|
||||||
|
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from flashinfer import (
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
|
)
|
||||||
|
from flashinfer.cascade import merge_state
|
||||||
|
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||||
|
|
||||||
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
|
||||||
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBackend(ABC):
    """Abstract interface that every attention backend must implement."""

    @abstractmethod
    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Prepare backend-specific auxiliary data before a forward pass."""
        pass

    def forward(self, q, k, v, layer, input_metadata: InputMetadata):
        """Route the call to the decode or extend operator by forward mode."""
        handler = (
            self.forward_decode
            if input_metadata.forward_mode.is_decode()
            else self.forward_extend
        )
        return handler(q, k, v, layer, input_metadata)
||||||
|
|
||||||
|
class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels.

    Wraps flashinfer's paged/ragged prefill and paged decode wrappers.
    When the model uses sliding-window attention, the paged prefill and
    decode wrappers are kept as 2-element lists (index 0 = sliding window,
    index 1 = full attention); otherwise they are single wrapper objects.
    """

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.model_runner = model_runner

        # If flashinfer has no precompiled decode kernel for this
        # (num_qo_heads, num_kv_heads) group size, fall back to the
        # tensor-core decode path.
        if not _grouped_size_compiled_for_decode_kernels(
            model_runner.model_config.num_attention_heads // model_runner.tp_size,
            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
        ):
            self.decode_use_tensor_cores = True
        else:
            self.decode_use_tensor_cores = False

        # One scratch buffer shared by all flashinfer wrappers.
        self.workspace_buffer = torch.empty(
            global_config.flashinfer_workspace_size,
            dtype=torch.uint8,
            device="cuda",
        )

        if model_runner.sliding_window_size is None:
            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_tensor_cores=self.decode_use_tensor_cores,
            )
        else:
            # Two wrappers: one for sliding window attention and one for full attention.
            # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
            self.prefill_wrapper_ragged = None
            self.prefill_wrapper_paged = []
            self.decode_wrapper = []
            for _ in range(2):
                self.prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                )
                self.decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_tensor_cores=self.decode_use_tensor_cores,
                    )
                )

        # Tuple (use_ragged, total_num_tokens, decode_wrapper), set per batch
        # by init_forward_metadata / capture_cuda_graph_init.
        self.forward_metadata = None
        # Maps batch size -> decode wrapper(s) captured for CUDA graph replay.
        self.cuda_graph_metadata = {}

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Build flashinfer indices for this batch and stash forward metadata."""
        if input_metadata.forward_mode.is_decode():
            prefix_lens = None
            use_ragged = False
            total_num_tokens = None
        else:
            prefix_lens = input_metadata.extend_prefix_lens

            # Some heuristics to check whether to use ragged forward
            use_ragged = False
            if (
                int(torch.sum(input_metadata.seq_lens)) > 4096
                and self.model_runner.sliding_window_size is None
            ):
                use_ragged = True

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()

        update_flashinfer_indices(
            input_metadata.forward_mode,
            self.model_runner,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            prefix_lens,
            use_ragged=use_ragged,
        )

        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)

    def init_cuda_graph_state(self, max_bs: int):
        """Allocate fixed-size index buffers reused across CUDA graph captures."""
        self.cuda_graph_kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_kv_indices = torch.zeros(
            (max_bs * self.model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        # Page size is 1, so every request's last page holds exactly one token.
        self.cuda_graph_kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device="cuda"
        )

        # Sliding-window models need a second buffer set for the second wrapper.
        if self.model_runner.sliding_window_size is not None:
            self.cuda_graph_kv_indptr = [
                self.cuda_graph_kv_indptr,
                self.cuda_graph_kv_indptr.clone(),
            ]
            self.cuda_graph_kv_indices = [
                self.cuda_graph_kv_indices,
                self.cuda_graph_kv_indices.clone(),
            ]

    def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        """Create CUDA-graph decode wrapper(s) for batch size `bs` and cache them."""
        if self.model_runner.sliding_window_size is None:
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                use_tensor_cores=self.decode_use_tensor_cores,
                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
            )
        else:
            decode_wrapper = []
            for i in range(2):
                decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_cuda_graph=True,
                        use_tensor_cores=self.decode_use_tensor_cores,
                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
                            :bs
                        ],
                    )
                )

        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            decode_wrapper,
        )

        self.cuda_graph_metadata[bs] = decode_wrapper

        # CUDA-graph decode never uses the ragged path or the token-count sync.
        self.forward_metadata = (False, None, decode_wrapper)

    def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        """Refresh indices in the captured wrapper(s) before a graph replay."""
        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices[:bs],
            seq_lens[:bs],
            None,
            self.cuda_graph_metadata[bs],
        )

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        """Extend (prefill with cached prefix) using flashinfer kernels."""
        if not isinstance(self.prefill_wrapper_paged, list):
            prefill_wrapper_paged = self.prefill_wrapper_paged
        else:
            # Pick the wrapper matching this layer's attention type.
            if layer.sliding_window_size != -1:
                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
            else:
                prefill_wrapper_paged = self.prefill_wrapper_paged[1]

        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if not use_ragged:
            # For cross-layer KV sharing, k/v can be None (already cached).
            if k is not None:
                assert v is not None
                input_metadata.token_to_kv_pool.set_kv_buffer(
                    layer.layer_id, input_metadata.out_cache_loc, k, v
                )
            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=True,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=layer.logit_cap,
            )
        else:
            # Ragged path: attend over the new tokens first, ...
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=layer.logit_cap,
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                # ... then over the cached prefix (non-causal), and merge the
                # two partial attention states via log-sum-exp.
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
                    logits_soft_cap=layer.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            # Write the new tokens into the KV cache after attention.
            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

            # For very large batches, synchronize per layer to bound the
            # amount of queued work.
            if total_num_tokens >= global_config.layer_sync_threshold:
                torch.cuda.synchronize()

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        """Decode one token per request using flashinfer paged kernels."""
        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if isinstance(decode_wrapper, list):
            # Pick the wrapper matching this layer's attention type.
            if layer.sliding_window_size != -1:
                decode_wrapper = decode_wrapper[0]
            else:
                decode_wrapper = decode_wrapper[1]

        # For cross-layer KV sharing, k/v can be None (already cached).
        if k is not None:
            assert v is not None
            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

        o = decode_wrapper.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
|
class TritonAttnBackend(AttentionBackend):
    """Triton attention kernels (easier to customize than flashinfer)."""

    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.triton_attention.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.triton_attention.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        # Bind the kernels on the instance so forward paths don't re-import.
        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        # Tuple (start_loc, max_seq_len, max_extend_len, total_num_tokens),
        # set per batch by init_forward_metadata.
        self.forward_metadata = None

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Init auxiliary variables for triton attention backend."""

        if input_metadata.forward_mode.is_decode():
            max_seq_len = torch.max(input_metadata.seq_lens).item()
            # Exclusive prefix sum of seq_lens: per-request start offsets.
            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
            max_extend_len = None
        else:
            start_loc = max_seq_len = total_num_tokens = None
            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            # Longest number of newly extended tokens in the batch.
            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

        self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        """Extend (prefill with cached prefix) using the Triton extend kernel."""
        # When qk and v head dims differ, the output shape follows v_head_dim.
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            input_metadata.extend_seq_lens,
            input_metadata.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )

        return o

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        """Decode one token per request using the Triton decode kernel."""
        # When qk and v head dims differ, the output shape follows v_head_dim.
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            start_loc,
            input_metadata.seq_lens,
            max_seq_len,
            total_num_tokens,
            layer.scaling,
            layer.logit_cap,
        )

        return o
|
||||||
@@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton(
|
|||||||
page_kernel_lens_ptr,
|
page_kernel_lens_ptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_start_idx,
|
kv_start_idx,
|
||||||
max_context_len,
|
|
||||||
kv_indices_ptr,
|
kv_indices_ptr,
|
||||||
|
max_context_len: tl.constexpr,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 512
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
@@ -47,15 +47,15 @@ class FlashinferUpdater:
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
flashinfer_decode_wrapper=None,
|
decode_wrapper=None,
|
||||||
flashinfer_use_ragged=False,
|
use_ragged=False,
|
||||||
):
|
):
|
||||||
self.forward_mode = forward_mode
|
self.forward_mode = forward_mode
|
||||||
self.model_runner = model_runner
|
self.model_runner = model_runner
|
||||||
self.req_pool_indices = req_pool_indices
|
self.req_pool_indices = req_pool_indices
|
||||||
self.seq_lens = seq_lens
|
self.seq_lens = seq_lens
|
||||||
self.prefix_lens = prefix_lens
|
self.prefix_lens = prefix_lens
|
||||||
self.flashinfer_use_ragged = flashinfer_use_ragged
|
self.use_ragged = use_ragged
|
||||||
|
|
||||||
self.num_qo_heads = (
|
self.num_qo_heads = (
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
@@ -71,20 +71,17 @@ class FlashinferUpdater:
|
|||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
self.flashinfer_decode_wrapper,
|
self.decode_wrapper,
|
||||||
self.flashinfer_prefill_wrapper_ragged,
|
self.prefill_wrapper_ragged,
|
||||||
self.flashinfer_prefill_wrapper_paged,
|
self.prefill_wrapper_paged,
|
||||||
) = (
|
) = (
|
||||||
flashinfer_decode_wrapper,
|
decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
|
||||||
self.model_runner.flashinfer_prefill_wrapper_ragged,
|
self.model_runner.attn_backend.prefill_wrapper_ragged,
|
||||||
self.model_runner.flashinfer_prefill_wrapper_paged,
|
self.model_runner.attn_backend.prefill_wrapper_paged,
|
||||||
)
|
)
|
||||||
# CUDA graph uses different flashinfer_decode_wrapper
|
|
||||||
if self.flashinfer_decode_wrapper is None:
|
|
||||||
self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper
|
|
||||||
|
|
||||||
def _init_indices_no_window(self):
|
def _init_indices_no_sliding_window(self):
|
||||||
if self.flashinfer_use_ragged:
|
if self.use_ragged:
|
||||||
paged_kernel_lens = self.prefix_lens
|
paged_kernel_lens = self.prefix_lens
|
||||||
else:
|
else:
|
||||||
paged_kernel_lens = self.seq_lens
|
paged_kernel_lens = self.seq_lens
|
||||||
@@ -103,13 +100,13 @@ class FlashinferUpdater:
|
|||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
None,
|
None,
|
||||||
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
|
||||||
self.kv_indices,
|
self.kv_indices,
|
||||||
|
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_indices_window(self, wrapper_id):
|
def _init_indices_sliding_window(self, wrapper_id):
|
||||||
# window attention use paged only
|
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
|
# window attention use paged only
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
paged_kernel_lens = torch.minimum(
|
paged_kernel_lens = torch.minimum(
|
||||||
self.seq_lens,
|
self.seq_lens,
|
||||||
@@ -123,6 +120,7 @@ class FlashinferUpdater:
|
|||||||
- self.prefix_lens,
|
- self.prefix_lens,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# full attention
|
||||||
paged_kernel_lens = self.seq_lens
|
paged_kernel_lens = self.seq_lens
|
||||||
|
|
||||||
kv_start_idx = self.seq_lens - paged_kernel_lens
|
kv_start_idx = self.seq_lens - paged_kernel_lens
|
||||||
@@ -139,8 +137,8 @@ class FlashinferUpdater:
|
|||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
kv_start_idx,
|
kv_start_idx,
|
||||||
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
|
||||||
self.kv_indices,
|
self.kv_indices,
|
||||||
|
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _update_decode_indices(self, decode_wrapper):
|
def _update_decode_indices(self, decode_wrapper):
|
||||||
@@ -164,7 +162,7 @@ class FlashinferUpdater:
|
|||||||
)
|
)
|
||||||
qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
|
qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
|
||||||
|
|
||||||
if self.flashinfer_use_ragged:
|
if self.use_ragged:
|
||||||
ragged_wrapper.end_forward()
|
ragged_wrapper.end_forward()
|
||||||
ragged_wrapper.begin_forward(
|
ragged_wrapper.begin_forward(
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
@@ -187,28 +185,28 @@ class FlashinferUpdater:
|
|||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_indices_no_window(self):
|
def update_indices_no_sliding_window(self):
|
||||||
self._init_indices_no_window()
|
self._init_indices_no_sliding_window()
|
||||||
|
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
self._update_decode_indices(self.flashinfer_decode_wrapper)
|
self._update_decode_indices(self.decode_wrapper)
|
||||||
else:
|
else:
|
||||||
self._update_extend_indices(
|
self._update_extend_indices(
|
||||||
self.flashinfer_prefill_wrapper_ragged,
|
self.prefill_wrapper_ragged,
|
||||||
self.flashinfer_prefill_wrapper_paged,
|
self.prefill_wrapper_paged,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_indices_window(self):
|
def update_indices_sliding_window(self):
|
||||||
assert self.flashinfer_use_ragged is False
|
assert self.use_ragged is False
|
||||||
|
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
self._init_indices_window(wrapper_id)
|
self._init_indices_sliding_window(wrapper_id)
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id])
|
self._update_decode_indices(self.decode_wrapper[wrapper_id])
|
||||||
else:
|
else:
|
||||||
self._update_extend_indices(
|
self._update_extend_indices(
|
||||||
None,
|
None,
|
||||||
self.flashinfer_prefill_wrapper_paged[wrapper_id],
|
self.prefill_wrapper_paged[wrapper_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -218,20 +216,20 @@ def update_flashinfer_indices(
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
flashinfer_decode_wrapper=None,
|
decode_wrapper=None,
|
||||||
flashinfer_use_ragged=False,
|
use_ragged=False,
|
||||||
):
|
):
|
||||||
flashinfer_updater = FlashinferUpdater(
|
updater = FlashinferUpdater(
|
||||||
forward_mode,
|
forward_mode,
|
||||||
model_runner,
|
model_runner,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
flashinfer_decode_wrapper,
|
decode_wrapper,
|
||||||
flashinfer_use_ragged,
|
use_ragged,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_runner.sliding_window_size is None:
|
if model_runner.sliding_window_size is None:
|
||||||
flashinfer_updater.update_indices_no_window()
|
updater.update_indices_no_sliding_window()
|
||||||
else:
|
else:
|
||||||
flashinfer_updater.update_indices_window()
|
updater.update_indices_sliding_window()
|
||||||
|
|||||||
@@ -15,25 +15,14 @@ limitations under the License.
|
|||||||
|
|
||||||
"""Radix attention."""
|
"""Radix attention."""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from flashinfer.cascade import merge_state
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
|
|
||||||
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
|
||||||
from sglang.srt.model_executor.model_runner import global_server_args_dict
|
|
||||||
|
|
||||||
|
|
||||||
class RadixAttention(nn.Module):
|
class RadixAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
The attention layer implementation.
|
The attention layer implementation.
|
||||||
Now it has two backends: FlashInfer and Triton.
|
|
||||||
FlashInfer is faster and Triton is easier to customize.
|
|
||||||
It supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -43,8 +32,8 @@ class RadixAttention(nn.Module):
|
|||||||
scaling: float,
|
scaling: float,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
sliding_window_size: Optional[int] = None,
|
sliding_window_size: int = -1,
|
||||||
logit_cap: int = -1,
|
logit_cap: float = 0.0,
|
||||||
v_head_dim: int = -1,
|
v_head_dim: int = -1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -56,164 +45,14 @@ class RadixAttention(nn.Module):
|
|||||||
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
||||||
self.scaling = scaling
|
self.scaling = scaling
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
|
self.logit_cap = logit_cap
|
||||||
self.sliding_window_size = sliding_window_size if sliding_window_size else -1
|
self.sliding_window_size = sliding_window_size or -1
|
||||||
|
|
||||||
# Choose backend
|
|
||||||
if (
|
|
||||||
global_server_args_dict["attention_backend"] == "flashinfer"
|
|
||||||
and self.qk_head_dim == self.v_head_dim
|
|
||||||
):
|
|
||||||
self.extend_forward = self.extend_forward_flashinfer
|
|
||||||
self.decode_forward = self.decode_forward_flashinfer
|
|
||||||
elif global_server_args_dict["attention_backend"] == "triton":
|
|
||||||
self.extend_forward = self.extend_forward_triton
|
|
||||||
self.decode_forward = self.decode_forward_triton
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
|
||||||
if self.qk_head_dim != self.v_head_dim:
|
|
||||||
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
|
|
||||||
else:
|
|
||||||
o = torch.empty_like(q)
|
|
||||||
|
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
|
||||||
extend_attention_fwd(
|
|
||||||
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
|
|
||||||
k.contiguous(),
|
|
||||||
v.contiguous(),
|
|
||||||
o.view(-1, self.tp_q_head_num, self.v_head_dim),
|
|
||||||
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
|
||||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
|
||||||
input_metadata.req_to_token_pool.req_to_token,
|
|
||||||
input_metadata.req_pool_indices,
|
|
||||||
input_metadata.triton_start_loc,
|
|
||||||
input_metadata.seq_lens,
|
|
||||||
input_metadata.triton_prefix_lens,
|
|
||||||
input_metadata.extend_start_loc,
|
|
||||||
input_metadata.extend_seq_lens,
|
|
||||||
input_metadata.triton_max_seq_len,
|
|
||||||
input_metadata.triton_max_extend_len,
|
|
||||||
sm_scale=self.scaling,
|
|
||||||
logit_cap=self.logit_cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
return o
|
|
||||||
|
|
||||||
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
|
||||||
if self.qk_head_dim != self.v_head_dim:
|
|
||||||
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
|
|
||||||
else:
|
|
||||||
o = torch.empty_like(q)
|
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
|
||||||
|
|
||||||
decode_attention_fwd(
|
|
||||||
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
|
|
||||||
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
|
||||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
|
||||||
o.view(-1, self.tp_q_head_num, self.v_head_dim),
|
|
||||||
input_metadata.req_to_token_pool.req_to_token,
|
|
||||||
input_metadata.req_pool_indices,
|
|
||||||
input_metadata.triton_start_loc,
|
|
||||||
input_metadata.seq_lens,
|
|
||||||
input_metadata.triton_max_seq_len,
|
|
||||||
input_metadata.total_num_tokens,
|
|
||||||
sm_scale=self.scaling,
|
|
||||||
logit_cap=self.logit_cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
return o
|
|
||||||
|
|
||||||
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
|
||||||
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
|
|
||||||
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
|
|
||||||
if self.sliding_window_size != -1:
|
|
||||||
prefill_wrapper_paged = prefill_wrapper_paged[0]
|
|
||||||
else:
|
|
||||||
if isinstance(prefill_wrapper_paged, list):
|
|
||||||
prefill_wrapper_paged = prefill_wrapper_paged[1]
|
|
||||||
|
|
||||||
if not input_metadata.flashinfer_use_ragged:
|
|
||||||
if k is not None:
|
|
||||||
assert v is not None
|
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
|
||||||
|
|
||||||
o = prefill_wrapper_paged.forward(
|
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
|
||||||
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
|
||||||
causal=True,
|
|
||||||
sm_scale=self.scaling,
|
|
||||||
window_left=self.sliding_window_size,
|
|
||||||
logits_soft_cap=self.logit_cap,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
o1, s1 = (
|
|
||||||
input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
|
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
|
||||||
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
|
|
||||||
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
|
|
||||||
causal=True,
|
|
||||||
sm_scale=self.scaling,
|
|
||||||
logits_soft_cap=self.logit_cap,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if input_metadata.extend_no_prefix:
|
|
||||||
o = o1
|
|
||||||
else:
|
|
||||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
|
||||||
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
|
||||||
causal=False,
|
|
||||||
sm_scale=self.scaling,
|
|
||||||
logits_soft_cap=self.logit_cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
o, _ = merge_state(o1, s1, o2, s2)
|
|
||||||
|
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
|
||||||
|
|
||||||
if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
|
||||||
|
|
||||||
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
|
||||||
decode_wrapper = input_metadata.flashinfer_decode_wrapper
|
|
||||||
if self.sliding_window_size != -1:
|
|
||||||
decode_wrapper = decode_wrapper[0]
|
|
||||||
else:
|
|
||||||
if isinstance(decode_wrapper, list):
|
|
||||||
decode_wrapper = decode_wrapper[1]
|
|
||||||
|
|
||||||
if k is not None:
|
|
||||||
assert v is not None
|
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
|
||||||
|
|
||||||
o = decode_wrapper.forward(
|
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
|
||||||
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
|
||||||
sm_scale=self.scaling,
|
|
||||||
logits_soft_cap=self.logit_cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
|
||||||
|
|
||||||
def forward(self, q, k, v, input_metadata: InputMetadata):
|
def forward(self, q, k, v, input_metadata: InputMetadata):
|
||||||
if k is not None:
|
if k is not None:
|
||||||
|
# For cross-layer sharing, kv can be None
|
||||||
assert v is not None
|
assert v is not None
|
||||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||||
|
|
||||||
if input_metadata.forward_mode.is_extend():
|
return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
|
||||||
return self.extend_forward(q, k, v, input_metadata)
|
|
||||||
elif input_metadata.forward_mode.is_decode():
|
|
||||||
return self.decode_forward(q, k, v, input_metadata)
|
|
||||||
|
|
||||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
|
||||||
input_metadata.token_to_kv_pool.set_kv_buffer(
|
|
||||||
self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Memory-efficient attention for decoding.
|
Memory-efficient attention for decoding.
|
||||||
|
It supports page size = 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Adapted from
|
# Adapted from
|
||||||
@@ -197,7 +198,6 @@ def _decode_att_m_fwd(
|
|||||||
logit_cap,
|
logit_cap,
|
||||||
):
|
):
|
||||||
BLOCK = 32
|
BLOCK = 32
|
||||||
# shape constraints
|
|
||||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||||
|
|
||||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||||
@@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd(
|
|||||||
logit_cap,
|
logit_cap,
|
||||||
):
|
):
|
||||||
BLOCK = 32
|
BLOCK = 32
|
||||||
# shape constraints
|
|
||||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||||
|
|
||||||
if Lk == 576:
|
if Lk == 576:
|
||||||
@@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
BLOCK_H=BLOCK_H,
|
BLOCK_H=BLOCK_H,
|
||||||
|
Lv=Lv,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
Lv=Lv,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -588,7 +587,7 @@ def decode_attention_fwd(
|
|||||||
max_len_in_batch,
|
max_len_in_batch,
|
||||||
total_num_tokens,
|
total_num_tokens,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=-1,
|
logit_cap=0.0,
|
||||||
att_m=None,
|
att_m=None,
|
||||||
):
|
):
|
||||||
if att_m is None:
|
if att_m is None:
|
||||||
|
|||||||
@@ -61,14 +61,14 @@ def _fwd_kernel(
|
|||||||
stride_buf_vbs,
|
stride_buf_vbs,
|
||||||
stride_buf_vh,
|
stride_buf_vh,
|
||||||
stride_req_to_tokens_b,
|
stride_req_to_tokens_b,
|
||||||
|
logit_cap: tl.constexpr,
|
||||||
|
Lq: tl.constexpr,
|
||||||
|
Lv: tl.constexpr,
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_DPE: tl.constexpr,
|
BLOCK_DPE: tl.constexpr,
|
||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
|
||||||
Lq: tl.constexpr,
|
|
||||||
Lv: tl.constexpr,
|
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -111,7 +111,7 @@ def _fwd_kernel(
|
|||||||
)
|
)
|
||||||
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
|
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
|
||||||
|
|
||||||
# stage1: compute scores with prefix
|
# stage 1: compute scores with prefix
|
||||||
offs_n = tl.arange(0, BLOCK_N)
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
|
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
|
||||||
@@ -174,7 +174,7 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
# stage2: compute the trianlge part
|
# stage 2: compute the trianlge part
|
||||||
|
|
||||||
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
|
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
|
||||||
for start_n in range(0, cur_block_m_end, BLOCK_N):
|
for start_n in range(0, cur_block_m_end, BLOCK_N):
|
||||||
@@ -255,26 +255,22 @@ def extend_attention_fwd(
|
|||||||
v_buffer,
|
v_buffer,
|
||||||
req_to_tokens,
|
req_to_tokens,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_seq_len_prefix,
|
|
||||||
b_start_loc_extend,
|
|
||||||
b_seq_len_extend,
|
b_seq_len_extend,
|
||||||
max_len_in_batch,
|
b_start_loc_extend,
|
||||||
max_len_extend,
|
max_len_extend,
|
||||||
sm_scale=None,
|
sm_scale=None,
|
||||||
logit_cap=-1,
|
logit_cap=0.0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
|
|
||||||
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||||
"""
|
"""
|
||||||
Lq, Lk, Lv, Lo = (
|
Lq, Lk, Lv = (
|
||||||
q_extend.shape[-1],
|
q_extend.shape[-1],
|
||||||
k_extend.shape[-1],
|
k_extend.shape[-1],
|
||||||
v_extend.shape[-1],
|
v_extend.shape[-1],
|
||||||
o_extend.shape[-1],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if Lq == 576:
|
if Lq == 576:
|
||||||
@@ -303,7 +299,7 @@ def extend_attention_fwd(
|
|||||||
else:
|
else:
|
||||||
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
||||||
|
|
||||||
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
|
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
||||||
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
||||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||||
|
|
||||||
@@ -338,27 +334,24 @@ def extend_attention_fwd(
|
|||||||
v_buffer.stride(0),
|
v_buffer.stride(0),
|
||||||
v_buffer.stride(1),
|
v_buffer.stride(1),
|
||||||
req_to_tokens.stride(0),
|
req_to_tokens.stride(0),
|
||||||
|
logit_cap=logit_cap,
|
||||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_DPE=BLOCK_DPE,
|
BLOCK_DPE=BLOCK_DPE,
|
||||||
BLOCK_DV=BLOCK_DV,
|
BLOCK_DV=BLOCK_DV,
|
||||||
BLOCK_M=BLOCK_M,
|
BLOCK_M=BLOCK_M,
|
||||||
BLOCK_N=BLOCK_N,
|
BLOCK_N=BLOCK_N,
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=num_stages,
|
|
||||||
logit_cap=logit_cap,
|
|
||||||
Lq=Lq,
|
Lq=Lq,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=num_stages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def redundant_attention(
|
def redundant_attention(
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
|
||||||
v_extend,
|
|
||||||
o_extend,
|
o_extend,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
req_to_tokens,
|
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
|
|||||||
@@ -368,7 +368,7 @@ class ScheduleBatch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return len(self.reqs) if self.reqs is not None else 0
|
return len(self.reqs) if self.reqs else 0
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
return len(self.reqs) == 0
|
return len(self.reqs) == 0
|
||||||
|
|||||||
@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""Run the model with cuda graph."""
|
"""Run the model with cuda graph and torch.compile."""
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Callable, List
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
|
||||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
|
||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
@@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
|||||||
def patch_model(
|
def patch_model(
|
||||||
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
|
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
|
||||||
):
|
):
|
||||||
|
"""Patch the model to make it compatible with with torch.compile"""
|
||||||
backup_ca_comm = None
|
backup_ca_comm = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -86,23 +85,28 @@ def set_torch_compile_config():
|
|||||||
|
|
||||||
|
|
||||||
class CudaGraphRunner:
|
class CudaGraphRunner:
|
||||||
def __init__(
|
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
||||||
self,
|
|
||||||
model_runner: "ModelRunner",
|
def __init__(self, model_runner: "ModelRunner"):
|
||||||
max_batch_size_to_capture: int,
|
# Parse args
|
||||||
use_torch_compile: bool,
|
|
||||||
disable_padding: bool,
|
|
||||||
):
|
|
||||||
self.model_runner = model_runner
|
self.model_runner = model_runner
|
||||||
self.graphs = {}
|
self.graphs = {}
|
||||||
self.input_buffers = {}
|
self.input_buffers = {}
|
||||||
self.output_buffers = {}
|
self.output_buffers = {}
|
||||||
self.flashinfer_handlers = {}
|
self.flashinfer_handlers = {}
|
||||||
self.graph_memory_pool = None
|
self.graph_memory_pool = None
|
||||||
self.disable_padding = disable_padding
|
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
|
|
||||||
|
# Batch sizes to capture
|
||||||
|
if self.model_runner.server_args.disable_cuda_graph_padding:
|
||||||
|
self.capture_bs = list(range(1, 32)) + [64, 128]
|
||||||
|
else:
|
||||||
|
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||||
|
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
self.max_bs = max_batch_size_to_capture
|
self.max_bs = max(self.capture_bs)
|
||||||
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
||||||
self.req_pool_indices = torch.zeros(
|
self.req_pool_indices = torch.zeros(
|
||||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
@@ -115,56 +119,39 @@ class CudaGraphRunner:
|
|||||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
# FlashInfer inputs
|
# Attention backend
|
||||||
self.flashinfer_kv_indptr = torch.zeros(
|
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
||||||
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
self.flashinfer_kv_indices = torch.zeros(
|
|
||||||
(self.max_bs * model_runner.model_config.context_len,),
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
self.flashinfer_kv_last_page_len = torch.ones(
|
|
||||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
if model_runner.sliding_window_size is None:
|
|
||||||
self.flashinfer_workspace_buffer = (
|
|
||||||
self.model_runner.flashinfer_workspace_buffer
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.flashinfer_workspace_buffer = (
|
|
||||||
self.model_runner.flashinfer_workspace_buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
self.flashinfer_kv_indptr = [
|
# Sampling info
|
||||||
self.flashinfer_kv_indptr,
|
|
||||||
self.flashinfer_kv_indptr.clone(),
|
|
||||||
]
|
|
||||||
self.flashinfer_kv_indices = [
|
|
||||||
self.flashinfer_kv_indices,
|
|
||||||
self.flashinfer_kv_indices.clone(),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Sampling inputs
|
|
||||||
vocab_size = model_runner.model_config.vocab_size
|
vocab_size = model_runner.model_config.vocab_size
|
||||||
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
|
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
|
||||||
|
|
||||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
if self.use_torch_compile:
|
||||||
|
|
||||||
if use_torch_compile:
|
|
||||||
set_torch_compile_config()
|
set_torch_compile_config()
|
||||||
|
|
||||||
|
# Capture
|
||||||
|
try:
|
||||||
|
self.capture()
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Capture cuda graph failed: {e}\n"
|
||||||
|
"Possible solutions:\n"
|
||||||
|
"1. disable cuda graph by --disable-cuda-graph\n"
|
||||||
|
"2. set --mem-fraction-static to a smaller value\n"
|
||||||
|
"3. disable torch compile by not using --enable-torch-compile\n"
|
||||||
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||||
|
)
|
||||||
|
|
||||||
def can_run(self, batch_size: int):
|
def can_run(self, batch_size: int):
|
||||||
if self.disable_padding:
|
if self.disable_padding:
|
||||||
return batch_size in self.graphs
|
return batch_size in self.graphs
|
||||||
else:
|
else:
|
||||||
return batch_size <= self.max_bs
|
return batch_size <= self.max_bs
|
||||||
|
|
||||||
def capture(self, batch_size_list: List[int]):
|
def capture(self):
|
||||||
self.batch_size_list = batch_size_list
|
|
||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
self.stream = graph_capture_context.stream
|
self.stream = graph_capture_context.stream
|
||||||
for bs in batch_size_list:
|
for bs in self.capture_bs:
|
||||||
with patch_model(
|
with patch_model(
|
||||||
self.model_runner.model,
|
self.model_runner.model,
|
||||||
bs in self.compile_bs,
|
bs in self.compile_bs,
|
||||||
@@ -172,14 +159,10 @@ class CudaGraphRunner:
|
|||||||
) as forward:
|
) as forward:
|
||||||
(
|
(
|
||||||
graph,
|
graph,
|
||||||
input_buffers,
|
|
||||||
output_buffers,
|
output_buffers,
|
||||||
flashinfer_handler,
|
|
||||||
) = self.capture_one_batch_size(bs, forward)
|
) = self.capture_one_batch_size(bs, forward)
|
||||||
self.graphs[bs] = graph
|
self.graphs[bs] = graph
|
||||||
self.input_buffers[bs] = input_buffers
|
|
||||||
self.output_buffers[bs] = output_buffers
|
self.output_buffers[bs] = output_buffers
|
||||||
self.flashinfer_handlers[bs] = flashinfer_handler
|
|
||||||
|
|
||||||
def capture_one_batch_size(self, bs: int, forward: Callable):
|
def capture_one_batch_size(self, bs: int, forward: Callable):
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
@@ -192,48 +175,9 @@ class CudaGraphRunner:
|
|||||||
position_ids_offsets = self.position_ids_offsets[:bs]
|
position_ids_offsets = self.position_ids_offsets[:bs]
|
||||||
out_cache_loc = self.out_cache_loc[:bs]
|
out_cache_loc = self.out_cache_loc[:bs]
|
||||||
|
|
||||||
# FlashInfer inputs
|
# Attention backend
|
||||||
if not _grouped_size_compiled_for_decode_kernels(
|
self.model_runner.attn_backend.capture_cuda_graph_init(
|
||||||
self.model_runner.model_config.num_attention_heads
|
bs, req_pool_indices, seq_lens
|
||||||
// self.model_runner.tp_size,
|
|
||||||
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
|
|
||||||
):
|
|
||||||
use_tensor_cores = True
|
|
||||||
else:
|
|
||||||
use_tensor_cores = False
|
|
||||||
if self.model_runner.sliding_window_size is None:
|
|
||||||
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
|
||||||
self.flashinfer_workspace_buffer,
|
|
||||||
"NHD",
|
|
||||||
use_cuda_graph=True,
|
|
||||||
use_tensor_cores=use_tensor_cores,
|
|
||||||
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
|
|
||||||
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
|
||||||
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
flashinfer_decode_wrapper = []
|
|
||||||
for i in range(2):
|
|
||||||
flashinfer_decode_wrapper.append(
|
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
|
||||||
self.flashinfer_workspace_buffer,
|
|
||||||
"NHD",
|
|
||||||
use_cuda_graph=True,
|
|
||||||
use_tensor_cores=use_tensor_cores,
|
|
||||||
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
|
|
||||||
paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
|
|
||||||
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
|
|
||||||
:bs
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
update_flashinfer_indices(
|
|
||||||
ForwardMode.DECODE,
|
|
||||||
self.model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
None,
|
|
||||||
flashinfer_decode_wrapper,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run and capture
|
# Run and capture
|
||||||
@@ -246,13 +190,12 @@ class CudaGraphRunner:
|
|||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
|
attn_backend=self.model_runner.attn_backend,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=0,
|
top_logprobs_nums=0,
|
||||||
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
|
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
|
||||||
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return forward(input_ids, input_metadata.positions, input_metadata)
|
return forward(input_ids, input_metadata.positions, input_metadata)
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
@@ -274,15 +217,15 @@ class CudaGraphRunner:
|
|||||||
self.model_runner.tp_group.barrier()
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
self.graph_memory_pool = graph.pool()
|
self.graph_memory_pool = graph.pool()
|
||||||
return graph, None, out, flashinfer_decode_wrapper
|
return graph, out
|
||||||
|
|
||||||
def replay(self, batch: ScheduleBatch):
|
def replay(self, batch: ScheduleBatch):
|
||||||
assert batch.out_cache_loc is not None
|
assert batch.out_cache_loc is not None
|
||||||
raw_bs = len(batch.reqs)
|
raw_bs = len(batch.reqs)
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
index = bisect.bisect_left(self.batch_size_list, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
bs = self.batch_size_list[index]
|
bs = self.capture_bs[index]
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens.zero_()
|
self.seq_lens.zero_()
|
||||||
self.position_ids_offsets.fill_(1)
|
self.position_ids_offsets.fill_(1)
|
||||||
@@ -295,14 +238,9 @@ class CudaGraphRunner:
|
|||||||
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
|
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
|
||||||
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
||||||
|
|
||||||
# FlashInfer inputs
|
# Attention backend
|
||||||
update_flashinfer_indices(
|
self.model_runner.attn_backend.replay_cuda_graph_init(
|
||||||
ForwardMode.DECODE,
|
bs, self.req_pool_indices, self.seq_lens
|
||||||
self.model_runner,
|
|
||||||
self.req_pool_indices[:bs],
|
|
||||||
self.seq_lens[:bs],
|
|
||||||
None,
|
|
||||||
self.flashinfer_handlers[bs],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sampling inputs
|
# Sampling inputs
|
||||||
|
|||||||
@@ -23,9 +23,8 @@ from typing import TYPE_CHECKING, List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.attention_backend import AttentionBackend
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -66,12 +65,11 @@ class InputMetadata:
|
|||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
req_to_token_pool: ReqToTokenPool
|
req_to_token_pool: ReqToTokenPool
|
||||||
token_to_kv_pool: BaseTokenToKVPool
|
token_to_kv_pool: BaseTokenToKVPool
|
||||||
|
attn_backend: AttentionBackend
|
||||||
|
|
||||||
# Output location of the KV cache
|
# Output location of the KV cache
|
||||||
out_cache_loc: torch.Tensor
|
out_cache_loc: torch.Tensor
|
||||||
|
|
||||||
total_num_tokens: int = None
|
|
||||||
|
|
||||||
# Position information
|
# Position information
|
||||||
positions: torch.Tensor = None
|
positions: torch.Tensor = None
|
||||||
|
|
||||||
@@ -93,18 +91,6 @@ class InputMetadata:
|
|||||||
image_offsets: List[List[int]] = None
|
image_offsets: List[List[int]] = None
|
||||||
modalities: List[List[str]] = None
|
modalities: List[List[str]] = None
|
||||||
|
|
||||||
# Trition attention backend
|
|
||||||
triton_max_seq_len: int = 0
|
|
||||||
triton_max_extend_len: int = 0
|
|
||||||
triton_start_loc: torch.Tensor = None
|
|
||||||
triton_prefix_lens: torch.Tensor = None
|
|
||||||
|
|
||||||
# FlashInfer attention backend
|
|
||||||
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
|
||||||
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
|
||||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
|
||||||
flashinfer_use_ragged: bool = False
|
|
||||||
|
|
||||||
def init_multimuldal_info(self, batch: ScheduleBatch):
|
def init_multimuldal_info(self, batch: ScheduleBatch):
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
self.pixel_values = [r.pixel_values for r in reqs]
|
self.pixel_values = [r.pixel_values for r in reqs]
|
||||||
@@ -154,32 +140,27 @@ class InputMetadata:
|
|||||||
self.positions = self.positions.to(torch.int64)
|
self.positions = self.positions.to(torch.int64)
|
||||||
|
|
||||||
def compute_extend_infos(self, batch: ScheduleBatch):
|
def compute_extend_infos(self, batch: ScheduleBatch):
|
||||||
if self.forward_mode.is_decode():
|
extend_lens_cpu = [
|
||||||
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
|
len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
|
||||||
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
|
]
|
||||||
else:
|
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
|
||||||
extend_lens_cpu = [
|
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||||
len(r.fill_ids) - batch.prefix_lens_cpu[i]
|
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||||
for i, r in enumerate(batch.reqs)
|
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||||
]
|
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
||||||
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
|
|
||||||
self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
|
||||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
|
||||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
|
||||||
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
|
|
||||||
|
|
||||||
self.extend_seq_lens_cpu = extend_lens_cpu
|
self.extend_seq_lens_cpu = extend_lens_cpu
|
||||||
self.logprob_start_lens_cpu = [
|
self.logprob_start_lens_cpu = [
|
||||||
(
|
(
|
||||||
min(
|
min(
|
||||||
req.logprob_start_len - batch.prefix_lens_cpu[i],
|
req.logprob_start_len - batch.prefix_lens_cpu[i],
|
||||||
extend_lens_cpu[i] - 1,
|
extend_lens_cpu[i] - 1,
|
||||||
)
|
|
||||||
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
|
|
||||||
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
|
|
||||||
)
|
)
|
||||||
for i, req in enumerate(batch.reqs)
|
if req.logprob_start_len >= batch.prefix_lens_cpu[i]
|
||||||
]
|
else extend_lens_cpu[i] - 1 # Fake extend, actually decode
|
||||||
|
)
|
||||||
|
for i, req in enumerate(batch.reqs)
|
||||||
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(
|
def from_schedule_batch(
|
||||||
@@ -195,6 +176,7 @@ class InputMetadata:
|
|||||||
seq_lens=batch.seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||||
|
attn_backend=model_runner.attn_backend,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
@@ -202,76 +184,12 @@ class InputMetadata:
|
|||||||
|
|
||||||
ret.sampling_info.update_penalties()
|
ret.sampling_info.update_penalties()
|
||||||
ret.sampling_info.update_regex_vocab_mask(batch)
|
ret.sampling_info.update_regex_vocab_mask(batch)
|
||||||
|
|
||||||
ret.compute_positions(batch)
|
ret.compute_positions(batch)
|
||||||
|
|
||||||
ret.compute_extend_infos(batch)
|
if not batch.forward_mode.is_decode():
|
||||||
|
|
||||||
fm = batch.forward_mode
|
|
||||||
if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
|
|
||||||
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
|
|
||||||
|
|
||||||
if not fm.is_decode():
|
|
||||||
ret.init_multimuldal_info(batch)
|
ret.init_multimuldal_info(batch)
|
||||||
|
ret.compute_extend_infos(batch)
|
||||||
|
|
||||||
if model_runner.server_args.attention_backend == "triton":
|
model_runner.attn_backend.init_forward_metadata(batch, ret)
|
||||||
ret.init_triton_args(batch)
|
|
||||||
|
|
||||||
flashinfer_use_ragged = False
|
|
||||||
if model_runner.server_args.attention_backend == "flashinfer":
|
|
||||||
if (
|
|
||||||
not fm.is_decode()
|
|
||||||
and int(torch.sum(ret.seq_lens)) > 4096
|
|
||||||
and model_runner.sliding_window_size is None
|
|
||||||
):
|
|
||||||
flashinfer_use_ragged = True
|
|
||||||
ret.init_flashinfer_handlers(
|
|
||||||
model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
|
|
||||||
)
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def init_triton_args(self, batch: ScheduleBatch):
|
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
|
||||||
self.triton_max_seq_len = int(torch.max(self.seq_lens))
|
|
||||||
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
|
||||||
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
|
||||||
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
self.triton_max_extend_len = None
|
|
||||||
else:
|
|
||||||
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
|
||||||
extend_seq_lens = self.seq_lens - self.triton_prefix_lens
|
|
||||||
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
|
|
||||||
|
|
||||||
def init_flashinfer_handlers(
|
|
||||||
self,
|
|
||||||
model_runner,
|
|
||||||
prefix_lens_cpu,
|
|
||||||
flashinfer_use_ragged,
|
|
||||||
):
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
prefix_lens = None
|
|
||||||
else:
|
|
||||||
prefix_lens = self.extend_prefix_lens
|
|
||||||
|
|
||||||
update_flashinfer_indices(
|
|
||||||
self.forward_mode,
|
|
||||||
model_runner,
|
|
||||||
self.req_pool_indices,
|
|
||||||
self.seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
flashinfer_use_ragged=flashinfer_use_ragged,
|
|
||||||
)
|
|
||||||
|
|
||||||
(
|
|
||||||
self.flashinfer_prefill_wrapper_ragged,
|
|
||||||
self.flashinfer_prefill_wrapper_paged,
|
|
||||||
self.flashinfer_decode_wrapper,
|
|
||||||
self.flashinfer_use_ragged,
|
|
||||||
) = (
|
|
||||||
model_runner.flashinfer_prefill_wrapper_ragged,
|
|
||||||
model_runner.flashinfer_prefill_wrapper_paged,
|
|
||||||
model_runner.flashinfer_decode_wrapper,
|
|
||||||
flashinfer_use_ragged,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from flashinfer import (
|
|
||||||
BatchDecodeWithPagedKVCacheWrapper,
|
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
|
||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
|
||||||
)
|
|
||||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
|
||||||
from vllm.config import DeviceConfig, LoadConfig
|
from vllm.config import DeviceConfig, LoadConfig
|
||||||
from vllm.config import ModelConfig as VllmModelConfig
|
from vllm.config import ModelConfig as VllmModelConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as
|
|||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
|
||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||||
|
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import SampleOutput
|
from sglang.srt.layers.sampler import SampleOutput
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
@@ -100,6 +96,7 @@ class ModelRunner:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Model-specific adjustment
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
||||||
@@ -107,6 +104,7 @@ class ModelRunner:
|
|||||||
server_args.chunked_prefill_size = None
|
server_args.chunked_prefill_size = None
|
||||||
server_args.mem_fraction_static *= 0.95
|
server_args.mem_fraction_static *= 0.95
|
||||||
|
|
||||||
|
# Init componnets
|
||||||
min_per_gpu_memory = self.init_torch_distributed()
|
min_per_gpu_memory = self.init_torch_distributed()
|
||||||
self.load_model()
|
self.load_model()
|
||||||
self.init_memory_pool(
|
self.init_memory_pool(
|
||||||
@@ -115,7 +113,7 @@ class ModelRunner:
|
|||||||
server_args.max_total_tokens,
|
server_args.max_total_tokens,
|
||||||
)
|
)
|
||||||
self.init_cublas()
|
self.init_cublas()
|
||||||
self.init_flashinfer()
|
self.init_attention_backend()
|
||||||
self.init_cuda_graphs()
|
self.init_cuda_graphs()
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
@@ -397,9 +395,6 @@ class ModelRunner:
|
|||||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
)
|
)
|
||||||
logger.info("using MLA Triton implementaion, flashinfer is disabled")
|
|
||||||
# FIXME: temporarily only Triton MLA is supported
|
|
||||||
self.server_args.attention_backend = "triton"
|
|
||||||
else:
|
else:
|
||||||
self.token_to_kv_pool = MHATokenToKVPool(
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
@@ -422,106 +417,42 @@ class ModelRunner:
|
|||||||
c = a @ b
|
c = a @ b
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def init_flashinfer(self):
|
def init_attention_backend(self):
|
||||||
"""Init flashinfer attention kernel wrappers."""
|
"""Init attention kernel backend."""
|
||||||
if self.server_args.attention_backend != "flashinfer":
|
if self.server_args.attention_backend == "flashinfer":
|
||||||
assert (
|
self.attn_backend = FlashInferAttnBackend(self)
|
||||||
self.sliding_window_size is None
|
elif self.server_args.attention_backend == "triton":
|
||||||
), "turn on flashinfer to support window attention"
|
assert self.sliding_window_size is None, (
|
||||||
self.flashinfer_prefill_wrapper_ragged = None
|
"Window attention is not supported in the triton attention backend. "
|
||||||
self.flashinfer_prefill_wrapper_paged = None
|
"Please use `--attention-backend flashinfer`."
|
||||||
self.flashinfer_decode_wrapper = None
|
)
|
||||||
return
|
self.attn_backend = TritonAttnBackend(self)
|
||||||
|
|
||||||
if not _grouped_size_compiled_for_decode_kernels(
|
|
||||||
self.model_config.num_attention_heads // self.tp_size,
|
|
||||||
self.model_config.get_num_kv_heads(self.tp_size),
|
|
||||||
):
|
|
||||||
use_tensor_cores = True
|
|
||||||
else:
|
else:
|
||||||
use_tensor_cores = False
|
raise ValueError(
|
||||||
|
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||||
if self.sliding_window_size is None:
|
|
||||||
self.flashinfer_workspace_buffer = torch.empty(
|
|
||||||
global_config.flashinfer_workspace_size,
|
|
||||||
dtype=torch.uint8,
|
|
||||||
device="cuda",
|
|
||||||
)
|
)
|
||||||
self.flashinfer_prefill_wrapper_ragged = (
|
|
||||||
BatchPrefillWithRaggedKVCacheWrapper(
|
|
||||||
self.flashinfer_workspace_buffer, "NHD"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
|
|
||||||
self.flashinfer_workspace_buffer, "NHD"
|
|
||||||
)
|
|
||||||
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
|
||||||
self.flashinfer_workspace_buffer,
|
|
||||||
"NHD",
|
|
||||||
use_tensor_cores=use_tensor_cores,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.flashinfer_workspace_buffer = torch.empty(
|
|
||||||
global_config.flashinfer_workspace_size,
|
|
||||||
dtype=torch.uint8,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
self.flashinfer_prefill_wrapper_ragged = None
|
|
||||||
self.flashinfer_prefill_wrapper_paged = []
|
|
||||||
self.flashinfer_decode_wrapper = []
|
|
||||||
for i in range(2):
|
|
||||||
self.flashinfer_prefill_wrapper_paged.append(
|
|
||||||
BatchPrefillWithPagedKVCacheWrapper(
|
|
||||||
self.flashinfer_workspace_buffer, "NHD"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.flashinfer_decode_wrapper.append(
|
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
|
||||||
self.flashinfer_workspace_buffer,
|
|
||||||
"NHD",
|
|
||||||
use_tensor_cores=use_tensor_cores,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_cuda_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
"""Capture cuda graphs."""
|
"""Capture cuda graphs."""
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
|
|
||||||
|
self.cuda_graph_runner = None
|
||||||
|
|
||||||
if not self.is_generation:
|
if not self.is_generation:
|
||||||
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
|
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
|
||||||
return
|
return
|
||||||
|
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
if self.server_args.disable_cuda_graph:
|
||||||
|
return
|
||||||
|
|
||||||
if (
|
if self.server_args.attention_backend != "flashinfer":
|
||||||
self.server_args.disable_cuda_graph
|
logger.warning(
|
||||||
or self.server_args.attention_backend != "flashinfer"
|
f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
|
||||||
):
|
)
|
||||||
self.cuda_graph_runner = None
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||||
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
if self.server_args.disable_cuda_graph_padding:
|
|
||||||
batch_size_list = list(range(1, 32)) + [64, 128]
|
|
||||||
else:
|
|
||||||
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
|
||||||
|
|
||||||
self.cuda_graph_runner = CudaGraphRunner(
|
|
||||||
self,
|
|
||||||
max_batch_size_to_capture=max(batch_size_list),
|
|
||||||
use_torch_compile=self.server_args.enable_torch_compile,
|
|
||||||
disable_padding=self.server_args.disable_cuda_graph_padding,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.cuda_graph_runner.capture(batch_size_list)
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise Exception(
|
|
||||||
f"Capture cuda graph failed: {e}\n"
|
|
||||||
"Possible solutions:\n"
|
|
||||||
"1. disable cuda graph by --disable-cuda-graph\n"
|
|
||||||
"2. set --mem-fraction-static to a smaller value\n"
|
|
||||||
"3. disable torch compile by not using --enable-torch-compile\n"
|
|
||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_decode(self, batch: ScheduleBatch):
|
def forward_decode(self, batch: ScheduleBatch):
|
||||||
|
|||||||
@@ -143,18 +143,16 @@ class SamplingBatchInfo:
|
|||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
|
|
||||||
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
||||||
bs, reqs = batch.batch_size(), batch.reqs
|
has_regex = any(req.regex_fsm is not None for req in batch.reqs)
|
||||||
device = "cuda"
|
|
||||||
has_regex = any(req.regex_fsm is not None for req in reqs)
|
|
||||||
|
|
||||||
# Reset the vocab mask
|
# Reset the vocab mask
|
||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
|
|
||||||
if has_regex:
|
if has_regex:
|
||||||
self.vocab_mask = torch.zeros(
|
self.vocab_mask = torch.zeros(
|
||||||
bs, self.vocab_size, dtype=torch.bool, device=device
|
batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||||
)
|
)
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if req.regex_fsm is not None:
|
if req.regex_fsm is not None:
|
||||||
self.vocab_mask[i].fill_(1)
|
self.vocab_mask[i].fill_(1)
|
||||||
self.vocab_mask[i][
|
self.vocab_mask[i][
|
||||||
|
|||||||
@@ -335,23 +335,19 @@ def launch_server(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Launch processes
|
# Launch processes
|
||||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
|
||||||
if server_args.chat_template:
|
|
||||||
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
|
||||||
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
|
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
|
||||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
|
||||||
|
|
||||||
if server_args.dp_size == 1:
|
if server_args.dp_size == 1:
|
||||||
start_controller_process = start_controller_process_single
|
start_controller_process = start_controller_process_single
|
||||||
else:
|
else:
|
||||||
start_controller_process = start_controller_process_multi
|
start_controller_process = start_controller_process_multi
|
||||||
|
|
||||||
proc_controller = mp.Process(
|
proc_controller = mp.Process(
|
||||||
target=start_controller_process,
|
target=start_controller_process,
|
||||||
args=(server_args, port_args, pipe_controller_writer),
|
args=(server_args, port_args, pipe_controller_writer),
|
||||||
)
|
)
|
||||||
proc_controller.start()
|
proc_controller.start()
|
||||||
|
|
||||||
|
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||||
proc_detoken = mp.Process(
|
proc_detoken = mp.Process(
|
||||||
target=start_detokenizer_process,
|
target=start_detokenizer_process,
|
||||||
args=(
|
args=(
|
||||||
@@ -362,6 +358,10 @@ def launch_server(
|
|||||||
)
|
)
|
||||||
proc_detoken.start()
|
proc_detoken.start()
|
||||||
|
|
||||||
|
tokenizer_manager = TokenizerManager(server_args, port_args)
|
||||||
|
if server_args.chat_template:
|
||||||
|
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
||||||
|
|
||||||
# Wait for the model to finish loading
|
# Wait for the model to finish loading
|
||||||
controller_init_state = pipe_controller_reader.recv()
|
controller_init_state = pipe_controller_reader.recv()
|
||||||
detoken_init_state = pipe_detoken_reader.recv()
|
detoken_init_state = pipe_detoken_reader.recv()
|
||||||
|
|||||||
@@ -83,8 +83,8 @@ class ServerArgs:
|
|||||||
json_model_override_args: str = "{}"
|
json_model_override_args: str = "{}"
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
attention_backend: str = "flashinfer"
|
attention_backend: Optional[str] = None
|
||||||
sampling_backend: str = "flashinfer"
|
sampling_backend: Optional[str] = None
|
||||||
|
|
||||||
disable_flashinfer: bool = False
|
disable_flashinfer: bool = False
|
||||||
disable_flashinfer_sampling: bool = False
|
disable_flashinfer_sampling: bool = False
|
||||||
@@ -148,6 +148,17 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.sampling_backend = "pytorch"
|
self.sampling_backend = "pytorch"
|
||||||
|
|
||||||
|
# Default kernel backends
|
||||||
|
if self.enable_mla:
|
||||||
|
logger.info("MLA optimization is tunred on. Use triton backend.")
|
||||||
|
self.attention_backend = "triton"
|
||||||
|
|
||||||
|
if self.attention_backend is None:
|
||||||
|
self.attention_backend = "flashinfer"
|
||||||
|
|
||||||
|
if self.sampling_backend is None:
|
||||||
|
self.sampling_backend = "flashinfer"
|
||||||
|
|
||||||
# Model-specific patches
|
# Model-specific patches
|
||||||
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
|
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase):
|
|||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
None,
|
None,
|
||||||
req_to_token.size(1),
|
|
||||||
kv_indices_triton,
|
kv_indices_triton,
|
||||||
|
req_to_token.size(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check
|
# Check
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
|
|||||||
other_args = []
|
other_args = []
|
||||||
if disable_radix_cache:
|
if disable_radix_cache:
|
||||||
other_args.append("--disable-radix-cache")
|
other_args.append("--disable-radix-cache")
|
||||||
other_args.extend(["--attention-backend", attention_backend])
|
if attention_backend:
|
||||||
|
other_args.extend(["--attention-backend", attention_backend])
|
||||||
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
|
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
|
||||||
other_args.extend(["--tensor-parallel-size", "2"])
|
other_args.extend(["--tensor-parallel-size", "2"])
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
|
|||||||
other_args = []
|
other_args = []
|
||||||
if disable_radix_cache:
|
if disable_radix_cache:
|
||||||
other_args.append("--disable-radix-cache")
|
other_args.append("--disable-radix-cache")
|
||||||
other_args.extend(["--attention-backend", attention_backend])
|
if attention_backend:
|
||||||
|
other_args.extend(["--attention-backend", attention_backend])
|
||||||
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
|
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
|
||||||
|
|
||||||
model = DEFAULT_MODEL_NAME_FOR_TEST
|
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
|||||||
@@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase):
|
|||||||
v_buffer,
|
v_buffer,
|
||||||
req_to_tokens,
|
req_to_tokens,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_seq_len_prefix,
|
|
||||||
b_start_loc_extend,
|
|
||||||
b_seq_len_extend,
|
b_seq_len_extend,
|
||||||
max_len_in_batch,
|
b_start_loc_extend,
|
||||||
max_len_extend,
|
max_len_extend,
|
||||||
)
|
)
|
||||||
|
|
||||||
redundant_attention(
|
redundant_attention(
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
|
||||||
v_extend,
|
|
||||||
o_redundant,
|
o_redundant,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
req_to_tokens,
|
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
|
|||||||
Reference in New Issue
Block a user