[Revision] Replace enable_flashinfer_mla argument with attention_backend (#5052)
This commit is contained in:
@@ -11,7 +11,11 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||
from sglang.srt.layers.dp_attention import disable_dp_size
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
ScheduleBatch,
|
||||
get_last_loc,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
@@ -146,15 +150,26 @@ class EAGLEWorker(TpModelWorker):
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
if not global_server_args_dict["use_mla_backend"]:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
@@ -171,19 +186,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "flashinfer_mla":
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionMultiStepBackend,
|
||||
|
||||
Reference in New Issue
Block a user