Support page size > 1 + eagle (#4908)
This commit is contained in:
@@ -33,6 +33,7 @@ runtime_common = [
|
|||||||
"prometheus-client>=0.20.0",
|
"prometheus-client>=0.20.0",
|
||||||
"psutil",
|
"psutil",
|
||||||
"pydantic",
|
"pydantic",
|
||||||
|
"pynvml",
|
||||||
"python-multipart",
|
"python-multipart",
|
||||||
"pyzmq>=25.1.2",
|
"pyzmq>=25.1.2",
|
||||||
"soundfile==0.13.1",
|
"soundfile==0.13.1",
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from functools import partial
|
|||||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
|
|||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
||||||
|
self.page_size = model_runner.page_size
|
||||||
|
|
||||||
max_bs = model_runner.req_to_token_pool.size * self.topk
|
max_bs = model_runner.req_to_token_pool.size * self.topk
|
||||||
self.kv_indptr = torch.zeros(
|
self.kv_indptr = torch.zeros(
|
||||||
@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
self.pool_len,
|
self.pool_len,
|
||||||
kv_indices_buffer.shape[1],
|
kv_indices_buffer.shape[1],
|
||||||
self.kv_indptr.shape[1],
|
self.kv_indptr.shape[1],
|
||||||
triton.next_power_of_2(num_seqs),
|
next_power_of_2(num_seqs),
|
||||||
triton.next_power_of_2(self.speculative_num_steps),
|
next_power_of_2(self.speculative_num_steps),
|
||||||
triton.next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert forward_batch.spec_info is not None
|
assert forward_batch.spec_info is not None
|
||||||
@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def call_fn(i, forward_batch):
|
def call_fn(i, forward_batch):
|
||||||
assert forward_batch.spec_info is not None
|
|
||||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
|
||||||
forward_batch.spec_info.kv_indptr = (
|
forward_batch.spec_info.kv_indptr = (
|
||||||
forward_batch.spec_info.kv_indptr.clone()
|
forward_batch.spec_info.kv_indptr.clone()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
)
|
)
|
||||||
return req_pool_indices
|
return req_pool_indices
|
||||||
|
|
||||||
def alloc_token_slots(self, num_tokens: int):
|
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
|
||||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
||||||
if self.tree_cache is not None:
|
if self.tree_cache is not None:
|
||||||
self.tree_cache.evict(num_tokens)
|
self.tree_cache.evict(num_tokens)
|
||||||
|
|
||||||
|
if backup_state:
|
||||||
|
state = self.token_to_kv_pool_allocator.backup_state()
|
||||||
|
|
||||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||||
@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.tree_cache.pretty_print()
|
self.tree_cache.pretty_print()
|
||||||
raise RuntimeError(error_msg)
|
raise RuntimeError(error_msg)
|
||||||
|
|
||||||
return out_cache_loc
|
if backup_state:
|
||||||
|
return out_cache_loc, state
|
||||||
|
else:
|
||||||
|
return out_cache_loc
|
||||||
|
|
||||||
def alloc_paged_token_slots_extend(
|
def alloc_paged_token_slots_extend(
|
||||||
self,
|
self,
|
||||||
@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
last_loc: torch.Tensor,
|
last_loc: torch.Tensor,
|
||||||
extend_num_tokens: int,
|
extend_num_tokens: int,
|
||||||
|
backup_state: bool = False,
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
self.token_to_kv_pool_allocator.available_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if backup_state:
|
||||||
|
state = self.token_to_kv_pool_allocator.backup_state()
|
||||||
|
|
||||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||||
prefix_lens, seq_lens, last_loc, extend_num_tokens
|
prefix_lens, seq_lens, last_loc, extend_num_tokens
|
||||||
)
|
)
|
||||||
@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
)
|
)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise RuntimeError(error_msg)
|
raise RuntimeError(error_msg)
|
||||||
return out_cache_loc
|
|
||||||
|
if backup_state:
|
||||||
|
return out_cache_loc, state
|
||||||
|
else:
|
||||||
|
return out_cache_loc
|
||||||
|
|
||||||
def alloc_paged_token_slots_decode(
|
def alloc_paged_token_slots_decode(
|
||||||
self,
|
self,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
last_loc: torch.Tensor,
|
last_loc: torch.Tensor,
|
||||||
|
backup_state: bool = False,
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
self.token_to_kv_pool_allocator.available_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.tree_cache.evict(
|
self.tree_cache.evict(
|
||||||
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||||
)
|
)
|
||||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
|
||||||
|
|
||||||
|
if backup_state:
|
||||||
|
state = self.token_to_kv_pool_allocator.backup_state()
|
||||||
|
|
||||||
|
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Decode out of memory. Try to lower your batch size.\n"
|
f"Decode out of memory. Try to lower your batch size.\n"
|
||||||
@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
)
|
)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise RuntimeError(error_msg)
|
raise RuntimeError(error_msg)
|
||||||
return out_cache_loc
|
|
||||||
|
if backup_state:
|
||||||
|
return out_cache_loc, state
|
||||||
|
else:
|
||||||
|
return out_cache_loc
|
||||||
|
|
||||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||||
self.encoder_lens_cpu = []
|
self.encoder_lens_cpu = []
|
||||||
|
|||||||
@@ -1110,7 +1110,7 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
if memory_leak:
|
if memory_leak:
|
||||||
msg = (
|
msg = (
|
||||||
"KV cache pool leak detected! "
|
"token_to_kv_pool_allocator memory leak detected! "
|
||||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||||
f"{self.tree_cache.evictable_size()=}\n"
|
f"{self.tree_cache.evictable_size()=}\n"
|
||||||
@@ -1121,7 +1121,7 @@ class Scheduler(
|
|||||||
|
|
||||||
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
|
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
|
||||||
msg = (
|
msg = (
|
||||||
"Memory pool leak detected!"
|
"req_to_token_pool memory leak detected!"
|
||||||
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
f"available_size={len(self.req_to_token_pool.free_slots)}, "
|
||||||
f"total_size={self.req_to_token_pool.size}\n"
|
f"total_size={self.req_to_token_pool.size}\n"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
|
|||||||
if self.free_group:
|
if self.free_group:
|
||||||
self.free(torch.cat(self.free_group))
|
self.free(torch.cat(self.free_group))
|
||||||
|
|
||||||
|
def backup_state(self):
|
||||||
|
return self.free_slots
|
||||||
|
|
||||||
|
def restore_state(self, free_slots):
|
||||||
|
self.free_slots = free_slots
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
self.free_slots = torch.arange(
|
self.free_slots = torch.arange(
|
||||||
|
|||||||
@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
|
|||||||
next_power_of_2(extend_num_tokens),
|
next_power_of_2(extend_num_tokens),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.debug_mode:
|
||||||
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||||
|
|
||||||
merged_value = self.ret_values.item()
|
merged_value = self.ret_values.item()
|
||||||
num_new_pages = merged_value >> 32
|
num_new_pages = merged_value >> 32
|
||||||
if num_new_pages > len(self.free_pages):
|
if num_new_pages > len(self.free_pages):
|
||||||
@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
|
|||||||
self.page_size,
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.debug_mode:
|
||||||
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||||
|
|
||||||
num_new_pages = self.ret_values.item()
|
num_new_pages = self.ret_values.item()
|
||||||
if num_new_pages > len(self.free_pages):
|
if num_new_pages > len(self.free_pages):
|
||||||
return None
|
return None
|
||||||
@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
|
|||||||
else:
|
else:
|
||||||
self.free_group.append(free_index)
|
self.free_group.append(free_index)
|
||||||
|
|
||||||
|
if self.debug_mode:
|
||||||
|
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
||||||
|
|
||||||
def free_group_begin(self):
|
def free_group_begin(self):
|
||||||
self.is_not_in_free_group = False
|
self.is_not_in_free_group = False
|
||||||
self.free_group = []
|
self.free_group = []
|
||||||
@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
|
|||||||
if self.free_group:
|
if self.free_group:
|
||||||
self.free(torch.cat(self.free_group))
|
self.free(torch.cat(self.free_group))
|
||||||
|
|
||||||
|
def backup_state(self):
|
||||||
|
return self.free_pages
|
||||||
|
|
||||||
|
def restore_state(self, free_pages):
|
||||||
|
self.free_pages = free_pages
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
self.free_pages = torch.arange(
|
self.free_pages = torch.arange(
|
||||||
|
|||||||
@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
if capture_bs is None:
|
if capture_bs is None:
|
||||||
if server_args.speculative_algorithm is None:
|
if server_args.speculative_algorithm is None:
|
||||||
if server_args.disable_cuda_graph_padding:
|
if server_args.disable_cuda_graph_padding:
|
||||||
capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
|
# `range` objects do not support `+` with a list in Python 3, so the range
# must be materialized before concatenation (same form as the speculative
# decoding branch).
capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
|
||||||
else:
|
else:
|
||||||
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
|
||||||
else:
|
else:
|
||||||
# Since speculative decoding requires more cuda graph memory, we
|
# Since speculative decoding requires more cuda graph memory, we
|
||||||
# capture less.
|
# capture less.
|
||||||
capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
|
capture_bs = (
|
||||||
|
list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
|
||||||
|
)
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
capture_bs += [i * 8 for i in range(21, 33)]
|
capture_bs += list(range(160, 257, 8))
|
||||||
|
|
||||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||||
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -132,9 +133,9 @@ class ServerArgs:
|
|||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
speculative_algorithm: Optional[str] = None
|
speculative_algorithm: Optional[str] = None
|
||||||
speculative_draft_model_path: Optional[str] = None
|
speculative_draft_model_path: Optional[str] = None
|
||||||
speculative_num_steps: int = 5
|
speculative_num_steps: Optional[int] = None
|
||||||
speculative_eagle_topk: int = 4
|
speculative_eagle_topk: Optional[int] = None
|
||||||
speculative_num_draft_tokens: int = 8
|
speculative_num_draft_tokens: Optional[int] = None
|
||||||
speculative_accept_threshold_single: float = 1.0
|
speculative_accept_threshold_single: float = 1.0
|
||||||
speculative_accept_threshold_acc: float = 1.0
|
speculative_accept_threshold_acc: float = 1.0
|
||||||
speculative_token_map: Optional[str] = None
|
speculative_token_map: Optional[str] = None
|
||||||
@@ -313,12 +314,29 @@ class ServerArgs:
|
|||||||
or self.speculative_algorithm == "EAGLE3"
|
or self.speculative_algorithm == "EAGLE3"
|
||||||
):
|
):
|
||||||
if self.max_running_requests is None:
|
if self.max_running_requests is None:
|
||||||
self.max_running_requests = 32
|
self.max_running_requests = 48
|
||||||
self.disable_overlap_schedule = True
|
self.disable_overlap_schedule = True
|
||||||
logger.info(
|
logger.info(
|
||||||
"Overlap scheduler is disabled because of using "
|
"Overlap scheduler is disabled because of using "
|
||||||
"eagle speculative decoding."
|
"eagle speculative decoding."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Auto choose parameters
|
||||||
|
if self.speculative_num_steps is None:
|
||||||
|
assert (
|
||||||
|
self.speculative_eagle_topk is None
|
||||||
|
and self.speculative_num_draft_tokens is None
|
||||||
|
)
|
||||||
|
(
|
||||||
|
self.speculative_num_steps,
|
||||||
|
self.speculative_eagle_topk,
|
||||||
|
self.speculative_num_draft_tokens,
|
||||||
|
) = auto_choose_speculative_params(self)
|
||||||
|
|
||||||
|
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||||
|
self.speculative_eagle_topk = 1
|
||||||
|
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
|
||||||
|
|
||||||
# The token generated from the verify step is counted.
|
# The token generated from the verify step is counted.
|
||||||
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||||
# assert self.speculative_num_steps < self.speculative_num_draft_tokens
|
# assert self.speculative_num_steps < self.speculative_num_draft_tokens
|
||||||
@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):
|
|||||||
|
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
raise ValueError(self.help)
|
raise ValueError(self.help)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_choose_speculative_params(self: "ServerArgs"):
    """
    Automatically choose the parameters for speculative decoding.

    The model configuration is read from ``self.decrypted_config_file`` when
    it is set, otherwise from ``config.json`` inside ``self.model_path``. The
    first entry of the config's ``architectures`` list selects the defaults.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py

    Returns:
        A 3-tuple of ints. The caller unpacks it, in order, into
        ``speculative_num_steps``, ``speculative_eagle_topk`` and
        ``speculative_num_draft_tokens``.

    Raises:
        ValueError: If the resolved config file does not exist.
    """
    if self.decrypted_config_file:
        config_path = self.decrypted_config_file
    else:
        config_path = os.path.join(self.model_path, "config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"{config_path} is not found.")

    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # A missing, null, or empty "architectures" entry falls back to the
    # generic defaults rather than raising IndexError/TypeError.
    architectures = config.get("architectures") or ["Unknown"]
    arch = architectures[0]

    # The branches are kept separate so each model family can be tuned
    # independently, even though they currently share the same values.
    if arch in ["LlamaForCausalLM"]:
        # The default value for llama
        return (5, 4, 8)
    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
        # The default value for deepseek
        return (5, 4, 8)
    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
        return (5, 4, 8)
    else:
        # The default value for all other models
        return (5, 4, 8)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
@@ -10,11 +11,15 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
ScheduleBatch,
|
||||||
|
get_last_loc,
|
||||||
|
global_server_args_dict,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||||
from sglang.srt.utils import is_cuda_available, is_hip
|
from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
|
||||||
|
|
||||||
if is_cuda_available():
|
if is_cuda_available():
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -34,6 +39,9 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EagleDraftInput:
|
class EagleDraftInput:
|
||||||
# The inputs for decode
|
# The inputs for decode
|
||||||
@@ -93,7 +101,7 @@ class EagleDraftInput:
|
|||||||
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
|
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
|
||||||
self.positions,
|
self.positions,
|
||||||
new_verified_id,
|
new_verified_id,
|
||||||
triton.next_power_of_2(speculative_num_steps + 1),
|
next_power_of_2(speculative_num_steps + 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch.seq_lens_sum = sum(seq_lens_cpu)
|
batch.seq_lens_sum = sum(seq_lens_cpu)
|
||||||
@@ -225,18 +233,34 @@ class EagleVerifyInput:
|
|||||||
CaptureHiddenMode.FULL,
|
CaptureHiddenMode.FULL,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_for_verify(self, batch: ScheduleBatch):
|
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
||||||
batch.input_ids = self.draft_token
|
batch.input_ids = self.draft_token
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
|
||||||
|
if page_size == 1:
|
||||||
|
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
|
||||||
|
end_offset = batch.seq_lens + self.draft_token_num
|
||||||
|
else:
|
||||||
|
prefix_lens = batch.seq_lens
|
||||||
|
end_offset = prefix_lens + self.draft_token_num
|
||||||
|
last_loc = get_last_loc(
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.req_pool_indices,
|
||||||
|
prefix_lens,
|
||||||
|
)
|
||||||
|
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
||||||
|
prefix_lens, end_offset, last_loc, len(batch.input_ids)
|
||||||
|
)
|
||||||
|
self.last_loc = last_loc
|
||||||
|
|
||||||
bs = batch.batch_size()
|
bs = batch.batch_size()
|
||||||
assign_req_to_token_pool[(bs,)](
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_to_token_pool.req_to_token,
|
||||||
batch.seq_lens,
|
batch.seq_lens,
|
||||||
batch.seq_lens + self.draft_token_num,
|
end_offset,
|
||||||
batch.out_cache_loc,
|
batch.out_cache_loc,
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
triton.next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_attn_arg_prefill(
|
def generate_attn_arg_prefill(
|
||||||
@@ -282,6 +306,7 @@ class EagleVerifyInput:
|
|||||||
batch: ScheduleBatch,
|
batch: ScheduleBatch,
|
||||||
logits_output: torch.Tensor,
|
logits_output: torch.Tensor,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
|
page_size: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Verify and find accepted tokens based on logits output and batch
|
Verify and find accepted tokens based on logits output and batch
|
||||||
@@ -305,6 +330,7 @@ class EagleVerifyInput:
|
|||||||
)
|
)
|
||||||
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
# Apply penalty
|
||||||
if sampling_info.penalizer_orchestrator.is_required:
|
if sampling_info.penalizer_orchestrator.is_required:
|
||||||
# This is a relaxed version of penalties for speculative decoding.
|
# This is a relaxed version of penalties for speculative decoding.
|
||||||
linear_penalty = torch.zeros(
|
linear_penalty = torch.zeros(
|
||||||
@@ -317,6 +343,7 @@ class EagleVerifyInput:
|
|||||||
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
|
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Sample tokens
|
||||||
if batch.sampling_info.is_all_greedy:
|
if batch.sampling_info.is_all_greedy:
|
||||||
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
||||||
target_predict = target_predict.reshape(bs, self.draft_token_num)
|
target_predict = target_predict.reshape(bs, self.draft_token_num)
|
||||||
@@ -378,13 +405,24 @@ class EagleVerifyInput:
|
|||||||
deterministic=True,
|
deterministic=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if SIMULATE_ACC_LEN:
|
||||||
|
# Do simulation
|
||||||
|
accept_index = _generate_simulated_accept_index(
|
||||||
|
accept_index=accept_index,
|
||||||
|
predict=predict, # mutable
|
||||||
|
accept_length=accept_length, # mutable
|
||||||
|
simulate_acc_len=SIMULATE_ACC_LEN,
|
||||||
|
bs=bs,
|
||||||
|
spec_steps=self.spec_steps,
|
||||||
|
)
|
||||||
|
|
||||||
new_accept_index = []
|
new_accept_index = []
|
||||||
unfinished_index = []
|
unfinished_index = []
|
||||||
accept_index_cpu = accept_index.tolist()
|
accept_index_cpu = accept_index.tolist()
|
||||||
predict_cpu = predict.tolist()
|
predict_cpu = predict.tolist()
|
||||||
has_finished = False
|
has_finished = False
|
||||||
|
|
||||||
# iterate every accepted token and check if req has finished after append the token
|
# Iterate every accepted token and check if req has finished after append the token
|
||||||
# should be checked BEFORE free kv cache slots
|
# should be checked BEFORE free kv cache slots
|
||||||
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
||||||
new_accept_index_ = []
|
new_accept_index_ = []
|
||||||
@@ -407,13 +445,28 @@ class EagleVerifyInput:
|
|||||||
unfinished_index.append(i)
|
unfinished_index.append(i)
|
||||||
req.spec_verify_ct += 1
|
req.spec_verify_ct += 1
|
||||||
|
|
||||||
|
if has_finished:
|
||||||
|
accept_length = (accept_index != -1).sum(dim=1) - 1
|
||||||
|
|
||||||
|
# Free the KV cache for unaccepted tokens
|
||||||
|
accept_index = accept_index[accept_index != -1]
|
||||||
|
verified_id = predict[accept_index]
|
||||||
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||||
|
evict_mask[accept_index] = False
|
||||||
|
|
||||||
|
if page_size != 1:
|
||||||
|
align_evict_mask_to_page_size[len(batch.seq_lens),](
|
||||||
|
batch.seq_lens,
|
||||||
|
evict_mask,
|
||||||
|
page_size,
|
||||||
|
self.draft_token_num,
|
||||||
|
next_power_of_2(self.draft_token_num),
|
||||||
|
)
|
||||||
|
|
||||||
|
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
|
||||||
|
|
||||||
|
# Construct EagleVerifyOutput
|
||||||
if not has_finished:
|
if not has_finished:
|
||||||
accept_index = accept_index[accept_index != -1]
|
|
||||||
verified_id = predict[accept_index]
|
|
||||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
|
||||||
evict_mask[accept_index] = False
|
|
||||||
mem_need_free_idx = batch.out_cache_loc[evict_mask]
|
|
||||||
token_to_kv_pool_allocator.free(mem_need_free_idx)
|
|
||||||
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
||||||
assign_req_to_token_pool[(bs,)](
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
@@ -422,7 +475,7 @@ class EagleVerifyInput:
|
|||||||
batch.seq_lens + accept_length + 1,
|
batch.seq_lens + accept_length + 1,
|
||||||
batch.out_cache_loc,
|
batch.out_cache_loc,
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
triton.next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
batch.seq_lens.add_(accept_length + 1)
|
batch.seq_lens.add_(accept_length + 1)
|
||||||
accept_length_cpu = accept_length.tolist()
|
accept_length_cpu = accept_length.tolist()
|
||||||
@@ -443,13 +496,6 @@ class EagleVerifyInput:
|
|||||||
accepeted_indices=accept_index,
|
accepeted_indices=accept_index,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
accept_length = (accept_index != -1).sum(dim=1) - 1
|
|
||||||
accept_index = accept_index[accept_index != -1]
|
|
||||||
verified_id = predict[accept_index]
|
|
||||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
|
||||||
evict_mask[accept_index] = False
|
|
||||||
mem_need_free_idx = batch.out_cache_loc[evict_mask]
|
|
||||||
token_to_kv_pool_allocator.free(mem_need_free_idx)
|
|
||||||
assign_req_to_token_pool[(bs,)](
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_to_token_pool.req_to_token,
|
||||||
@@ -457,7 +503,7 @@ class EagleVerifyInput:
|
|||||||
batch.seq_lens + accept_length + 1,
|
batch.seq_lens + accept_length + 1,
|
||||||
batch.out_cache_loc[accept_index],
|
batch.out_cache_loc[accept_index],
|
||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
triton.next_power_of_2(bs),
|
next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
batch.seq_lens.add_(accept_length + 1)
|
batch.seq_lens.add_(accept_length + 1)
|
||||||
accept_length_cpu = accept_length.tolist()
|
accept_length_cpu = accept_length.tolist()
|
||||||
@@ -465,20 +511,21 @@ class EagleVerifyInput:
|
|||||||
draft_input = EagleDraftInput()
|
draft_input = EagleDraftInput()
|
||||||
if len(new_accept_index) > 0:
|
if len(new_accept_index) > 0:
|
||||||
new_accept_index = torch.tensor(new_accept_index, device="cuda")
|
new_accept_index = torch.tensor(new_accept_index, device="cuda")
|
||||||
|
unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
|
||||||
draft_input.hidden_states = batch.spec_info.hidden_states[
|
draft_input.hidden_states = batch.spec_info.hidden_states[
|
||||||
new_accept_index
|
new_accept_index
|
||||||
]
|
]
|
||||||
draft_input.verified_id = predict[new_accept_index]
|
draft_input.verified_id = predict[new_accept_index]
|
||||||
draft_input.accept_length = accept_length[unfinished_index]
|
|
||||||
draft_input.accept_length_cpu = [
|
draft_input.accept_length_cpu = [
|
||||||
accept_length_cpu[i] for i in unfinished_index
|
accept_length_cpu[i] for i in unfinished_index
|
||||||
]
|
]
|
||||||
|
draft_input.accept_length = accept_length[unfinished_index_device]
|
||||||
if has_finished:
|
if has_finished:
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens[
|
||||||
unfinished_index
|
unfinished_index_device
|
||||||
]
|
]
|
||||||
draft_input.req_pool_indices_for_draft_extend = (
|
draft_input.req_pool_indices_for_draft_extend = (
|
||||||
batch.req_pool_indices[unfinished_index]
|
batch.req_pool_indices[unfinished_index_device]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||||
@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
|
|||||||
pool_len: tl.constexpr,
|
pool_len: tl.constexpr,
|
||||||
topk: tl.constexpr,
|
topk: tl.constexpr,
|
||||||
speculative_num_steps: tl.constexpr,
|
speculative_num_steps: tl.constexpr,
|
||||||
|
page_size: tl.constexpr,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 32
|
BLOCK_SIZE: tl.constexpr = 32
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
kv_start = tl.load(seq_lens + pid)
|
kv_start = tl.load(seq_lens + pid)
|
||||||
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
|
|
||||||
|
if page_size == 1 or topk == 1:
|
||||||
|
kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
|
||||||
|
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
||||||
|
else:
|
||||||
|
prefix_len = tl.load(seq_lens + pid)
|
||||||
|
last_page_len = prefix_len % page_size
|
||||||
|
num_new_page = (
|
||||||
|
last_page_len + speculative_num_steps + page_size - 1
|
||||||
|
) // page_size
|
||||||
|
kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
|
||||||
|
|
||||||
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||||
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
|
||||||
|
|
||||||
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
|
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
|
||||||
for i in range(num_loop):
|
for i in range(num_loop):
|
||||||
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
|
|||||||
tl.store(kv_indptr + zid, base + zid * iters)
|
tl.store(kv_indptr + zid, base + zid * iters)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def align_evict_mask_to_page_size(
|
||||||
|
seq_lens,
|
||||||
|
evict_mask,
|
||||||
|
page_size: tl.constexpr,
|
||||||
|
num_draft_tokens: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
t_range = tl.arange(0, BLOCK_SIZE)
|
||||||
|
|
||||||
|
bid = tl.program_id(axis=0)
|
||||||
|
seq_len = tl.load(seq_lens + bid)
|
||||||
|
io_mask = t_range < num_draft_tokens
|
||||||
|
mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
|
||||||
|
|
||||||
|
num_trues = tl.sum(mask_row)
|
||||||
|
num_false = num_draft_tokens - num_trues
|
||||||
|
|
||||||
|
start = (seq_len + num_false - 1) // page_size * page_size - seq_len
|
||||||
|
for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
|
||||||
|
tl.store(evict_mask + bid * num_draft_tokens + i, False)
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True)
|
@torch.compile(dynamic=True)
|
||||||
def select_top_k_tokens(
|
def select_top_k_tokens(
|
||||||
i: int,
|
i: int,
|
||||||
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
|
|||||||
else:
|
else:
|
||||||
# Use topk for efficiency with larger k values
|
# Use topk for efficiency with larger k values
|
||||||
return torch.topk(values, topk, dim=dim)
|
return torch.topk(values, topk, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_simulated_accept_index(
|
||||||
|
accept_index,
|
||||||
|
predict,
|
||||||
|
accept_length,
|
||||||
|
simulate_acc_len,
|
||||||
|
bs,
|
||||||
|
spec_steps,
|
||||||
|
):
|
||||||
|
simulate_acc_len_float = float(simulate_acc_len)
|
||||||
|
simulated_values = torch.normal(
|
||||||
|
mean=simulate_acc_len_float,
|
||||||
|
std=1.0,
|
||||||
|
size=(1,),
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
# clamp simulated values to be between 1 and spec_steps
|
||||||
|
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
|
||||||
|
simulate_acc_len = int(simulated_values.round().item())
|
||||||
|
|
||||||
|
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
|
||||||
|
sim_accept_index = torch.full(
|
||||||
|
(bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
|
||||||
|
simulate_acc_len, device=accept_index.device
|
||||||
|
)
|
||||||
|
accept_length.fill_(simulate_acc_len - 1)
|
||||||
|
predict.fill_(100) # some legit token id
|
||||||
|
return sim_accept_index
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
|||||||
from sglang.srt.layers.dp_attention import disable_dp_size
|
from sglang.srt.layers.dp_attention import disable_dp_size
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
self.device = server_args.device
|
self.device = server_args.device
|
||||||
self.target_worker = target_worker
|
self.target_worker = target_worker
|
||||||
|
self.page_size = server_args.page_size
|
||||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
"""
|
"""
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
spec_info, to_free_cache_loc = self.draft(batch)
|
spec_info = self.draft(batch)
|
||||||
logits_output, verify_output, model_worker_batch = self.verify(
|
logits_output, verify_output, model_worker_batch = self.verify(
|
||||||
batch, spec_info
|
batch, spec_info
|
||||||
)
|
)
|
||||||
|
|
||||||
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
|
|
||||||
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
|
|
||||||
|
|
||||||
# If it is None, it means all requests are finished
|
# If it is None, it means all requests are finished
|
||||||
if batch.spec_info.verified_id is not None:
|
if batch.spec_info.verified_id is not None:
|
||||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Allocate cache locations
|
# Allocate cache locations
|
||||||
out_cache_loc = batch.alloc_token_slots(
|
if self.page_size == 1:
|
||||||
num_seqs * self.topk * self.speculative_num_steps
|
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
|
||||||
)
|
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.topk == 1:
|
||||||
|
prefix_lens = batch.seq_lens
|
||||||
|
seq_lens = prefix_lens + self.speculative_num_steps
|
||||||
|
extend_num_tokens = num_seqs * self.speculative_num_steps
|
||||||
|
else:
|
||||||
|
# In this case, the last partial page needs to be duplicated.
|
||||||
|
# KV cache layout in batch.req_to_token_pool.req_to_token:
|
||||||
|
#
|
||||||
|
# | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
|
||||||
|
# prefix top-k = 0 top-k = 1 top-k = 2
|
||||||
|
#
|
||||||
|
# "-" means prefix tokens
|
||||||
|
# "x" means speculative draft tokens
|
||||||
|
# "." means padded tokens
|
||||||
|
|
||||||
|
# TODO: fuse these ops
|
||||||
|
prefix_lens = batch.seq_lens
|
||||||
|
last_page_lens = prefix_lens % self.page_size
|
||||||
|
num_new_pages = (
|
||||||
|
last_page_lens + self.speculative_num_steps + self.page_size - 1
|
||||||
|
) // self.page_size
|
||||||
|
seq_lens = (
|
||||||
|
prefix_lens // self.page_size * self.page_size
|
||||||
|
+ num_new_pages * (self.page_size * self.topk)
|
||||||
|
)
|
||||||
|
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
|
||||||
|
raise NotImplementedError(
|
||||||
|
"page_size > 1 and top_k > 1 are not supported."
|
||||||
|
)
|
||||||
|
# TODO: Support page_size > 1 and top_k > 1
|
||||||
|
# 1. Duplicate the KV cache in the last partial page for all top-k segments
|
||||||
|
# 2. Modify generate_draft_decode_kv_indices accordingly
|
||||||
|
|
||||||
|
last_loc = get_last_loc(
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.req_pool_indices,
|
||||||
|
prefix_lens,
|
||||||
|
)
|
||||||
|
out_cache_loc, token_to_kv_pool_state_backup = (
|
||||||
|
batch.alloc_paged_token_slots_extend(
|
||||||
|
prefix_lens,
|
||||||
|
seq_lens,
|
||||||
|
last_loc,
|
||||||
|
extend_num_tokens,
|
||||||
|
backup_state=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
assign_draft_cache_locs[(num_seqs,)](
|
assign_draft_cache_locs[(num_seqs,)](
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_to_token_pool.req_to_token,
|
||||||
@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.req_to_token_pool.req_to_token.shape[1],
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
|
self.page_size,
|
||||||
)
|
)
|
||||||
batch.out_cache_loc = out_cache_loc
|
batch.out_cache_loc = out_cache_loc
|
||||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||||
@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# Run forward steps
|
# Run forward steps
|
||||||
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
||||||
|
|
||||||
|
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||||
|
|
||||||
ret = EagleVerifyInput.create(
|
ret = EagleVerifyInput.create(
|
||||||
spec_info.verified_id,
|
spec_info.verified_id,
|
||||||
score_list,
|
score_list,
|
||||||
@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
self.server_args.speculative_num_draft_tokens,
|
self.server_args.speculative_num_draft_tokens,
|
||||||
)
|
)
|
||||||
return ret, out_cache_loc
|
return ret
|
||||||
|
|
||||||
def draft_forward(self, forward_batch: ForwardBatch):
|
def draft_forward(self, forward_batch: ForwardBatch):
|
||||||
# Parse args
|
# Parse args
|
||||||
@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
return score_list, token_list, parents_list
|
return score_list, token_list, parents_list
|
||||||
|
|
||||||
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||||
spec_info.prepare_for_verify(batch)
|
spec_info.prepare_for_verify(batch, self.page_size)
|
||||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
batch.spec_info = spec_info
|
batch.spec_info = spec_info
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self._detect_nan_if_needed(logits_output)
|
self._detect_nan_if_needed(logits_output)
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
res: EagleVerifyOutput = spec_info.verify(
|
res: EagleVerifyOutput = spec_info.verify(
|
||||||
batch, logits_output, self.token_to_kv_pool_allocator
|
batch,
|
||||||
|
logits_output,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Post process based on verified outputs.
|
# Post process based on verified outputs.
|
||||||
|
|||||||
@@ -76,11 +76,14 @@ def is_in_ci():
|
|||||||
|
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
|
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
|
||||||
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
|
5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
|
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
|
||||||
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
|
7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
|
||||||
|
)
|
||||||
|
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
|
||||||
|
|
||||||
|
|
||||||
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
|
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
|
||||||
@@ -1009,6 +1012,9 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
|
|||||||
|
|
||||||
|
|
||||||
class CustomTestCase(unittest.TestCase):
|
class CustomTestCase(unittest.TestCase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
"""
|
||||||
def _callTestMethod(self, method):
|
def _callTestMethod(self, method):
|
||||||
max_retry = int(
|
max_retry = int(
|
||||||
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
|
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
|
||||||
@@ -1017,3 +1023,4 @@ class CustomTestCase(unittest.TestCase):
|
|||||||
lambda: super(CustomTestCase, self)._callTestMethod(method),
|
lambda: super(CustomTestCase, self)._callTestMethod(method),
|
||||||
max_retry=max_retry,
|
max_retry=max_retry,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
|
|||||||
pip install sgl-kernel==0.0.5.post4 --force-reinstall
|
pip install sgl-kernel==0.0.5.post4 --force-reinstall
|
||||||
|
|
||||||
pip install torch_memory_saver
|
pip install torch_memory_saver
|
||||||
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm
|
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm torchaudio
|
||||||
|
|
||||||
# For compiling xgrammar kernels
|
# For compiling xgrammar kernels
|
||||||
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
pip install cuda-python nvidia-cuda-nvrtc-cu12
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ suites = {
|
|||||||
TestFile("test_abort.py", 51),
|
TestFile("test_abort.py", 51),
|
||||||
TestFile("test_block_int8.py", 22),
|
TestFile("test_block_int8.py", 22),
|
||||||
TestFile("test_chunked_prefill.py", 336),
|
TestFile("test_chunked_prefill.py", 336),
|
||||||
TestFile("test_eagle_infer.py", 447),
|
TestFile("test_eagle_infer.py", 500),
|
||||||
TestFile("test_ebnf_constrained.py"),
|
TestFile("test_ebnf_constrained.py"),
|
||||||
TestFile("test_fp8_kernel.py", 2),
|
TestFile("test_fp8_kernel.py", 2),
|
||||||
TestFile("test_embedding_openai_server.py", 36),
|
TestFile("test_embedding_openai_server.py", 36),
|
||||||
|
|||||||
@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase):
|
|||||||
print(f"{metrics=}")
|
print(f"{metrics=}")
|
||||||
self.assertGreater(metrics["accuracy"], 0.20)
|
self.assertGreater(metrics["accuracy"], 0.20)
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info").json()
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info["avg_spec_accept_length"]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 3.5)
|
|
||||||
|
speculative_eagle_topk = server_info["speculative_eagle_topk"]
|
||||||
|
|
||||||
|
if speculative_eagle_topk == 1:
|
||||||
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
else:
|
||||||
|
self.assertGreater(avg_spec_accept_length, 3.5)
|
||||||
|
|
||||||
# Wait a little bit so that the memory check happens.
|
# Wait a little bit so that the memory check happens.
|
||||||
time.sleep(4)
|
time.sleep(4)
|
||||||
@@ -535,5 +541,36 @@ class TestEAGLEServerTriton(TestEAGLEServer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEAGLEServerPageSize(TestEAGLEServer):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE",
|
||||||
|
"--speculative-draft-model-path",
|
||||||
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
"--speculative-num-steps",
|
||||||
|
5,
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
1,
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
6,
|
||||||
|
"--mem-fraction-static",
|
||||||
|
0.7,
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
128,
|
||||||
|
"--max-running-requests",
|
||||||
|
8,
|
||||||
|
"--page-size",
|
||||||
|
4,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.60)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
print(f"{server_info=}")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|||||||
Reference in New Issue
Block a user