Simplify flashinfer utilities (#1704)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
|
|||||||
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: nn.Module,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
"""Run forward on an attention layer."""
|
"""Run forward on an attention layer."""
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
return self.forward_decode(q, k, v, layer, forward_batch)
|
return self.forward_decode(q, k, v, layer, forward_batch)
|
||||||
else:
|
else:
|
||||||
return self.forward_extend(q, k, v, layer, forward_batch)
|
return self.forward_extend(q, k, v, layer, forward_batch)
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: nn.Module,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
"""Run a forward for decode."""
|
"""Run a forward for decode."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_extend(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: nn.Module,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
"""Run a forward for extend."""
|
"""Run a forward for extend."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
self.cuda_graph_start_loc,
|
self.cuda_graph_start_loc,
|
||||||
@@ -144,7 +144,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|||||||
@@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize.
|
|||||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from enum import Enum, auto
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.layers.attention.flashinfer_utils import (
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
WrapperDispatch,
|
|
||||||
update_flashinfer_indices,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -34,13 +33,18 @@ if is_flashinfer_available():
|
|||||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||||
|
|
||||||
|
|
||||||
|
class WrapperDispatch(Enum):
|
||||||
|
SLIDING_WINDOW = auto()
|
||||||
|
CROSS_ATTENTION = auto()
|
||||||
|
|
||||||
|
|
||||||
class FlashInferAttnBackend(AttentionBackend):
|
class FlashInferAttnBackend(AttentionBackend):
|
||||||
"""Flashinfer attention kernels."""
|
"""Flashinfer attention kernels."""
|
||||||
|
|
||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_runner = model_runner
|
|
||||||
|
|
||||||
|
# Parse constants
|
||||||
if not _grouped_size_compiled_for_decode_kernels(
|
if not _grouped_size_compiled_for_decode_kernels(
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size,
|
model_runner.model_config.num_attention_heads // model_runner.tp_size,
|
||||||
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
|
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
|
||||||
@@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.decode_use_tensor_cores = True
|
self.decode_use_tensor_cores = True
|
||||||
else:
|
else:
|
||||||
self.decode_use_tensor_cores = False
|
self.decode_use_tensor_cores = False
|
||||||
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.workspace_buffer = torch.empty(
|
|
||||||
global_config.flashinfer_workspace_size,
|
|
||||||
dtype=torch.uint8,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
model_runner.sliding_window_size is not None
|
model_runner.sliding_window_size is not None
|
||||||
and model_runner.has_cross_attention
|
and model_runner.has_cross_attention
|
||||||
), "Sliding window and cross attention are not supported together"
|
), "Sliding window and cross attention are not supported together"
|
||||||
|
|
||||||
self.num_wrappers = 1
|
|
||||||
self.dispatch_reason = None
|
|
||||||
if model_runner.sliding_window_size is not None:
|
if model_runner.sliding_window_size is not None:
|
||||||
self.num_wrappers = 2
|
self.num_wrappers = 2
|
||||||
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
|
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
|
||||||
elif model_runner.has_cross_attention:
|
elif model_runner.has_cross_attention:
|
||||||
self.num_wrappers = 2
|
self.num_wrappers = 2
|
||||||
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
|
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
|
||||||
|
else:
|
||||||
|
self.num_wrappers = 1
|
||||||
|
self.dispatch_reason = None
|
||||||
|
|
||||||
|
# Allocate buffers
|
||||||
|
self.workspace_buffer = torch.empty(
|
||||||
|
global_config.flashinfer_workspace_size,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=model_runner.device,
|
||||||
|
)
|
||||||
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
|
self.kv_indptr = [
|
||||||
|
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
||||||
|
for _ in range(self.num_wrappers)
|
||||||
|
]
|
||||||
|
self.kv_last_page_len = torch.ones(
|
||||||
|
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
self.qo_indptr = [
|
||||||
|
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
||||||
|
for _ in range(self.num_wrappers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create wrappers
|
||||||
# NOTE: we do not use ragged attention when there are multiple wrappers
|
# NOTE: we do not use ragged attention when there are multiple wrappers
|
||||||
self.prefill_wrapper_ragged = (
|
self.prefill_wrapper_ragged = (
|
||||||
BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
|
BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||||
@@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create indices updater
|
||||||
|
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
|
||||||
|
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
|
||||||
|
model_runner, self
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other metadata
|
||||||
self.forward_metadata = None
|
self.forward_metadata = None
|
||||||
self.cuda_graph_metadata = {}
|
self.cuda_graph_metadata = {}
|
||||||
|
|
||||||
def _get_wrapper_idx(self, layer: nn.Module):
|
|
||||||
if self.num_wrappers == 1:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
|
||||||
return layer.sliding_window_size == -1
|
|
||||||
if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
|
||||||
return layer.is_cross_attention
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
|
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
prefix_lens = None
|
self.indices_updater_decode.update(
|
||||||
use_ragged = False
|
forward_batch.req_pool_indices,
|
||||||
extend_no_prefix = False
|
forward_batch.seq_lens,
|
||||||
total_num_tokens = None
|
)
|
||||||
|
self.forward_metadata = (self.decode_wrappers,)
|
||||||
else:
|
else:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
|
||||||
@@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
use_ragged = True
|
use_ragged = True
|
||||||
|
|
||||||
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
|
|
||||||
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
|
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
|
||||||
|
|
||||||
update_flashinfer_indices(
|
self.indices_updater_prefill.update(
|
||||||
forward_batch.forward_mode,
|
forward_batch.req_pool_indices,
|
||||||
self.model_runner,
|
forward_batch.seq_lens,
|
||||||
forward_batch.req_pool_indices,
|
prefix_lens,
|
||||||
forward_batch.seq_lens,
|
use_ragged,
|
||||||
prefix_lens,
|
)
|
||||||
use_ragged=use_ragged,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
use_ragged,
|
use_ragged,
|
||||||
extend_no_prefix,
|
extend_no_prefix,
|
||||||
total_num_tokens,
|
)
|
||||||
self.decode_wrappers,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
self.cuda_graph_kv_indptr = torch.zeros(
|
cuda_graph_kv_indices = torch.zeros(
|
||||||
(max_bs + 1,), dtype=torch.int32, device="cuda"
|
(max_bs * self.max_context_len,),
|
||||||
)
|
|
||||||
self.cuda_graph_kv_indices = torch.zeros(
|
|
||||||
(max_bs * self.model_runner.model_config.context_len,),
|
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
self.cuda_graph_kv_last_page_len = torch.ones(
|
self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
|
||||||
(max_bs,), dtype=torch.int32, device="cuda"
|
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE: the buffers are always in the form of list
|
|
||||||
self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
|
|
||||||
self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
|
|
||||||
]
|
|
||||||
self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
|
|
||||||
self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
decode_wrappers = []
|
decode_wrappers = []
|
||||||
for i in range(self.num_wrappers):
|
for i in range(self.num_wrappers):
|
||||||
@@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
"NHD",
|
"NHD",
|
||||||
use_cuda_graph=True,
|
use_cuda_graph=True,
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
use_tensor_cores=self.decode_use_tensor_cores,
|
||||||
paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
|
paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
|
||||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
||||||
paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
|
paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
update_flashinfer_indices(
|
self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)
|
||||||
ForwardMode.DECODE,
|
|
||||||
self.model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
None,
|
|
||||||
decode_wrappers,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.cuda_graph_metadata[bs] = decode_wrappers
|
self.cuda_graph_metadata[bs] = decode_wrappers
|
||||||
|
self.forward_metadata = (decode_wrappers,)
|
||||||
self.forward_metadata = (False, False, None, decode_wrappers)
|
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
update_flashinfer_indices(
|
self.indices_updater_decode.update(
|
||||||
ForwardMode.DECODE,
|
req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
|
||||||
self.model_runner,
|
|
||||||
req_pool_indices[:bs],
|
|
||||||
seq_lens[:bs],
|
|
||||||
None,
|
|
||||||
self.cuda_graph_metadata[bs],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
@@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self._get_wrapper_idx(layer)
|
self._get_wrapper_idx(layer)
|
||||||
]
|
]
|
||||||
|
|
||||||
use_ragged, extend_no_prefix, _, _ = self.forward_metadata
|
use_ragged, extend_no_prefix = self.forward_metadata
|
||||||
|
|
||||||
if not use_ragged:
|
if not use_ragged:
|
||||||
if k is not None:
|
if k is not None:
|
||||||
@@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||||
decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
|
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
@@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
|
def _get_wrapper_idx(self, layer: nn.Module):
|
||||||
|
if self.num_wrappers == 1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
|
return layer.sliding_window_size == -1
|
||||||
|
if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
|
return layer.is_cross_attention
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferIndicesUpdaterDecode:
|
||||||
|
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
||||||
|
# Constants
|
||||||
|
self.num_qo_heads = (
|
||||||
|
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
|
)
|
||||||
|
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
|
||||||
|
model_runner.tp_size
|
||||||
|
)
|
||||||
|
self.head_dim = model_runner.model_config.head_dim
|
||||||
|
self.data_type = model_runner.kv_cache_dtype
|
||||||
|
self.q_data_type = model_runner.dtype
|
||||||
|
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||||
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
|
||||||
|
# Buffers and wrappers
|
||||||
|
self.kv_indptr = attn_backend.kv_indptr
|
||||||
|
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
self.decode_wrappers = attn_backend.decode_wrappers
|
||||||
|
|
||||||
|
# Dispatch
|
||||||
|
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
|
self.update = self.update_sliding_window
|
||||||
|
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
|
self.update = self.update_cross_attention
|
||||||
|
else:
|
||||||
|
assert attn_backend.num_wrappers == 1
|
||||||
|
self.update = self.update_single_wrapper
|
||||||
|
|
||||||
|
def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
|
||||||
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
|
self.call_begin_forward(
|
||||||
|
decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
|
||||||
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
|
|
||||||
|
for wrapper_id in range(2):
|
||||||
|
if wrapper_id == 0:
|
||||||
|
# Sliding window attention
|
||||||
|
paged_kernel_lens = torch.minimum( # TODO: replace this with clamp
|
||||||
|
seq_lens,
|
||||||
|
torch.tensor(self.sliding_window_size + 1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Full attention
|
||||||
|
paged_kernel_lens = seq_lens
|
||||||
|
|
||||||
|
kv_start_idx = seq_lens - paged_kernel_lens
|
||||||
|
|
||||||
|
self.call_begin_forward(
|
||||||
|
decode_wrappers[wrapper_id],
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
self.kv_indptr[wrapper_id],
|
||||||
|
kv_start_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_cross_attention(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def call_begin_forward(
|
||||||
|
self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
|
||||||
|
):
|
||||||
|
bs = len(req_pool_indices)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
# TODO: optimize the blocking call on kv_indptr[-1]
|
||||||
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
kv_indptr,
|
||||||
|
kv_start_idx,
|
||||||
|
kv_indices,
|
||||||
|
self.max_context_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper.end_forward()
|
||||||
|
wrapper.begin_forward(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
self.kv_last_page_len[:bs],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
1,
|
||||||
|
data_type=self.data_type,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferIndicesUpdaterPrefill:
|
||||||
|
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
||||||
|
# Constants
|
||||||
|
self.num_qo_heads = (
|
||||||
|
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
|
)
|
||||||
|
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
|
||||||
|
model_runner.tp_size
|
||||||
|
)
|
||||||
|
self.head_dim = model_runner.model_config.head_dim
|
||||||
|
self.data_type = model_runner.kv_cache_dtype
|
||||||
|
self.q_data_type = model_runner.dtype
|
||||||
|
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||||
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
|
||||||
|
# Buffers and wrappers
|
||||||
|
self.kv_indptr = attn_backend.kv_indptr
|
||||||
|
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||||
|
self.qo_indptr = attn_backend.qo_indptr
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
|
||||||
|
self.wrappers_paged = attn_backend.prefill_wrappers_paged
|
||||||
|
|
||||||
|
# Dispatch
|
||||||
|
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
|
self.update = self.update_sliding_window
|
||||||
|
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
|
self.update = self.update_cross_attention
|
||||||
|
else:
|
||||||
|
assert attn_backend.num_wrappers == 1
|
||||||
|
self.update = self.update_single_wrapper
|
||||||
|
|
||||||
|
def update_single_wrapper(
|
||||||
|
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
|
||||||
|
):
|
||||||
|
if use_ragged:
|
||||||
|
paged_kernel_lens = prefix_lens
|
||||||
|
else:
|
||||||
|
paged_kernel_lens = seq_lens
|
||||||
|
|
||||||
|
self.call_begin_forward(
|
||||||
|
self.wrapper_ragged,
|
||||||
|
self.wrappers_paged[0],
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
None,
|
||||||
|
self.kv_indptr[0],
|
||||||
|
self.qo_indptr[0],
|
||||||
|
use_ragged,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_sliding_window(
|
||||||
|
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
|
||||||
|
):
|
||||||
|
for wrapper_id in range(2):
|
||||||
|
if wrapper_id == 0:
|
||||||
|
# window attention use paged only
|
||||||
|
paged_kernel_lens = torch.minimum(
|
||||||
|
seq_lens,
|
||||||
|
torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# full attention
|
||||||
|
paged_kernel_lens = seq_lens
|
||||||
|
kv_start_idx = seq_lens - paged_kernel_lens
|
||||||
|
|
||||||
|
self.call_begin_forward(
|
||||||
|
self.wrapper_ragged,
|
||||||
|
self.wrappers_paged[wrapper_id],
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
kv_start_idx,
|
||||||
|
self.kv_indptr[wrapper_id],
|
||||||
|
self.qo_indptr[wrapper_id],
|
||||||
|
use_ragged,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_cross_attention(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def call_begin_forward(
|
||||||
|
self,
|
||||||
|
wrapper_ragged,
|
||||||
|
wrapper_paged,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
kv_start_idx,
|
||||||
|
kv_indptr,
|
||||||
|
qo_indptr,
|
||||||
|
use_ragged,
|
||||||
|
):
|
||||||
|
bs = len(req_pool_indices)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
kv_indptr,
|
||||||
|
kv_start_idx,
|
||||||
|
kv_indices,
|
||||||
|
self.max_context_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
|
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||||
|
|
||||||
|
# extend part
|
||||||
|
if use_ragged:
|
||||||
|
wrapper_ragged.end_forward()
|
||||||
|
wrapper_ragged.begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
qo_indptr,
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cached part
|
||||||
|
wrapper_paged.end_forward()
|
||||||
|
wrapper_paged.begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
self.kv_last_page_len[:bs],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def create_flashinfer_kv_indices_triton(
|
||||||
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
|
req_pool_indices_ptr,
|
||||||
|
page_kernel_lens_ptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_start_idx,
|
||||||
|
kv_indices_ptr,
|
||||||
|
max_context_len: tl.constexpr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
||||||
|
kv_indices_offset = tl.load(kv_indptr + pid)
|
||||||
|
|
||||||
|
kv_start = 0
|
||||||
|
kv_end = 0
|
||||||
|
if kv_start_idx:
|
||||||
|
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
|
||||||
|
kv_end = kv_start
|
||||||
|
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
||||||
|
|
||||||
|
req_to_token_ptr += req_pool_index * max_context_len
|
||||||
|
kv_indices_ptr += kv_indices_offset
|
||||||
|
|
||||||
|
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
st_offset = tl.arange(0, BLOCK_SIZE)
|
||||||
|
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||||
|
for _ in range(num_loop):
|
||||||
|
mask = ld_offset < kv_end
|
||||||
|
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
|
||||||
|
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
|
||||||
|
ld_offset += BLOCK_SIZE
|
||||||
|
st_offset += BLOCK_SIZE
|
||||||
|
|||||||
@@ -1,237 +0,0 @@
|
|||||||
from enum import Enum, auto
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
|
||||||
SLIDING_WINDOW = auto()
|
|
||||||
CROSS_ATTENTION = auto()
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def create_flashinfer_kv_indices_triton(
|
|
||||||
req_to_token_ptr, # [max_batch, max_context_len]
|
|
||||||
req_pool_indices_ptr,
|
|
||||||
page_kernel_lens_ptr,
|
|
||||||
kv_indptr,
|
|
||||||
kv_start_idx,
|
|
||||||
kv_indices_ptr,
|
|
||||||
max_context_len: tl.constexpr,
|
|
||||||
):
|
|
||||||
BLOCK_SIZE: tl.constexpr = 512
|
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
|
||||||
kv_indices_offset = tl.load(kv_indptr + pid)
|
|
||||||
|
|
||||||
kv_start = 0
|
|
||||||
kv_end = 0
|
|
||||||
if kv_start_idx:
|
|
||||||
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
|
|
||||||
kv_end = kv_start
|
|
||||||
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
|
||||||
|
|
||||||
req_to_token_ptr += req_pool_index * max_context_len
|
|
||||||
kv_indices_ptr += kv_indices_offset
|
|
||||||
|
|
||||||
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
|
|
||||||
st_offset = tl.arange(0, BLOCK_SIZE)
|
|
||||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
|
||||||
for _ in range(num_loop):
|
|
||||||
mask = ld_offset < kv_end
|
|
||||||
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
|
|
||||||
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
|
|
||||||
ld_offset += BLOCK_SIZE
|
|
||||||
st_offset += BLOCK_SIZE
|
|
||||||
|
|
||||||
|
|
||||||
class FlashinferUpdater:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
forward_mode,
|
|
||||||
model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
decode_wrappers=None,
|
|
||||||
use_ragged=False,
|
|
||||||
):
|
|
||||||
self.forward_mode = forward_mode
|
|
||||||
self.model_runner = model_runner
|
|
||||||
self.req_pool_indices = req_pool_indices
|
|
||||||
self.seq_lens = seq_lens
|
|
||||||
self.prefix_lens = prefix_lens
|
|
||||||
self.use_ragged = use_ragged
|
|
||||||
|
|
||||||
self.num_qo_heads = (
|
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
|
||||||
)
|
|
||||||
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
|
|
||||||
model_runner.tp_size
|
|
||||||
)
|
|
||||||
self.head_dim = model_runner.model_config.head_dim
|
|
||||||
self.batch_size = len(req_pool_indices)
|
|
||||||
|
|
||||||
self.decode_wrappers = (
|
|
||||||
decode_wrappers or self.model_runner.attn_backend.decode_wrappers
|
|
||||||
)
|
|
||||||
self.prefill_wrapper_ragged = (
|
|
||||||
self.model_runner.attn_backend.prefill_wrapper_ragged
|
|
||||||
)
|
|
||||||
self.prefill_wrappers_paged = (
|
|
||||||
self.model_runner.attn_backend.prefill_wrappers_paged
|
|
||||||
)
|
|
||||||
|
|
||||||
self.kv_last_page_len = torch.ones(
|
|
||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_decode_indices(self, decode_wrapper):
|
|
||||||
assert not isinstance(decode_wrapper, list)
|
|
||||||
decode_wrapper.end_forward()
|
|
||||||
decode_wrapper.begin_forward(
|
|
||||||
self.kv_indptr,
|
|
||||||
self.kv_indices,
|
|
||||||
self.kv_last_page_len,
|
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
1,
|
|
||||||
data_type=self.model_runner.kv_cache_dtype,
|
|
||||||
q_data_type=self.model_runner.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
|
|
||||||
assert not isinstance(paged_wrapper, list)
|
|
||||||
assert not isinstance(ragged_wrapper, list)
|
|
||||||
|
|
||||||
# extend part
|
|
||||||
qo_indptr = torch.zeros(
|
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
|
|
||||||
|
|
||||||
if self.use_ragged:
|
|
||||||
ragged_wrapper.end_forward()
|
|
||||||
ragged_wrapper.begin_forward(
|
|
||||||
qo_indptr,
|
|
||||||
qo_indptr,
|
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cached part
|
|
||||||
paged_wrapper.end_forward()
|
|
||||||
paged_wrapper.begin_forward(
|
|
||||||
qo_indptr,
|
|
||||||
self.kv_indptr,
|
|
||||||
self.kv_indices,
|
|
||||||
self.kv_last_page_len,
|
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
|
|
||||||
if dispatch_reason is None:
|
|
||||||
if self.use_ragged:
|
|
||||||
paged_kernel_lens = self.prefix_lens
|
|
||||||
else:
|
|
||||||
paged_kernel_lens = self.seq_lens
|
|
||||||
self.kv_start_idx = None
|
|
||||||
elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
|
||||||
if wrapper_id == 0:
|
|
||||||
# window attention use paged only
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
paged_kernel_lens = torch.minimum(
|
|
||||||
self.seq_lens,
|
|
||||||
torch.tensor(self.model_runner.sliding_window_size + 1),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
paged_kernel_lens = torch.minimum(
|
|
||||||
self.seq_lens,
|
|
||||||
torch.tensor(self.model_runner.sliding_window_size)
|
|
||||||
+ self.seq_lens
|
|
||||||
- self.prefix_lens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# full attention
|
|
||||||
paged_kernel_lens = self.seq_lens
|
|
||||||
self.kv_start_idx = self.seq_lens - paged_kernel_lens
|
|
||||||
|
|
||||||
self.kv_indptr = torch.zeros(
|
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
||||||
self.kv_indices = torch.empty(
|
|
||||||
self.kv_indptr[-1], dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(self.batch_size,)](
|
|
||||||
self.model_runner.req_to_token_pool.req_to_token,
|
|
||||||
self.req_pool_indices,
|
|
||||||
paged_kernel_lens,
|
|
||||||
self.kv_indptr,
|
|
||||||
self.kv_start_idx,
|
|
||||||
self.kv_indices,
|
|
||||||
self.model_runner.req_to_token_pool.req_to_token.size(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_indicess_single_wrapper(self):
|
|
||||||
self._get_indices()
|
|
||||||
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
self._update_decode_indices(self.decode_wrappers[0])
|
|
||||||
else:
|
|
||||||
self._update_extend_indices(
|
|
||||||
self.prefill_wrapper_ragged,
|
|
||||||
self.prefill_wrappers_paged[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_indices_cross_attention(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _update_indices_sliding_window(self):
|
|
||||||
assert self.use_ragged is False
|
|
||||||
for wrapper_id in range(2):
|
|
||||||
self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
|
|
||||||
if self.forward_mode.is_decode():
|
|
||||||
self._update_decode_indices(self.decode_wrappers[wrapper_id])
|
|
||||||
else:
|
|
||||||
self._update_extend_indices(
|
|
||||||
None,
|
|
||||||
self.prefill_wrappers_paged[wrapper_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_flashinfer_indices(
|
|
||||||
forward_mode,
|
|
||||||
model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
decode_wrappers=None,
|
|
||||||
use_ragged=False,
|
|
||||||
):
|
|
||||||
updater = FlashinferUpdater(
|
|
||||||
forward_mode,
|
|
||||||
model_runner,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
prefix_lens,
|
|
||||||
decode_wrappers,
|
|
||||||
use_ragged,
|
|
||||||
)
|
|
||||||
|
|
||||||
dispatch_reason = model_runner.attn_backend.dispatch_reason
|
|
||||||
|
|
||||||
if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
|
||||||
updater._update_indices_sliding_window()
|
|
||||||
elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
|
||||||
updater._update_indices_cross_attention()
|
|
||||||
else:
|
|
||||||
assert model_runner.attn_backend.num_wrappers == 1
|
|
||||||
updater._update_indicess_single_wrapper()
|
|
||||||
@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
self.cuda_graph_start_loc,
|
self.cuda_graph_start_loc,
|
||||||
@@ -91,7 +91,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||||
):
|
):
|
||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|||||||
@@ -744,7 +744,6 @@ class ScheduleBatch:
|
|||||||
self.forward_mode = ForwardMode.DECODE
|
self.forward_mode = ForwardMode.DECODE
|
||||||
|
|
||||||
self.input_ids = self.output_ids
|
self.input_ids = self.output_ids
|
||||||
self.seq_lens.add_(1)
|
|
||||||
self.output_ids = None
|
self.output_ids = None
|
||||||
if self.sampling_info.penalizer_orchestrator:
|
if self.sampling_info.penalizer_orchestrator:
|
||||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
@@ -755,9 +754,10 @@ class ScheduleBatch:
|
|||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||||
|
|
||||||
self.req_to_token_pool.req_to_token[
|
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
|
||||||
self.req_pool_indices, self.seq_lens - 1
|
self.out_cache_loc
|
||||||
] = self.out_cache_loc
|
)
|
||||||
|
self.seq_lens.add_(1)
|
||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -134,9 +134,7 @@ class ForwardBatch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Init position information
|
# Init position information
|
||||||
if ret.forward_mode.is_decode():
|
if not ret.forward_mode.is_decode():
|
||||||
ret.positions = (ret.seq_lens - 1).to(torch.int64)
|
|
||||||
else:
|
|
||||||
ret.positions = torch.tensor(
|
ret.positions = torch.tensor(
|
||||||
np.concatenate(
|
np.concatenate(
|
||||||
[
|
[
|
||||||
@@ -164,7 +162,6 @@ class ForwardBatch:
|
|||||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||||
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||||
ret.attn_backend = model_runner.attn_backend
|
ret.attn_backend = model_runner.attn_backend
|
||||||
model_runner.attn_backend.init_forward_metadata(ret)
|
|
||||||
|
|
||||||
# Init lora information
|
# Init lora information
|
||||||
if model_runner.server_args.lora_paths is not None:
|
if model_runner.server_args.lora_paths is not None:
|
||||||
|
|||||||
@@ -551,11 +551,14 @@ class ModelRunner:
|
|||||||
):
|
):
|
||||||
return self.cuda_graph_runner.replay(forward_batch)
|
return self.cuda_graph_runner.replay(forward_batch)
|
||||||
|
|
||||||
|
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
||||||
|
self.attn_backend.init_forward_metadata(forward_batch)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_extend(self, forward_batch: ForwardBatch):
|
def forward_extend(self, forward_batch: ForwardBatch):
|
||||||
|
self.attn_backend.init_forward_metadata(forward_batch)
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
|
|||||||
Reference in New Issue
Block a user