Support page size > 1 + eagle (#4908)
@@ -33,6 +33,7 @@ runtime_common = [
     "prometheus-client>=0.20.0",
     "psutil",
     "pydantic",
     "pynvml",
     "python-multipart",
     "pyzmq>=25.1.2",
+    "soundfile==0.13.1",
@@ -14,7 +14,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, List, Optional, Union

 import torch
-import triton

 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        self.page_size = model_runner.page_size

         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-            triton.next_power_of_2(num_seqs),
-            triton.next_power_of_2(self.speculative_num_steps),
-            triton.next_power_of_2(bs),
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )

         assert forward_batch.spec_info is not None
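These hunks (and several below) swap `triton.next_power_of_2` for a `next_power_of_2` helper imported from `sglang.srt.utils`, so the host-side launch-grid math no longer goes through the triton module and `import triton` can be dropped above. A minimal pure-Python sketch of such a helper, assuming the usual semantics (smallest power of two that is >= n, for n >= 1):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, computed on the host without triton.
    return 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 2, 3, 5, 8, 9)] == [1, 2, 4, 8, 8, 16]
```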
@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )
         return req_pool_indices

-    def alloc_token_slots(self, num_tokens: int):
+    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
         if self.token_to_kv_pool_allocator.available_size() < num_tokens:
             if self.tree_cache is not None:
                 self.tree_cache.evict(num_tokens)

+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
         if out_cache_loc is None:
             phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.tree_cache.pretty_print()
             raise RuntimeError(error_msg)

-        return out_cache_loc
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc

     def alloc_paged_token_slots_extend(
         self,
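With `backup_state=True`, the allocation helpers now return a tuple so the caller can later roll the allocator back. A sketch of the new call convention (hypothetical `batch` and token count):

```python
num_tokens = 64  # hypothetical

# Old behavior is unchanged:
out_cache_loc = batch.alloc_token_slots(num_tokens)

# New: also snapshot the allocator's free list before allocating, so the
# caller can undo the whole allocation with restore_state() later.
out_cache_loc, state = batch.alloc_token_slots(num_tokens, backup_state=True)
```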
@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()
@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
             )

+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
             prefix_lens, seq_lens, last_loc, extend_num_tokens
         )
@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc

     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()
@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.tree_cache.evict(
                     len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)

+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc

     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
@@ -1110,7 +1110,7 @@ class Scheduler(
             )
             if memory_leak:
                 msg = (
-                    "KV cache pool leak detected! "
+                    "token_to_kv_pool_allocator memory leak detected! "
                     f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                     f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                     f"{self.tree_cache.evictable_size()=}\n"
@@ -1121,7 +1121,7 @@ class Scheduler(

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
-                "Memory pool leak detected!"
+                "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_slots
+
+    def restore_state(self, free_slots):
+        self.free_slots = free_slots
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
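`backup_state`/`restore_state` make "allocate, use briefly, discard everything" cheap: instead of freeing each draft slot individually, the caller snapshots the free list and later swaps it back (the EAGLE worker hunks below do exactly this around the draft forward passes). A minimal sketch of the pattern, with a toy allocator standing in for `TokenToKVPoolAllocator`:

```python
import torch

class ToyAllocator:
    """Stand-in for the backup/restore protocol added above."""

    def __init__(self, size: int):
        self.free_slots = torch.arange(1, size + 1)

    def alloc(self, n: int):
        if n > len(self.free_slots):
            return None
        # Slicing creates new tensors, so an earlier backup stays valid.
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def backup_state(self):
        return self.free_slots

    def restore_state(self, free_slots):
        self.free_slots = free_slots

alloc = ToyAllocator(16)
backup = alloc.backup_state()   # snapshot before speculative use
draft_locs = alloc.alloc(8)     # slots used only while drafting
# ... run draft forward passes writing into draft_locs ...
alloc.restore_state(backup)     # one O(1) rollback instead of free()
assert len(alloc.free_slots) == 16
```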
@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
             next_power_of_2(extend_num_tokens),
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         merged_value = self.ret_values.item()
         num_new_pages = merged_value >> 32
         if num_new_pages > len(self.free_pages):
@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
             self.page_size,
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         num_new_pages = self.ret_values.item()
         if num_new_pages > len(self.free_pages):
             return None
@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
         else:
             self.free_group.append(free_index)

+        if self.debug_mode:
+            assert len(torch.unique(self.free_pages)) == len(self.free_pages)
+
     def free_group_begin(self):
         self.is_not_in_free_group = False
         self.free_group = []
@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_pages
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+                capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
             else:
-                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
             # Since speculative decoding requires more cuda graph memory, we
             # capture less.
-            capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
+            capture_bs = (
+                list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+            )

     if _is_hip:
-        capture_bs += [i * 8 for i in range(21, 33)]
+        capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
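For reference, the new default capture lists expand to concrete batch sizes as follows (a quick host-side check of the expressions above):

```python
# Non-speculative, padding enabled (the common branch):
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
assert capture_bs[-1] == 160 and 48 in capture_bs

# Speculative decoding: dense at small sizes, sparser above 32.
capture_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
assert capture_bs[:10] == [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]
assert capture_bs[-3:] == [120, 136, 152]
```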
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""

 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -15,6 +15,7 @@

 import argparse
 import dataclasses
+import json
 import logging
 import os
 import random
@@ -132,9 +133,9 @@ class ServerArgs:
     # Speculative decoding
     speculative_algorithm: Optional[str] = None
     speculative_draft_model_path: Optional[str] = None
-    speculative_num_steps: int = 5
-    speculative_eagle_topk: int = 4
-    speculative_num_draft_tokens: int = 8
+    speculative_num_steps: Optional[int] = None
+    speculative_eagle_topk: Optional[int] = None
+    speculative_num_draft_tokens: Optional[int] = None
     speculative_accept_threshold_single: float = 1.0
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
@@ -313,12 +314,29 @@ class ServerArgs:
             or self.speculative_algorithm == "EAGLE3"
         ):
             if self.max_running_requests is None:
-                self.max_running_requests = 32
+                self.max_running_requests = 48
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )

+            # Auto choose parameters
+            if self.speculative_num_steps is None:
+                assert (
+                    self.speculative_eagle_topk is None
+                    and self.speculative_num_draft_tokens is None
+                )
+                (
+                    self.speculative_num_steps,
+                    self.speculative_eagle_topk,
+                    self.speculative_num_draft_tokens,
+                ) = auto_choose_speculative_params(self)
+
+            if self.page_size > 1 and self.speculative_eagle_topk > 1:
+                self.speculative_eagle_topk = 1
+                logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+
         # The token generated from the verify step is counted.
         # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
         # assert self.speculative_num_steps < self.speculative_num_draft_tokens
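The resolution order for the three speculative knobs is worth spelling out: leaving all of them unset triggers `auto_choose_speculative_params`, and page sizes above 1 then force top-1 drafting. A standalone sketch of that logic (the real code lives in `__post_init__` above; the `(5, 4, 8)` tuple mirrors the auto-chosen defaults below):

```python
def resolve_speculative_params(num_steps, topk, num_draft_tokens, page_size):
    if num_steps is None:
        # All three must be unset together, as asserted in __post_init__.
        assert topk is None and num_draft_tokens is None
        num_steps, topk, num_draft_tokens = (5, 4, 8)  # auto-chosen defaults
    if page_size > 1 and topk > 1:
        topk = 1  # page_size > 1 currently requires top-1 drafting
    return num_steps, topk, num_draft_tokens

assert resolve_speculative_params(None, None, None, page_size=4) == (5, 1, 8)
assert resolve_speculative_params(3, 2, 6, page_size=1) == (3, 2, 6)
```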
@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):

     def __call__(self, parser, namespace, values, option_string=None):
         raise ValueError(self.help)
+
+
+def auto_choose_speculative_params(self: ServerArgs):
+    """
+    Automatically choose the parameters for speculative decoding.
+
+    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
+    """
+    if self.decrypted_config_file:
+        config_path = self.decrypted_config_file
+    else:
+        config_path = os.path.join(self.model_path, "config.json")
+    if not os.path.exists(config_path):
+        raise ValueError(f"{config_path} is not found.")
+
+    config = json.load(open(config_path))
+
+    arch = config.get("architectures", ["Unknown"])[0]
+
+    if arch in ["LlamaForCausalLM"]:
+        # The default value for llama
+        return (5, 4, 8)
+    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+        # The default value for deepseek
+        return (5, 4, 8)
+    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
+        return (5, 4, 8)
+    else:
+        # The default value for all other models
+        return (5, 4, 8)
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional
@@ -10,11 +11,15 @@ import triton.language as tl

 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2

 if is_cuda_available():
     from sgl_kernel import (
@@ -34,6 +39,9 @@ import logging
 logger = logging.getLogger(__name__)


+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+
+
 @dataclass
 class EagleDraftInput:
     # The inputs for decode
@@ -93,7 +101,7 @@ class EagleDraftInput:
             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
             new_verified_id,
-            triton.next_power_of_2(speculative_num_steps + 1),
+            next_power_of_2(speculative_num_steps + 1),
         )

         batch.seq_lens_sum = sum(seq_lens_cpu)
@@ -225,18 +233,34 @@ class EagleVerifyInput:
             CaptureHiddenMode.FULL,
         )

-    def prepare_for_verify(self, batch: ScheduleBatch):
+    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
-        batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+
+        if page_size == 1:
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            end_offset = batch.seq_lens + self.draft_token_num
+        else:
+            prefix_lens = batch.seq_lens
+            end_offset = prefix_lens + self.draft_token_num
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+            )
+            self.last_loc = last_loc
+
         bs = batch.batch_size()
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
+            end_offset,
             batch.out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
+            next_power_of_2(bs),
         )

     def generate_attn_arg_prefill(
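`get_last_loc` (imported above from `schedule_batch`) supplies the KV slot of each request's last prefix token, which the paged allocator needs to decide whether the first draft token can share the last partial page. A reference sketch of the semantics assumed here, with a tiny made-up `req_to_token` table:

```python
import torch

def get_last_loc_ref(req_to_token, req_pool_indices, prefix_lens):
    # Assumed semantics: per request, the cache slot of its last prefix
    # token, or -1 when the prefix is empty.
    return torch.where(
        prefix_lens > 0,
        req_to_token[req_pool_indices, prefix_lens - 1],
        torch.full_like(prefix_lens, -1),
    )

req_to_token = torch.tensor([[7, 8, 9, 0], [3, 4, 0, 0]])
last = get_last_loc_ref(req_to_token, torch.tensor([0, 1]), torch.tensor([3, 0]))
assert last.tolist() == [9, -1]
```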
@@ -282,6 +306,7 @@ class EagleVerifyInput:
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -305,6 +330,7 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

+        # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
             linear_penalty = torch.zeros(
@@ -317,6 +343,7 @@ class EagleVerifyInput:
                 torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )

+        # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
@@ -378,13 +405,24 @@ class EagleVerifyInput:
                 deterministic=True,
             )

+        if SIMULATE_ACC_LEN:
+            # Do simulation
+            accept_index = _generate_simulated_accept_index(
+                accept_index=accept_index,
+                predict=predict,  # mutable
+                accept_length=accept_length,  # mutable
+                simulate_acc_len=SIMULATE_ACC_LEN,
+                bs=bs,
+                spec_steps=self.spec_steps,
+            )
+
         new_accept_index = []
         unfinished_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False

-        # iterate every accepted token and check if req has finished after append the token
+        # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
             new_accept_index_ = []
@@ -407,13 +445,28 @@ class EagleVerifyInput:
                     unfinished_index.append(i)
                 req.spec_verify_ct += 1

+        if has_finished:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        # Free the KV cache for unaccepted tokens
+        accept_index = accept_index[accept_index != -1]
+        verified_id = predict[accept_index]
+        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+        evict_mask[accept_index] = False
+
+        if page_size != 1:
+            align_evict_mask_to_page_size[len(batch.seq_lens),](
+                batch.seq_lens,
+                evict_mask,
+                page_size,
+                self.draft_token_num,
+                next_power_of_2(self.draft_token_num),
+            )
+
+        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+
+        # Construct EagleVerifyOutput
         if not has_finished:
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             batch.out_cache_loc = batch.out_cache_loc[accept_index]
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
@@ -422,7 +475,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc,
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -443,13 +496,6 @@ class EagleVerifyInput:
                 accepeted_indices=accept_index,
             )
         else:
-            accept_length = (accept_index != -1).sum(dim=1) - 1
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
                 batch.req_to_token_pool.req_to_token,
@@ -457,7 +503,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc[accept_index],
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -465,20 +511,21 @@ class EagleVerifyInput:
             draft_input = EagleDraftInput()
             if len(new_accept_index) > 0:
                 new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
                 draft_input.hidden_states = batch.spec_info.hidden_states[
                     new_accept_index
                 ]
                 draft_input.verified_id = predict[new_accept_index]
-                draft_input.accept_length = accept_length[unfinished_index]
                 draft_input.accept_length_cpu = [
                     accept_length_cpu[i] for i in unfinished_index
                 ]
+                draft_input.accept_length = accept_length[unfinished_index_device]
                 if has_finished:
                     draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                        unfinished_index
+                        unfinished_index_device
                     ]
                     draft_input.req_pool_indices_for_draft_extend = (
-                        batch.req_pool_indices[unfinished_index]
+                        batch.req_pool_indices[unfinished_index_device]
                     )
                 else:
                     draft_input.seq_lens_for_draft_extend = batch.seq_lens
@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 32
     pid = tl.program_id(axis=0)
     kv_start = tl.load(seq_lens + pid)
-    kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+
+    if page_size == 1 or topk == 1:
+        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+    else:
+        prefix_len = tl.load(seq_lens + pid)
+        last_page_len = prefix_len % page_size
+        num_new_page = (
+            last_page_len + speculative_num_steps + page_size - 1
+        ) // page_size
+        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

     num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
     for i in range(num_loop):
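The page math in the new `else` branch (mirrored in the EAGLE worker hunk below) is easiest to see with numbers. With `page_size=4`, a prefix of 10 tokens (so the last page is half full), `speculative_num_steps=3` and `topk=2`, each top-k branch gets a fresh copy of the partial last page plus its draft tokens, padded out to whole pages (a host-side check of the kernel's formula):

```python
page_size, speculative_num_steps, topk = 4, 3, 2
prefix_len = 10

last_page_len = prefix_len % page_size                 # 2 tokens on the last page
num_new_page = (
    last_page_len + speculative_num_steps + page_size - 1
) // page_size                                         # ceil((2 + 3) / 4) = 2
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
assert (last_page_len, num_new_page, kv_end) == (2, 2, 24)
# Slots 8..24 then hold, per top-k branch, the duplicated partial page
# plus that branch's draft tokens, padded out to whole pages.
```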
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
         tl.store(kv_indptr + zid, base + zid * iters)


+@triton.jit
+def align_evict_mask_to_page_size(
+    seq_lens,
+    evict_mask,
+    page_size: tl.constexpr,
+    num_draft_tokens: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    t_range = tl.arange(0, BLOCK_SIZE)
+
+    bid = tl.program_id(axis=0)
+    seq_len = tl.load(seq_lens + bid)
+    io_mask = t_range < num_draft_tokens
+    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+
+    num_trues = tl.sum(mask_row)
+    num_false = num_draft_tokens - num_trues
+
+    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
+    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
+        tl.store(evict_mask + bid * num_draft_tokens + i, False)
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
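`align_evict_mask_to_page_size` rounds each request's keep-set down to a page boundary: after accepting `num_false` draft tokens, the partially filled last page must stay allocated, so the mask entries covering it are flipped to False. A plain-Python rendering of the same row-wise logic (assuming one boolean mask row per request, as in the kernel):

```python
def align_evict_mask_to_page_size_ref(seq_len, evict_mask_row, page_size):
    # evict_mask_row: one bool per draft token; True means the slot is freed.
    num_false = sum(1 for kept in evict_mask_row if not kept)  # accepted tokens
    # Draft position that starts the page containing the last kept token.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, len(evict_mask_row))):
        evict_mask_row[i] = False  # keep the whole last partial page alive
    return evict_mask_row

# seq_len=9, page_size=4, 2 accepted draft tokens: the last kept token sits
# at global position 10 (page 8..11), so draft position 2 must also be kept.
row = [False, False, True, True]
assert align_evict_mask_to_page_size_ref(9, row, 4) == [False, False, False, True]
```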
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def _generate_simulated_accept_index(
+    accept_index,
+    predict,
+    accept_length,
+    simulate_acc_len,
+    bs,
+    spec_steps,
+):
+    simulate_acc_len_float = float(simulate_acc_len)
+    simulated_values = torch.normal(
+        mean=simulate_acc_len_float,
+        std=1.0,
+        size=(1,),
+        device="cpu",
+    )
+    # clamp simulated values to be between 1 and self.spec_steps
+    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
+    simulate_acc_len = int(simulated_values.round().item())
+
+    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
+    sim_accept_index = torch.full(
+        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+    )
+    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
+        simulate_acc_len, device=accept_index.device
+    )
+    accept_length.fill_(simulate_acc_len - 1)
+    predict.fill_(100)  # some legit token id
+    return sim_accept_index
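A quick numeric check of the simulation above: with `bs=2`, `spec_steps=3` and a draw that rounds to 2, each row keeps its first entry plus the following `simulate_acc_len - 1` positions (a CPU sketch with fixed values standing in for the `torch.normal` draw):

```python
import torch

bs, spec_steps, simulate_acc_len = 2, 3, 2
accept_index_first_col = torch.tensor([[0], [4]], dtype=torch.int32)

sim_accept_index = torch.full((bs, spec_steps + 1), -1, dtype=torch.int32)
sim_accept_index[:, :simulate_acc_len] = accept_index_first_col + torch.arange(
    simulate_acc_len, dtype=torch.int32
)
# Rows now accept positions [0, 1] and [4, 5]; the rest stay -1.
assert sim_accept_index.tolist() == [[0, 1, -1, -1], [4, 5, -1, -1]]
```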
@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
         self.gpu_id = gpu_id
         self.device = server_args.device
         self.target_worker = target_worker
+        self.page_size = server_args.page_size
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
-                spec_info, to_free_cache_loc = self.draft(batch)
+                spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )

-            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
-            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
-
             # If it is None, it means all requests are finished
             if batch.spec_info.verified_id is not None:
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
         )

         # Allocate cache locations
-        out_cache_loc = batch.alloc_token_slots(
-            num_seqs * self.topk * self.speculative_num_steps
-        )
+        if self.page_size == 1:
+            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
+                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+            )
+        else:
+            if self.topk == 1:
+                prefix_lens = batch.seq_lens
+                seq_lens = prefix_lens + self.speculative_num_steps
+                extend_num_tokens = num_seqs * self.speculative_num_steps
+            else:
+                # In this case, the last partial page needs to be duplicated.
+                # KV cache layout in batch.req_to_token_pool.req_to_token:
+                #
+                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
+                #    prefix     top-k = 0    top-k = 1    top-k = 2
+                #
+                # "-" means prefix tokens
+                # "x" means speculative draft tokens
+                # "." means padded tokens
+
+                # TODO: fuse these ops
+                prefix_lens = batch.seq_lens
+                last_page_lens = prefix_lens % self.page_size
+                num_new_pages = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens = (
+                    prefix_lens // self.page_size * self.page_size
+                    + num_new_pages * (self.page_size * self.topk)
+                )
+                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
+                raise NotImplementedError(
+                    "page_size > 1 and top_k > 1 are not supported."
+                )
+                # TODO: Support page_size > 1 and top_k > 1
+                # 1. Duplicate the KV cache in the last partial page for all top-k segments
+                # 2. Modify generate_draft_decode_kv_indices accordingly
+
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            out_cache_loc, token_to_kv_pool_state_backup = (
+                batch.alloc_paged_token_slots_extend(
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    extend_num_tokens,
+                    backup_state=True,
+                )
+            )

         assign_draft_cache_locs[(num_seqs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
+            self.page_size,
         )
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)

+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
         ret = EagleVerifyInput.create(
             spec_info.verified_id,
             score_list,
@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
         )
-        return ret, out_cache_loc
+        return ret

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
         return score_list, token_list, parents_list

     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
-        spec_info.prepare_for_verify(batch)
+        spec_info.prepare_for_verify(batch, self.page_size)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
-            batch, logits_output, self.token_to_kv_pool_allocator
+            batch,
+            logits_output,
+            self.token_to_kv_pool_allocator,
+            self.page_size,
         )

         # Post process based on verified outputs.
@@ -76,11 +76,14 @@ def is_in_ci():


 if is_in_ci():
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
 else:
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
+DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
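The test ports are now derived from `CUDA_VISIBLE_DEVICES` so parallel shards on different GPUs do not collide; note that only the first character of the variable is used, so `"3,4"` maps to the same ports as `"3"` (a quick check of the arithmetic):

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
port = 5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100  # CI branch
assert port == 5300
assert f"http://127.0.0.1:{port + 1000}" == "http://127.0.0.1:6300"
```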
@@ -1009,6 +1012,9 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):


 class CustomTestCase(unittest.TestCase):
+    pass
+
+    """
     def _callTestMethod(self, method):
         max_retry = int(
             os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
@@ -1017,3 +1023,4 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
+    """