Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)
Co-authored-by: Sehoon Kim <kssteven418@gmail.com> Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
|
||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -37,7 +35,7 @@ if is_flashinfer_available():
|
||||
BatchPrefillWithRaggedKVCacheWrapper,
|
||||
)
|
||||
from flashinfer.cascade import merge_state
|
||||
from flashinfer.decode import PosEncodingMode
|
||||
from flashinfer.decode import _get_range_buf, get_seq_lens
|
||||
|
||||
|
||||
class WrapperDispatch(Enum):
|
||||
@@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||
|
||||
# Parse constants
|
||||
self.decode_use_tensor_cores = should_use_tensor_core(
|
||||
kv_cache_dtype=model_runner.kv_cache_dtype,
|
||||
@@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
)
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.skip_prefill = skip_prefill
|
||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||
|
||||
assert not (
|
||||
model_runner.sliding_window_size is not None
|
||||
@@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
device=model_runner.device,
|
||||
)
|
||||
self.workspace_buffer = global_workspace_buffer
|
||||
|
||||
max_bs = model_runner.req_to_token_pool.size
|
||||
if kv_indptr_buf is None:
|
||||
self.kv_indptr = [
|
||||
@@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
)
|
||||
)
|
||||
self.prefill_wrappers_verify.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
"NHD",
|
||||
)
|
||||
)
|
||||
|
||||
self.decode_wrappers.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer,
|
||||
@@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
if not skip_prefill:
|
||||
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
|
||||
model_runner, self
|
||||
)
|
||||
) # for verify
|
||||
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
|
||||
|
||||
# Other metadata
|
||||
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
||||
self.decode_cuda_graph_metadata = {}
|
||||
self.prefill_cuda_graph_metadata = {}
|
||||
self.prefill_cuda_graph_metadata = {} # For verify
|
||||
self.draft_extend_cuda_graph_metadata = {} # For draft extend
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
@@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
self.indices_updater_decode.update(
|
||||
req_pool_indices,
|
||||
@@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
||||
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
||||
for i in range(self.num_wrappers):
|
||||
decode_wrappers[i].begin_forward = partial(
|
||||
fast_decode_plan, decode_wrappers[i]
|
||||
)
|
||||
elif forward_mode.is_target_verify():
|
||||
prefill_wrappers = []
|
||||
for i in range(self.num_wrappers):
|
||||
@@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=False,
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
)
|
||||
|
||||
o, _ = merge_state(o1, s1, o2, s2)
|
||||
@@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode:
|
||||
bs = len(req_pool_indices)
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
if wrapper.is_cuda_graph_enabled:
|
||||
# Directly write to the cuda graph input buffer
|
||||
kv_indices = wrapper._paged_kv_indices_buf
|
||||
else:
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
@@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.req_to_token.shape[1],
|
||||
)
|
||||
else:
|
||||
assert isinstance(spec_info, EagleDraftInput)
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
|
||||
def update(
|
||||
self,
|
||||
req_pool_indices: torch.Tnesor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
@@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
|
||||
def update_single_wrapper(
|
||||
self,
|
||||
req_pool_indices: torch.Tnesor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
@@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
kv_indices,
|
||||
self.req_to_token.shape[1],
|
||||
)
|
||||
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
@@ -954,7 +961,10 @@ class FlashInferMultiStepDraftBackend:
|
||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||
|
||||
def common_template(
|
||||
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
||||
self,
|
||||
forward_batch: ForwardBatch,
|
||||
kv_indices_buffer: torch.Tensor,
|
||||
call_fn: Callable,
|
||||
):
|
||||
num_seqs = forward_batch.batch_size
|
||||
bs = self.topk * num_seqs
|
||||
@@ -1042,17 +1052,15 @@ class FlashInferMultiStepDraftBackend:
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
)
|
||||
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
|
||||
forward_batch.batch_size
|
||||
][0]
|
||||
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
|
||||
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self, forward_batch: ForwardBatch, bs: int
|
||||
):
|
||||
def call_fn(i, forward_batch):
|
||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
||||
forward_batch.batch_size,
|
||||
bs,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
seq_lens_sum=-1,
|
||||
@@ -1113,6 +1121,11 @@ def should_use_tensor_core(
|
||||
return False
|
||||
|
||||
|
||||
# Use as a fast path to override the indptr in flashinfer's plan function
|
||||
# This is used to remove some host-to-device copy overhead.
|
||||
global_override_indptr_cpu = None
|
||||
|
||||
|
||||
def fast_decode_plan(
|
||||
self,
|
||||
indptr: torch.Tensor,
|
||||
@@ -1142,6 +1155,9 @@ def fast_decode_plan(
|
||||
if logits_soft_cap is None:
|
||||
logits_soft_cap = 0.0
|
||||
|
||||
if self.use_tensor_cores:
|
||||
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
|
||||
|
||||
if self.is_cuda_graph_enabled:
|
||||
if batch_size != self._fixed_batch_size:
|
||||
raise ValueError(
|
||||
@@ -1154,7 +1170,7 @@ def fast_decode_plan(
|
||||
raise ValueError(
|
||||
"The size of indices should be less than or equal to the allocated buffer"
|
||||
)
|
||||
# Skip these copies
|
||||
# Skip these copies because we directly write to them during prepartion
|
||||
# self._paged_kv_indptr_buf.copy_(indptr)
|
||||
# self._paged_kv_indices_buf[: len(indices)] = indices
|
||||
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
|
||||
@@ -1162,6 +1178,7 @@ def fast_decode_plan(
|
||||
self._paged_kv_indptr_buf = indptr
|
||||
self._paged_kv_indices_buf = indices
|
||||
self._paged_kv_last_page_len_buf = last_page_len
|
||||
self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
|
||||
|
||||
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
|
||||
if not q_data_type:
|
||||
@@ -1184,27 +1201,55 @@ def fast_decode_plan(
|
||||
)
|
||||
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
||||
|
||||
empty_q_data = self.empty_q_data
|
||||
empty_kv_cache = self.empty_kv_cache
|
||||
stream = torch.cuda.current_stream()
|
||||
self._cached_module.plan(
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
self._pin_memory_int_workspace_buffer,
|
||||
indptr.to("cpu"),
|
||||
batch_size,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
page_size,
|
||||
self.is_cuda_graph_enabled,
|
||||
window_left,
|
||||
logits_soft_cap,
|
||||
head_dim,
|
||||
head_dim,
|
||||
empty_q_data,
|
||||
empty_kv_cache,
|
||||
stream.cuda_stream,
|
||||
indptr_host = (
|
||||
global_override_indptr_cpu
|
||||
if global_override_indptr_cpu is not None
|
||||
else indptr.cpu()
|
||||
)
|
||||
|
||||
if self.use_tensor_cores:
|
||||
kv_lens_arr_host = get_seq_lens(
|
||||
indptr_host, self.last_page_len[:batch_size], page_size
|
||||
)
|
||||
|
||||
self._plan_info = self._cached_module.plan(
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
self._pin_memory_int_workspace_buffer,
|
||||
qo_indptr_host,
|
||||
indptr_host,
|
||||
kv_lens_arr_host,
|
||||
batch_size, # total_num_rows
|
||||
batch_size,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
page_size,
|
||||
self.is_cuda_graph_enabled,
|
||||
head_dim,
|
||||
head_dim,
|
||||
False, # causal
|
||||
torch.cuda.current_stream().cuda_stream,
|
||||
)
|
||||
else:
|
||||
self._plan_info = self._cached_module.plan(
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
self._pin_memory_int_workspace_buffer,
|
||||
indptr_host,
|
||||
batch_size,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
page_size,
|
||||
self.is_cuda_graph_enabled,
|
||||
window_left,
|
||||
logits_soft_cap,
|
||||
head_dim,
|
||||
head_dim,
|
||||
self.empty_q_data,
|
||||
self.empty_kv_cache,
|
||||
torch.cuda.current_stream().cuda_stream,
|
||||
)
|
||||
|
||||
self._pos_encoding_mode = pos_encoding_mode
|
||||
self._window_left = window_left
|
||||
self._logits_soft_cap = logits_soft_cap
|
||||
|
||||
Reference in New Issue
Block a user