Support Eagle2 for Triton backend (#3466)
This commit is contained in:
@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Restore the CUDA-graph setting that was overridden while building the
# target-model runner above (backup taken before this fragment — see caller).
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

# Create the multi-step draft attention backend and CUDA graph runners.
# Each branch imports its backend lazily so that only the dependencies of
# the configured attention backend need to be installed.
if server_args.attention_backend == "flashinfer":
    from sglang.srt.layers.attention.flashinfer_backend import (
        FlashInferMultiStepDraftBackend,
    )

    self.draft_attn_backend = FlashInferMultiStepDraftBackend(
        self.model_runner,
        self.topk,
        self.speculative_num_steps,
    )
elif server_args.attention_backend == "triton":
    from sglang.srt.layers.attention.triton_backend import (
        TritonMultiStepDraftBackend,
    )

    self.draft_attn_backend = TritonMultiStepDraftBackend(
        self.model_runner,
        self.topk,
        self.speculative_num_steps,
    )
else:
    raise ValueError(
        f"EAGLE is not supported in attention backend {server_args.attention_backend}"
    )

# Publish the draft backend on the runner, then capture CUDA graphs for it.
self.model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user