diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 8e55fb1d7..7cfda1656 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -405,9 +405,9 @@ class ScheduleBatch:
 
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
 
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
@@ -874,12 +874,9 @@ class ScheduleBatch:
 
     def copy(self):
         return ScheduleBatch(
            reqs=self.reqs,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool=self.token_to_kv_pool,
-            tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
-            output_ids=self.output_ids,
-            sampling_info=self.sampling_info,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
         )
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
             forward_mode=self.forward_mode,
             input_ids=self.input_ids.clone(),
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
+            seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 16c43dd16..10a76d53d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -261,12 +261,7 @@ class Scheduler:
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-            self.cache_finished_req = cache_finished_req
+            self.cache_finished_req = self.tree_cache.cache_finished_req
         else:
             self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
@@ -798,7 +793,6 @@ class Scheduler:
                     i, req, logprob_pt, next_token_ids, logits_output
                 )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             embeddings, bid = result
             embeddings = embeddings.tolist()
 
@@ -838,6 +832,7 @@ class Scheduler:
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
             req.completion_tokens_wo_jump_forward += 1
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index d416aa64a..814b2d2cc 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -149,14 +149,12 @@ class TpModelWorker:
         )
 
         # Resolve future tokens in the input
-        # logger.info(f"raw input {model_worker_batch.input_ids=}")
         tic2 = time.time()
         resolved_input_ids = model_worker_batch.input_ids
         future_mask = resolved_input_ids < 0
         resolved_input_ids[future_mask] = self.future_token_ids_map[
             -resolved_input_ids[future_mask]
         ]
-        # logger.info(f"resolved input {model_worker_batch.input_ids=}")
 
         # Run forward
         logits_output, next_token_ids = self.forward_batch_generation(
@@ -215,12 +213,13 @@ class TpModelWorker:
             self.future_logits_output_ct += 1
 
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = -torch.arange(
-            self.future_token_ids_ct + 1,
-            self.future_token_ids_ct + 1 + bs,
-            dtype=torch.int32,
-            device=self.device,
-        )
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index e13a2075a..3c430aba3 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
             max_prefix_len = len(key)
         return entry.value[:max_prefix_len], entry
 
-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
         else:
             token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : token_id_len + free_delta
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 6e2dee3c2..ca294c3bd 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if self.disable:
             if token_ids is None:
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
                 token_ids_len = len(token_ids)
 
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : token_ids_len + free_delta
+                req.req_pool_idx, :token_ids_len
             ]
             self.token_to_kv_pool.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
-        if free_delta:
-            self.token_to_kv_pool.free(
-                self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
-                ]
-            )
 
         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 5716815e0..ceb2d55c2 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -542,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
             kill_child_process(pid, including_parent=False)
             return
 
+    # logger.info(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
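
For context, the placeholder trick the `tp_worker.py` hunks rely on can be sketched standalone: the worker hands out *negative* ids via `-torch.arange(...)` before the sampled tokens exist, and a later step resolves them with the `future_mask = resolved_input_ids < 0` lookup. The sketch below is a toy reconstruction, not the actual `TpModelWorker` API; `allocate_future_ids`, `resolve_input_ids`, and the map sizing are illustrative, and the real code bounds in-flight batches so wrapped slots are never reused before resolution.

```python
import torch

# Toy parameters; in sglang the limit derives from the max running requests.
FUTURE_LIMIT = 16
# Sized with headroom so indices never run past the map in this toy setting.
future_token_ids_map = torch.zeros(FUTURE_LIMIT + 8, dtype=torch.int32)
future_token_ids_ct = 0


def allocate_future_ids(bs: int) -> torch.Tensor:
    """Reserve bs slots and hand out negative placeholder ids,
    mirroring the -torch.arange(...) in tp_worker.py."""
    global future_token_ids_ct
    start = future_token_ids_ct + 1
    placeholders = -torch.arange(start, start + bs, dtype=torch.int32)
    future_token_ids_ct = (future_token_ids_ct + bs) % FUTURE_LIMIT
    return placeholders


def resolve_input_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Replace negative placeholders with real tokens, mirroring the
    future_mask lookup in tp_worker.py."""
    future_mask = input_ids < 0
    input_ids[future_mask] = future_token_ids_map[(-input_ids[future_mask]).long()]
    return input_ids


# Step N: launch a decode batch of 2 before its tokens exist.
ph = allocate_future_ids(2)                    # tensor([-1, -2])
# ...forward pass completes; the sampled tokens land in the map.
future_token_ids_map[1:3] = torch.tensor([101, 202], dtype=torch.int32)
# Step N+1: the placeholders arrive as input ids and get resolved.
print(resolve_input_ids(ph.clone()))           # tensor([101, 202])
```

Negative values work as placeholders because real token ids are always non-negative, so `input_ids < 0` cleanly separates pending positions from resolved ones.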
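The `free_delta` removal hinges on an ownership change that the `scheduler.py` hunk makes explicit: under overlap scheduling, a request can finish while the next decode step has already written one more KV slot for it via `out_cache_loc`. Rather than threading a `free_delta` into every cache backend's `cache_finished_req`, the scheduler now frees that slot itself. A minimal sketch of the pattern, with a hypothetical `FakePool` standing in for `BaseTokenToKVPool`:

```python
class FakePool:
    """Hypothetical stand-in for BaseTokenToKVPool: tracks free slot indices."""

    def __init__(self, size: int):
        self.free_slots = set(range(size))

    def alloc(self, n: int) -> list:
        # Grab n arbitrary free slots, one per decoding request.
        return [self.free_slots.pop() for _ in range(n)]

    def free(self, indices) -> None:
        self.free_slots.update(indices)


token_to_kv_pool = FakePool(8)
# One KV slot was already allocated per running request for this decode step.
out_cache_loc = token_to_kv_pool.alloc(2)

finished = [True, False]  # req 0 finished during the overlapped step
for i, done in enumerate(finished):
    if done:
        # Mirrors scheduler.py: free only this request's just-written slot.
        token_to_kv_pool.free(out_cache_loc[i : i + 1])
```

This keeps `cache_finished_req` signatures identical across `ChunkCache` and `RadixCache`, which is what lets the overlap branch bind `self.cache_finished_req = self.tree_cache.cache_finished_req` directly instead of wrapping it.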