feat: add fast_decode_plan from flashinfer, bump flashinfer to 0.4.0rc3 (#10760)
Co-authored-by: Zihao Ye <yezihhhao@gmail.com> Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
@@ -62,7 +62,7 @@ dependencies = [
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.4.0rc1",
|
||||
"flashinfer_python==0.4.0rc3",
|
||||
"openai==1.99.1",
|
||||
"tiktoken",
|
||||
"anthropic>=0.20.0",
|
||||
|
||||
@@ -70,7 +70,7 @@ srt = [
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.4.0rc1",
|
||||
"flashinfer_python==0.4.0rc3",
|
||||
]
|
||||
|
||||
blackwell = [
|
||||
@@ -80,8 +80,8 @@ blackwell = [
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.4.0rc1",
|
||||
"nvidia-cutlass-dsl==4.2.1",
|
||||
"flashinfer_python==0.4.0rc3",
|
||||
"nvidia-cutlass-dsl==4.2.0",
|
||||
]
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
|
||||
@@ -703,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if server_args.attention_backend == "flashinfer":
|
||||
assert_pkg_version(
|
||||
"flashinfer_python",
|
||||
"0.4.0rc1",
|
||||
"0.4.0rc3",
|
||||
"Please uninstall the old version and "
|
||||
"reinstall the latest version by following the instructions "
|
||||
"at https://docs.flashinfer.ai/installation.html.",
|
||||
|
||||
@@ -47,6 +47,7 @@ if is_flashinfer_available():
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithRaggedKVCacheWrapper,
|
||||
fast_decode_plan,
|
||||
)
|
||||
from flashinfer.cascade import merge_state
|
||||
from flashinfer.decode import _get_range_buf, get_seq_lens
|
||||
@@ -842,23 +843,51 @@ class FlashInferIndicesUpdaterDecode:
|
||||
global_override_indptr_cpu[0] = 0
|
||||
global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
|
||||
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
self.kv_last_page_len[:bs],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
1,
|
||||
data_type=self.data_type,
|
||||
q_data_type=self.q_data_type,
|
||||
non_blocking=True,
|
||||
fixed_split_size=fixed_split_size,
|
||||
disable_split_kv=(
|
||||
disable_split_kv if disable_split_kv is not None else False
|
||||
),
|
||||
# Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
|
||||
# by checking if it's a partial function with fast_decode_plan as the func
|
||||
wrapper_uses_fast_decode_plan = (
|
||||
hasattr(wrapper.begin_forward, "func")
|
||||
and wrapper.begin_forward.func == fast_decode_plan
|
||||
)
|
||||
|
||||
if wrapper_uses_fast_decode_plan:
|
||||
# When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
self.kv_last_page_len[:bs],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
1,
|
||||
data_type=self.data_type,
|
||||
q_data_type=self.q_data_type,
|
||||
non_blocking=True,
|
||||
fixed_split_size=fixed_split_size,
|
||||
disable_split_kv=(
|
||||
disable_split_kv if disable_split_kv is not None else False
|
||||
),
|
||||
global_override_indptr_cpu=global_override_indptr_cpu,
|
||||
)
|
||||
else:
|
||||
# When using original begin_forward, don't pass global_override_indptr_cpu
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
self.kv_last_page_len[:bs],
|
||||
self.num_qo_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
1,
|
||||
data_type=self.data_type,
|
||||
q_data_type=self.q_data_type,
|
||||
non_blocking=True,
|
||||
fixed_split_size=fixed_split_size,
|
||||
disable_split_kv=(
|
||||
disable_split_kv if disable_split_kv is not None else False
|
||||
),
|
||||
)
|
||||
|
||||
if locally_override:
|
||||
global_override_indptr_cpu = None
|
||||
|
||||
@@ -1328,174 +1357,3 @@ def should_use_tensor_core(
|
||||
return gqa_group_size >= 4
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
# Used as a fast path to override the indptr in flashinfer's plan function.
|
||||
# This is used to remove some host-to-device copy overhead.
|
||||
global_override_indptr_cpu = None
|
||||
|
||||
|
||||
def _as_torch_dtype(dtype):
    """Return ``dtype`` as a torch dtype, resolving string names (e.g. "float16") via ``getattr(torch, name)``."""
    return getattr(torch, dtype) if isinstance(dtype, str) else dtype


def fast_decode_plan(
    self,
    indptr: torch.Tensor,
    indices: torch.Tensor,
    last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    q_data_type: Optional[Union[str, torch.dtype]] = None,
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    data_type: Optional[Union[str, torch.dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    non_blocking: bool = True,
    fixed_split_size: Optional[int] = None,
    disable_split_kv: bool = False,
) -> None:
    """
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.

    Modifications:
    - Remove unnecessary device-to-device copy for the cuda graph buffers.
    - Remove unnecessary host-to-device copy for the metadata buffers.

    If the module-level ``global_override_indptr_cpu`` is not None it is used
    as the host-side indptr; otherwise ``indptr.cpu()`` is called.

    Args:
        self: The decode wrapper this function is bound to. Reads
            ``use_tensor_cores``, ``is_cuda_graph_enabled``, ``device``,
            ``_cached_module`` and the workspace buffers.
        indptr: Paged KV indptr. Outside cuda-graph mode it is stored on the
            wrapper as-is (no copy).
        indices: Paged KV indices. Outside cuda-graph mode it is stored on the
            wrapper as-is; in cuda-graph mode only its length is validated.
        last_page_len: Last-page lengths; its length defines the batch size.
        num_qo_heads: Forwarded to the plan call.
        num_kv_heads: Forwarded to the plan call.
        head_dim: Forwarded to the plan call (passed twice).
        page_size: Forwarded to the plan call. When it is 1 on the
            tensor-core path, a host tensor of ones replaces
            ``last_page_len.cpu()``.
        pos_encoding_mode: Recorded on the wrapper; not used for planning here.
        window_left: Forwarded to the plan call and recorded on the wrapper.
        logits_soft_cap: ``None`` is treated as ``0.0``. Forwarded to the
            non-tensor-core plan call and recorded on the wrapper.
        q_data_type: Query dtype. Falls back to ``data_type``, then "float16".
        kv_data_type: KV dtype. Falls back to ``data_type``, then to the
            resolved ``q_data_type``.
        data_type: Fallback for both ``q_data_type`` and ``kv_data_type``.
        sm_scale: Recorded on the wrapper.
        rope_scale: Recorded on the wrapper.
        rope_theta: Recorded on the wrapper.
        non_blocking: Used for the ``qo_indptr`` host-to-device copy
            (tensor-core path, non-cuda-graph mode only).
        fixed_split_size: Passed to the tensor-core plan call only; ``None``
            becomes ``-1`` on that path.
        disable_split_kv: Passed to the tensor-core plan call only.

    Raises:
        ValueError: In cuda-graph mode, if the batch size differs from
            ``self._fixed_batch_size`` or ``indices`` is longer than the
            pre-allocated indices buffer.
        RuntimeError: If the underlying ``plan`` call fails; the original
            exception is attached as ``__cause__``.
    """
    batch_size = len(last_page_len)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type

    if self.use_tensor_cores:
        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
        # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
        if fixed_split_size is None:
            fixed_split_size = -1

    if self.is_cuda_graph_enabled:
        # In cuda-graph mode the wrapper's buffers are left untouched; only validate sizes.
        if batch_size != self._fixed_batch_size:
            raise ValueError(
                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
                " mismatches the batch size set during initialization {}".format(
                    batch_size, self._fixed_batch_size
                )
            )
        if len(indices) > len(self._paged_kv_indices_buf):
            raise ValueError(
                "The size of indices should be less than or equal to the allocated buffer"
            )
    else:
        # Rebind the buffers to the caller's tensors instead of copying into them.
        self._paged_kv_indptr_buf = indptr
        self._paged_kv_indices_buf = indices
        self._paged_kv_last_page_len_buf = last_page_len
        if self.use_tensor_cores:
            self._qo_indptr_buf = qo_indptr_host.to(
                self.device, non_blocking=non_blocking
            )

    # Create empty tensors for dtype info if needed
    empty_q_data = torch.empty(
        0,
        dtype=_as_torch_dtype(q_data_type),
        device=self.device,
    )

    empty_kv_cache = torch.empty(
        0,
        dtype=_as_torch_dtype(kv_data_type),
        device=self.device,
    )

    # Prefer the caller-provided CPU indptr to avoid a device-to-host copy.
    indptr_host = (
        global_override_indptr_cpu
        if global_override_indptr_cpu is not None
        else indptr.cpu()
    )

    with torch.cuda.device(self.device):

        if self.use_tensor_cores:
            # ALSO convert last_page_len to CPU
            if page_size == 1:
                # When page size is 1, last_page_len is always 1.
                # Directly construct the host tensor rather than executing a device-to-host copy.
                last_page_len_host = torch.ones(
                    (batch_size,), dtype=torch.int32, device="cpu"
                )
            else:
                last_page_len_host = last_page_len.cpu()

            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

            try:
                # Positional argument order must match the tensor-core plan signature.
                self._plan_info = self._cached_module.plan(
                    self._float_workspace_buffer,
                    self._int_workspace_buffer,
                    self._pin_memory_int_workspace_buffer,
                    qo_indptr_host,
                    indptr_host,
                    kv_lens_arr_host,
                    batch_size,  # total_num_rows
                    batch_size,
                    num_qo_heads,
                    num_kv_heads,
                    page_size,
                    self.is_cuda_graph_enabled,
                    head_dim,
                    head_dim,
                    False,  # causal
                    window_left,
                    fixed_split_size,
                    disable_split_kv,
                )
            except Exception as e:
                raise RuntimeError(f"Error in tensor core plan: {e}") from e
        else:
            try:
                # Positional argument order must match the standard decode plan signature.
                self._plan_info = self._cached_module.plan(
                    self._float_workspace_buffer,
                    self._int_workspace_buffer,
                    self._pin_memory_int_workspace_buffer,
                    indptr_host,
                    batch_size,
                    num_qo_heads,
                    num_kv_heads,
                    page_size,
                    self.is_cuda_graph_enabled,
                    window_left,
                    logits_soft_cap,
                    head_dim,
                    head_dim,
                    empty_q_data,
                    empty_kv_cache,
                )
            except Exception as e:
                raise RuntimeError(f"Error in standard plan: {e}") from e

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
|
||||
|
||||
Reference in New Issue
Block a user