diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 9ce4480bf..b02ce9f81 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
-        self.new_token_ratio_recovery = 0.05
 
         # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 4e9b9eb2f..205f31797 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -100,6 +100,9 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
+        # Memory info
+        self.req_pool_idx = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids |
@@ -321,6 +324,9 @@ class ScheduleBatch:
             return_logprob=return_logprob,
         )
 
+    def batch_size(self):
+        return len(self.reqs) if self.reqs is not None else 0
+
     def is_empty(self):
         return len(self.reqs) == 0
 
@@ -328,52 +334,22 @@ class ScheduleBatch:
         # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
-        bs = len(self.reqs)
-        reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        flatten_input_ids = []
-        extend_lens = []
-        prefix_lens = []
-        seq_lens = []
-
-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-
+    def alloc_req_slots(self, num_reqs):
+        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
                 "Out of memory. "
                 "Please set a smaller number for `--max-running-requests`."
             )
+        return req_pool_indices
 
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
-            extend_lens.append(len(input_ids[i]))
+    def alloc_token_slots(self, num_tokens: int):
+        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
 
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
 
             if out_cache_loc is None:
                 logger.error("Prefill out of memory. Try to lower your batch size.")
Try to lower your batch size.") @@ -381,40 +357,11 @@ class ScheduleBatch: self.tree_cache.pretty_print() exit(1) - pt = 0 - for i in range(bs): - self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][ - prefix_lens[i] : prefix_lens[i] + extend_lens[i] - ] = out_cache_loc[pt : pt + extend_lens[i]] - pt += extend_lens[i] - - # Handle logit bias but only allocate when needed - logit_bias = None - for i in range(bs): - if reqs[i].sampling_params.dtype == "int": - if logit_bias is None: - logit_bias = torch.zeros( - (bs, vocab_size), dtype=torch.float32, device=device - ) - logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias - - # Set fields - self.input_ids = torch.tensor( - flatten_input_ids, dtype=torch.int32, device=device - ) - self.pixel_values = [r.pixel_values for r in reqs] - self.image_sizes = [r.image_size for r in reqs] - self.image_offsets = [ - r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens) - ] - self.req_pool_indices = req_pool_indices - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device) - self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device) - self.position_ids_offsets = position_ids_offsets - self.extend_num_tokens = extend_num_tokens - self.out_cache_loc = out_cache_loc - self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + return out_cache_loc + def batch_sampling_params(self, vocab_size, int_token_logit_bias): + device = "cuda" + bs, reqs = self.batch_size(), self.reqs self.temperatures = torch.tensor( [r.sampling_params.temperature for r in reqs], dtype=torch.float, @@ -436,10 +383,78 @@ class ScheduleBatch: dtype=torch.float, device=device, ) - self.logit_bias = logit_bias + + # Handle logit bias but only allocate when needed + self.logit_bias = None + for i in range(bs): + if reqs[i].sampling_params.dtype == "int": + if self.logit_bias is None: + self.logit_bias = torch.zeros( + (bs, vocab_size), dtype=torch.float32, device=device + ) + self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias + + def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor): + device = "cuda" + bs = self.batch_size() + reqs = self.reqs + input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs] + prefix_indices = [r.prefix_indices for r in reqs] + + # Handle prefix + extend_lens = [] + prefix_lens = [] + seq_lens = [] + + req_pool_indices_cpu = self.alloc_req_slots(bs) + + for i, req in enumerate(reqs): + req.req_pool_idx = req_pool_indices_cpu[i] + extend_lens.append(len(input_ids[i])) + + if len(prefix_indices[i]) == 0: + prefix_lens.append(0) + else: + prefix_lens.append(len(prefix_indices[i])) + self.req_to_token_pool.req_to_token[req.req_pool_idx][ + : len(prefix_indices[i]) + ] = prefix_indices[i] + + seq_lens.append(prefix_lens[-1] + extend_lens[-1]) + + # Allocate memory + seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens) + extend_num_tokens = seq_lens.sum() - prefix_lens.sum() + out_cache_loc = self.alloc_token_slots(extend_num_tokens) + + pt = 0 + for i, req in enumerate(reqs): + self.req_to_token_pool.req_to_token[req.req_pool_idx][ + prefix_lens[i] : prefix_lens[i] + extend_lens[i] + ] = out_cache_loc[pt : pt + extend_lens[i]] + pt += extend_lens[i] + + # Set fields + with torch.device("cuda"): + self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) + self.req_pool_indices = torch.tensor(req_pool_indices_cpu) + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) + self.position_ids_offsets = 
+
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+        ]
+        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+        self.extend_num_tokens = extend_num_tokens
+        self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+
+        self.batch_sampling_params(vocab_size, int_token_logit_bias)
 
     def check_decode_mem(self):
-        bs = len(self.reqs)
+        bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
@@ -464,7 +479,6 @@ class ScheduleBatch:
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
@@ -482,20 +496,20 @@
 
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][: seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][last_uncached_pos : seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)
 
                 # release the last node
                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -533,8 +547,6 @@ class ScheduleBatch:
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
-        req_pool_indices_cpu = None
-
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -584,13 +596,11 @@
                     req.vid += 1
 
                     # insert the old request into tree_cache
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
                         rid=req.rid,
                         token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req_pool_indices_cpu[i],
+                        req_pool_idx=req.req_pool_idx,
                     )
 
                     # unlock the last node
@@ -626,14 +636,8 @@
         self.prefix_lens = None
 
         # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-        if self.out_cache_loc is None:
-            logger.error("Decode out of memory. Try to lower your batch size.")
Try to lower your batch size.") - if self.tree_cache is not None: - self.tree_cache.pretty_print() - exit(1) + bs = self.batch_size() + self.out_cache_loc = self.alloc_token_slots(bs) self.req_to_token_pool.req_to_token[ self.req_pool_indices, self.seq_lens - 1 diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 54d6805d8..cd543da34 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -200,7 +200,6 @@ class ModelTpServer: ) self.new_token_ratio = self.min_new_token_ratio self.new_token_ratio_decay = global_config.new_token_ratio_decay - self.new_token_ratio_recovery = global_config.new_token_ratio_recovery def exposed_step(self, recv_reqs): try: @@ -625,13 +624,12 @@ class ModelTpServer: req.output_top_logprobs.append(output.output_top_logprobs[i]) def cache_filled_batch(self, batch: ScheduleBatch): - req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): new_prefix_indices, new_last_node = self.tree_cache.cache_req( rid=req.rid, token_ids=tuple(req.input_ids), last_uncached_pos=len(req.prefix_indices), - req_pool_idx=req_pool_indices_cpu[i], + req_pool_idx=req.req_pool_idx, del_in_memory_pool=False, old_last_node=req.last_node, ) @@ -639,7 +637,7 @@ class ModelTpServer: if req is self.current_inflight_req: # inflight request would get a new req idx - self.req_to_token_pool.free(int(req_pool_indices_cpu[i])) + self.req_to_token_pool.free(req.req_pool_idx) def forward_decode_batch(self, batch: ScheduleBatch): # Check if decode out of memory @@ -782,14 +780,13 @@ class ModelTpServer: # Remove finished reqs if finished_indices: # Update radix cache - req_pool_indices_cpu = batch.req_pool_indices.tolist() for i in finished_indices: req = batch.reqs[i] self.tree_cache.cache_req( rid=req.rid, token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], last_uncached_pos=len(req.prefix_indices), - req_pool_idx=req_pool_indices_cpu[i], + req_pool_idx=req.req_pool_idx, ) self.tree_cache.dec_lock_ref(req.last_node) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 761b668bd..9036d73d0 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -16,6 +16,7 @@ limitations under the License. 
"""Memory pool.""" import logging +from typing import List import torch @@ -27,34 +28,29 @@ class ReqToTokenPool: def __init__(self, size: int, max_context_len: int): self.size = size - self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda") + self.free_slots = list(range(size)) self.req_to_token = torch.empty( (size, max_context_len), dtype=torch.int32, device="cuda" ) self.can_use_mem_size = size - def alloc(self, need_size: int): - if need_size > self.can_use_mem_size: + def alloc(self, need_size: int) -> List[int]: + if need_size > len(self.free_slots): return None - select_index = ( - torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32) - ) - self.mem_state[select_index] = False - self.can_use_mem_size -= need_size + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] return select_index def free(self, free_index): - self.mem_state[free_index] = True if isinstance(free_index, (int,)): - self.can_use_mem_size += 1 + self.free_slots.append(free_index) else: - self.can_use_mem_size += free_index.shape[0] + self.free_slots.extend(free_index) def clear(self): - self.mem_state.fill_(True) - self.can_use_mem_size = len(self.mem_state) + self.free_slots = list(range(self.size)) class BaseTokenToKVPool: