Support sliding window in triton backend (#6509)
This commit is contained in:
@@ -72,6 +72,65 @@ def get_num_kv_splits_triton(
|
|||||||
tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
|
tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
|
||||||
|
|
||||||
|
|
||||||
|
def update_sliding_window_buffer(
    window_kv_indptr,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    device,
):
    """Build the sliding-window KV index buffers for a batch.

    Each request's KV length is capped at ``sliding_window_size + 1``, the
    capped lengths are accumulated into ``window_kv_indptr``, and a freshly
    allocated ``window_kv_indices`` tensor is filled by
    ``create_flashinfer_kv_indices_triton``.

    Args:
        window_kv_indptr: Preallocated index-pointer buffer. Entries
            ``[1 : bs + 1]`` are overwritten in place; entry 0 is not
            written here (presumably zero-initialised by the caller).
        req_to_token: Request-to-token table handed to the Triton kernel
            together with its ``stride(0)``.
        sliding_window_size: Window size; at most ``sliding_window_size + 1``
            KV entries are kept per request.
        seq_lens: Per-request sequence lengths (tensor).
        req_pool_indices: Per-request indices, forwarded to the kernel.
        bs: Batch size; also used as the kernel launch grid.
        device: Device on which ``window_kv_indices`` is allocated.

    Returns:
        Tuple ``(window_kv_indptr[: bs + 1], window_kv_indices,
        window_kv_lens)``.
    """
    # Cap every sequence length at the window span. Clamping against a Python
    # scalar gives the same result as torch.minimum against a 0-dim tensor,
    # without building a new host tensor on every call.
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size + 1)

    # Fill the caller's buffer in place, then narrow to the slots in use.
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

    # The last indptr entry is the total number of windowed KV slots.
    window_kv_indices = torch.empty(
        window_kv_indptr[-1], dtype=torch.int32, device=device
    )

    # Offset of the first in-window token of each request.
    window_kv_start_idx = seq_lens - window_kv_lens

    # NOTE(review): presumably copies, per request, `window_kv_lens` token
    # slots starting at `window_kv_start_idx` into `window_kv_indices` --
    # confirm against the kernel definition.
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    return window_kv_indptr, window_kv_indices, window_kv_lens
|
||||||
|
|
||||||
|
|
||||||
|
def update_sliding_window_buffer_cuda_graph(
    window_kv_indptr,
    window_kv_indices,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
):
    """Refresh the sliding-window KV buffers using caller-owned storage.

    Unlike a version that allocates its own index tensor, this function
    writes into the ``window_kv_indices`` buffer supplied by the caller.

    Args:
        window_kv_indptr: Preallocated index-pointer buffer. Entries
            ``[1 : bs + 1]`` are overwritten in place; entry 0 is not
            written here (presumably zero-initialised by the caller).
        window_kv_indices: Preallocated output buffer passed to the Triton
            kernel. NOTE(review): per-request window lengths can reach
            ``sliding_window_size + 1``, so this buffer presumably needs room
            for up to ``bs * (sliding_window_size + 1)`` entries -- confirm
            the allocation at the call site.
        req_to_token: Request-to-token table handed to the Triton kernel
            together with its ``stride(0)``.
        sliding_window_size: Window size; at most ``sliding_window_size + 1``
            KV entries are kept per request.
        seq_lens: Per-request sequence lengths (tensor).
        req_pool_indices: Per-request indices, forwarded to the kernel.
        bs: Batch size; also used as the kernel launch grid.

    Returns:
        Tuple ``(window_kv_indptr[: bs + 1], window_kv_lens)``.
    """
    # Cap every sequence length at the window span. Clamping against a Python
    # scalar gives the same result as torch.minimum against a 0-dim tensor,
    # without building a new host tensor on every call.
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size + 1)

    # Fill the caller's buffer in place, then narrow to the slots in use.
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

    # Offset of the first in-window token of each request.
    window_kv_start_idx = seq_lens - window_kv_lens

    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )
    return window_kv_indptr, window_kv_lens
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardMetadata:
|
class ForwardMetadata:
|
||||||
attn_logits: torch.Tensor
|
attn_logits: torch.Tensor
|
||||||
@@ -83,6 +142,10 @@ class ForwardMetadata:
|
|||||||
qo_indptr: torch.Tensor
|
qo_indptr: torch.Tensor
|
||||||
custom_mask: torch.Tensor
|
custom_mask: torch.Tensor
|
||||||
mask_indptr: torch.Tensor
|
mask_indptr: torch.Tensor
|
||||||
|
# Sliding window
|
||||||
|
window_kv_indptr: torch.Tensor
|
||||||
|
window_kv_indices: torch.Tensor
|
||||||
|
window_num_kv_splits: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class TritonAttnBackend(AttentionBackend):
|
class TritonAttnBackend(AttentionBackend):
|
||||||
@@ -109,6 +172,13 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
max_bs = model_runner.req_to_token_pool.size
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
|
|
||||||
|
assert not (
|
||||||
|
model_runner.sliding_window_size is not None
|
||||||
|
and model_runner.model_config.is_encoder_decoder
|
||||||
|
), "Sliding window and cross attention are not supported together"
|
||||||
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
|
||||||
|
# TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
|
||||||
if kv_indptr_buf is None:
|
if kv_indptr_buf is None:
|
||||||
self.kv_indptr = torch.zeros(
|
self.kv_indptr = torch.zeros(
|
||||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
@@ -116,6 +186,18 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
self.kv_indptr = kv_indptr_buf
|
self.kv_indptr = kv_indptr_buf
|
||||||
|
|
||||||
|
# If sliding window is enabled, we might need two sets of buffers
|
||||||
|
# because of interleaved attention types (e.g. for Gemma3)
|
||||||
|
self.window_kv_indptr = None
|
||||||
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
|
if kv_indptr_buf is None:
|
||||||
|
self.window_kv_indptr = torch.zeros(
|
||||||
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# When provided a buffer, create a clone for the second buffer
|
||||||
|
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
|
||||||
|
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
|
||||||
if not self.skip_prefill:
|
if not self.skip_prefill:
|
||||||
@@ -191,6 +273,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
kv_indptr = self.kv_indptr
|
kv_indptr = self.kv_indptr
|
||||||
|
window_kv_indptr = self.window_kv_indptr
|
||||||
|
window_kv_indices = None
|
||||||
|
window_num_kv_splits = None
|
||||||
spec_info = forward_batch.spec_info
|
spec_info = forward_batch.spec_info
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
@@ -209,6 +294,26 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
# Sliding window
|
||||||
|
if (
|
||||||
|
self.sliding_window_size is not None
|
||||||
|
and self.sliding_window_size > 0
|
||||||
|
):
|
||||||
|
window_kv_indptr, window_kv_indices, window_kv_lens = (
|
||||||
|
update_sliding_window_buffer(
|
||||||
|
self.window_kv_indptr,
|
||||||
|
self.req_to_token,
|
||||||
|
self.sliding_window_size,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
window_num_kv_splits = torch.empty(
|
||||||
|
(bs,), dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
|
||||||
else:
|
else:
|
||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
bs = kv_indptr.shape[0] - 1
|
bs = kv_indptr.shape[0] - 1
|
||||||
@@ -224,7 +329,6 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
|
self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
|
||||||
|
|
||||||
qo_indptr = None
|
qo_indptr = None
|
||||||
@@ -232,6 +336,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
mask_indptr = None
|
mask_indptr = None
|
||||||
max_extend_len = None
|
max_extend_len = None
|
||||||
elif forward_batch.forward_mode.is_target_verify():
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
|
# TODO: Support sliding window in spec inference
|
||||||
bs = len(forward_batch.req_pool_indices)
|
bs = len(forward_batch.req_pool_indices)
|
||||||
qo_indptr = torch.arange(
|
qo_indptr = torch.arange(
|
||||||
0,
|
0,
|
||||||
@@ -303,6 +408,17 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
# Sliding window
|
||||||
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
|
window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
|
||||||
|
self.window_kv_indptr,
|
||||||
|
self.req_to_token,
|
||||||
|
self.sliding_window_size,
|
||||||
|
forward_batch.extend_prefix_lens,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
|
||||||
qo_indptr = self.qo_indptr
|
qo_indptr = self.qo_indptr
|
||||||
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
||||||
@@ -324,6 +440,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
qo_indptr,
|
qo_indptr,
|
||||||
custom_mask,
|
custom_mask,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
|
window_kv_indptr,
|
||||||
|
window_kv_indices,
|
||||||
|
window_num_kv_splits,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(
|
def init_cuda_graph_state(
|
||||||
@@ -358,6 +477,20 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
|
if kv_indices_buf is None:
|
||||||
|
self.cuda_graph_window_kv_indices = torch.zeros(
|
||||||
|
(max_bs * self.sliding_window_size),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
|
||||||
|
|
||||||
|
self.cuda_graph_window_num_kv_splits = torch.full(
|
||||||
|
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
@@ -369,6 +502,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
assert encoder_lens is None, "Not supported"
|
assert encoder_lens is None, "Not supported"
|
||||||
|
window_kv_indptr = self.window_kv_indptr
|
||||||
|
window_kv_indices = None
|
||||||
|
window_num_kv_splits = None
|
||||||
|
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
@@ -385,6 +521,21 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
self.sliding_window_size is not None
|
||||||
|
and self.sliding_window_size > 0
|
||||||
|
):
|
||||||
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
|
window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
|
||||||
|
self.window_kv_indptr,
|
||||||
|
window_kv_indices,
|
||||||
|
self.req_to_token,
|
||||||
|
self.sliding_window_size,
|
||||||
|
seq_lens[:bs],
|
||||||
|
req_pool_indices,
|
||||||
|
bs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
|
|
||||||
@@ -468,6 +619,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
qo_indptr,
|
qo_indptr,
|
||||||
custom_mask,
|
custom_mask,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
|
window_kv_indptr,
|
||||||
|
window_kv_indices,
|
||||||
|
window_num_kv_splits,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -500,11 +654,31 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
)
|
)
|
||||||
num_token = bs
|
num_token = bs
|
||||||
|
if (
|
||||||
|
self.sliding_window_size is not None
|
||||||
|
and self.sliding_window_size > 0
|
||||||
|
):
|
||||||
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
|
_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
||||||
|
self.window_kv_indptr,
|
||||||
|
window_kv_indices,
|
||||||
|
self.req_to_token,
|
||||||
|
self.sliding_window_size,
|
||||||
|
seq_lens[:bs],
|
||||||
|
req_pool_indices[:bs],
|
||||||
|
bs,
|
||||||
|
)
|
||||||
|
self.get_num_kv_splits(
|
||||||
|
window_num_kv_splits[:num_token], window_kv_lens[:bs]
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
|
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
|
||||||
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
||||||
num_token = spec_info.kv_indptr.shape[0] - 1
|
num_token = spec_info.kv_indptr.shape[0] - 1
|
||||||
self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
|
self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
|
||||||
|
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
|
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
|
||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
@@ -582,6 +756,17 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
if layer.attn_type == AttentionType.ENCODER_ONLY:
|
if layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
causal = False
|
causal = False
|
||||||
|
|
||||||
|
if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
|
||||||
|
sliding_window_size = (
|
||||||
|
layer.sliding_window_size
|
||||||
|
) # Needed for sliding window mask
|
||||||
|
kv_indptr = self.forward_metadata.window_kv_indptr
|
||||||
|
kv_indices = self.forward_metadata.window_kv_indices
|
||||||
|
else:
|
||||||
|
sliding_window_size = -1
|
||||||
|
kv_indptr = self.forward_metadata.kv_indptr
|
||||||
|
kv_indices = self.forward_metadata.kv_indices
|
||||||
|
|
||||||
self.extend_attention_fwd(
|
self.extend_attention_fwd(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
@@ -590,14 +775,15 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
self.forward_metadata.qo_indptr,
|
self.forward_metadata.qo_indptr,
|
||||||
self.forward_metadata.kv_indptr,
|
kv_indptr,
|
||||||
self.forward_metadata.kv_indices,
|
kv_indices,
|
||||||
self.forward_metadata.custom_mask,
|
self.forward_metadata.custom_mask,
|
||||||
causal,
|
causal,
|
||||||
self.forward_metadata.mask_indptr,
|
self.forward_metadata.mask_indptr,
|
||||||
self.forward_metadata.max_extend_len,
|
self.forward_metadata.max_extend_len,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
|
sliding_window_size,
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@@ -625,13 +811,20 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
|
||||||
|
kv_indptr = self.forward_metadata.window_kv_indptr
|
||||||
|
kv_indices = self.forward_metadata.window_kv_indices
|
||||||
|
else:
|
||||||
|
kv_indptr = self.forward_metadata.kv_indptr
|
||||||
|
kv_indices = self.forward_metadata.kv_indices
|
||||||
|
|
||||||
self.decode_attention_fwd(
|
self.decode_attention_fwd(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
self.forward_metadata.kv_indptr,
|
kv_indptr,
|
||||||
self.forward_metadata.kv_indices,
|
kv_indices,
|
||||||
self.forward_metadata.attn_logits,
|
self.forward_metadata.attn_logits,
|
||||||
self.forward_metadata.attn_lse,
|
self.forward_metadata.attn_lse,
|
||||||
self.forward_metadata.num_kv_splits,
|
self.forward_metadata.num_kv_splits,
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ def _fwd_kernel(
|
|||||||
stride_buf_kh,
|
stride_buf_kh,
|
||||||
stride_buf_vbs,
|
stride_buf_vbs,
|
||||||
stride_buf_vh,
|
stride_buf_vh,
|
||||||
|
SLIDING_WINDOW_SIZE: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
Lq: tl.constexpr,
|
Lq: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
@@ -163,6 +164,7 @@ def _fwd_kernel(
|
|||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
qk = logit_cap * tanh(qk / logit_cap)
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
|
final_mask = mask_m[:, None] & mask_n[None, :]
|
||||||
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
|
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
|
||||||
custom_mask = tl.load(
|
custom_mask = tl.load(
|
||||||
mask_ptr
|
mask_ptr
|
||||||
@@ -173,10 +175,14 @@ def _fwd_kernel(
|
|||||||
mask=(mask_m[:, None] & mask_n[None, :]),
|
mask=(mask_m[:, None] & mask_n[None, :]),
|
||||||
other=0,
|
other=0,
|
||||||
)
|
)
|
||||||
custom_mask &= mask_m[:, None] & mask_n[None, :]
|
final_mask &= custom_mask
|
||||||
qk = tl.where(custom_mask, qk, float("-inf"))
|
if SLIDING_WINDOW_SIZE > 0:
|
||||||
else:
|
# Add mask where q_id <= kv_id + sliding_window_size
|
||||||
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
|
||||||
|
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
|
||||||
|
)
|
||||||
|
final_mask &= window_mask
|
||||||
|
qk = tl.where(final_mask, qk, float("-inf"))
|
||||||
|
|
||||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||||
re_scale = tl.exp(e_max - n_e_max)
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
@@ -314,6 +320,7 @@ def extend_attention_fwd(
|
|||||||
sm_scale=None,
|
sm_scale=None,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
skip_prefix_custom_mask=True,
|
skip_prefix_custom_mask=True,
|
||||||
|
sliding_window_size=-1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
@@ -412,6 +419,7 @@ def extend_attention_fwd(
|
|||||||
k_buffer.stride(1),
|
k_buffer.stride(1),
|
||||||
v_buffer.stride(0),
|
v_buffer.stride(0),
|
||||||
v_buffer.stride(1),
|
v_buffer.stride(1),
|
||||||
|
SLIDING_WINDOW_SIZE=sliding_window_size,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_DPE=BLOCK_DPE,
|
BLOCK_DPE=BLOCK_DPE,
|
||||||
|
|||||||
@@ -1025,10 +1025,6 @@ class ModelRunner:
|
|||||||
|
|
||||||
return AiterAttnBackend(self)
|
return AiterAttnBackend(self)
|
||||||
elif self.server_args.attention_backend == "triton":
|
elif self.server_args.attention_backend == "triton":
|
||||||
assert self.sliding_window_size is None, (
|
|
||||||
"Window attention is not supported in the triton attention backend. "
|
|
||||||
"Please use `--attention-backend flashinfer`."
|
|
||||||
)
|
|
||||||
assert not self.model_config.is_encoder_decoder, (
|
assert not self.model_config.is_encoder_decoder, (
|
||||||
"Cross attention is not supported in the triton attention backend. "
|
"Cross attention is not supported in the triton attention backend. "
|
||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
|
|||||||
@@ -277,6 +277,13 @@ class Gemma3Attention(nn.Module):
|
|||||||
k = k.permute(0, 2, 1, 3)
|
k = k.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
attn_output = self.attn(q, k, v, forward_batch=forward_batch)
|
attn_output = self.attn(q, k, v, forward_batch=forward_batch)
|
||||||
|
|
||||||
|
# Compatible with triton backend which returns [1, s, h, head_dim]
|
||||||
|
if attn_output.dim() == 4 and attn_output.shape[0] == 1:
|
||||||
|
attn_output = attn_output.squeeze(0)
|
||||||
|
attn_output = attn_output.flatten(-2, -1)
|
||||||
|
# [s, h * head_dim]
|
||||||
|
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ suites = {
|
|||||||
TestFile("test_triton_attention_kernels.py", 4),
|
TestFile("test_triton_attention_kernels.py", 4),
|
||||||
TestFile("test_triton_attention_backend.py", 134),
|
TestFile("test_triton_attention_backend.py", 134),
|
||||||
TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
|
TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
|
||||||
|
TestFile("test_triton_sliding_window.py", 250),
|
||||||
TestFile("test_update_weights_from_disk.py", 114),
|
TestFile("test_update_weights_from_disk.py", 114),
|
||||||
TestFile("test_update_weights_from_tensor.py", 48),
|
TestFile("test_update_weights_from_tensor.py", 48),
|
||||||
TestFile("test_vertex_endpoint.py", 31),
|
TestFile("test_vertex_endpoint.py", 31),
|
||||||
|
|||||||
132
test/srt/test_triton_sliding_window.py
Normal file
132
test/srt/test_triton_sliding_window.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention functionality with triton backend."""

    @classmethod
    def setUpClass(cls):
        """Set up the shared configuration (model, URL, server args, prompts)."""
        # Gemma3 model supports sliding window attention
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST

        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]

        cls.short_context_prompt = "The capital of France is"

        # Test prompt longer than window size
        cls.long_context_prompt = (
            "\nOnce upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:\n"
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

    @classmethod
    def tearDownClass(cls):
        # Servers are started and stopped inside each test method.
        pass

    def _test_mmlu(self):
        """Run a small MMLU evaluation against the server and check the score."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")

        self.assertGreaterEqual(metrics["score"], 0.64)

    def _test_short_context_generation(self):
        """Greedy-decode a short prompt and check the expected answer appears."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

    def _test_long_context_generation(self):
        """Greedy-decode the long prompt and check the output is non-empty."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

    def test_no_cuda_graph(self):
        """Exercise the backend with CUDA graphs disabled."""
        self.no_cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + ["--disable-cuda-graph"],
        )

        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            # Always stop the server, even when an assertion fails; otherwise
            # the process keeps running and holds the shared test URL.
            kill_process_tree(self.no_cuda_graph_process.pid)
            time.sleep(5)

    def test_cuda_graph(self):
        """Exercise the backend with the default (CUDA graph) configuration."""
        self.cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args,
        )

        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            # Always stop the server, even when an assertion fails; otherwise
            # the process keeps running and holds the shared test URL.
            kill_process_tree(self.cuda_graph_process.pid)
            time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user