Support Eagle2 for Triton backend (#3466)
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.layers.attention import AttentionBackend
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
@@ -18,7 +19,12 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class TritonAttnBackend(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Lazy import to avoid the initialization of cuda context
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import (
|
||||
decode_attention_fwd,
|
||||
@@ -33,14 +39,25 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.extend_attention_fwd = extend_attention_fwd
|
||||
|
||||
max_bs = model_runner.req_to_token_pool.size
|
||||
self.kv_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
if kv_indptr_buf is None:
|
||||
self.kv_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
else:
|
||||
self.kv_indptr = kv_indptr_buf
|
||||
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.qo_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
self.mask_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
|
||||
)
|
||||
|
||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||
|
||||
self.num_head = (
|
||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||
)
|
||||
@@ -50,7 +67,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
self.forward_metadata = None
|
||||
|
||||
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
|
||||
self.device = model_runner.device
|
||||
|
||||
@@ -59,11 +76,31 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
bs = forward_batch.batch_size
|
||||
kv_indptr = self.kv_indptr
|
||||
spec_info = forward_batch.spec_info
|
||||
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
attn_logits = torch.empty(
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
if spec_info is None:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.zeros(
|
||||
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
else:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
attn_logits = torch.zeros(
|
||||
(
|
||||
forward_batch.batch_size,
|
||||
bs,
|
||||
self.num_head,
|
||||
self.num_kv_splits,
|
||||
self.v_head_dim + 1,
|
||||
@@ -72,12 +109,24 @@ class TritonAttnBackend(AttentionBackend):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
qo_indptr = None
|
||||
custom_mask = None
|
||||
mask_indptr = None
|
||||
max_extend_len = None
|
||||
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
bs = len(forward_batch.req_pool_indices)
|
||||
qo_indptr = torch.arange(
|
||||
0,
|
||||
(1 + bs) * self.num_draft_tokens,
|
||||
step=self.num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# Different with flashinfer kv_indptr and kv_indices construction
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
||||
kv_indices = torch.zeros(
|
||||
kv_indptr[-1], dtype=torch.int32, device=self.device
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
@@ -89,15 +138,32 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
qo_indptr = None
|
||||
custom_mask = None
|
||||
mask_offsets = None
|
||||
custom_mask = spec_info.custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (
|
||||
forward_batch.seq_lens + self.num_draft_tokens
|
||||
)
|
||||
mask_indptr = self.mask_indptr
|
||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
|
||||
mask_indptr = mask_indptr[: bs + 1]
|
||||
max_extend_len = self.num_draft_tokens
|
||||
attn_logits = None
|
||||
elif forward_batch.forward_mode.is_draft_extend():
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
self.req_to_token,
|
||||
)
|
||||
)
|
||||
mask_indptr = None
|
||||
max_extend_len = torch.max(spec_info.accept_length).item()
|
||||
attn_logits = None
|
||||
else:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(
|
||||
forward_batch.extend_prefix_lens, dim=0
|
||||
)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
kv_indices = torch.zeros(
|
||||
forward_batch.extend_prefix_lens.sum().item(),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
@@ -116,8 +182,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
mask_offsets = None
|
||||
|
||||
mask_indptr = None
|
||||
attn_logits = None
|
||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||
|
||||
@@ -128,22 +193,22 @@ class TritonAttnBackend(AttentionBackend):
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
custom_mask,
|
||||
mask_offsets,
|
||||
mask_indptr,
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
||||
self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
|
||||
|
||||
self.cuda_graph_start_loc = torch.zeros(
|
||||
(max_bs,), dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.cuda_graph_attn_logits = torch.empty(
|
||||
self.cuda_graph_attn_logits = torch.zeros(
|
||||
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(max_bs * self.cuda_graph_max_seq_len),
|
||||
(max_bs * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -244,8 +309,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
custom_mask,
|
||||
mask_offsets,
|
||||
mask_indptr,
|
||||
) = self.forward_metadata
|
||||
|
||||
self.extend_attention_fwd(
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
k.contiguous(),
|
||||
@@ -257,7 +323,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask,
|
||||
mask_offsets,
|
||||
mask_indptr,
|
||||
max_extend_len,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
@@ -303,3 +369,136 @@ class TritonAttnBackend(AttentionBackend):
|
||||
layer.logit_cap,
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
class TritonMultiStepDraftBackend:
|
||||
"""
|
||||
Wrap multiple triton attention backends as one for multiple consecutive
|
||||
draft decoding steps.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: ModelRunner,
|
||||
topk: int,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
||||
max_bs = model_runner.req_to_token_pool.size
|
||||
self.kv_indptr = torch.zeros(
|
||||
(
|
||||
self.speculative_num_steps,
|
||||
max_bs + 1,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device,
|
||||
)
|
||||
self.attn_backends = []
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends.append(
|
||||
TritonAttnBackend(
|
||||
model_runner,
|
||||
skip_prefill=True,
|
||||
kv_indptr_buf=self.kv_indptr[i],
|
||||
)
|
||||
)
|
||||
self.max_context_len = self.attn_backends[0].max_context_len
|
||||
# Cached variables for generate_draft_decode_kv_indices
|
||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||
|
||||
def common_template(
|
||||
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
||||
):
|
||||
num_seqs = forward_batch.batch_size
|
||||
bs = self.topk * num_seqs
|
||||
seq_lens_sum = forward_batch.seq_lens_sum
|
||||
|
||||
self.generate_draft_decode_kv_indices[
|
||||
(self.speculative_num_steps, num_seqs, self.topk)
|
||||
](
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.seq_lens,
|
||||
kv_indices_buffer,
|
||||
self.kv_indptr,
|
||||
forward_batch.positions,
|
||||
num_seqs,
|
||||
self.topk,
|
||||
self.pool_len,
|
||||
kv_indices_buffer.shape[1],
|
||||
self.kv_indptr.shape[1],
|
||||
triton.next_power_of_2(num_seqs),
|
||||
triton.next_power_of_2(self.speculative_num_steps),
|
||||
triton.next_power_of_2(bs),
|
||||
)
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||
]
|
||||
call_fn(i, forward_batch)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
kv_indices = torch.zeros(
|
||||
(
|
||||
self.speculative_num_steps,
|
||||
forward_batch.batch_size * self.topk * self.max_context_len,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
def call_fn(i, forward_batch):
|
||||
forward_batch.spec_info.kv_indptr = (
|
||||
forward_batch.spec_info.kv_indptr.clone()
|
||||
)
|
||||
forward_batch.spec_info.kv_indices = (
|
||||
forward_batch.spec_info.kv_indices.clone()
|
||||
)
|
||||
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||
|
||||
self.common_template(forward_batch, kv_indices, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
def call_fn(i, forward_batch):
|
||||
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||
forward_batch.batch_size,
|
||||
forward_batch.batch_size * self.topk,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
|
||||
def call_fn(i, forward_batch):
|
||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||
forward_batch.batch_size,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
seq_lens_sum=-1,
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
@@ -50,7 +50,7 @@ def _fwd_kernel(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
mask_ptr,
|
||||
mask_offsets,
|
||||
mask_indptr,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
stride_qbs,
|
||||
@@ -87,7 +87,7 @@ def _fwd_kernel(
|
||||
cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
|
||||
|
||||
if USE_CUSTOM_MASK:
|
||||
cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
|
||||
cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
@@ -288,7 +288,7 @@ def extend_attention_fwd(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask,
|
||||
mask_offsets,
|
||||
mask_indptr,
|
||||
max_len_extend,
|
||||
sm_scale=None,
|
||||
logit_cap=0.0,
|
||||
@@ -364,7 +364,7 @@ def extend_attention_fwd(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask,
|
||||
mask_offsets,
|
||||
mask_indptr,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
q_extend.stride(0),
|
||||
|
||||
@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
if server_args.attention_backend == "flashinfer":
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||
self.model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
elif server_args.attention_backend == "triton":
|
||||
from sglang.srt.layers.attention.triton_backend import (
|
||||
TritonMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = TritonMultiStepDraftBackend(
|
||||
self.model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||
self.model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.model_runner.draft_attn_backend = self.draft_attn_backend
|
||||
self.init_cuda_graphs()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user