[Fix] Compatibility of window attention and cuda graph (#1090)
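This commit makes FlashInfer sliding-window attention (as used by gemma-2) compatible with CUDA graph capture. The moves visible in the hunks below: `RadixAttention` gains a `reuse` flag that lets a layer attend over an already-populated KV cache without writing its own; `CudaGraphRunner` and `ModelRunner` keep two decode wrappers with separate workspace and index buffers when the model reports a sliding window; `sliding_window_size` is read from `model_runner` instead of being threaded through every call; and `ServerArgs` stops forcing `disable_cuda_graph` for such models.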
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        reuse: bool = False,
         sliding_window_size: int = -1,
         logit_cap: int = -1,
         v_head_dim: int = -1,
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.reuse = reuse
         self.sliding_window_size = sliding_window_size

         if (
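The two hunks above add a `reuse` flag to `RadixAttention`. Judging from the rest of this diff (the skipped `store_kv_cache` calls and the `k is not None` guard added to `forward`), a layer built with `reuse=True` does not write fresh k/v into the KV cache and instead attends over entries that are already populated. A minimal, self-contained sketch of that apparent contract (the class and cache here are toys, not the real implementation):

```python
# Toy model of the `reuse` contract implied by this diff: a reuse layer
# never writes to the shared KV cache and may receive k = v = None.
class TinyAttention:
    def __init__(self, layer_id: int, reuse: bool = False):
        self.layer_id = layer_id
        self.reuse = reuse

    def forward(self, k, v, kv_cache: dict):
        if not self.reuse:
            # normal layer: store its own fresh k/v before attending
            kv_cache[self.layer_id] = (k, v)
        # both kinds of layer attend over whatever the cache now holds
        return kv_cache[self.layer_id]

cache = {}
writer = TinyAttention(layer_id=0)
writer.forward("k0", "v0", cache)           # populates the cache
reader = TinyAttention(layer_id=0, reuse=True)
print(reader.forward(None, None, cache))    # ('k0', 'v0'), nothing stored
```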
@@ -127,8 +129,9 @@ class RadixAttention(nn.Module):
         if isinstance(prefill_wrapper_paged, list):
             prefill_wrapper_paged = prefill_wrapper_paged[1]

-        if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.flashinfer_use_ragged or self.reuse:
+            if not self.reuse:
+                self.store_kv_cache(k, v, input_metadata)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
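The reworked condition is easy to misread, so here is the branch table it produces. Note the asymmetry: a reuse layer takes the paged branch even in ragged mode (its k/v already live in the paged cache), but it never stores. A purely illustrative check:

```python
# Enumerate the new extend-path condition
# `not flashinfer_use_ragged or self.reuse` and the nested store guard.
from itertools import product

for use_ragged, reuse in product([False, True], repeat=2):
    paged_branch = (not use_ragged) or reuse
    stores_kv = paged_branch and not reuse
    print(f"use_ragged={use_ragged!s:<5} reuse={reuse!s:<5} "
          f"-> paged_branch={paged_branch!s:<5} stores_kv={stores_kv}")
```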
@@ -179,7 +182,8 @@ class RadixAttention(nn.Module):
         if isinstance(decode_wrapper, list):
             decode_wrapper = decode_wrapper[1]

-        self.store_kv_cache(k, v, input_metadata)
+        if not self.reuse:
+            self.store_kv_cache(k, v, input_metadata)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
@@ -191,8 +195,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
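The guard matters because `.view()` on `None` raises, and with `reuse` a caller may legitimately pass no k/v. The reshape being protected is the usual flatten-to-heads layout; a tiny demo with made-up sizes:

```python
# The reshape the new guard protects: a (tokens, heads * head_dim)
# tensor becomes (tokens, heads, head_dim). Sizes here are arbitrary.
import torch

tp_k_head_num, qk_head_dim, num_tokens = 2, 4, 3
k = torch.randn(num_tokens, tp_k_head_num * qk_head_dim)
k = k.view(-1, tp_k_head_num, qk_head_dim)
print(k.shape)  # torch.Size([3, 2, 4])
```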
@@ -107,9 +107,6 @@ class CudaGraphRunner:
         )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +118,23 @@ class CudaGraphRunner:
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffers[0]
+            )
+        else:
+            self.flashinfer_workspace_buffers = [
+                self.model_runner.flashinfer_workspace_buffers[0],
+                self.model_runner.flashinfer_workspace_buffers[2],
+            ]
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]

         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
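Why the clones: once a CUDA graph is captured, replay reuses the exact device addresses recorded at capture time, so each decode wrapper needs its own fixed tensors rather than a shared set. A CPU-side sketch of the invariant (the real buffers live on `device="cuda"`):

```python
import torch

max_bs = 8
kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)
# one independent buffer per wrapper, each at a stable address
kv_indptr_pair = [kv_indptr, kv_indptr.clone()]
assert kv_indptr_pair[0].data_ptr() != kv_indptr_pair[1].data_ptr()
print("distinct buffers:", [t.data_ptr() for t in kv_indptr_pair])
```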
@@ -171,15 +185,32 @@ class CudaGraphRunner:
             use_tensor_cores = True
         else:
             use_tensor_cores = False
-        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
-            "NHD",
-            use_cuda_graph=True,
-            use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-            paged_kv_indices_buffer=self.flashinfer_kv_indices,
-            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-        )
+        if self.model_runner.sliding_window_size is None:
+            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=use_tensor_cores,
+                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.flashinfer_kv_indices,
+                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+            )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffers[i],
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
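With two decode wrappers, each attention layer has to pick one. The hunks shown here only contain the `wrappers[1]` branch; from the `wrapper_id == 0` window clamp later in this diff, index 0 appears to serve sliding-window layers and index 1 full-attention layers. A hypothetical selector in that spirit (not taken verbatim from the commit):

```python
# Hypothetical selection logic: wrapper 0 for layers with a sliding
# window, wrapper 1 for full-attention layers.
def pick_wrapper(wrappers, sliding_window_size: int = -1):
    if isinstance(wrappers, list):
        return wrappers[0] if sliding_window_size != -1 else wrappers[1]
    return wrappers  # single-wrapper models pass the object through

print(pick_wrapper(["window", "full"], sliding_window_size=4096))  # window
print(pick_wrapper(["window", "full"]))                            # full
```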
@@ -154,7 +154,6 @@ class InputMetadata:
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
         forward_mode: ForwardMode,
-        sliding_window_size: Optional[int] = None,
     ):
         ret = cls(
             forward_mode=forward_mode,
@@ -198,7 +197,7 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
+            model_runner, prefix_lens, flashinfer_use_ragged
         )

         return ret
@@ -221,7 +220,6 @@ class InputMetadata:
         model_runner,
         prefix_lens,
         flashinfer_use_ragged,
-        sliding_window_size=None,
     ):
         update_flashinfer_indices(
             self.forward_mode,
@@ -230,7 +228,6 @@ class InputMetadata:
             self.seq_lens,
             prefix_lens,
             flashinfer_use_ragged=flashinfer_use_ragged,
-            sliding_window_size=sliding_window_size,
         )

         (
@@ -254,7 +251,6 @@ def update_flashinfer_indices(
     prefix_lens,
     flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
-    sliding_window_size=None,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -262,7 +258,7 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)

-    if sliding_window_size is None:
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
             paged_kernel_lens = prefix_lens
         else:
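The hunks in this file all apply one refactor: rather than threading `sliding_window_size` through `InputMetadata.create`, `init_flashinfer_handlers`, and `update_flashinfer_indices`, the value is read off `model_runner`, which every one of these functions already receives. A toy before/after of the same idea:

```python
# Before: the value rides along as an extra parameter at every hop.
def update_before(model_runner, sliding_window_size=None):
    return "dual" if sliding_window_size is not None else "single"

# After: read it from the object that is already passed everywhere.
def update_after(model_runner):
    return "dual" if model_runner.sliding_window_size is not None else "single"

class Runner:
    sliding_window_size = 4096  # hypothetical gemma-2-style window

print(update_before(Runner(), 4096), update_after(Runner()))  # dual dual
```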
@@ -335,7 +331,7 @@ def update_flashinfer_indices(

         if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
             paged_kernel_lens = torch.minimum(
-                paged_kernel_lens, torch.tensor(sliding_window_size)
+                paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
             )
             kv_start_idx = seq_lens - paged_kernel_lens
         else:
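A worked example of the clamp above: on the decode path of the sliding-window wrapper (`wrapper_id == 0`), each sequence reads at most `sliding_window_size` most-recent KV entries, and `kv_start_idx` marks where that window begins.

```python
import torch

sliding_window_size = 4  # toy window width
seq_lens = torch.tensor([2, 10])
paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
kv_start_idx = seq_lens - paged_kernel_lens
print(paged_kernel_lens)  # tensor([2, 4]): short seqs keep their full history
print(kv_start_idx)       # tensor([0, 6]): the 10-token seq reads from token 6
```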
@@ -187,6 +187,11 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures
         )
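Moving the probe into the constructor makes `self.sliding_window_size` available to everything built afterwards (`init_flashinfer`, `CudaGraphRunner`) instead of only after `init_flashinfer` ran. The hook itself is duck-typed; a hypothetical model exposing it:

```python
# Hypothetical models illustrating the duck-typed hook the runner probes.
class WindowedModel:
    def get_window_size(self) -> int:
        return 4096  # made-up window width

class PlainModel:
    pass

def probe_window(model):
    return model.get_window_size() if hasattr(model, "get_window_size") else None

print(probe_window(WindowedModel()), probe_window(PlainModel()))  # 4096 None
```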
@@ -295,12 +300,6 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
-        self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
-            else None
-        )
-
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None
@@ -339,7 +338,7 @@ class ModelRunner:
                 use_tensor_cores=use_tensor_cores,
             )
         else:
-            workspace_buffers = torch.empty(
+            self.flashinfer_workspace_buffers = torch.empty(
                 4,
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -351,17 +350,17 @@ class ModelRunner:
             for i in range(2):
                 self.flashinfer_prefill_wrapper_ragged.append(
                     BatchPrefillWithRaggedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
                     )
                 )
                 self.flashinfer_prefill_wrapper_paged.append(
                     BatchPrefillWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 1], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
                     )
                 )
                 self.flashinfer_decode_wrapper.append(
                     BatchDecodeWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0],
+                        self.flashinfer_workspace_buffers[2 * i + 0],
                         "NHD",
                         use_tensor_cores=use_tensor_cores,
                     )
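Promoting the local `workspace_buffers` to `self.flashinfer_workspace_buffers` is what lets `CudaGraphRunner` reach them. The indexing in the loop above also explains the `[0]`/`[2]` picks earlier in this diff: for wrapper pair `i`, the decode wrapper shares buffer `2*i + 0` with the ragged prefill wrapper, so the two decode wrappers land on buffers 0 and 2.

```python
# Buffer assignment implied by the loop above, for i in {0, 1}.
for i in range(2):
    print(f"pair {i}: ragged prefill -> buffer {2 * i + 0}, "
          f"paged prefill -> buffer {2 * i + 1}, decode -> buffer {2 * i + 0}")
```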
@@ -404,7 +403,6 @@ class ModelRunner:
             self,
             batch,
             ForwardMode.DECODE,
-            sliding_window_size=self.sliding_window_size,
         )

         return self.model.forward(
@@ -417,7 +415,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -429,7 +426,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids,
@@ -453,10 +453,12 @@ class ServerArgs:
             logger.info(
                 f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
             )
+            # FIXME: compatibility with radix attention
             self.disable_radix_cache = True
+            # FIXME: compatibility with jump forward
             self.disable_regex_jump_forward = True
             self.disable_flashinfer = False
-            self.disable_cuda_graph = True
+            # FIXME: compatibility with chunked prefill
             self.chunked_prefill_size = None

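The behavioral change here is the dropped `self.disable_cuda_graph = True`: with the dual-wrapper plumbing above, sliding-window models no longer need to opt out of CUDA graphs, while radix cache, jump-forward, and chunked prefill remain force-disabled. A compressed sketch of the override pattern (field names mirror the diff; the class itself is illustrative):

```python
# Illustrative override pattern; only the field names come from the diff.
class Args:
    disable_radix_cache = False
    disable_regex_jump_forward = False
    disable_flashinfer = False
    chunked_prefill_size = 8192

def apply_sliding_window_overrides(args: Args) -> Args:
    args.disable_radix_cache = True         # FIXME: radix attention
    args.disable_regex_jump_forward = True  # FIXME: jump forward
    args.disable_flashinfer = False         # window attention needs flashinfer
    args.chunked_prefill_size = None        # FIXME: chunked prefill
    return args                             # CUDA graphs now stay enabled

print(vars(apply_sliding_window_overrides(Args())))
```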
@@ -36,7 +36,7 @@ DEFAULT_PROMPTS = [
 ]

 dirpath = os.path.dirname(__file__)
-with open(os.path.join(dirpath, "long_prompt"), "r") as f:
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
     long_prompt = f.read()
 DEFAULT_PROMPTS.append(long_prompt)