Add speculator attention backend switch (#9981)
This commit is contained in:
@@ -22,17 +22,45 @@ class HybridAttnBackend(AttentionBackend):
|
|||||||
self.prefill_backend = prefill_backend
|
self.prefill_backend = prefill_backend
|
||||||
self.decode_backend = decode_backend
|
self.decode_backend = decode_backend
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
"""
|
||||||
self.decode_backend.init_forward_metadata(forward_batch)
|
Select the appropriate attention backend based on the forward mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
forward_mode: The current forward mode indicating the operation type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The selected attention backend (prefill or decode)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- decode_or_idle: Always uses decode backend
|
||||||
|
- target_verify or draft_extend: Uses decode backend if speculative_attention_backend is "decode", otherwise prefill backend
|
||||||
|
- prefill: Always uses prefill backend
|
||||||
|
"""
|
||||||
|
if forward_mode.is_decode_or_idle():
|
||||||
|
return self.decode_backend
|
||||||
|
elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
|
||||||
|
return (
|
||||||
|
self.decode_backend
|
||||||
|
if self.model_runner.server_args.speculative_attention_backend
|
||||||
|
== "decode"
|
||||||
|
else self.prefill_backend
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.prefill_backend.init_forward_metadata(forward_batch)
|
return self.prefill_backend
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
backend = self._select_backend(forward_batch.forward_mode)
|
||||||
|
backend.init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||||
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
|
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
|
||||||
if self.model_runner.server_args.speculative_algorithm is not None:
|
if (
|
||||||
# When speculative decoding is enabled, we also need to initialize the
|
self.model_runner.server_args.speculative_algorithm is not None
|
||||||
# prefill backend's cuda graph state to support target_verify.
|
and self.model_runner.server_args.speculative_attention_backend == "prefill"
|
||||||
|
):
|
||||||
|
# The decode backend is always initialized above. When speculative decoding routes
|
||||||
|
# target_verify / draft_extend to the prefill backend, initialize its cuda graph state too.
|
||||||
self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
|
self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
@@ -45,26 +73,16 @@ class HybridAttnBackend(AttentionBackend):
|
|||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
backend = self._select_backend(forward_mode)
|
||||||
self.decode_backend.init_forward_metadata_capture_cuda_graph(
|
backend.init_forward_metadata_capture_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
encoder_lens,
|
encoder_lens,
|
||||||
forward_mode,
|
forward_mode,
|
||||||
spec_info,
|
spec_info,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self.prefill_backend.init_forward_metadata_capture_cuda_graph(
|
|
||||||
bs,
|
|
||||||
num_tokens,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
encoder_lens,
|
|
||||||
forward_mode,
|
|
||||||
spec_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
@@ -77,28 +95,17 @@ class HybridAttnBackend(AttentionBackend):
|
|||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
backend = self._select_backend(forward_mode)
|
||||||
self.decode_backend.init_forward_metadata_replay_cuda_graph(
|
backend.init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
encoder_lens,
|
encoder_lens,
|
||||||
forward_mode,
|
forward_mode,
|
||||||
spec_info,
|
spec_info,
|
||||||
seq_lens_cpu,
|
seq_lens_cpu,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self.prefill_backend.init_forward_metadata_replay_cuda_graph(
|
|
||||||
bs,
|
|
||||||
req_pool_indices,
|
|
||||||
seq_lens,
|
|
||||||
seq_lens_sum,
|
|
||||||
encoder_lens,
|
|
||||||
forward_mode,
|
|
||||||
spec_info,
|
|
||||||
seq_lens_cpu,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return self.decode_backend.get_cuda_graph_seq_len_fill_value()
|
return self.decode_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
@@ -127,6 +134,7 @@ class HybridAttnBackend(AttentionBackend):
|
|||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return self.prefill_backend.forward_extend(
|
backend = self._select_backend(forward_batch.forward_mode)
|
||||||
|
return backend.forward_extend(
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"sampling_backend",
|
"sampling_backend",
|
||||||
"speculative_accept_threshold_single",
|
"speculative_accept_threshold_single",
|
||||||
"speculative_accept_threshold_acc",
|
"speculative_accept_threshold_acc",
|
||||||
|
"speculative_attention_backend",
|
||||||
"torchao_config",
|
"torchao_config",
|
||||||
"triton_attention_reduce_in_fp32",
|
"triton_attention_reduce_in_fp32",
|
||||||
"num_reserved_decode_tokens",
|
"num_reserved_decode_tokens",
|
||||||
|
|||||||
@@ -1045,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
# Determine attention backend used by current forward batch
|
# Determine attention backend used by current forward batch
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
attention_backend = global_server_args_dict["decode_attention_backend"]
|
attention_backend = global_server_args_dict["decode_attention_backend"]
|
||||||
|
elif (
|
||||||
|
forward_batch.forward_mode.is_target_verify()
|
||||||
|
or forward_batch.forward_mode.is_draft_extend()
|
||||||
|
):
|
||||||
|
# Use the specified backend for speculative operations (both verify and draft extend)
|
||||||
|
if global_server_args_dict["speculative_attention_backend"] == "decode":
|
||||||
|
attention_backend = global_server_args_dict["decode_attention_backend"]
|
||||||
|
else: # default to prefill
|
||||||
|
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||||
else:
|
else:
|
||||||
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||||
self.current_attention_backend = attention_backend
|
self.current_attention_backend = attention_backend
|
||||||
|
|||||||
@@ -262,6 +262,7 @@ class ServerArgs:
|
|||||||
speculative_accept_threshold_single: float = 1.0
|
speculative_accept_threshold_single: float = 1.0
|
||||||
speculative_accept_threshold_acc: float = 1.0
|
speculative_accept_threshold_acc: float = 1.0
|
||||||
speculative_token_map: Optional[str] = None
|
speculative_token_map: Optional[str] = None
|
||||||
|
speculative_attention_backend: str = "prefill"
|
||||||
|
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
ep_size: int = 1
|
ep_size: int = 1
|
||||||
@@ -1561,6 +1562,13 @@ class ServerArgs:
|
|||||||
help="The path of the draft model's small vocab table.",
|
help="The path of the draft model's small vocab table.",
|
||||||
default=ServerArgs.speculative_token_map,
|
default=ServerArgs.speculative_token_map,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-attention-backend",
|
||||||
|
type=str,
|
||||||
|
choices=["prefill", "decode"],
|
||||||
|
help="Attention backend to use for speculative decoding operations (both target verify and draft extend). 'prefill' (default) or 'decode'.",
|
||||||
|
default=ServerArgs.speculative_attention_backend,
|
||||||
|
)
|
||||||
|
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -191,7 +191,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# Initialize decode attention backend
|
# Initialize decode attention backend
|
||||||
self.draft_attn_backend = self._create_decode_backend()
|
self.draft_attn_backend = self._create_decode_backend()
|
||||||
|
|
||||||
# Initialize prefill attention backend
|
# Initialize draft extend attention backend (respects speculative_attention_backend setting)
|
||||||
self.draft_extend_attn_backend = self._create_draft_extend_backend()
|
self.draft_extend_attn_backend = self._create_draft_extend_backend()
|
||||||
|
|
||||||
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
||||||
@@ -234,11 +234,15 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
||||||
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
||||||
}
|
}
|
||||||
|
backend_name = (
|
||||||
|
"decode_attention_backend"
|
||||||
|
if self.server_args.speculative_attention_backend == "decode"
|
||||||
|
else "prefill_attention_backend"
|
||||||
|
)
|
||||||
return self._create_backend(
|
return self._create_backend(
|
||||||
"prefill_attention_backend",
|
backend_name,
|
||||||
backend_map,
|
backend_map,
|
||||||
"EAGLE is not supported in prefill attention backend {backend_type}",
|
"EAGLE is not supported in attention backend {backend_type}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_flashinfer_decode_backend(self):
|
def _create_flashinfer_decode_backend(self):
|
||||||
|
|||||||
@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase):
|
||||||
|
speculative_decode = True
|
||||||
|
# This eagle test uses a very small model, so the accuracy is low.
|
||||||
|
accuracy_threshold = 0.2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
return DEFAULT_SERVER_ARGS + [
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE",
|
||||||
|
"--speculative-draft",
|
||||||
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
"--speculative-num-steps",
|
||||||
|
"3",
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
"2",
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
"4",
|
||||||
|
"--speculative-attention-backend",
|
||||||
|
"prefill",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBackendBase):
|
||||||
|
speculative_decode = True
|
||||||
|
# This eagle test uses a very small model, so the accuracy is low.
|
||||||
|
accuracy_threshold = 0.2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
return DEFAULT_SERVER_ARGS + [
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE",
|
||||||
|
"--speculative-draft",
|
||||||
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
"--speculative-num-steps",
|
||||||
|
"3",
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
"2",
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
"4",
|
||||||
|
"--speculative-attention-backend",
|
||||||
|
"decode",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user