[Fix] Compatibility of window attention and cuda graph (#1090)
This commit is contained in:
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
|
||||
scaling: float,
|
||||
num_kv_heads: int,
|
||||
layer_id: int,
|
||||
reuse: bool = False,
|
||||
sliding_window_size: int = -1,
|
||||
logit_cap: int = -1,
|
||||
v_head_dim: int = -1,
|
||||
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
|
||||
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
||||
self.scaling = scaling
|
||||
self.layer_id = layer_id
|
||||
self.reuse = reuse
|
||||
self.sliding_window_size = sliding_window_size
|
||||
|
||||
if (
|
||||
@@ -127,8 +129,9 @@ class RadixAttention(nn.Module):
|
||||
if isinstance(prefill_wrapper_paged, list):
|
||||
prefill_wrapper_paged = prefill_wrapper_paged[1]
|
||||
|
||||
if not input_metadata.flashinfer_use_ragged:
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
if not input_metadata.flashinfer_use_ragged or self.reuse:
|
||||
if not self.reuse:
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
|
||||
o = prefill_wrapper_paged.forward(
|
||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||
@@ -179,7 +182,8 @@ class RadixAttention(nn.Module):
|
||||
if isinstance(decode_wrapper, list):
|
||||
decode_wrapper = decode_wrapper[1]
|
||||
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
if not self.reuse:
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||
@@ -191,8 +195,10 @@ class RadixAttention(nn.Module):
|
||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||
|
||||
def forward(self, q, k, v, input_metadata: InputMetadata):
|
||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
return self.extend_forward(q, k, v, input_metadata)
|
||||
|
||||
@@ -107,9 +107,6 @@ class CudaGraphRunner:
|
||||
)
|
||||
|
||||
# FlashInfer inputs
|
||||
self.flashinfer_workspace_buffer = (
|
||||
self.model_runner.flashinfer_workspace_buffers[0]
|
||||
)
|
||||
self.flashinfer_kv_indptr = torch.zeros(
|
||||
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
@@ -121,6 +118,23 @@ class CudaGraphRunner:
|
||||
self.flashinfer_kv_last_page_len = torch.ones(
|
||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
if model_runner.sliding_window_size is None:
|
||||
self.flashinfer_workspace_buffer = (
|
||||
self.model_runner.flashinfer_workspace_buffers[0]
|
||||
)
|
||||
else:
|
||||
self.flashinfer_workspace_buffers = [
|
||||
self.model_runner.flashinfer_workspace_buffers[0],
|
||||
self.model_runner.flashinfer_workspace_buffers[2],
|
||||
]
|
||||
self.flashinfer_kv_indptr = [
|
||||
self.flashinfer_kv_indptr,
|
||||
self.flashinfer_kv_indptr.clone(),
|
||||
]
|
||||
self.flashinfer_kv_indices = [
|
||||
self.flashinfer_kv_indices,
|
||||
self.flashinfer_kv_indices.clone(),
|
||||
]
|
||||
|
||||
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
||||
|
||||
@@ -171,15 +185,32 @@ class CudaGraphRunner:
|
||||
use_tensor_cores = True
|
||||
else:
|
||||
use_tensor_cores = False
|
||||
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffer,
|
||||
"NHD",
|
||||
use_cuda_graph=True,
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
|
||||
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
||||
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
||||
)
|
||||
if self.model_runner.sliding_window_size is None:
|
||||
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffer,
|
||||
"NHD",
|
||||
use_cuda_graph=True,
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
|
||||
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
||||
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
||||
)
|
||||
else:
|
||||
flashinfer_decode_wrapper = []
|
||||
for i in range(2):
|
||||
flashinfer_decode_wrapper.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[i],
|
||||
"NHD",
|
||||
use_cuda_graph=True,
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
|
||||
paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
|
||||
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
|
||||
:bs
|
||||
],
|
||||
)
|
||||
)
|
||||
update_flashinfer_indices(
|
||||
ForwardMode.DECODE,
|
||||
self.model_runner,
|
||||
|
||||
@@ -154,7 +154,6 @@ class InputMetadata:
|
||||
model_runner: "ModelRunner",
|
||||
batch: ScheduleBatch,
|
||||
forward_mode: ForwardMode,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
):
|
||||
ret = cls(
|
||||
forward_mode=forward_mode,
|
||||
@@ -198,7 +197,7 @@ class InputMetadata:
|
||||
):
|
||||
flashinfer_use_ragged = True
|
||||
ret.init_flashinfer_handlers(
|
||||
model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
|
||||
model_runner, prefix_lens, flashinfer_use_ragged
|
||||
)
|
||||
|
||||
return ret
|
||||
@@ -221,7 +220,6 @@ class InputMetadata:
|
||||
model_runner,
|
||||
prefix_lens,
|
||||
flashinfer_use_ragged,
|
||||
sliding_window_size=None,
|
||||
):
|
||||
update_flashinfer_indices(
|
||||
self.forward_mode,
|
||||
@@ -230,7 +228,6 @@ class InputMetadata:
|
||||
self.seq_lens,
|
||||
prefix_lens,
|
||||
flashinfer_use_ragged=flashinfer_use_ragged,
|
||||
sliding_window_size=sliding_window_size,
|
||||
)
|
||||
|
||||
(
|
||||
@@ -254,7 +251,6 @@ def update_flashinfer_indices(
|
||||
prefix_lens,
|
||||
flashinfer_decode_wrapper=None,
|
||||
flashinfer_use_ragged=False,
|
||||
sliding_window_size=None,
|
||||
):
|
||||
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||
@@ -262,7 +258,7 @@ def update_flashinfer_indices(
|
||||
head_dim = model_runner.model_config.head_dim
|
||||
batch_size = len(req_pool_indices)
|
||||
|
||||
if sliding_window_size is None:
|
||||
if model_runner.sliding_window_size is None:
|
||||
if flashinfer_use_ragged:
|
||||
paged_kernel_lens = prefix_lens
|
||||
else:
|
||||
@@ -335,7 +331,7 @@ def update_flashinfer_indices(
|
||||
|
||||
if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
|
||||
paged_kernel_lens = torch.minimum(
|
||||
paged_kernel_lens, torch.tensor(sliding_window_size)
|
||||
paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
|
||||
)
|
||||
kv_start_idx = seq_lens - paged_kernel_lens
|
||||
else:
|
||||
|
||||
@@ -187,6 +187,11 @@ class ModelRunner:
|
||||
scheduler_config=None,
|
||||
cache_config=None,
|
||||
)
|
||||
self.sliding_window_size = (
|
||||
self.model.get_window_size()
|
||||
if hasattr(self.model, "get_window_size")
|
||||
else None
|
||||
)
|
||||
self.is_generation = is_generation_model(
|
||||
self.model_config.hf_config.architectures
|
||||
)
|
||||
@@ -295,12 +300,6 @@ class ModelRunner:
|
||||
return c
|
||||
|
||||
def init_flashinfer(self):
|
||||
self.sliding_window_size = (
|
||||
self.model.get_window_size()
|
||||
if hasattr(self.model, "get_window_size")
|
||||
else None
|
||||
)
|
||||
|
||||
if self.server_args.disable_flashinfer:
|
||||
assert (
|
||||
self.sliding_window_size is None
|
||||
@@ -339,7 +338,7 @@ class ModelRunner:
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
)
|
||||
else:
|
||||
workspace_buffers = torch.empty(
|
||||
self.flashinfer_workspace_buffers = torch.empty(
|
||||
4,
|
||||
global_config.flashinfer_workspace_size,
|
||||
dtype=torch.uint8,
|
||||
@@ -351,17 +350,17 @@ class ModelRunner:
|
||||
for i in range(2):
|
||||
self.flashinfer_prefill_wrapper_ragged.append(
|
||||
BatchPrefillWithRaggedKVCacheWrapper(
|
||||
workspace_buffers[2 * i + 0], "NHD"
|
||||
self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_paged.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffers[2 * i + 1], "NHD"
|
||||
self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_decode_wrapper.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffers[2 * i + 0],
|
||||
self.flashinfer_workspace_buffers[2 * i + 0],
|
||||
"NHD",
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
)
|
||||
@@ -404,7 +403,6 @@ class ModelRunner:
|
||||
self,
|
||||
batch,
|
||||
ForwardMode.DECODE,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
)
|
||||
|
||||
return self.model.forward(
|
||||
@@ -417,7 +415,6 @@ class ModelRunner:
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
)
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
@@ -429,7 +426,6 @@ class ModelRunner:
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
)
|
||||
return self.model.forward(
|
||||
batch.input_ids,
|
||||
|
||||
@@ -453,10 +453,12 @@ class ServerArgs:
|
||||
logger.info(
|
||||
f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
|
||||
)
|
||||
# FIXME: compatibility with radix attention
|
||||
self.disable_radix_cache = True
|
||||
# FIXME: compatibility with jump forward
|
||||
self.disable_regex_jump_forward = True
|
||||
self.disable_flashinfer = False
|
||||
self.disable_cuda_graph = True
|
||||
# FIXME: compatibility with chunked prefill
|
||||
self.chunked_prefill_size = None
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ DEFAULT_PROMPTS = [
|
||||
]
|
||||
|
||||
dirpath = os.path.dirname(__file__)
|
||||
with open(os.path.join(dirpath, "long_prompt"), "r") as f:
|
||||
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
|
||||
long_prompt = f.read()
|
||||
DEFAULT_PROMPTS.append(long_prompt)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user