Refactor attention backend (#1381)
python/sglang/srt/layers/attention_backend.py (new file, 383 lines)
@@ -0,0 +1,383 @@
from __future__ import annotations

"""
Support different attention backends.
Now there are two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        pass

    def forward(self, q, k, v, layer, input_metadata: InputMetadata):
        if input_metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, input_metadata)
        else:
            return self.forward_extend(q, k, v, layer, input_metadata)


class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.model_runner = model_runner

        if not _grouped_size_compiled_for_decode_kernels(
            model_runner.model_config.num_attention_heads // model_runner.tp_size,
            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
        ):
            self.decode_use_tensor_cores = True
        else:
            self.decode_use_tensor_cores = False

        self.workspace_buffer = torch.empty(
            global_config.flashinfer_workspace_size,
            dtype=torch.uint8,
            device="cuda",
        )

        if model_runner.sliding_window_size is None:
            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_tensor_cores=self.decode_use_tensor_cores,
            )
        else:
            # Two wrappers: one for sliding window attention and one for full attention.
            # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
            self.prefill_wrapper_ragged = None
            self.prefill_wrapper_paged = []
            self.decode_wrapper = []
            for _ in range(2):
                self.prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                )
                self.decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_tensor_cores=self.decode_use_tensor_cores,
                    )
                )

        self.forward_metadata = None
        self.cuda_graph_metadata = {}

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        if input_metadata.forward_mode.is_decode():
            prefix_lens = None
            use_ragged = False
            total_num_tokens = None
        else:
            prefix_lens = input_metadata.extend_prefix_lens

            # Some heuristics to check whether to use ragged forward
            use_ragged = False
            if (
                int(torch.sum(input_metadata.seq_lens)) > 4096
                and self.model_runner.sliding_window_size is None
            ):
                use_ragged = True

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()

        update_flashinfer_indices(
            input_metadata.forward_mode,
            self.model_runner,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            prefix_lens,
            use_ragged=use_ragged,
        )

        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_kv_indices = torch.zeros(
            (max_bs * self.model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.cuda_graph_kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device="cuda"
        )

        if self.model_runner.sliding_window_size is not None:
            self.cuda_graph_kv_indptr = [
                self.cuda_graph_kv_indptr,
                self.cuda_graph_kv_indptr.clone(),
            ]
            self.cuda_graph_kv_indices = [
                self.cuda_graph_kv_indices,
                self.cuda_graph_kv_indices.clone(),
            ]

    def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        if self.model_runner.sliding_window_size is None:
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                use_tensor_cores=self.decode_use_tensor_cores,
                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
            )
        else:
            decode_wrapper = []
            for i in range(2):
                decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_cuda_graph=True,
                        use_tensor_cores=self.decode_use_tensor_cores,
                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
                            :bs
                        ],
                    )
                )

        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            decode_wrapper,
        )

        self.cuda_graph_metadata[bs] = decode_wrapper

        self.forward_metadata = (False, None, decode_wrapper)

    def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices[:bs],
            seq_lens[:bs],
            None,
            self.cuda_graph_metadata[bs],
        )

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if not isinstance(self.prefill_wrapper_paged, list):
            prefill_wrapper_paged = self.prefill_wrapper_paged
        else:
            if layer.sliding_window_size != -1:
                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
            else:
                prefill_wrapper_paged = self.prefill_wrapper_paged[1]

        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if not use_ragged:
            if k is not None:
                assert v is not None
                input_metadata.token_to_kv_pool.set_kv_buffer(
                    layer.layer_id, input_metadata.out_cache_loc, k, v
                )
            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=True,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=layer.logit_cap,
            )
        else:
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=layer.logit_cap,
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
                    logits_soft_cap=layer.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

            if total_num_tokens >= global_config.layer_sync_threshold:
                torch.cuda.synchronize()

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if isinstance(decode_wrapper, list):
            if layer.sliding_window_size != -1:
                decode_wrapper = decode_wrapper[0]
            else:
                decode_wrapper = decode_wrapper[1]

        if k is not None:
            assert v is not None
            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

        o = decode_wrapper.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)


class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.triton_attention.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.triton_attention.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        self.forward_metadata = None

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Init auxiliary variables for triton attention backend."""

        if input_metadata.forward_mode.is_decode():
            max_seq_len = torch.max(input_metadata.seq_lens).item()
            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
            max_extend_len = None
        else:
            start_loc = max_seq_len = total_num_tokens = None
            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

        self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            input_metadata.extend_seq_lens,
            input_metadata.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )

        return o

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            start_loc,
            input_metadata.seq_lens,
            max_seq_len,
            total_num_tokens,
            layer.scaling,
            layer.logit_cap,
        )

        return o
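
The new file above is the whole backend abstraction: a backend only needs to provide init_forward_metadata plus the forward_extend / forward_decode pair, and AttentionBackend.forward dispatches between them by forward mode. As a rough sketch of what plugging in another backend would look like, consider the hypothetical NaiveSDPABackend below; it is not part of this commit, it ignores the paged KV pools, and it assumes plain multi-head attention:

# Hypothetical illustration only -- not part of this commit. It recomputes
# attention with torch SDPA over the current tokens and skips the KV pools.
import torch
import torch.nn.functional as F

from sglang.srt.layers.attention_backend import AttentionBackend


class NaiveSDPABackend(AttentionBackend):
    def init_forward_metadata(self, batch, input_metadata):
        pass  # nothing to precompute for this toy backend

    def _run(self, q, k, v, layer, causal):
        # (tokens, heads * head_dim) -> (heads, tokens, head_dim)
        q = q.view(-1, layer.tp_q_head_num, layer.head_dim).transpose(0, 1)
        k = k.view(-1, layer.tp_k_head_num, layer.head_dim).transpose(0, 1)
        v = v.view(-1, layer.tp_v_head_num, layer.head_dim).transpose(0, 1)
        o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        return o.transpose(0, 1).reshape(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_extend(self, q, k, v, layer, input_metadata):
        return self._run(q, k, v, layer, causal=True)

    def forward_decode(self, q, k, v, layer, input_metadata):
        return self._run(q, k, v, layer, causal=False)
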
@@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton(
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    max_context_len,
    kv_indices_ptr,
    max_context_len: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)
@@ -47,15 +47,15 @@ class FlashinferUpdater:
        req_pool_indices,
        seq_lens,
        prefix_lens,
        flashinfer_decode_wrapper=None,
        flashinfer_use_ragged=False,
        decode_wrapper=None,
        use_ragged=False,
    ):
        self.forward_mode = forward_mode
        self.model_runner = model_runner
        self.req_pool_indices = req_pool_indices
        self.seq_lens = seq_lens
        self.prefix_lens = prefix_lens
        self.flashinfer_use_ragged = flashinfer_use_ragged
        self.use_ragged = use_ragged

        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -71,20 +71,17 @@ class FlashinferUpdater:
        )

        (
            self.flashinfer_decode_wrapper,
            self.flashinfer_prefill_wrapper_ragged,
            self.flashinfer_prefill_wrapper_paged,
            self.decode_wrapper,
            self.prefill_wrapper_ragged,
            self.prefill_wrapper_paged,
        ) = (
            flashinfer_decode_wrapper,
            self.model_runner.flashinfer_prefill_wrapper_ragged,
            self.model_runner.flashinfer_prefill_wrapper_paged,
            decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
            self.model_runner.attn_backend.prefill_wrapper_ragged,
            self.model_runner.attn_backend.prefill_wrapper_paged,
        )
        # CUDA graph uses different flashinfer_decode_wrapper
        if self.flashinfer_decode_wrapper is None:
            self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper

    def _init_indices_no_window(self):
        if self.flashinfer_use_ragged:
    def _init_indices_no_sliding_window(self):
        if self.use_ragged:
            paged_kernel_lens = self.prefix_lens
        else:
            paged_kernel_lens = self.seq_lens
@@ -103,13 +100,13 @@ class FlashinferUpdater:
            paged_kernel_lens,
            self.kv_indptr,
            None,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
            self.kv_indices,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
        )

    def _init_indices_window(self, wrapper_id):
        # window attention use paged only
    def _init_indices_sliding_window(self, wrapper_id):
        if wrapper_id == 0:
            # window attention use paged only
            if self.forward_mode.is_decode():
                paged_kernel_lens = torch.minimum(
                    self.seq_lens,
@@ -123,6 +120,7 @@ class FlashinferUpdater:
                    - self.prefix_lens,
                )
        else:
            # full attention
            paged_kernel_lens = self.seq_lens

        kv_start_idx = self.seq_lens - paged_kernel_lens
@@ -139,8 +137,8 @@ class FlashinferUpdater:
            paged_kernel_lens,
            self.kv_indptr,
            kv_start_idx,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
            self.kv_indices,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
        )

    def _update_decode_indices(self, decode_wrapper):
@@ -164,7 +162,7 @@ class FlashinferUpdater:
        )
        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)

        if self.flashinfer_use_ragged:
        if self.use_ragged:
            ragged_wrapper.end_forward()
            ragged_wrapper.begin_forward(
                qo_indptr,
@@ -187,28 +185,28 @@ class FlashinferUpdater:
            1,
        )

    def update_indices_no_window(self):
        self._init_indices_no_window()
    def update_indices_no_sliding_window(self):
        self._init_indices_no_sliding_window()

        if self.forward_mode.is_decode():
            self._update_decode_indices(self.flashinfer_decode_wrapper)
            self._update_decode_indices(self.decode_wrapper)
        else:
            self._update_extend_indices(
                self.flashinfer_prefill_wrapper_ragged,
                self.flashinfer_prefill_wrapper_paged,
                self.prefill_wrapper_ragged,
                self.prefill_wrapper_paged,
            )

    def update_indices_window(self):
        assert self.flashinfer_use_ragged is False
    def update_indices_sliding_window(self):
        assert self.use_ragged is False

        for wrapper_id in range(2):
            self._init_indices_window(wrapper_id)
            self._init_indices_sliding_window(wrapper_id)
            if self.forward_mode.is_decode():
                self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id])
                self._update_decode_indices(self.decode_wrapper[wrapper_id])
            else:
                self._update_extend_indices(
                    None,
                    self.flashinfer_prefill_wrapper_paged[wrapper_id],
                    self.prefill_wrapper_paged[wrapper_id],
                )


@@ -218,20 +216,20 @@ def update_flashinfer_indices(
    req_pool_indices,
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper=None,
    flashinfer_use_ragged=False,
    decode_wrapper=None,
    use_ragged=False,
):
    flashinfer_updater = FlashinferUpdater(
    updater = FlashinferUpdater(
        forward_mode,
        model_runner,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        flashinfer_decode_wrapper,
        flashinfer_use_ragged,
        decode_wrapper,
        use_ragged,
    )

    if model_runner.sliding_window_size is None:
        flashinfer_updater.update_indices_no_window()
        updater.update_indices_no_sliding_window()
    else:
        flashinfer_updater.update_indices_window()
        updater.update_indices_sliding_window()
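
The FlashinferUpdater hunks above are mostly renames (flashinfer_use_ragged -> use_ragged, window -> sliding_window) plus reading the wrappers from model_runner.attn_backend instead of from model_runner directly. The data the updater prepares is a CSR-style index into the token-level KV pool: kv_indptr delimits each request's slots and kv_indices lists them. A simplified host-side sketch of what create_flashinfer_kv_indices_triton produces, under the page-size-1 assumption and ignoring the sliding-window kv_start_idx offset:

import torch


def build_kv_indices(req_to_token, req_pool_indices, paged_kernel_lens):
    # kv_indptr[i]:kv_indptr[i+1] spans request i's token slots in the KV pool.
    kv_indptr = torch.zeros(len(req_pool_indices) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i, req_idx in enumerate(req_pool_indices.tolist()):
        start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
        kv_indices[start:end] = req_to_token[req_idx, : end - start]
    return kv_indptr, kv_indices
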
@@ -15,25 +15,14 @@ limitations under the License.

"""Radix attention."""

from typing import Optional

import torch
from flashinfer.cascade import merge_state
from torch import nn

from sglang.global_config import global_config
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class RadixAttention(nn.Module):
    """
    The attention layer implementation.
    Now it has two backends: FlashInfer and Triton.
    FlashInfer is faster and Triton is easier to customize.
    It supports two operators: extend (i.e. prefill with cached prefix) and decode.
    """

    def __init__(
@@ -43,8 +32,8 @@ class RadixAttention(nn.Module):
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
        sliding_window_size: Optional[int] = None,
        logit_cap: int = -1,
        sliding_window_size: int = -1,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
    ):
        super().__init__()
@@ -56,164 +45,14 @@ class RadixAttention(nn.Module):
        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
        self.scaling = scaling
        self.layer_id = layer_id
        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
        self.sliding_window_size = sliding_window_size if sliding_window_size else -1

        # Choose backend
        if (
            global_server_args_dict["attention_backend"] == "flashinfer"
            and self.qk_head_dim == self.v_head_dim
        ):
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        elif global_server_args_dict["attention_backend"] == "triton":
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton
        else:
            raise ValueError(
                f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
            )

    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        if self.qk_head_dim != self.v_head_dim:
            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
        else:
            o = torch.empty_like(q)

        self.store_kv_cache(k, v, input_metadata)
        extend_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, self.tp_q_head_num, self.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_prefix_lens,
            input_metadata.extend_start_loc,
            input_metadata.extend_seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.triton_max_extend_len,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o

    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        if self.qk_head_dim != self.v_head_dim:
            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
        else:
            o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        decode_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            o.view(-1, self.tp_q_head_num, self.v_head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.triton_start_loc,
            input_metadata.seq_lens,
            input_metadata.triton_max_seq_len,
            input_metadata.total_num_tokens,
            sm_scale=self.scaling,
            logit_cap=self.logit_cap,
        )

        return o

    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
        if self.sliding_window_size != -1:
            prefill_wrapper_paged = prefill_wrapper_paged[0]
        else:
            if isinstance(prefill_wrapper_paged, list):
                prefill_wrapper_paged = prefill_wrapper_paged[1]

        if not input_metadata.flashinfer_use_ragged:
            if k is not None:
                assert v is not None
                self.store_kv_cache(k, v, input_metadata)

            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                causal=True,
                sm_scale=self.scaling,
                window_left=self.sliding_window_size,
                logits_soft_cap=self.logit_cap,
            )
        else:
            o1, s1 = (
                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
                    causal=True,
                    sm_scale=self.scaling,
                    logits_soft_cap=self.logit_cap,
                )
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                    causal=False,
                    sm_scale=self.scaling,
                    logits_soft_cap=self.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            self.store_kv_cache(k, v, input_metadata)

            if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
                torch.cuda.synchronize()

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        decode_wrapper = input_metadata.flashinfer_decode_wrapper
        if self.sliding_window_size != -1:
            decode_wrapper = decode_wrapper[0]
        else:
            if isinstance(decode_wrapper, list):
                decode_wrapper = decode_wrapper[1]

        if k is not None:
            assert v is not None
            self.store_kv_cache(k, v, input_metadata)

        o = decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )

        return o.view(-1, self.tp_q_head_num * self.head_dim)
        self.logit_cap = logit_cap
        self.sliding_window_size = sliding_window_size or -1

    def forward(self, q, k, v, input_metadata: InputMetadata):
        if k is not None:
            # For cross-layer sharing, kv can be None
            assert v is not None
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

        if input_metadata.forward_mode.is_extend():
            return self.extend_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode.is_decode():
            return self.decode_forward(q, k, v, input_metadata)

    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
        input_metadata.token_to_kv_pool.set_kv_buffer(
            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
        )
        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)

@@ -15,6 +15,7 @@ limitations under the License.

"""
Memory-efficient attention for decoding.
It supports page size = 1.
"""

# Adapted from
@@ -197,7 +198,6 @@ def _decode_att_m_fwd(
    logit_cap,
):
    BLOCK = 32
    # shape constraints
    Lq, Lk = q.shape[-1], k_buffer.shape[-1]

    batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd(
    logit_cap,
):
    BLOCK = 32
    # shape constraints
    Lq, Lk = q.shape[-1], k_buffer.shape[-1]

    if Lk == 576:
@@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd(
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_N=BLOCK,
        BLOCK_H=BLOCK_H,
        Lv=Lv,
        num_warps=num_warps,
        num_stages=1,
        Lv=Lv,
    )


@@ -588,7 +587,7 @@ def decode_attention_fwd(
    max_len_in_batch,
    total_num_tokens,
    sm_scale,
    logit_cap=-1,
    logit_cap=0.0,
    att_m=None,
):
    if att_m is None:

@@ -61,14 +61,14 @@ def _fwd_kernel(
    stride_buf_vbs,
    stride_buf_vh,
    stride_req_to_tokens_b,
    logit_cap: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    logit_cap: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
@@ -111,7 +111,7 @@ def _fwd_kernel(
        )
        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)

    # stage1: compute scores with prefix
    # stage 1: compute scores with prefix
    offs_n = tl.arange(0, BLOCK_N)

    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
@@ -174,7 +174,7 @@ def _fwd_kernel(

        e_max = n_e_max

    # stage2: compute the trianlge part
    # stage 2: compute the trianlge part

    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
    for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -255,26 +255,22 @@ def extend_attention_fwd(
    v_buffer,
    req_to_tokens,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    b_seq_len_prefix,
    b_start_loc_extend,
    b_seq_len_extend,
    max_len_in_batch,
    b_start_loc_extend,
    max_len_extend,
    sm_scale=None,
    logit_cap=-1,
    logit_cap=0.0,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors

    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
    """
    Lq, Lk, Lv, Lo = (
    Lq, Lk, Lv = (
        q_extend.shape[-1],
        k_extend.shape[-1],
        v_extend.shape[-1],
        o_extend.shape[-1],
    )

    if Lq == 576:
@@ -303,7 +299,7 @@ def extend_attention_fwd(
    else:
        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
@@ -338,27 +334,24 @@ def extend_attention_fwd(
        v_buffer.stride(0),
        v_buffer.stride(1),
        req_to_tokens.stride(0),
        logit_cap=logit_cap,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
        logit_cap=logit_cap,
        Lq=Lq,
        Lv=Lv,
        num_warps=num_warps,
        num_stages=num_stages,
    )


def redundant_attention(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    req_to_tokens,
    b_req_idx,
    b_start_loc,
    b_seq_len,
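
A note on the logit_cap signature changes in the kernels above: the default moves from -1 to 0.0, with the convention (here and in the flashinfer calls' logits_soft_cap) that any value <= 0 disables capping. The capping itself is the usual tanh soft cap on the attention scores; a sketch of the semantics being assumed:

import torch


def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # <= 0 (including the new 0.0 default) means "no capping".
    if logit_cap <= 0:
        return scores
    return logit_cap * torch.tanh(scores / logit_cap)
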
@@ -368,7 +368,7 @@ class ScheduleBatch:
        )

    def batch_size(self):
        return len(self.reqs) if self.reqs is not None else 0
        return len(self.reqs) if self.reqs else 0

    def is_empty(self):
        return len(self.reqs) == 0

@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

"""Run the model with cuda graph."""
"""Run the model with cuda graph and torch.compile."""

import bisect
from contextlib import contextmanager
from typing import Callable, List
from typing import Callable

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp

@@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
def patch_model(
    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
):
    """Patch the model to make it compatible with with torch.compile"""
    backup_ca_comm = None

    try:
@@ -86,23 +85,28 @@ def set_torch_compile_config():


class CudaGraphRunner:
    def __init__(
        self,
        model_runner: "ModelRunner",
        max_batch_size_to_capture: int,
        use_torch_compile: bool,
        disable_padding: bool,
    ):
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

    def __init__(self, model_runner: "ModelRunner"):
        # Parse args
        self.model_runner = model_runner
        self.graphs = {}
        self.input_buffers = {}
        self.output_buffers = {}
        self.flashinfer_handlers = {}
        self.graph_memory_pool = None
        self.disable_padding = disable_padding
        self.use_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding

        # Batch sizes to capture
        if self.model_runner.server_args.disable_cuda_graph_padding:
            self.capture_bs = list(range(1, 32)) + [64, 128]
        else:
            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []

        # Common inputs
        self.max_bs = max_batch_size_to_capture
        self.max_bs = max(self.capture_bs)
        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.req_pool_indices = torch.zeros(
            (self.max_bs,), dtype=torch.int32, device="cuda"
@@ -115,56 +119,39 @@ class CudaGraphRunner:
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )

        # FlashInfer inputs
        self.flashinfer_kv_indptr = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.flashinfer_kv_indices = torch.zeros(
            (self.max_bs * model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.flashinfer_kv_last_page_len = torch.ones(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )
        if model_runner.sliding_window_size is None:
            self.flashinfer_workspace_buffer = (
                self.model_runner.flashinfer_workspace_buffer
            )
        else:
            self.flashinfer_workspace_buffer = (
                self.model_runner.flashinfer_workspace_buffer
            )
        # Attention backend
        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)

            self.flashinfer_kv_indptr = [
                self.flashinfer_kv_indptr,
                self.flashinfer_kv_indptr.clone(),
            ]
            self.flashinfer_kv_indices = [
                self.flashinfer_kv_indices,
                self.flashinfer_kv_indices.clone(),
            ]

        # Sampling inputs
        # Sampling info
        vocab_size = model_runner.model_config.vocab_size
        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)

        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

        if use_torch_compile:
        if self.use_torch_compile:
            set_torch_compile_config()

        # Capture
        try:
            self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. disable cuda graph by --disable-cuda-graph\n"
                "2. set --mem-fraction-static to a smaller value\n"
                "3. disable torch compile by not using --enable-torch-compile\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )

    def can_run(self, batch_size: int):
        if self.disable_padding:
            return batch_size in self.graphs
        else:
            return batch_size <= self.max_bs

    def capture(self, batch_size_list: List[int]):
        self.batch_size_list = batch_size_list
    def capture(self):
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            for bs in batch_size_list:
            for bs in self.capture_bs:
                with patch_model(
                    self.model_runner.model,
                    bs in self.compile_bs,
@@ -172,14 +159,10 @@ class CudaGraphRunner:
                ) as forward:
                    (
                        graph,
                        input_buffers,
                        output_buffers,
                        flashinfer_handler,
                    ) = self.capture_one_batch_size(bs, forward)
                    self.graphs[bs] = graph
                    self.input_buffers[bs] = input_buffers
                    self.output_buffers[bs] = output_buffers
                    self.flashinfer_handlers[bs] = flashinfer_handler

    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = torch.cuda.CUDAGraph()
@@ -192,48 +175,9 @@ class CudaGraphRunner:
        position_ids_offsets = self.position_ids_offsets[:bs]
        out_cache_loc = self.out_cache_loc[:bs]

        # FlashInfer inputs
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_runner.model_config.num_attention_heads
            // self.model_runner.tp_size,
            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False
        if self.model_runner.sliding_window_size is None:
            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                use_tensor_cores=use_tensor_cores,
                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
                paged_kv_indices_buffer=self.flashinfer_kv_indices,
                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
            )
        else:
            flashinfer_decode_wrapper = []
            for i in range(2):
                flashinfer_decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer,
                        "NHD",
                        use_cuda_graph=True,
                        use_tensor_cores=use_tensor_cores,
                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
                            :bs
                        ],
                    )
                )
        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            flashinfer_decode_wrapper,
        # Attention backend
        self.model_runner.attn_backend.capture_cuda_graph_init(
            bs, req_pool_indices, seq_lens
        )

        # Run and capture
@@ -246,13 +190,12 @@ class CudaGraphRunner:
                seq_lens=seq_lens,
                req_to_token_pool=self.model_runner.req_to_token_pool,
                token_to_kv_pool=self.model_runner.token_to_kv_pool,
                attn_backend=self.model_runner.attn_backend,
                out_cache_loc=out_cache_loc,
                return_logprob=False,
                top_logprobs_nums=0,
                positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
            )

            return forward(input_ids, input_metadata.positions, input_metadata)

        for _ in range(2):
@@ -274,15 +217,15 @@ class CudaGraphRunner:
            self.model_runner.tp_group.barrier()

        self.graph_memory_pool = graph.pool()
        return graph, None, out, flashinfer_decode_wrapper
        return graph, out

    def replay(self, batch: ScheduleBatch):
        assert batch.out_cache_loc is not None
        raw_bs = len(batch.reqs)

        # Pad
        index = bisect.bisect_left(self.batch_size_list, raw_bs)
        bs = self.batch_size_list[index]
        index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            self.seq_lens.zero_()
            self.position_ids_offsets.fill_(1)
@@ -295,14 +238,9 @@ class CudaGraphRunner:
        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
        self.out_cache_loc[:raw_bs] = batch.out_cache_loc

        # FlashInfer inputs
        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            None,
            self.flashinfer_handlers[bs],
        # Attention backend
        self.model_runner.attn_backend.replay_cuda_graph_init(
            bs, self.req_pool_indices, self.seq_lens
        )

        # Sampling inputs
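
The replay path above pads a live batch up to the nearest captured batch size with bisect over capture_bs (kept sorted), so every decode step can reuse one of the pre-captured graphs. A small self-contained sketch of that padding rule:

import bisect

# Mirrors the default capture list built in __init__ above.
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]


def padded_batch_size(raw_bs: int) -> int:
    # Smallest captured size >= raw_bs; raw_bs must not exceed max(capture_bs),
    # which is what can_run() guards against.
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]


assert padded_batch_size(3) == 4
assert padded_batch_size(9) == 16
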
@@ -23,9 +23,8 @@ from typing import TYPE_CHECKING, List
import numpy as np
import torch

from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices

if TYPE_CHECKING:
    from sglang.srt.layers.attention_backend import AttentionBackend
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -66,12 +65,11 @@ class InputMetadata:
    seq_lens: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: BaseTokenToKVPool
    attn_backend: AttentionBackend

    # Output location of the KV cache
    out_cache_loc: torch.Tensor

    total_num_tokens: int = None

    # Position information
    positions: torch.Tensor = None
@@ -93,18 +91,6 @@ class InputMetadata:
    image_offsets: List[List[int]] = None
    modalities: List[List[str]] = None

    # Trition attention backend
    triton_max_seq_len: int = 0
    triton_max_extend_len: int = 0
    triton_start_loc: torch.Tensor = None
    triton_prefix_lens: torch.Tensor = None

    # FlashInfer attention backend
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
    flashinfer_use_ragged: bool = False

    def init_multimuldal_info(self, batch: ScheduleBatch):
        reqs = batch.reqs
        self.pixel_values = [r.pixel_values for r in reqs]
@@ -154,32 +140,27 @@ class InputMetadata:
        self.positions = self.positions.to(torch.int64)

    def compute_extend_infos(self, batch: ScheduleBatch):
        if self.forward_mode.is_decode():
            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
        else:
            extend_lens_cpu = [
                len(r.fill_ids) - batch.prefix_lens_cpu[i]
                for i, r in enumerate(batch.reqs)
            ]
            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
            self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            self.extend_start_loc = torch.zeros_like(self.seq_lens)
            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
        extend_lens_cpu = [
            len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
        ]
        self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)

            self.extend_seq_lens_cpu = extend_lens_cpu
            self.logprob_start_lens_cpu = [
                (
                    min(
                        req.logprob_start_len - batch.prefix_lens_cpu[i],
                        extend_lens_cpu[i] - 1,
                    )
                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
        self.extend_seq_lens_cpu = extend_lens_cpu
        self.logprob_start_lens_cpu = [
            (
                min(
                    req.logprob_start_len - batch.prefix_lens_cpu[i],
                    extend_lens_cpu[i] - 1,
                )
                for i, req in enumerate(batch.reqs)
            ]
                if req.logprob_start_len >= batch.prefix_lens_cpu[i]
                else extend_lens_cpu[i] - 1  # Fake extend, actually decode
            )
            for i, req in enumerate(batch.reqs)
        ]

    @classmethod
    def from_schedule_batch(
@@ -195,6 +176,7 @@ class InputMetadata:
            seq_lens=batch.seq_lens,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            attn_backend=model_runner.attn_backend,
            out_cache_loc=batch.out_cache_loc,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
@@ -202,76 +184,12 @@ class InputMetadata:

        ret.sampling_info.update_penalties()
        ret.sampling_info.update_regex_vocab_mask(batch)

        ret.compute_positions(batch)

        ret.compute_extend_infos(batch)

        fm = batch.forward_mode
        if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
            ret.total_num_tokens = int(torch.sum(ret.seq_lens))

        if not fm.is_decode():
        if not batch.forward_mode.is_decode():
            ret.init_multimuldal_info(batch)
            ret.compute_extend_infos(batch)

        if model_runner.server_args.attention_backend == "triton":
            ret.init_triton_args(batch)

        flashinfer_use_ragged = False
        if model_runner.server_args.attention_backend == "flashinfer":
            if (
                not fm.is_decode()
                and int(torch.sum(ret.seq_lens)) > 4096
                and model_runner.sliding_window_size is None
            ):
                flashinfer_use_ragged = True
            ret.init_flashinfer_handlers(
                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
            )
        model_runner.attn_backend.init_forward_metadata(batch, ret)

        return ret

    def init_triton_args(self, batch: ScheduleBatch):
        """Init auxiliary variables for triton attention backend."""
        self.triton_max_seq_len = int(torch.max(self.seq_lens))
        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

        if self.forward_mode.is_decode():
            self.triton_max_extend_len = None
        else:
            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
            self.triton_max_extend_len = int(torch.max(extend_seq_lens))

    def init_flashinfer_handlers(
        self,
        model_runner,
        prefix_lens_cpu,
        flashinfer_use_ragged,
    ):
        if self.forward_mode.is_decode():
            prefix_lens = None
        else:
            prefix_lens = self.extend_prefix_lens

        update_flashinfer_indices(
            self.forward_mode,
            model_runner,
            self.req_pool_indices,
            self.seq_lens,
            prefix_lens,
            flashinfer_use_ragged=flashinfer_use_ragged,
        )

        (
            self.flashinfer_prefill_wrapper_ragged,
            self.flashinfer_prefill_wrapper_paged,
            self.flashinfer_decode_wrapper,
            self.flashinfer_use_ragged,
        ) = (
            model_runner.flashinfer_prefill_wrapper_ragged,
            model_runner.flashinfer_prefill_wrapper_paged,
            model_runner.flashinfer_decode_wrapper,
            flashinfer_use_ragged,
        )

@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
@@ -100,6 +96,7 @@ class ModelRunner:
            }
        )

        # Model-specific adjustment
        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -107,6 +104,7 @@ class ModelRunner:
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # Init componnets
        min_per_gpu_memory = self.init_torch_distributed()
        self.load_model()
        self.init_memory_pool(
@@ -115,7 +113,7 @@ class ModelRunner:
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_flashinfer()
        self.init_attention_backend()
        self.init_cuda_graphs()

    def init_torch_distributed(self):
@@ -397,9 +395,6 @@ class ModelRunner:
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
            logger.info("using MLA Triton implementaion, flashinfer is disabled")
            # FIXME: temporarily only Triton MLA is supported
            self.server_args.attention_backend = "triton"
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
@@ -422,106 +417,42 @@ class ModelRunner:
        c = a @ b
        return c

    def init_flashinfer(self):
        """Init flashinfer attention kernel wrappers."""
        if self.server_args.attention_backend != "flashinfer":
            assert (
                self.sliding_window_size is None
            ), "turn on flashinfer to support window attention"
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None
            self.flashinfer_decode_wrapper = None
            return

        if not _grouped_size_compiled_for_decode_kernels(
            self.model_config.num_attention_heads // self.tp_size,
            self.model_config.get_num_kv_heads(self.tp_size),
        ):
            use_tensor_cores = True
    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            self.attn_backend = TritonAttnBackend(self)
        else:
            use_tensor_cores = False

        if self.sliding_window_size is None:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
            self.flashinfer_prefill_wrapper_ragged = (
                BatchPrefillWithRaggedKVCacheWrapper(
                    self.flashinfer_workspace_buffer, "NHD"
                )
            )
            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer, "NHD"
            )
            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer,
                "NHD",
                use_tensor_cores=use_tensor_cores,
            )
        else:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            )
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = []
            self.flashinfer_decode_wrapper = []
            for i in range(2):
                self.flashinfer_prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer, "NHD"
                    )
                )
                self.flashinfer_decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer,
                        "NHD",
                        use_tensor_cores=use_tensor_cores,
                    )
                )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
        if self.server_args.disable_cuda_graph:
            return

        if (
            self.server_args.disable_cuda_graph
            or self.server_args.attention_backend != "flashinfer"
        ):
            self.cuda_graph_runner = None
        if self.server_args.attention_backend != "flashinfer":
            logger.warning(
                f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
            )
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")

        if self.server_args.disable_cuda_graph_padding:
            batch_size_list = list(range(1, 32)) + [64, 128]
        else:
            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]

        self.cuda_graph_runner = CudaGraphRunner(
            self,
            max_batch_size_to_capture=max(batch_size_list),
            use_torch_compile=self.server_args.enable_torch_compile,
            disable_padding=self.server_args.disable_cuda_graph_padding,
        )
        try:
            self.cuda_graph_runner.capture(batch_size_list)
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. disable cuda graph by --disable-cuda-graph\n"
                "2. set --mem-fraction-static to a smaller value\n"
                "3. disable torch compile by not using --enable-torch-compile\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )
        self.cuda_graph_runner = CudaGraphRunner(self)

    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):

@@ -143,18 +143,16 @@ class SamplingBatchInfo:
        self.linear_penalties = penalizer.apply(self.linear_penalties)

    def update_regex_vocab_mask(self, batch: ScheduleBatch):
        bs, reqs = batch.batch_size(), batch.reqs
        device = "cuda"
        has_regex = any(req.regex_fsm is not None for req in reqs)
        has_regex = any(req.regex_fsm is not None for req in batch.reqs)

        # Reset the vocab mask
        self.vocab_mask = None

        if has_regex:
            self.vocab_mask = torch.zeros(
                bs, self.vocab_size, dtype=torch.bool, device=device
                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
            )
            for i, req in enumerate(reqs):
            for i, req in enumerate(batch.reqs):
                if req.regex_fsm is not None:
                    self.vocab_mask[i].fill_(1)
                    self.vocab_mask[i][

@@ -335,23 +335,19 @@ def launch_server(
        return

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    if server_args.dp_size == 1:
        start_controller_process = start_controller_process_single
    else:
        start_controller_process = start_controller_process_multi

    proc_controller = mp.Process(
        target=start_controller_process,
        args=(server_args, port_args, pipe_controller_writer),
    )
    proc_controller.start()

    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
@@ -362,6 +358,10 @@ def launch_server(
    )
    proc_detoken.start()

    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for the model to finish loading
    controller_init_state = pipe_controller_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

@@ -83,8 +83,8 @@ class ServerArgs:
    json_model_override_args: str = "{}"

    # Optimization/debug options
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None

    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
@@ -148,6 +148,17 @@ class ServerArgs:
        )
        self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.enable_mla:
            logger.info("MLA optimization is tunred on. Use triton backend.")
            self.attention_backend = "triton"

        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Model-specific patches
        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
            logger.info(
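
With the ServerArgs hunk above, the backend flags become Optional and are resolved after parsing: --enable-mla forces the Triton backend, otherwise attention_backend and sampling_backend fall back to "flashinfer". A sketch of that resolution as seen from Python (assuming ServerArgs can be constructed directly with only a model path, as elsewhere in sglang):

from sglang.srt.server_args import ServerArgs

# The dataclass defaults are now None; __post_init__ fills them in.
args = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
assert args.attention_backend == "flashinfer"
assert args.sampling_backend == "flashinfer"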