[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -128,7 +128,7 @@ class HiCacheController:
|
||||
def __init__(
|
||||
self,
|
||||
mem_pool_device: MHATokenToKVPool,
|
||||
mem_pool_host: MLATokenToKVPoolHost,
|
||||
mem_pool_host: MHATokenToKVPoolHost,
|
||||
write_policy: str = "write_through_selective",
|
||||
):
|
||||
|
||||
|
||||
@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
# Put some global args for easy access
|
||||
@@ -523,7 +521,7 @@ class ScheduleBatch:
|
||||
# Request, memory pool, and cache
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
# Batch configs
|
||||
@@ -596,7 +594,7 @@ class ScheduleBatch:
|
||||
cls,
|
||||
reqs: List[Req],
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
tree_cache: BasePrefixCache,
|
||||
model_config: ModelConfig,
|
||||
enable_overlap: bool,
|
||||
@@ -606,7 +604,7 @@ class ScheduleBatch:
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
enable_overlap=enable_overlap,
|
||||
@@ -637,19 +635,19 @@ class ScheduleBatch:
|
||||
return req_pool_indices
|
||||
|
||||
def alloc_token_slots(self, num_tokens: int):
|
||||
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
|
||||
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
||||
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||
logger.error(
|
||||
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {num_tokens} tokens.\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
)
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.pretty_print()
|
||||
@@ -917,12 +915,12 @@ class ScheduleBatch:
|
||||
|
||||
def check_decode_mem(self, buf_multiplier=1):
|
||||
bs = len(self.reqs) * buf_multiplier
|
||||
if self.token_to_kv_pool.available_size() >= bs:
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
|
||||
|
||||
if self.token_to_kv_pool.available_size() >= bs:
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -945,6 +943,10 @@ class ScheduleBatch:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
|
||||
def get_required_tokens(num_reqs: int):
|
||||
headroom_for_spec_decode = 0
|
||||
if server_args.speculative_algorithm:
|
||||
@@ -958,18 +960,15 @@ class ScheduleBatch:
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool.available_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< get_required_tokens(len(sorted_indices))
|
||||
or first_iter
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
# Corner case: only one request left
|
||||
assert (
|
||||
self.token_to_kv_pool.available_size() > 0
|
||||
self.token_to_kv_pool_allocator.available_size() > 0
|
||||
), "No space left for only one request"
|
||||
break
|
||||
|
||||
@@ -983,7 +982,7 @@ class ScheduleBatch:
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
del self.tree_cache.entries[req.rid]
|
||||
else:
|
||||
@@ -992,7 +991,7 @@ class ScheduleBatch:
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
# release the last node
|
||||
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
residual_size = (
|
||||
len(sorted_indices) * global_config.retract_decode_steps
|
||||
- self.token_to_kv_pool.available_size()
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
)
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
||||
self.tree_cache.evict(
|
||||
residual_size, self.token_to_kv_pool_allocator.free
|
||||
)
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
self.filter_batch(keep_indices=sorted_indices)
|
||||
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
|
||||
if self.spec_info:
|
||||
self.spec_info.merge_batch(other.spec_info)
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
def get_model_worker_batch(self) -> ModelWorkerBatch:
|
||||
if self.forward_mode.is_decode_or_idle():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
|
||||
req_pool_indices: torch.Tensor
|
||||
# The sequence length
|
||||
seq_lens: torch.Tensor
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
# The indices of output tokens in the token_to_kv_pool_allocator
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# The sum of all sequence lengths
|
||||
|
||||
@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Req,
|
||||
ScheduleBatch,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
|
||||
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
||||
@@ -75,7 +79,7 @@ class SchedulePolicy:
|
||||
|
||||
# It is used to find the matching prefix for in-batch prefix caching.
|
||||
self.waiting_queue_radix_tree = RadixCache(
|
||||
req_to_token_pool=None, token_to_kv_pool=None, disable=False
|
||||
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
|
||||
)
|
||||
|
||||
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
||||
@@ -251,7 +255,7 @@ class PrefillAdder:
|
||||
def __init__(
|
||||
self,
|
||||
tree_cache: BasePrefixCache,
|
||||
token_to_kv_pool: BaseTokenToKVPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
running_batch: ScheduleBatch,
|
||||
new_token_ratio: float,
|
||||
rem_input_tokens: int,
|
||||
@@ -259,7 +263,7 @@ class PrefillAdder:
|
||||
mixed_with_decode_tokens: int = 0,
|
||||
):
|
||||
self.tree_cache = tree_cache
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.running_batch = running_batch
|
||||
self.new_token_ratio = new_token_ratio
|
||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||
@@ -291,7 +295,7 @@ class PrefillAdder:
|
||||
@property
|
||||
def rem_total_tokens(self):
|
||||
return (
|
||||
self.token_to_kv_pool.available_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
- self.rem_total_token_offset
|
||||
)
|
||||
@@ -299,7 +303,7 @@ class PrefillAdder:
|
||||
@property
|
||||
def cur_rem_tokens(self):
|
||||
return (
|
||||
self.token_to_kv_pool.available_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
- self.cur_rem_token_offset
|
||||
)
|
||||
@@ -332,7 +336,6 @@ class PrefillAdder:
|
||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||
self.can_run_list.append(req)
|
||||
|
||||
self._prefill_one_req(
|
||||
0,
|
||||
req.extend_input_len,
|
||||
@@ -400,8 +403,8 @@ class PrefillAdder:
|
||||
tokens_freed += tokens_occupied
|
||||
|
||||
if (
|
||||
self.rem_chunk_tokens is None
|
||||
or req.extend_input_len <= self.rem_chunk_tokens
|
||||
self.rem_chunk_tokens is None # chunked prefill is disabled
|
||||
or req.extend_input_len <= self.rem_chunk_tokens # it is the last chunk
|
||||
):
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
@@ -411,10 +414,11 @@ class PrefillAdder:
|
||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
|
||||
)
|
||||
else:
|
||||
if self.rem_chunk_tokens == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
# Chunked prefill
|
||||
trunc_len = self.rem_chunk_tokens
|
||||
if trunc_len == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
req.extend_input_len = trunc_len
|
||||
req.fill_ids = req.fill_ids[:trunc_len]
|
||||
@@ -457,10 +461,11 @@ class PrefillAdder:
|
||||
),
|
||||
)
|
||||
else:
|
||||
if self.rem_chunk_tokens == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
# Chunked prefill
|
||||
trunc_len = self.rem_chunk_tokens
|
||||
if trunc_len == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
req.extend_input_len = trunc_len
|
||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||
|
||||
@@ -164,7 +164,7 @@ class Scheduler:
|
||||
self.server_args.speculative_num_draft_tokens
|
||||
+ (
|
||||
self.server_args.speculative_eagle_topk
|
||||
* self.server_args.speculative_num_steps
|
||||
* self.server_args.speculative_num_draft_tokens
|
||||
)
|
||||
)
|
||||
if not self.spec_algorithm.is_none()
|
||||
@@ -309,7 +309,9 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Init memory pool and cache
|
||||
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||
self.tp_worker.get_memory_pool()
|
||||
)
|
||||
|
||||
if (
|
||||
server_args.chunked_prefill_size is not None
|
||||
@@ -317,18 +319,18 @@ class Scheduler:
|
||||
):
|
||||
self.tree_cache = ChunkCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
else:
|
||||
if self.enable_hierarchical_cache:
|
||||
self.tree_cache = HiRadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
|
||||
@@ -458,7 +460,6 @@ class Scheduler:
|
||||
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
||||
(ProfileReq, self.profile),
|
||||
(GetInternalStateReq, self.get_internal_state),
|
||||
(SetInternalStateReq, self.set_internal_state),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -809,7 +810,8 @@ class Scheduler:
|
||||
running_bs: int,
|
||||
):
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
self._largest_prefill_len = max(
|
||||
self._largest_prefill_len, adder.log_input_tokens
|
||||
@@ -844,7 +846,8 @@ class Scheduler:
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
@@ -894,7 +897,8 @@ class Scheduler:
|
||||
|
||||
def check_memory(self):
|
||||
available_size = (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
@@ -999,7 +1003,7 @@ class Scheduler:
|
||||
# Prefill policy
|
||||
adder = PrefillAdder(
|
||||
self.tree_cache,
|
||||
self.token_to_kv_pool,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.running_batch,
|
||||
self.new_token_ratio,
|
||||
self.max_prefill_tokens,
|
||||
@@ -1099,7 +1103,7 @@ class Scheduler:
|
||||
new_batch = ScheduleBatch.init_new(
|
||||
can_run_list,
|
||||
self.req_to_token_pool,
|
||||
self.token_to_kv_pool,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
self.enable_overlap,
|
||||
@@ -1143,8 +1147,6 @@ class Scheduler:
|
||||
|
||||
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
|
||||
self.new_token_ratio = new_token_ratio
|
||||
if self.draft_worker:
|
||||
self.draft_worker.finish_request(retracted_reqs)
|
||||
|
||||
logger.info(
|
||||
"Decode out of memory happened. "
|
||||
@@ -1184,11 +1186,12 @@ class Scheduler:
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
bid = model_worker_batch.bid
|
||||
else:
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
model_worker_batch,
|
||||
bid,
|
||||
num_accepted_tokens,
|
||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||
self.spec_num_total_accepted_tokens += (
|
||||
@@ -1214,7 +1217,7 @@ class Scheduler:
|
||||
next_token_ids=next_token_ids,
|
||||
extend_input_len_per_req=extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||
bid=model_worker_batch.bid,
|
||||
bid=bid,
|
||||
)
|
||||
else: # embedding or reward model
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -1230,6 +1233,7 @@ class Scheduler:
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
assert isinstance(result, GenerationBatchResult)
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
self.running_batch = None
|
||||
@@ -1302,7 +1306,7 @@ class Scheduler:
|
||||
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
||||
# Free the one delayed token for the mixed decode batch
|
||||
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
|
||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
||||
continue
|
||||
|
||||
if req.is_chunked <= 0:
|
||||
@@ -1420,23 +1424,27 @@ class Scheduler:
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
if self.enable_overlap:
|
||||
assert batch.spec_algorithm.is_none()
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
next_token_logprobs = logits_output.next_token_logprobs
|
||||
else:
|
||||
elif batch.spec_algorithm.is_none():
|
||||
# spec decoding handles output logprobs inside verify process.
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||
|
||||
self.token_to_kv_pool.free_group_begin()
|
||||
self.token_to_kv_pool_allocator.free_group_begin()
|
||||
|
||||
# Check finish condition
|
||||
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
|
||||
# We should ignore using next_token_ids for spec decoding cases.
|
||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||
if req.is_retracted:
|
||||
continue
|
||||
|
||||
if self.enable_overlap and req.finished():
|
||||
# Free the one delayed token
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
|
||||
continue
|
||||
|
||||
if batch.spec_algorithm.is_none():
|
||||
@@ -1479,7 +1487,7 @@ class Scheduler:
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
self.stream_output(batch.reqs, batch.return_logprob)
|
||||
|
||||
self.token_to_kv_pool.free_group_end()
|
||||
self.token_to_kv_pool_allocator.free_group_end()
|
||||
|
||||
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
||||
if (
|
||||
@@ -1718,9 +1726,6 @@ class Scheduler:
|
||||
and not self.model_config.is_multimodal_gen
|
||||
)
|
||||
):
|
||||
if self.draft_worker and req.finished():
|
||||
self.draft_worker.finish_request(req)
|
||||
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(
|
||||
req.finished_reason.to_json() if req.finished_reason else None
|
||||
@@ -1860,7 +1865,7 @@ class Scheduler:
|
||||
idle_batch = ScheduleBatch.init_new(
|
||||
[],
|
||||
self.req_to_token_pool,
|
||||
self.token_to_kv_pool,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
self.enable_overlap,
|
||||
@@ -1916,11 +1921,11 @@ class Scheduler:
|
||||
if self.grammar_backend:
|
||||
self.grammar_backend.reset()
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool.clear()
|
||||
self.token_to_kv_pool_allocator.clear()
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
self.draft_worker.model_runner.req_to_token_pool.clear()
|
||||
self.draft_worker.model_runner.token_to_kv_pool.clear()
|
||||
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
|
||||
|
||||
self.num_generated_tokens = 0
|
||||
self.forward_ct_decode = 0
|
||||
|
||||
@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
|
||||
ResumeMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
SessionParams,
|
||||
SetInternalStateReq,
|
||||
SetInternalStateReqOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
@@ -257,9 +255,6 @@ class TokenizerManager:
|
||||
self.get_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.set_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
|
||||
self._result_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
@@ -309,10 +304,6 @@ class TokenizerManager:
|
||||
GetInternalStateReqOutput,
|
||||
self.get_internal_state_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
SetInternalStateReqOutput,
|
||||
self.set_internal_state_communicator.handle_recv,
|
||||
),
|
||||
(HealthCheckOutput, lambda x: None),
|
||||
]
|
||||
)
|
||||
@@ -774,14 +765,6 @@ class TokenizerManager:
|
||||
)
|
||||
return res[0].internal_state
|
||||
|
||||
async def set_internal_state(
|
||||
self, obj: SetInternalStateReq
|
||||
) -> SetInternalStateReqOutput:
|
||||
res: List[SetInternalStateReqOutput] = (
|
||||
await self.set_internal_state_communicator(obj)
|
||||
)
|
||||
return res[0]
|
||||
|
||||
def get_log_request_metadata(self):
|
||||
max_length = None
|
||||
skip_names = None
|
||||
|
||||
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -49,6 +50,8 @@ class TpModelWorker:
|
||||
dp_rank: Optional[int],
|
||||
nccl_port: int,
|
||||
is_draft_worker: bool = False,
|
||||
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||
):
|
||||
# Parse args
|
||||
self.tp_rank = tp_rank
|
||||
@@ -77,6 +80,8 @@ class TpModelWorker:
|
||||
nccl_port=nccl_port,
|
||||
server_args=server_args,
|
||||
is_draft_worker=is_draft_worker,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
||||
)
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = self.processor = None
|
||||
@@ -154,7 +159,7 @@ class TpModelWorker:
|
||||
def get_memory_pool(self):
|
||||
return (
|
||||
self.model_runner.req_to_token_pool,
|
||||
self.model_runner.token_to_kv_pool,
|
||||
self.model_runner.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
def forward_batch_generation(
|
||||
|
||||
@@ -100,7 +100,7 @@ class TpModelWorkerClient:
|
||||
def get_memory_pool(self):
|
||||
return (
|
||||
self.worker.model_runner.req_to_token_pool,
|
||||
self.worker.model_runner.token_to_kv_pool,
|
||||
self.worker.model_runner.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
def forward_thread_func(self):
|
||||
|
||||
Reference in New Issue
Block a user