Cleanup attention backend: flashinfer and triton (#611)
@@ -1,6 +1,5 @@
 """Radix attention."""
 
-import numpy as np
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -51,13 +50,13 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.prefix_lens,
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.max_extend_len,
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
         )
@@ -75,9 +74,9 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -95,7 +94,7 @@ class RadixAttention(nn.Module):
                 logits_soft_cap=self.logit_cap,
             )
 
-        if input_metadata.no_prefix:
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
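
For context on the `extend_no_prefix` branch above: when a request has a cached prefix, the attention output over the new tokens (`o1`, from the ragged wrapper) and the output over the cached prefix (`o2`, from the paged wrapper) are combined by `merge_state` using their log-sum-exp values. A minimal reference of that merge math in plain PyTorch, assuming `merge_state` follows the standard LSE combination; shapes are hypothetical:

import torch

def merge_state_reference(v1, s1, v2, s2):
    # Combine two partial attention outputs (v1, v2) given their per-head
    # log-sum-exp values (s1, s2): a numerically stable softmax-weighted average.
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)
    w2 = torch.exp(s2 - s_max)
    v = (v1 * w1[..., None] + v2 * w2[..., None]) / (w1 + w2)[..., None]
    s = s_max + torch.log(w1 + w2)
    return v, s

# Hypothetical shapes: [num_tokens, num_heads, head_dim] outputs, [num_tokens, num_heads] LSEs.
o1, s1 = torch.randn(4, 2, 8), torch.randn(4, 2)
o2, s2 = torch.randn(4, 2, 8), torch.randn(4, 2)
o, _ = merge_state_reference(o1, s1, o2, s2)
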
@@ -312,7 +312,7 @@ def token_attention_fwd(
     b_seq_len,
     max_len_in_batch,
     total_num_tokens,
-    sm_scale=None,
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):
@@ -320,7 +320,6 @@ def token_attention_fwd(
         att_m = torch.empty(
            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
 
     _token_att_m_fwd(
         q,
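
Note on the `sm_scale` change above: the kernel previously defaulted the scale to 1/sqrt(head_dim) when `sm_scale=None`; after this change the caller must always pass it (the RadixAttention layer passes `sm_scale=self.scaling`). A small worked example of the value, assuming a hypothetical head dimension of 128:

import math

head_dim = 128                        # hypothetical; corresponds to q.shape[-1] (Lq) in the kernel
sm_scale = 1.0 / math.sqrt(head_dim)  # ~0.0884, the same value the removed default computed
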
@@ -75,6 +75,7 @@ class Req:
     """Store all inforamtion of a request."""
 
     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
@@ -97,6 +98,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
         # Sampling parameters
         self.sampling_params = None
         self.stream = False
@@ -105,11 +111,6 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None
 
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -261,35 +262,36 @@ class Req:
 class Batch:
     """Store all inforamtion of a batch."""
 
+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache
 
-    # batched arguments to model runner
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None
 
-    # for processing logprobs
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # for multimodal
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    # other arguments for control
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None
 
-    # batched sampling params
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
@@ -312,8 +314,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
@@ -347,7 +349,7 @@ class Batch:
 
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
 
-        # Alloc mem
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
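
A small worked example of the allocation above: only the non-cached part of each request needs new KV-cache slots, so the allocation size is the total sequence length minus the cached prefix length, summed over the batch. The numbers below are hypothetical:

import numpy as np

seq_lens = np.array([10, 7, 4])     # total tokens per request (hypothetical)
prefix_lens = np.array([6, 0, 4])   # tokens already cached by the radix tree
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()  # 21 - 10 = 11 new slots to allocate
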
@@ -703,7 +705,6 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
|
|||||||
return probs_sort, probs_idx
|
return probs_sort, probs_idx
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InputMetadata:
|
class InputMetadata:
|
||||||
"""Store all inforamtion of a forward pass."""
|
"""Store all inforamtion of a forward pass."""
|
||||||
@@ -711,110 +712,37 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
-    max_seq_len: int
     req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
     seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
 
-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
 
+    # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None
 
+    # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
 
+    # FlashInfer attention backend
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if self.forward_mode == ForwardMode.DECODE:
-            paged_kernel_lens = self.seq_lens
-        else:
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-
-        kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-        kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-        if self.forward_mode == ForwardMode.DECODE:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
     @classmethod
     def create(
         cls,
@@ -830,14 +758,20 @@ class InputMetadata:
         top_logprobs_nums=None,
         return_logprob=False,
     ):
+        if not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
+
         batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
 
         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -855,22 +789,27 @@ class InputMetadata:
                 ),
                 device="cuda",
             )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
 
         ret = cls(
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
             req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
             seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
             positions=positions,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
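
The extend-branch additions above derive the per-request metadata from `seq_lens` and `prefix_lens`: `extend_seq_lens` is the number of new tokens per request, `extend_start_loc` is its exclusive prefix sum (each request's offset into the flattened extend tokens), and `extend_no_prefix` flags whether the whole batch has no cached prefix. A minimal sketch with hypothetical lengths:

import torch

seq_lens = torch.tensor([10, 7, 5])     # hypothetical total lengths
prefix_lens = torch.tensor([6, 0, 3])   # cached prefix lengths

extend_seq_lens = seq_lens - prefix_lens                          # tensor([4, 7, 2])
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)  # tensor([0, 4, 11])
extend_no_prefix = torch.all(prefix_lens == 0)                    # tensor(False)
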
@@ -878,14 +817,96 @@ class InputMetadata:
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )
 
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-                model_runner.model_config.head_dim,
-            )
+        if model_runner.server_args.disable_flashinfer:
+            (ret.triton_max_seq_len,
+             ret.triton_max_extend_len,
+             ret.triton_start_loc,
+             ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
 
         return ret
 
+
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device="cuda"
+    )
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones(
+        (batch_size,), dtype=torch.int32, device="cuda"
+    )
+
+    if forward_mode == ForwardMode.DECODE:
+        model_runner.flashinfer_decode_wrapper.end_forward()
+        model_runner.flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros(
+            (batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
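
For the new `init_flashinfer_args` above, the paged-KV layout it builds is a CSR-style pair: `kv_indptr` holds cumulative per-request KV lengths and `kv_indices` concatenates each request's slice of the `req_to_token` table (with page size 1, so `kv_last_page_len` is all ones). A self-contained toy example of that layout; pool sizes and lengths are hypothetical:

import torch

req_to_token = torch.arange(32).reshape(4, 8)  # hypothetical pool: 4 request slots x 8 token slots
req_pool_indices = torch.tensor([2, 0])        # two running requests
paged_kernel_lens = torch.tensor([5, 3])       # KV tokens indexed for each request

kv_indptr = torch.zeros(len(req_pool_indices) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)   # tensor([0, 5, 8]): CSR offsets
kv_indices = torch.cat(
    [
        req_to_token[r, :l]
        for r, l in zip(req_pool_indices.tolist(), paged_kernel_lens.tolist())
    ]
)  # tensor([16, 17, 18, 19, 20, 0, 1, 2]): flattened KV-cache locations
kv_last_page_len = torch.ones(len(req_pool_indices), dtype=torch.int32)  # page size 1
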
@@ -182,39 +182,39 @@ class ModelRunner:
         return c
 
     def init_flash_infer(self):
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            from flashinfer import (
-                BatchDecodeWithPagedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
-                BatchPrefillWithRaggedKVCacheWrapper,
-            )
-            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-            if not _grouped_size_compiled_for_decode_kernels(
-                self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size),
-            ):
-                use_tensor_cores = True
-            else:
-                use_tensor_cores = False
-
-            workspace_buffers = torch.empty(
-                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffers[1], "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-            )
-        else:
-            self.flashinfer_prefill_wrapper_ragged = (
-                self.flashinfer_prefill_wrapper_paged
-            ) = None
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
+            return
+
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+
+        workspace_buffers = torch.empty(
+            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+        )
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            workspace_buffers[1], "NHD"
+        )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+        )
 
     @torch.inference_mode()
     def forward_extend(self, batch: Batch):