Bump Flashinfer to 0.2.5 (#5870)
Co-authored-by: Yuhao Chen <yxckeis8@gmail.com>
This commit is contained in:
2
.github/workflows/pr-test.yml
vendored
2
.github/workflows/pr-test.yml
vendored
@@ -96,8 +96,6 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
|
|||||||
@@ -164,4 +164,4 @@ sky status --endpoint 30000 sglang
|
|||||||
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.
|
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.
|
||||||
- If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
|
- If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
|
||||||
- The language frontend operates independently of the backend runtime. You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime.
|
- The language frontend operates independently of the backend runtime. You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime.
|
||||||
- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.3" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
|
- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.5" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ runtime_common = [
|
|||||||
"python-multipart",
|
"python-multipart",
|
||||||
"pyzmq>=25.1.2",
|
"pyzmq>=25.1.2",
|
||||||
"soundfile==0.13.1",
|
"soundfile==0.13.1",
|
||||||
"torchao>=0.7.0",
|
"torchao>=0.9.0",
|
||||||
"transformers==4.51.1",
|
"transformers==4.51.1",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"uvloop",
|
"uvloop",
|
||||||
@@ -47,7 +47,7 @@ runtime_common = [
|
|||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.1.0",
|
"sgl-kernel==0.1.0",
|
||||||
"flashinfer_python==0.2.3",
|
"flashinfer_python==0.2.5",
|
||||||
"torch==2.6.0",
|
"torch==2.6.0",
|
||||||
"torchvision==0.21.0",
|
"torchvision==0.21.0",
|
||||||
"cuda-python",
|
"cuda-python",
|
||||||
|
|||||||
@@ -453,7 +453,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if server_args.attention_backend == "flashinfer":
|
if server_args.attention_backend == "flashinfer":
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"flashinfer_python",
|
"flashinfer_python",
|
||||||
"0.2.3",
|
"0.2.5",
|
||||||
"Please uninstall the old version and "
|
"Please uninstall the old version and "
|
||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
|||||||
@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
||||||
|
import torch._dynamo
|
||||||
|
|
||||||
|
torch._dynamo.config.suppress_errors = True
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
@@ -82,8 +87,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||||
self.kv_cache_dtype = model_runner.kv_cache_dtype
|
|
||||||
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
|
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
model_runner.sliding_window_size is not None
|
model_runner.sliding_window_size is not None
|
||||||
@@ -268,6 +271,12 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Ensure tensors are properly allocated
|
||||||
|
for i in range(self.num_wrappers):
|
||||||
|
# Force allocation by performing a small operation
|
||||||
|
if len(self.cuda_graph_kv_indices[i]) > 0:
|
||||||
|
self.cuda_graph_kv_indices[i][0] = 0
|
||||||
|
|
||||||
if not self.skip_prefill:
|
if not self.skip_prefill:
|
||||||
self.cuda_graph_custom_mask = torch.zeros(
|
self.cuda_graph_custom_mask = torch.zeros(
|
||||||
(max_bs * self.max_context_len),
|
(max_bs * self.max_context_len),
|
||||||
@@ -396,8 +405,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
|
|
||||||
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
|
|
||||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||||
self._get_wrapper_idx(layer)
|
self._get_wrapper_idx(layer)
|
||||||
]
|
]
|
||||||
@@ -414,7 +421,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, cache_loc, k, v, k_scale, v_scale
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
o = prefill_wrapper_paged.forward(
|
o = prefill_wrapper_paged.forward(
|
||||||
@@ -424,8 +431,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
window_left=layer.sliding_window_size,
|
window_left=layer.sliding_window_size,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
k_scale=k_scale,
|
k_scale=layer.k_scale,
|
||||||
v_scale=v_scale,
|
v_scale=layer.v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||||
@@ -452,7 +459,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, cache_loc, k, v, k_scale, v_scale
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
@@ -466,8 +473,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
|
|
||||||
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
|
|
||||||
decode_wrapper = self.forward_metadata.decode_wrappers[
|
decode_wrapper = self.forward_metadata.decode_wrappers[
|
||||||
self._get_wrapper_idx(layer)
|
self._get_wrapper_idx(layer)
|
||||||
]
|
]
|
||||||
@@ -481,16 +486,17 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, cache_loc, k, v, k_scale, v_scale
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Call the wrapped function
|
||||||
o = decode_wrapper.forward(
|
o = decode_wrapper.forward(
|
||||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=layer.logit_cap,
|
||||||
k_scale=k_scale,
|
k_scale=layer.k_scale,
|
||||||
v_scale=v_scale,
|
v_scale=layer.v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
|
|||||||
pos_encoding_mode: str = "NONE",
|
pos_encoding_mode: str = "NONE",
|
||||||
window_left: int = -1,
|
window_left: int = -1,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
data_type: Union[str, torch.dtype] = "float16",
|
|
||||||
q_data_type: Optional[Union[str, torch.dtype]] = None,
|
q_data_type: Optional[Union[str, torch.dtype]] = None,
|
||||||
|
kv_data_type: Optional[Union[str, torch.dtype]] = None,
|
||||||
|
data_type: Optional[Union[str, torch.dtype]] = None,
|
||||||
sm_scale: Optional[float] = None,
|
sm_scale: Optional[float] = None,
|
||||||
rope_scale: Optional[float] = None,
|
rope_scale: Optional[float] = None,
|
||||||
rope_theta: Optional[float] = None,
|
rope_theta: Optional[float] = None,
|
||||||
@@ -1163,6 +1170,18 @@ def fast_decode_plan(
|
|||||||
if logits_soft_cap is None:
|
if logits_soft_cap is None:
|
||||||
logits_soft_cap = 0.0
|
logits_soft_cap = 0.0
|
||||||
|
|
||||||
|
# Handle data types consistently
|
||||||
|
if data_type is not None:
|
||||||
|
if q_data_type is None:
|
||||||
|
q_data_type = data_type
|
||||||
|
if kv_data_type is None:
|
||||||
|
kv_data_type = data_type
|
||||||
|
elif q_data_type is None:
|
||||||
|
q_data_type = "float16"
|
||||||
|
|
||||||
|
if kv_data_type is None:
|
||||||
|
kv_data_type = q_data_type
|
||||||
|
|
||||||
if self.use_tensor_cores:
|
if self.use_tensor_cores:
|
||||||
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
|
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
|
||||||
|
|
||||||
@@ -1178,36 +1197,33 @@ def fast_decode_plan(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The size of indices should be less than or equal to the allocated buffer"
|
"The size of indices should be less than or equal to the allocated buffer"
|
||||||
)
|
)
|
||||||
# Skip these copies because we directly write to them during prepartion
|
|
||||||
# self._paged_kv_indptr_buf.copy_(indptr)
|
|
||||||
# self._paged_kv_indices_buf[: len(indices)] = indices
|
|
||||||
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
|
|
||||||
else:
|
else:
|
||||||
self._paged_kv_indptr_buf = indptr
|
self._paged_kv_indptr_buf = indptr
|
||||||
self._paged_kv_indices_buf = indices
|
self._paged_kv_indices_buf = indices
|
||||||
self._paged_kv_last_page_len_buf = last_page_len
|
self._paged_kv_last_page_len_buf = last_page_len
|
||||||
self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
|
if self.use_tensor_cores:
|
||||||
|
self._qo_indptr_buf = qo_indptr_host.to(
|
||||||
|
self.device, non_blocking=non_blocking
|
||||||
|
)
|
||||||
|
|
||||||
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
|
# Create empty tensors for dtype info if needed
|
||||||
if not q_data_type:
|
empty_q_data = torch.empty(
|
||||||
q_data_type = data_type
|
0,
|
||||||
|
dtype=(
|
||||||
|
getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
|
||||||
|
),
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
if not hasattr(self, "empty_q_data"):
|
empty_kv_cache = torch.empty(
|
||||||
self.empty_q_data = torch.empty(
|
0,
|
||||||
0,
|
dtype=(
|
||||||
dtype=(
|
getattr(torch, kv_data_type)
|
||||||
getattr(torch, q_data_type)
|
if isinstance(kv_data_type, str)
|
||||||
if isinstance(q_data_type, str)
|
else kv_data_type
|
||||||
else q_data_type
|
),
|
||||||
),
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.empty_kv_cache = torch.empty(
|
|
||||||
0,
|
|
||||||
dtype=(
|
|
||||||
getattr(torch, data_type) if isinstance(data_type, str) else data_type
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
|
||||||
|
|
||||||
indptr_host = (
|
indptr_host = (
|
||||||
global_override_indptr_cpu
|
global_override_indptr_cpu
|
||||||
@@ -1215,48 +1231,57 @@ def fast_decode_plan(
|
|||||||
else indptr.cpu()
|
else indptr.cpu()
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_tensor_cores:
|
with torch.cuda.device(self.device):
|
||||||
kv_lens_arr_host = get_seq_lens(
|
|
||||||
indptr_host, self.last_page_len[:batch_size], page_size
|
|
||||||
)
|
|
||||||
|
|
||||||
self._plan_info = self._cached_module.plan(
|
if self.use_tensor_cores:
|
||||||
self._float_workspace_buffer,
|
# ALSO convert last_page_len to CPU
|
||||||
self._int_workspace_buffer,
|
last_page_len_host = last_page_len.cpu()
|
||||||
self._pin_memory_int_workspace_buffer,
|
|
||||||
qo_indptr_host,
|
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
|
||||||
indptr_host,
|
|
||||||
kv_lens_arr_host,
|
try:
|
||||||
batch_size, # total_num_rows
|
# Make sure we pass exactly 15 arguments for tensor core version
|
||||||
batch_size,
|
self._plan_info = self._cached_module.plan(
|
||||||
num_qo_heads,
|
self._float_workspace_buffer,
|
||||||
num_kv_heads,
|
self._int_workspace_buffer,
|
||||||
page_size,
|
self._pin_memory_int_workspace_buffer,
|
||||||
self.is_cuda_graph_enabled,
|
qo_indptr_host,
|
||||||
head_dim,
|
indptr_host,
|
||||||
head_dim,
|
kv_lens_arr_host,
|
||||||
False, # causal
|
batch_size, # total_num_rows
|
||||||
torch.cuda.current_stream().cuda_stream,
|
batch_size,
|
||||||
)
|
num_qo_heads,
|
||||||
else:
|
num_kv_heads,
|
||||||
self._plan_info = self._cached_module.plan(
|
page_size,
|
||||||
self._float_workspace_buffer,
|
self.is_cuda_graph_enabled,
|
||||||
self._int_workspace_buffer,
|
head_dim,
|
||||||
self._pin_memory_int_workspace_buffer,
|
head_dim,
|
||||||
indptr_host,
|
False, # causal
|
||||||
batch_size,
|
)
|
||||||
num_qo_heads,
|
except Exception as e:
|
||||||
num_kv_heads,
|
raise RuntimeError(f"Error in standard plan: {e}")
|
||||||
page_size,
|
else:
|
||||||
self.is_cuda_graph_enabled,
|
try:
|
||||||
window_left,
|
# Make sure we pass exactly 15 arguments for standard version
|
||||||
logits_soft_cap,
|
self._plan_info = self._cached_module.plan(
|
||||||
head_dim,
|
self._float_workspace_buffer,
|
||||||
head_dim,
|
self._int_workspace_buffer,
|
||||||
self.empty_q_data,
|
self._pin_memory_int_workspace_buffer,
|
||||||
self.empty_kv_cache,
|
indptr_host,
|
||||||
torch.cuda.current_stream().cuda_stream,
|
batch_size,
|
||||||
)
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
page_size,
|
||||||
|
self.is_cuda_graph_enabled,
|
||||||
|
window_left,
|
||||||
|
logits_soft_cap,
|
||||||
|
head_dim,
|
||||||
|
head_dim,
|
||||||
|
empty_q_data,
|
||||||
|
empty_kv_cache,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in standard plan: {e}")
|
||||||
|
|
||||||
self._pos_encoding_mode = pos_encoding_mode
|
self._pos_encoding_mode = pos_encoding_mode
|
||||||
self._window_left = window_left
|
self._window_left = window_left
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
|
|||||||
More details can be found in https://docs.flashinfer.ai/api/mla.html
|
More details can be found in https://docs.flashinfer.ai/api/mla.html
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||||
@@ -16,6 +17,11 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
|
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
||||||
|
import torch._dynamo
|
||||||
|
|
||||||
|
torch._dynamo.config.suppress_errors = True
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
@@ -388,14 +394,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reshape inputs
|
||||||
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
|
|
||||||
|
# Direct call to run without the wrapper
|
||||||
o = decode_wrapper.run(
|
o = decode_wrapper.run(
|
||||||
reshaped_q[:, :, : layer.v_head_dim],
|
reshaped_q[:, :, : layer.v_head_dim],
|
||||||
reshaped_q[:, :, layer.v_head_dim :],
|
reshaped_q[:, :, layer.v_head_dim :],
|
||||||
reshaped_k[:, :, : layer.v_head_dim],
|
k_buffer[:, :, : layer.v_head_dim],
|
||||||
reshaped_k[:, :, layer.v_head_dim :],
|
k_buffer[:, :, layer.v_head_dim :],
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
@@ -825,16 +834,18 @@ def fast_mla_decode_plan(
|
|||||||
self._sm_scale = sm_scale
|
self._sm_scale = sm_scale
|
||||||
|
|
||||||
with self.device as device:
|
with self.device as device:
|
||||||
stream = torch.cuda.current_stream(device).cuda_stream
|
try:
|
||||||
self._cached_module.plan(
|
# Standard version with just the required arguments (no use_profiler)
|
||||||
self._float_workspace_buffer,
|
self._cached_module.plan.default(
|
||||||
self._int_workspace_buffer,
|
self._float_workspace_buffer,
|
||||||
self._pin_memory_int_workspace_buffer,
|
self._int_workspace_buffer,
|
||||||
qo_indptr_cpu,
|
self._pin_memory_int_workspace_buffer,
|
||||||
kv_indptr_cpu,
|
qo_indptr_cpu,
|
||||||
kv_len_arr_cpu,
|
kv_indptr_cpu,
|
||||||
num_heads,
|
kv_len_arr_cpu,
|
||||||
head_dim_ckv,
|
num_heads,
|
||||||
causal,
|
head_dim_ckv,
|
||||||
stream,
|
causal,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error in alternate MLA plan: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user