Support speculative decoding in the trtllm_mha attention backend (#9331)
Co-authored-by: ispobock <ispobaoke@gmail.com>
This commit is contained in:
@@ -10,13 +10,18 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferAttnBackend,
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
import flashinfer
|
||||
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
@@ -55,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||
q_indptr_decode_buf: Optional[torch.Tensor] = None,
|
||||
kv_last_page_len_buf: Optional[torch.Tensor] = None,
|
||||
speculative_step_id: int = 0,
|
||||
):
|
||||
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
|
||||
super().__init__(
|
||||
model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
|
||||
)
|
||||
|
||||
config = model_runner.model_config
|
||||
|
||||
@@ -87,6 +95,16 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
# CUDA graph state
|
||||
self.decode_cuda_graph_metadata = {}
|
||||
|
||||
# Speculative decoding
|
||||
# Only support topk <= 1 for now.
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk or 0
|
||||
self.speculative_step_id = speculative_step_id
|
||||
self.target_verify_metadata = {}
|
||||
|
||||
self.speculative_num_draft_tokens = (
|
||||
model_runner.server_args.speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
# Forward metadata
|
||||
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
|
||||
|
||||
@@ -97,11 +115,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Initialize CUDA graph state for TRTLLM MHA."""
|
||||
max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
|
||||
self.decode_cuda_graph_metadata = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
@@ -110,6 +129,70 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
),
|
||||
}
|
||||
|
||||
if (
|
||||
self.speculative_num_draft_tokens is not None
|
||||
and self.speculative_num_draft_tokens > 0
|
||||
):
|
||||
self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
|
||||
0, max_bs + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
|
||||
max_bs,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.target_verify_metadata = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0,
|
||||
max_bs * self.speculative_num_draft_tokens + 1,
|
||||
step=self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"strided_indices": torch.arange(
|
||||
0, self.max_context_len, self.page_size, device=self.device
|
||||
),
|
||||
}
|
||||
|
||||
self.draft_extend_metadata = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_q": torch.zeros(
|
||||
max_bs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
max_num_pages,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"strided_indices": torch.arange(
|
||||
0, self.max_context_len, self.page_size, device=self.device
|
||||
),
|
||||
}
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
@@ -122,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
):
|
||||
"""Initialize metadata for CUDA graph capture."""
|
||||
metadata = TRTLLMMHAMetadata()
|
||||
device = seq_lens.device
|
||||
|
||||
# Get sequence information
|
||||
metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
|
||||
if forward_mode.is_decode_or_idle():
|
||||
if spec_info is not None:
|
||||
# Draft Decode
|
||||
# Here we only support topk = 1 for now.
|
||||
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
||||
"cache_seqlens"
|
||||
][:bs]
|
||||
metadata.max_seq_len_k = seq_lens.max().item() + (
|
||||
self.speculative_step_id + 1
|
||||
)
|
||||
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
|
||||
: bs + 1
|
||||
]
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.page_table = self.decode_cuda_graph_metadata[
|
||||
"page_table_draft_decode"
|
||||
][:bs, :]
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
else:
|
||||
# Normal Decode
|
||||
# Get sequence information
|
||||
metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
|
||||
batch_size = len(seq_lens)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
|
||||
# Precompute maximum sequence length
|
||||
metadata.max_seq_len_k = seq_lens[:bs].max().item()
|
||||
# Precompute maximum sequence length
|
||||
metadata.max_seq_len_k = seq_lens.max().item()
|
||||
# Precompute cumulative sequence lengths
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0, batch_size + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
# Precompute page table
|
||||
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
|
||||
:bs, :
|
||||
]
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
elif forward_mode.is_target_verify():
|
||||
# Target Verify
|
||||
# Here we only support topk = 1 for now.
|
||||
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
|
||||
:bs
|
||||
]
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + self.speculative_num_draft_tokens)
|
||||
)
|
||||
|
||||
# Precompute page table
|
||||
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
bs * self.speculative_num_draft_tokens + 1,
|
||||
self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
|
||||
: (bs + 1)
|
||||
]
|
||||
|
||||
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||
metadata.max_seq_len_k = (
|
||||
seq_lens.max().item() + self.speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
|
||||
|
||||
self.target_verify_metadata[bs] = metadata
|
||||
elif forward_mode.is_draft_extend():
|
||||
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
|
||||
:bs
|
||||
]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||
num_tokens_per_bs = num_tokens // bs
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
bs * num_tokens_per_bs + 1,
|
||||
num_tokens_per_bs,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
|
||||
: (bs + 1)
|
||||
]
|
||||
num_tokens_per_bs = num_tokens // bs
|
||||
metadata.max_seq_len_q = num_tokens_per_bs
|
||||
metadata.max_seq_len_k = seq_lens.max().item()
|
||||
|
||||
metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
|
||||
|
||||
self.draft_extend_metadata[bs] = metadata
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
@@ -149,21 +321,91 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
seq_lens = seq_lens[:bs]
|
||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||
req_pool_indices = req_pool_indices[:bs]
|
||||
device = seq_lens.device
|
||||
metadata = None
|
||||
if forward_mode.is_decode_or_idle():
|
||||
if spec_info is not None:
|
||||
# Draft Decode
|
||||
# Here we only support topk = 1 for now.
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
|
||||
|
||||
# Normal Decode
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
||||
metadata.max_seq_len_k = max_len
|
||||
max_seq_pages = (
|
||||
metadata.max_seq_len_k + self.page_size - 1
|
||||
) // self.page_size
|
||||
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
|
||||
]
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
seq_lens + self.speculative_step_id + 1
|
||||
)
|
||||
else:
|
||||
# Normal Decode
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
||||
metadata.max_seq_len_k = max_len
|
||||
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
|
||||
)
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
|
||||
None, :
|
||||
],
|
||||
]
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
|
||||
elif forward_mode.is_target_verify():
|
||||
# Here we only support topk = 1 for now.
|
||||
metadata = self.target_verify_metadata[bs]
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + self.speculative_num_draft_tokens)
|
||||
)
|
||||
|
||||
metadata.max_seq_len_k = (
|
||||
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
|
||||
)
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
|
||||
)
|
||||
max_seq_pages = (
|
||||
metadata.max_seq_len_k + self.page_size - 1
|
||||
) // self.page_size
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
|
||||
]
|
||||
page_indices //= self.page_size
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
elif forward_mode.is_draft_extend():
|
||||
metadata = self.draft_extend_metadata[bs]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||
|
||||
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
|
||||
)
|
||||
accept_length = spec_info.accept_length[:bs]
|
||||
if spec_info.accept_length_cpu:
|
||||
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
|
||||
else:
|
||||
metadata.max_seq_len_q = 1
|
||||
|
||||
metadata.cu_seqlens_q[1:].copy_(
|
||||
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
|
||||
)
|
||||
|
||||
max_seq_pages = (
|
||||
metadata.max_seq_len_k + self.page_size - 1
|
||||
) // self.page_size
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.draft_extend_metadata["strided_indices"][:max_seq_pages],
|
||||
]
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self) -> int:
|
||||
@@ -179,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
device = seqlens_in_batch.device
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
# Normal Decode
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
if forward_batch.spec_info is not None:
|
||||
# Draft Decode
|
||||
# Here we only support topk = 1 for now.
|
||||
metadata.cache_seqlens_int32 = (
|
||||
seqlens_in_batch + (self.speculative_step_id + 1)
|
||||
).to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
|
||||
self.speculative_step_id + 1
|
||||
)
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0, batch_size + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
else:
|
||||
# Normal Decode
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0, batch_size + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
# Only support topk = 1 for now.
|
||||
metadata.cache_seqlens_int32 = (
|
||||
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||
).to(torch.int32)
|
||||
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||
metadata.max_seq_len_k = (
|
||||
forward_batch.seq_lens_cpu.max().item()
|
||||
+ self.speculative_num_draft_tokens
|
||||
)
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size * self.speculative_num_draft_tokens + 1,
|
||||
self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
else:
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
@@ -195,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
if any(forward_batch.extend_prefix_lens_cpu):
|
||||
if (
|
||||
any(forward_batch.extend_prefix_lens_cpu)
|
||||
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
|
||||
):
|
||||
extend_seq_lens = forward_batch.extend_seq_lens
|
||||
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
||||
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
||||
@@ -332,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
|
||||
class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
|
||||
"""Multi-step TRTLLM MHA attention kernel used by EAGLE."""
|
||||
|
||||
def __init__(
|
||||
self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
|
||||
):
|
||||
super().__init__(model_runner, topk, speculative_num_steps)
|
||||
for i in range(speculative_num_steps):
|
||||
self.attn_backends[i] = TRTLLMHAAttnBackend(
|
||||
model_runner,
|
||||
skip_prefill=True,
|
||||
kv_indptr_buf=self.kv_indptr[i],
|
||||
kv_last_page_len_buf=self.kv_last_page_len,
|
||||
speculative_step_id=i,
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||
forward_batch.batch_size,
|
||||
forward_batch.batch_size * self.topk,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self, forward_batch: ForwardBatch, bs: int
|
||||
):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
|
||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
@@ -500,11 +500,6 @@ class ServerArgs:
|
||||
)
|
||||
self.page_size = 64
|
||||
|
||||
if self.speculative_algorithm is not None:
|
||||
raise ValueError(
|
||||
"trtllm_mha backend does not support speculative decoding yet."
|
||||
)
|
||||
|
||||
if self.attention_backend == "dual_chunk_flash_attn":
|
||||
logger.warning(
|
||||
"Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
|
||||
@@ -653,6 +648,16 @@ class ServerArgs:
|
||||
self.speculative_num_draft_tokens,
|
||||
) = auto_choose_speculative_params(self)
|
||||
|
||||
if (
|
||||
self.attention_backend == "trtllm_mha"
|
||||
or self.decode_attention_backend == "trtllm_mha"
|
||||
or self.prefill_attention_backend == "trtllm_mha"
|
||||
):
|
||||
if self.speculative_eagle_topk > 1:
|
||||
raise ValueError(
|
||||
"trtllm_mha backend only supports topk = 1 for speculative decoding."
|
||||
)
|
||||
|
||||
if (
|
||||
self.speculative_eagle_topk == 1
|
||||
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
|
||||
|
||||
@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
elif self.server_args.attention_backend == "trtllm_mha":
|
||||
from sglang.srt.layers.attention.trtllm_mha_backend import (
|
||||
TRTLLMHAAttnBackend,
|
||||
TRTLLMHAAttnMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "trtllm_mla":
|
||||
if not global_server_args_dict["use_mla_backend"]:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user