Reduce one step decode for draft model. (#11561)
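All multi-step draft backends (Aiter, FlashAttention, FlashInfer, FlashInfer MLA, FlashMLA, Triton, TRT-LLM MHA/MLA) now allocate and iterate their per-step attention backends and CUDA-graph state over speculative_num_steps - 1 steps instead of speculative_num_steps, since the final draft step no longer runs its own decode pass. When speculative_num_steps == 1 no draft decode backend is needed at all (see the DummyAttnBackend fallback in the new draft_utils.py below). A minimal sketch of the resulting bookkeeping, with hypothetical names only, not sglang code:

# Hypothetical sketch (not sglang code): one attention backend per draft
# decode pass means speculative_num_steps - 1 backends in total.
class _FakeAttnBackend:
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        pass

def build_per_step_backends(speculative_num_steps: int) -> list:
    # The last draft step performs no decode pass, so it needs no backend.
    return [_FakeAttnBackend() for _ in range(speculative_num_steps - 1)]

backends = build_per_step_backends(4)
assert len(backends) == 3
for backend in backends:
    backend.init_cuda_graph_state(max_bs=8, max_num_tokens=64)
assert build_per_step_backends(1) == []  # single-step drafting: no decode backend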
@@ -1064,7 +1064,7 @@ class AiterMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 AiterAttnBackend(
                     model_runner,
@@ -1107,7 +1107,7 @@ class AiterMultiStepDraftBackend:
             self.page_size,
         )

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1141,7 +1141,7 @@ class AiterMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -2320,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
@@ -2335,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
@@ -1405,7 +1405,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1493,7 +1493,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -916,7 +916,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )

         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +998,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )

         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )
@@ -918,7 +918,7 @@ class TritonMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends: List[TritonAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 TritonAttnBackend(
                     model_runner,
@@ -969,7 +969,7 @@ class TritonMultiStepDraftBackend:
         if call_fn is None:
             return

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1009,7 +1009,8 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs,
                 max_num_tokens,
@@ -637,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMHAAttnBackend(
                 model_runner,
                 skip_prefill=True,
@@ -651,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
@@ -735,7 +735,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     ):
         super().__init__(model_runner, topk, speculative_num_steps)

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMMLABackend(
                 model_runner,
                 skip_prefill=True,
python/sglang/srt/speculative/draft_utils.py (new file, 222 lines)
@@ -0,0 +1,222 @@
+import logging
+
+from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.utils.common import is_blackwell
+
+logger = logging.getLogger(__name__)
+
+
+class DraftBackendFactory:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        draft_model_runner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        self.server_args = server_args
+        self.draft_model_runner = draft_model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def create_decode_backend(self):
+        if self.speculative_num_steps == 1:
+
+            class DummyAttnBackend:
+                def __init__(self):
+                    pass
+
+                def init_forward_metadata(*args, **kwargs):
+                    pass
+
+            return DummyAttnBackend()
+
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
+
+    def create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
+            "flashmla": self._create_flashmla_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )
+
+    def _create_flashinfer_decode_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
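The hunks that follow move backend construction for the EAGLE draft worker into this factory. A condensed usage sketch, assuming server_args and draft_model_runner come from an already-initialized worker (illustrative only, not part of the commit):

# Usage sketch for DraftBackendFactory, mirroring the EAGLEWorker change in
# the next file; server_args and draft_model_runner are assumed to be the
# worker's already-constructed objects.
from sglang.srt.speculative.draft_utils import DraftBackendFactory

def create_draft_attention_backends(
    server_args, draft_model_runner, topk: int, speculative_num_steps: int
):
    factory = DraftBackendFactory(
        server_args, draft_model_runner, topk, speculative_num_steps
    )
    # Selection follows server_args.decode_attention_backend (falling back to
    # server_args.attention_backend); the draft-extend backend additionally
    # honors server_args.speculative_attention_mode.
    draft_attn_backend = factory.create_decode_backend()
    draft_extend_attn_backend = factory.create_draft_extend_backend()
    return draft_attn_backend, draft_extend_attn_backend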
@@ -27,7 +27,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.draft_utils import DraftBackendFactory
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
@@ -195,205 +196,23 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None

+        draft_backend_factory = DraftBackendFactory(
+            self.server_args,
+            self.draft_model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
+
         # Initialize decode attention backend
-        self.draft_attn_backend = self._create_decode_backend()
+        self.draft_attn_backend = draft_backend_factory.create_decode_backend()

         # Initialize draft extend attention backend (respects speculative_attention_mode setting)
-        self.draft_extend_attn_backend = self._create_draft_extend_backend()
+        self.draft_extend_attn_backend = (
+            draft_backend_factory.create_draft_extend_backend()
+        )

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

-    def _create_backend(
-        self, backend_name: str, backend_map: dict, error_template: str
-    ):
-        backend_type = getattr(self.server_args, backend_name)
-        if backend_type is None:
-            backend_type = self.server_args.attention_backend
-
-        if backend_type not in backend_map:
-            raise ValueError(error_template.format(backend_type=backend_type))
-
-        return backend_map[backend_type]()
-
-    def _create_decode_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_decode_backend,
-            "triton": self._create_triton_decode_backend,
-            "aiter": self._create_aiter_decode_backend,
-            "fa3": self._create_fa3_decode_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_decode_backend
-                if not is_blackwell()
-                else self._create_triton_decode_backend
-            ),
-            "flashmla": self._create_flashmla_decode_backend,
-            "trtllm_mha": self._create_trtllm_mha_decode_backend,
-            "trtllm_mla": self._create_trtllm_mla_decode_backend,
-        }
-
-        return self._create_backend(
-            "decode_attention_backend",
-            backend_map,
-            "EAGLE is not supported in decode attention backend {backend_type}",
-        )
-
-    def _create_draft_extend_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_prefill_backend,
-            "triton": self._create_triton_prefill_backend,
-            "aiter": self._create_aiter_prefill_backend,
-            "fa3": self._create_fa3_prefill_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_prefill_backend
-                if not is_blackwell()
-                else self._create_triton_prefill_backend
-            ),
-            "flashmla": self._create_flashmla_prefill_backend,
-            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
-            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
-        }
-        backend_name = (
-            "decode_attention_backend"
-            if self.server_args.speculative_attention_mode == "decode"
-            else "prefill_attention_backend"
-        )
-        return self._create_backend(
-            backend_name,
-            backend_map,
-            "EAGLE is not supported in attention backend {backend_type}",
-        )
-
-    def _create_flashinfer_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMLAMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-
-    def _create_triton_decode_backend(self):
-        from sglang.srt.layers.attention.triton_backend import (
-            TritonMultiStepDraftBackend,
-        )
-
-        return TritonMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_aiter_decode_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
-
-        return AiterMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_fa3_decode_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionMultiStepBackend,
-        )
-
-        return FlashAttentionMultiStepBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashmla_decode_backend(self):
-        from sglang.srt.layers.attention.flashmla_backend import (
-            FlashMLAMultiStepDraftBackend,
-        )
-
-        return FlashMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mha_decode_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import (
-            TRTLLMHAAttnMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMHAAttnMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mla_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import (
-            TRTLLMMLAMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashinfer_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
-
-            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_triton_prefill_backend(self):
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_aiter_prefill_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_fa3_prefill_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
-        )
-
-        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mha_prefill_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
-
-        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mla_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_flashmla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
-
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
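The selection rule that moved into DraftBackendFactory._create_backend can be exercised on its own; the following is a self-contained sketch of the same fallback behavior using a hypothetical stand-in for ServerArgs (not the sglang class):

# Stand-alone sketch of the backend-selection rule: a per-phase override
# (decode_attention_backend / prefill_attention_backend) falls back to the
# global attention_backend, and unsupported names raise ValueError.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeServerArgs:  # hypothetical stand-in, not sglang's ServerArgs
    attention_backend: str = "triton"
    decode_attention_backend: Optional[str] = None

def resolve_backend(args: FakeServerArgs, backend_name: str, backend_map: dict):
    backend_type = getattr(args, backend_name)
    if backend_type is None:
        backend_type = args.attention_backend
    if backend_type not in backend_map:
        raise ValueError(f"EAGLE is not supported in attention backend {backend_type}")
    return backend_map[backend_type]()

backend_map = {"triton": lambda: "triton-draft", "fa3": lambda: "fa3-draft"}
assert resolve_backend(FakeServerArgs(), "decode_attention_backend", backend_map) == "triton-draft"
assert (
    resolve_backend(
        FakeServerArgs(decode_attention_backend="fa3"), "decode_attention_backend", backend_map
    )
    == "fa3-draft"
)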