Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210)
This commit is contained in:
@@ -35,17 +35,11 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
EagleVerifyInput,
|
||||
EagleVerifyOutput,
|
||||
assign_draft_cache_locs,
|
||||
fast_topk,
|
||||
generate_token_bitmask,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
is_cuda,
|
||||
next_power_of_2,
|
||||
)
|
||||
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import segment_packbits
|
||||
@@ -158,12 +152,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
|
||||
# Some dummy tensors
|
||||
self.num_new_pages_per_topk = torch.empty(
|
||||
(), dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
@@ -266,7 +254,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
)
|
||||
|
||||
# Capture extend
|
||||
@@ -281,7 +269,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -302,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
A tuple of the final logit output of the target model, next tokens accepted,
|
||||
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||
"""
|
||||
|
||||
if batch.forward_mode.is_decode():
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info = self.draft(batch)
|
||||
@@ -377,21 +366,14 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
# Allocate cache locations
|
||||
# Layout of the out_cache_loc
|
||||
# [ topk 0 ] [ topk 1 ]
|
||||
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
|
||||
if self.page_size == 1:
|
||||
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
|
||||
num_seqs * self.speculative_num_steps * self.topk, backup_state=True
|
||||
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
|
||||
)
|
||||
else:
|
||||
if self.topk == 1:
|
||||
prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
batch.seq_lens,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
prefix_lens = batch.seq_lens
|
||||
seq_lens = prefix_lens + self.speculative_num_steps
|
||||
extend_num_tokens = num_seqs * self.speculative_num_steps
|
||||
else:
|
||||
# In this case, the last partial page needs to be duplicated.
|
||||
@@ -404,33 +386,29 @@ class EAGLEWorker(TpModelWorker):
|
||||
# "x" means speculative draft tokens
|
||||
# "." means padded tokens
|
||||
|
||||
# TODO(lmzheng): The current implementation is still a fake support
|
||||
# for page size > 1. In the `assign_draft_cache_locs` below,
|
||||
# we directly move the indices instead of the real kv cache.
|
||||
# This only works when the kernel backend runs with page size = 1.
|
||||
# If the kernel backend runs with page size > 1, we need to
|
||||
# duplicate the real KV cache. The overhead of duplicating KV
|
||||
# cache seems okay because the draft KV cache only has one layer.
|
||||
# see a related copy operation in MHATokenToKVPool::move_kv_cache.
|
||||
|
||||
(
|
||||
prefix_lens,
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.num_new_pages_per_topk,
|
||||
self.extend_lens,
|
||||
) = get_last_loc_large_page_size_large_top_k(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
batch.seq_lens,
|
||||
self.speculative_num_steps,
|
||||
self.topk,
|
||||
self.page_size,
|
||||
# TODO: fuse these ops
|
||||
prefix_lens = batch.seq_lens
|
||||
last_page_lens = prefix_lens % self.page_size
|
||||
num_new_pages = (
|
||||
last_page_lens + self.speculative_num_steps + self.page_size - 1
|
||||
) // self.page_size
|
||||
seq_lens = (
|
||||
prefix_lens // self.page_size * self.page_size
|
||||
+ num_new_pages * (self.page_size * self.topk)
|
||||
)
|
||||
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
|
||||
raise NotImplementedError(
|
||||
"page_size > 1 and top_k > 1 are not supported."
|
||||
)
|
||||
# TODO: Support page_size > 1 and top_k > 1
|
||||
# 1. Duplicate the KV cache in the last partial page for all top-k segments
|
||||
# 2. Modify generate_draft_decode_kv_indices accordingly
|
||||
|
||||
# TODO(lmzheng): remove this device sync
|
||||
extend_num_tokens = torch.sum(self.extend_lens).item()
|
||||
|
||||
last_loc = get_last_loc(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
out_cache_loc, token_to_kv_pool_state_backup = (
|
||||
batch.alloc_paged_token_slots_extend(
|
||||
prefix_lens,
|
||||
@@ -445,31 +423,19 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.seq_lens,
|
||||
self.extend_lens,
|
||||
self.num_new_pages_per_topk,
|
||||
out_cache_loc,
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.page_size,
|
||||
next_power_of_2(num_seqs),
|
||||
next_power_of_2(self.speculative_num_steps),
|
||||
)
|
||||
|
||||
if self.page_size > 1 and self.topk > 1:
|
||||
# Remove padded slots
|
||||
out_cache_loc = out_cache_loc[
|
||||
: num_seqs * self.topk * self.speculative_num_steps
|
||||
]
|
||||
|
||||
batch.out_cache_loc = out_cache_loc
|
||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||
batch.return_hidden_states = False
|
||||
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
|
||||
# Get forward batch
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -538,13 +504,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
if self.hot_token_id is not None:
|
||||
topk_index = self.hot_token_id[topk_index]
|
||||
|
||||
out_cache_loc = out_cache_loc.reshape(
|
||||
forward_batch.batch_size, self.topk, self.speculative_num_steps
|
||||
)
|
||||
out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
|
||||
self.speculative_num_steps, -1
|
||||
)
|
||||
|
||||
# Return values
|
||||
score_list: List[torch.Tensor] = []
|
||||
token_list: List[torch.Tensor] = []
|
||||
@@ -566,7 +525,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Set inputs
|
||||
forward_batch.input_ids = input_ids
|
||||
forward_batch.out_cache_loc = out_cache_loc[i:]
|
||||
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
|
||||
forward_batch.out_cache_loc = out_cache_loc[
|
||||
:, self.topk * i : self.topk * (i + 1)
|
||||
].flatten()
|
||||
forward_batch.positions.add_(1)
|
||||
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
||||
spec_info.hidden_states = hidden_states
|
||||
@@ -624,7 +586,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
if vocab_mask is not None:
|
||||
assert spec_info.grammar is not None
|
||||
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
|
||||
# NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
|
||||
# otherwise, this vocab mask will be the one from the previous extend stage
|
||||
# and will be applied to produce wrong results
|
||||
batch.sampling_info.vocab_mask = None
|
||||
|
||||
@@ -645,13 +607,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
]
|
||||
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
|
||||
|
||||
if batch.return_logprob:
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
# Prepare the batch for the next draft forwards.
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.spec_info = res.draft_input
|
||||
|
||||
if batch.return_logprob:
|
||||
self.add_logprob_values(batch, res, logits_output)
|
||||
|
||||
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
||||
|
||||
def add_logprob_values(
|
||||
@@ -664,16 +626,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output = res.logits_output
|
||||
top_logprobs_nums = batch.top_logprobs_nums
|
||||
token_ids_logprobs = batch.token_ids_logprobs
|
||||
accepted_indices = res.accepted_indices
|
||||
assert len(accepted_indices) == len(logits_output.next_token_logits)
|
||||
temperatures = batch.sampling_info.temperatures
|
||||
num_draft_tokens = batch.spec_info.draft_token_num
|
||||
# acceptance indices are the indices in a "flattened" batch.
|
||||
# dividing it to num_draft_tokens will yield the actual batch index.
|
||||
temperatures = temperatures[accepted_indices // num_draft_tokens]
|
||||
|
||||
logprobs = torch.nn.functional.log_softmax(
|
||||
logits_output.next_token_logits / temperatures, dim=-1
|
||||
logits_output.next_token_logits, dim=-1
|
||||
)
|
||||
batch_next_token_ids = res.verified_id
|
||||
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
||||
@@ -708,7 +662,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
pt = 0
|
||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||
verified_ids = batch_next_token_ids.tolist()
|
||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
|
||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
||||
for _ in range(num_tokens):
|
||||
if req.return_logprob:
|
||||
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||
@@ -736,6 +690,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
hidden_states: Hidden states from the target model forward
|
||||
next_token_ids: Next token ids generated from the target forward.
|
||||
"""
|
||||
# Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
|
||||
batch.spec_info = EagleDraftInput(
|
||||
hidden_states=hidden_states,
|
||||
verified_id=next_token_ids,
|
||||
@@ -746,6 +701,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -768,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
batch.return_hidden_states = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -832,47 +790,3 @@ def load_token_map(token_map_path: str) -> List[int]:
|
||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||
return torch.tensor(hot_token_id, dtype=torch.int32)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_last_loc_large_page_size_top_k_1(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
prefix_lens = seq_lens
|
||||
seq_lens = prefix_lens + speculative_num_steps
|
||||
last_loc = get_last_loc(
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
return prefix_lens, seq_lens, last_loc
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_last_loc_large_page_size_large_top_k(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
speculative_num_steps: int,
|
||||
topk: int,
|
||||
page_size: int,
|
||||
):
|
||||
prefix_lens = seq_lens
|
||||
last_page_lens = prefix_lens % page_size
|
||||
num_new_pages_per_topk = (
|
||||
last_page_lens + speculative_num_steps + page_size - 1
|
||||
) // page_size
|
||||
seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
|
||||
page_size * topk
|
||||
)
|
||||
extend_lens = seq_lens - prefix_lens
|
||||
last_loc = get_last_loc(
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
|
||||
return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
|
||||
|
||||
Reference in New Issue
Block a user