Support Eagle2 for Triton backend (#3466)

This commit is contained in:
Ke Bao
2025-02-10 20:00:42 +08:00
committed by GitHub
parent cddb1cdf8f
commit 2d61132374
5 changed files with 285 additions and 41 deletions

View File

@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker):
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

# Create the multi-step draft attention backend and CUDA graph runners.
# The backend is chosen to match the server's attention backend; imports are
# deferred into each branch so only the selected backend's module is loaded.
if server_args.attention_backend == "flashinfer":
    from sglang.srt.layers.attention.flashinfer_backend import (
        FlashInferMultiStepDraftBackend,
    )

    self.draft_attn_backend = FlashInferMultiStepDraftBackend(
        self.model_runner,
        self.topk,
        self.speculative_num_steps,
    )
elif server_args.attention_backend == "triton":
    from sglang.srt.layers.attention.triton_backend import (
        TritonMultiStepDraftBackend,
    )

    self.draft_attn_backend = TritonMultiStepDraftBackend(
        self.model_runner,
        self.topk,
        self.speculative_num_steps,
    )
else:
    # Any other attention backend has no multi-step draft implementation.
    raise ValueError(
        f"EAGLE is not supported in attention backend {server_args.attention_backend}"
    )
# Share the selected backend with the model runner, then build CUDA graphs.
self.model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()