"""Meta data for requests and batches""" from dataclasses import dataclass from enum import IntEnum, auto from typing import List import numpy as np import torch from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool class ForwardMode(IntEnum): PREFILL = auto() EXTEND = auto() DECODE = auto() class BaseFinishReason: def __init__(self, is_error: bool = False): self.is_error = is_error def __str__(self): raise NotImplementedError("Subclasses must implement this method") class FINISH_MATCHED_TOKEN(BaseFinishReason): def __init__(self, matched: int | List[int]): super().__init__() self.matched = matched def __str__(self) -> str: return f"FINISH_MATCHED_TOKEN: {self.matched}" class FINISH_LENGTH(BaseFinishReason): def __init__(self, length: int): super().__init__() self.length = length def __str__(self) -> str: return f"FINISH_LENGTH: {self.length}" class FINISH_MATCHED_STR(BaseFinishReason): def __init__(self, matched: str): super().__init__() self.matched = matched def __str__(self) -> str: return f"FINISH_MATCHED_STR: {self.matched}" class FINISH_ABORT(BaseFinishReason): def __init__(self): super().__init__(is_error=True) def __str__(self) -> str: return "FINISH_ABORT" class Req: def __init__(self, rid, origin_input_text, origin_input_ids): self.rid = rid self.origin_input_text = origin_input_text self.origin_input_ids = origin_input_ids self.origin_input_ids_unpadded = origin_input_ids # before image padding self.prev_output_str = "" self.prev_output_ids = [] self.output_ids = [] self.input_ids = None # input_ids = origin_input_ids + prev_output_ids # The number of decoded tokens for token usage report. Note that # this does not include the jump forward tokens. self.completion_tokens_wo_jump_forward = 0 # For vision input self.pixel_values = None self.image_size = None self.image_offset = 0 self.pad_value = None # Sampling parameters self.sampling_params = None self.stream = False self.tokenizer = None # Check finish self.finished_reason = None # Prefix info self.extend_input_len = 0 self.prefix_indices = [] self.last_node = None # Logprobs self.return_logprob = False self.logprob_start_len = 0 self.top_logprobs_num = 0 self.normalized_prompt_logprob = None self.prefill_token_logprobs = None self.prefill_top_logprobs = None self.decode_token_logprobs = [] self.decode_top_logprobs = [] # The tokens is prefilled but need to be considered as decode tokens # and should be updated for the decode logprobs self.last_update_decode_tokens = 0 # Constrained decoding self.regex_fsm = None self.regex_fsm_state = 0 self.jump_forward_map = None # whether request reached finished condition def finished(self) -> bool: return self.finished_reason is not None def partial_decode(self, ids): first_token = self.tokenizer.convert_ids_to_tokens(ids[0]) first_token = ( first_token.decode() if isinstance(first_token, bytes) else first_token ) return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids) def max_new_tokens(self): return self.sampling_params.max_new_tokens def check_finished(self): if self.finished(): return if ( len(self.prev_output_ids) + len(self.output_ids) >= self.sampling_params.max_new_tokens ): self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids)) return if ( self.output_ids[-1] == self.tokenizer.eos_token_id and not self.sampling_params.ignore_eos ): self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id) return if 
    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        # FIXME: This logic does not really solve the problem of determining
        # whether there should be a leading space.
        cur_output_str = self.partial_decode(self.output_ids)

        # TODO(lsyin): apply re-tokenization only to the decode tokens so that
        # origin_input_text is no longer needed.
        if self.origin_input_text is None:
            # Recovering the text can only use the unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = (
            self.origin_input_text
            + self.prev_output_str
            + cur_output_str
            + jump_forward_str
        )
        all_ids = self.tokenizer.encode(all_text)
        prompt_tokens = len(self.origin_input_ids_unpadded)
        self.origin_input_ids = all_ids[:prompt_tokens]
        self.origin_input_ids_unpadded = self.origin_input_ids
        # NOTE: the output ids may not strictly correspond to the output text
        old_prev_output_ids = self.prev_output_ids
        self.prev_output_ids = all_ids[prompt_tokens:]
        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
        self.output_ids = []

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # Keep the logprobs of the longest common prefix of the old and
            # new output ids; everything after it must be recomputed.
            k = 0
            for i, old_id in enumerate(old_prev_output_ids):
                if old_id == self.prev_output_ids[i]:
                    k = k + 1
                else:
                    break
            self.decode_token_logprobs = self.decode_token_logprobs[:k]
            self.decode_top_logprobs = self.decode_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.prev_output_ids) - k

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


@dataclass
class Batch:
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool
    tree_cache: RadixCache

    # Batched arguments to the model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # For multimodal inputs
    pixel_values: List[torch.Tensor] = None
    image_sizes: List[List[int]] = None
    image_offsets: List[int] = None

    # Other arguments for control
    output_ids: torch.Tensor = None
    extend_num_tokens: int = None

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    frequency_penalties: torch.Tensor = None
    presence_penalties: torch.Tensor = None
    logit_bias: torch.Tensor = None

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=return_logprob,
        )

    def is_empty(self):
        return len(self.reqs) == 0
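    # Illustrative shapes for prepare_for_extend (an assumed example, not from
    # the source): for two requests with cached prefix lengths [4, 0] and
    # uncached (extend) lengths [3, 5], input_ids holds the 8 uncached tokens
    # back to back, seq_lens == [7, 5], prefix_lens == [4, 0], and
    # extend_num_tokens == 8 new KV-cache slots are allocated.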
    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
        device = "cuda"
        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
        prefix_indices = [r.prefix_indices for r in reqs]

        # Handle prefix
        flatten_input_ids = []
        extend_lens = []
        prefix_lens = []
        seq_lens = []

        req_pool_indices = self.req_to_token_pool.alloc(bs)
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        for i in range(bs):
            flatten_input_ids.extend(input_ids[i])
            extend_lens.append(len(input_ids[i]))

            if len(prefix_indices[i]) == 0:
                prefix_lens.append(0)
            else:
                prefix_lens.append(len(prefix_indices[i]))
                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                    : len(prefix_indices[i])
                ] = prefix_indices[i]

            seq_lens.append(prefix_lens[-1] + extend_lens[-1])

        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

        # Alloc mem
        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
        if out_cache_loc is None:
            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

            if out_cache_loc is None:
                print("Prefill out of memory. This should never happen.")
                self.tree_cache.pretty_print()
                exit()

        pt = 0
        for i in range(bs):
            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
            ] = out_cache_loc[pt : pt + extend_lens[i]]
            pt += extend_lens[i]

        # Handle logit bias, but only allocate when needed
        logit_bias = None
        for i in range(bs):
            if reqs[i].sampling_params.dtype == "int":
                if logit_bias is None:
                    logit_bias = torch.zeros(
                        (bs, vocab_size), dtype=torch.float32, device=device
                    )
                logit_bias[i] = int_token_logit_bias

        # Set fields
        self.input_ids = torch.tensor(
            flatten_input_ids, dtype=torch.int32, device=device
        )
        self.pixel_values = [r.pixel_values for r in reqs]
        self.image_sizes = [r.image_size for r in reqs]
        self.image_offsets = [
            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
        ]
        self.req_pool_indices = req_pool_indices
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
        self.position_ids_offsets = position_ids_offsets
        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

        self.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        ).view(-1, 1)
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        ).view(-1, 1)
        self.frequency_penalties = torch.tensor(
            [r.sampling_params.frequency_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.presence_penalties = torch.tensor(
            [r.sampling_params.presence_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.logit_bias = logit_bias

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False
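    # Retraction: when decode cannot get one KV-cache slot per running
    # request, some requests are rolled back to the waiting queue. The sort
    # below places the cheapest-to-redo requests (fewest generated tokens;
    # among ties, the longest inputs) at the end of the list, so they are
    # popped and retracted first.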
    def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]
        # TODO(lsyin): improve the priority of retraction
        sorted_indices.sort(
            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        while self.token_to_kv_pool.available_size() < len(self.reqs):
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            # TODO: apply more fine-grained retraction
            last_uncached_pos = len(req.prefix_indices)
            token_indices = self.req_to_token_pool.req_to_token[
                req_pool_indices_cpu[idx]
            ][last_uncached_pos : seq_lens_cpu[idx]]
            self.token_to_kv_pool.dec_refs(token_indices)

            # Release the last node
            self.tree_cache.dec_lock_ref(req.last_node)

            cur_output_str = req.partial_decode(req.output_ids)
            req.prev_output_str = req.prev_output_str + cur_output_str
            req.prev_output_ids.extend(req.output_ids)

            req.prefix_indices = None
            req.last_node = None
            req.extend_input_len = 0
            req.output_ids = []

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(sorted_indices)

        return retracted_reqs

    def check_for_jump_forward(self, model_runner):
        jump_forward_reqs = []
        filter_indices = [i for i in range(len(self.reqs))]

        req_pool_indices_cpu = None

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
                if res is not None:
                    jump_forward_str, next_state = res
                    if len(jump_forward_str) <= 1:
                        continue

                    if req_pool_indices_cpu is None:
                        req_pool_indices_cpu = self.req_pool_indices.tolist()

                    # Insert the old request into tree_cache
                    self.tree_cache.cache_req(
                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                        last_uncached_pos=len(req.prefix_indices),
                        req_pool_idx=req_pool_indices_cpu[i],
                    )

                    # Unlock the last node
                    self.tree_cache.dec_lock_ref(req.last_node)

                    # Jump forward
                    req.jump_forward_and_retokenize(jump_forward_str, next_state)

                    # Re-apply image padding
                    if req.pixel_values is not None:
                        (
                            req.origin_input_ids,
                            req.image_offset,
                        ) = model_runner.model.pad_input_ids(
                            req.origin_input_ids_unpadded,
                            req.pad_value,
                            req.pixel_values.shape,
                            req.image_size,
                        )

                    jump_forward_reqs.append(req)
                    filter_indices.remove(i)

        if len(filter_indices) < len(self.reqs):
            self.filter_batch(filter_indices)

        return jump_forward_reqs
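    # Decode step: every running request contributes exactly one new token, so
    # the batch needs exactly one KV-cache slot per request. A contiguous
    # range is tried first (alloc_contiguous) so the new slots can be
    # addressed as a single [start, end) span; the start/end markers are
    # cleared when allocation falls back to scattered slots.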
    def prepare_for_decode(self, input_ids=None):
        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.input_ids[-1]
                for r in self.reqs
            ]
        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)
        self.prefix_lens = None

        # Alloc mem
        bs = len(self.reqs)
        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
        if alloc_res is None:
            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

            if self.out_cache_loc is None:
                print("Decode out of memory. This should never happen.")
                self.tree_cache.pretty_print()
                exit()

            self.out_cache_cont_start = None
            self.out_cache_cont_end = None
        else:
            self.out_cache_loc = alloc_res[0]
            self.out_cache_cont_start = alloc_res[1]
            self.out_cache_cont_end = alloc_res[2]

        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

    def filter_batch(self, unfinished_indices: List[int]):
        self.reqs = [self.reqs[i] for i in unfinished_indices]
        new_indices = torch.tensor(
            unfinished_indices, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = self.seq_lens[new_indices]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.prefix_lens = None
        self.position_ids_offsets = self.position_ids_offsets[new_indices]
        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
        self.top_logprobs_nums = [
            self.top_logprobs_nums[i] for i in unfinished_indices
        ]
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            # logit_bias can be None
            if self_val is not None:
                setattr(self, item, self_val[new_indices])

    def merge(self, other: "Batch"):
        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.prefix_lens = None
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                # self.reqs already includes other.reqs at this point, so size
                # the zero block by the original batch only.
                self.logit_bias = torch.zeros(
                    (len(self.reqs) - len(other.reqs), vocab_size),
                    dtype=torch.float32,
                    device="cuda",
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
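    # Sampling pipeline: temperature scaling and logit bias are applied in
    # place, regex-constrained requests have disallowed tokens masked to -inf,
    # and top-p/top-k filtering runs on the sorted probabilities (see
    # _top_p_top_k below). torch.multinomial does not require its input rows
    # to sum to one, so the filtered probabilities are only rescaled, never
    # re-normalized.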
    def sample(self, logits: torch.Tensor):
        # Post process logits
        logits = logits.contiguous()
        logits.div_(self.temperatures)
        if self.logit_bias is not None:
            logits.add_(self.logit_bias)

        has_regex = any(req.regex_fsm is not None for req in self.reqs)
        if has_regex:
            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    allowed_mask.zero_()
                    allowed_mask[
                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
                    ] = 1
                    logits[i].masked_fill_(~allowed_mask, float("-inf"))

        # TODO(lmzheng): apply penalty
        probs = torch.softmax(logits, dim=-1)
        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
        batch_next_token_ids = torch.gather(
            probs_idx, dim=1, index=sampled_index
        ).view(-1)
        batch_next_token_probs = torch.gather(
            probs_sort, dim=1, index=sampled_index
        ).view(-1)

        if has_regex:
            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.next_state(
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )

        return batch_next_token_ids, batch_next_token_probs


def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
    # Sort descending, then zero out (a) tokens whose preceding cumulative
    # mass already exceeds top_p and (b) tokens ranked at or beyond top_k.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
    ] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    return probs_sort, probs_idx
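# Minimal smoke test for _top_p_top_k (an illustrative addition, not part of
# the original module). It assumes running the helper on CPU tensors is
# acceptable; the function itself has no CUDA-specific code.
if __name__ == "__main__":
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    top_ps = torch.tensor([[0.8]])
    top_ks = torch.tensor([[2]])
    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
    # Tokens ranked >= 2 are zeroed by top_k, tokens whose preceding
    # cumulative mass exceeds 0.8 are zeroed by top_p, and the row is rescaled
    # by its max: probs_sort ~= [[1.0, 0.6, 0.0, 0.0]] (up to float rounding),
    # probs_idx == [[0, 1, 2, 3]].
    print(probs_sort)
    print(probs_idx)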