Code structure refactor (#807)
This commit is contained in:
995
python/sglang/srt/managers/schedule_batch.py
Normal file
995
python/sglang/srt/managers/schedule_batch.py
Normal file
@@ -0,0 +1,995 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Meta data for requests and batches"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
# Put some global args for easy access
|
||||
global_server_args_dict = {
|
||||
"disable_flashinfer": False,
|
||||
"disable_flashinfer_sampling": False,
|
||||
"attention_reduce_in_fp32": False,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ForwardMode(IntEnum):
|
||||
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
|
||||
PREFILL = auto()
|
||||
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
|
||||
EXTEND = auto()
|
||||
# Decode one token.
|
||||
DECODE = auto()
|
||||
|
||||
|
||||
class BaseFinishReason:
|
||||
def __init__(self, is_error: bool = False):
|
||||
self.is_error = is_error
|
||||
|
||||
def __str__(self):
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
|
||||
class FINISH_MATCHED_TOKEN(BaseFinishReason):
|
||||
def __init__(self, matched: Union[int, List[int]]):
|
||||
super().__init__()
|
||||
self.matched = matched
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FINISH_MATCHED_TOKEN: {self.matched}"
|
||||
|
||||
|
||||
class FINISH_LENGTH(BaseFinishReason):
|
||||
def __init__(self, length: int):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FINISH_LENGTH: {self.length}"
|
||||
|
||||
|
||||
class FINISH_MATCHED_STR(BaseFinishReason):
|
||||
def __init__(self, matched: str):
|
||||
super().__init__()
|
||||
self.matched = matched
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FINISH_MATCHED_STR: {self.matched}"
|
||||
|
||||
|
||||
class FINISH_ABORT(BaseFinishReason):
|
||||
def __init__(self):
|
||||
super().__init__(is_error=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "FINISH_ABORT"
|
||||
|
||||
|
||||
class Req:
|
||||
"""Store all inforamtion of a request."""
|
||||
|
||||
def __init__(self, rid, origin_input_text, origin_input_ids):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
self.origin_input_text = origin_input_text
|
||||
self.origin_input_ids_unpadded = origin_input_ids # Before image padding
|
||||
self.origin_input_ids = origin_input_ids
|
||||
self.output_ids = [] # Each decode stage's output ids
|
||||
self.input_ids = None # input_ids = origin_input_ids + output_ids
|
||||
|
||||
# For incremental decoding
|
||||
# ----- | --------- read_ids -------|
|
||||
# ----- | surr_ids |
|
||||
# xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
|
||||
# ----- ^ ----------- ^ ----------- ^
|
||||
# ----- 1 ----------- 2 ----------- 3
|
||||
# 1: surr_offset
|
||||
# 2: read_offset
|
||||
# 3: last token
|
||||
self.vid = 0 # version id to sync decode status with in detokenizer_manager
|
||||
self.decoded_text = ""
|
||||
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
||||
self.read_offset = None
|
||||
|
||||
# The number of decoded tokens for token usage report. Note that
|
||||
# this does not include the jump forward tokens.
|
||||
self.completion_tokens_wo_jump_forward = 0
|
||||
|
||||
# For vision input
|
||||
self.pixel_values = None
|
||||
self.image_size = None
|
||||
self.image_offset = 0
|
||||
self.pad_value = None
|
||||
|
||||
# Prefix info
|
||||
self.extend_input_len = 0
|
||||
self.prefix_indices = []
|
||||
self.last_node = None
|
||||
|
||||
# Sampling parameters
|
||||
self.sampling_params = None
|
||||
self.stream = False
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
self.finished_reason = None
|
||||
|
||||
# Logprobs
|
||||
self.return_logprob = False
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = 0
|
||||
self.normalized_prompt_logprob = None
|
||||
self.input_token_logprobs = None
|
||||
self.input_top_logprobs = None
|
||||
self.output_token_logprobs = []
|
||||
self.output_top_logprobs = []
|
||||
# The tokens is prefilled but need to be considered as decode tokens
|
||||
# and should be updated for the decode logprobs
|
||||
self.last_update_decode_tokens = 0
|
||||
|
||||
# Constrained decoding
|
||||
self.regex_fsm: RegexGuide = None
|
||||
self.regex_fsm_state: int = 0
|
||||
self.jump_forward_map: JumpForwardMap = None
|
||||
|
||||
# whether request reached finished condition
|
||||
def finished(self) -> bool:
|
||||
return self.finished_reason is not None
|
||||
|
||||
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
|
||||
def init_incremental_detokenize(self):
|
||||
first_iter = self.surr_offset is None or self.read_offset is None
|
||||
|
||||
if first_iter:
|
||||
self.read_offset = len(self.origin_input_ids_unpadded)
|
||||
self.surr_offset = max(
|
||||
self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
|
||||
)
|
||||
|
||||
all_ids = self.origin_input_ids_unpadded + self.output_ids
|
||||
return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
|
||||
|
||||
def get_next_inc_detokenization(self):
|
||||
read_ids, read_offset = self.init_incremental_detokenize()
|
||||
surr_ids = read_ids[:read_offset]
|
||||
|
||||
surr_text = self.tokenizer.decode(
|
||||
surr_ids,
|
||||
skip_special_tokens=self.sampling_params.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
|
||||
)
|
||||
new_text = self.tokenizer.decode(
|
||||
read_ids,
|
||||
skip_special_tokens=self.sampling_params.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
|
||||
)
|
||||
|
||||
if len(new_text) > len(surr_text) and not new_text.endswith("<EFBFBD>"):
|
||||
return True, new_text[len(surr_text) :]
|
||||
|
||||
return False, ""
|
||||
|
||||
def check_finished(self):
|
||||
if self.finished():
|
||||
return
|
||||
|
||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
||||
self.finished_reason = FINISH_LENGTH(len(self.output_ids))
|
||||
return
|
||||
|
||||
if (
|
||||
self.output_ids[-1] == self.tokenizer.eos_token_id
|
||||
and not self.sampling_params.ignore_eos
|
||||
):
|
||||
self.finished_reason = FINISH_MATCHED_TOKEN(
|
||||
matched=self.tokenizer.eos_token_id
|
||||
)
|
||||
return
|
||||
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
tail_str = self.tokenizer.decode(
|
||||
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
||||
)
|
||||
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||
return
|
||||
|
||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||
if self.origin_input_text is None:
|
||||
# Recovering text can only use unpadded ids
|
||||
self.origin_input_text = self.tokenizer.decode(
|
||||
self.origin_input_ids_unpadded
|
||||
)
|
||||
|
||||
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
|
||||
all_ids = self.tokenizer.encode(all_text)
|
||||
prompt_tokens = len(self.origin_input_ids_unpadded)
|
||||
|
||||
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
|
||||
# TODO(lsyin): fix token fusion
|
||||
warnings.warn(
|
||||
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
|
||||
)
|
||||
return False
|
||||
|
||||
old_output_ids = self.output_ids
|
||||
self.output_ids = all_ids[prompt_tokens:]
|
||||
self.decoded_text = self.decoded_text + jump_forward_str
|
||||
self.surr_offset = prompt_tokens
|
||||
self.read_offset = len(all_ids)
|
||||
|
||||
# NOTE: A trick to reduce the surrouding tokens decoding overhead
|
||||
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
|
||||
surr_text_ = self.tokenizer.decode(
|
||||
all_ids[self.read_offset - i : self.read_offset]
|
||||
)
|
||||
if not surr_text_.endswith("<EFBFBD>"):
|
||||
self.surr_offset = self.read_offset - i
|
||||
break
|
||||
|
||||
self.regex_fsm_state = next_state
|
||||
|
||||
if self.return_logprob:
|
||||
# For fast-forward part's logprobs
|
||||
k = 0
|
||||
for i, old_id in enumerate(old_output_ids):
|
||||
if old_id == self.output_ids[i]:
|
||||
k = k + 1
|
||||
else:
|
||||
break
|
||||
self.output_token_logprobs = self.output_token_logprobs[:k]
|
||||
self.output_top_logprobs = self.output_top_logprobs[:k]
|
||||
self.logprob_start_len = prompt_tokens + k
|
||||
self.last_update_decode_tokens = len(self.output_ids) - k
|
||||
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
|
||||
|
||||
|
||||
@dataclass
|
||||
class Batch:
|
||||
"""Store all inforamtion of a batch."""
|
||||
|
||||
# Request, memory pool, and cache
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: TokenToKVPool
|
||||
tree_cache: RadixCache
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
prefix_lens: torch.Tensor = None
|
||||
position_ids_offsets: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
extend_num_tokens: int = None
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
# For multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
image_sizes: List[List[int]] = None
|
||||
image_offsets: List[int] = None
|
||||
|
||||
# Batched sampling params
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
top_ks: torch.Tensor = None
|
||||
frequency_penalties: torch.Tensor = None
|
||||
presence_penalties: torch.Tensor = None
|
||||
logit_bias: torch.Tensor = None
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
return_logprob = any(req.return_logprob for req in reqs)
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
return_logprob=return_logprob,
|
||||
)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.reqs) == 0
|
||||
|
||||
def has_stream(self) -> bool:
|
||||
# Return whether batch has at least 1 streaming request
|
||||
return any(r.stream for r in self.reqs)
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||
device = "cuda"
|
||||
bs = len(self.reqs)
|
||||
reqs = self.reqs
|
||||
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
|
||||
prefix_indices = [r.prefix_indices for r in reqs]
|
||||
|
||||
# Handle prefix
|
||||
flatten_input_ids = []
|
||||
extend_lens = []
|
||||
prefix_lens = []
|
||||
seq_lens = []
|
||||
|
||||
req_pool_indices = self.req_to_token_pool.alloc(bs)
|
||||
|
||||
if req_pool_indices is None:
|
||||
raise RuntimeError(
|
||||
"Out of memory. "
|
||||
"Please set a smaller number for `--max-running-requests`."
|
||||
)
|
||||
|
||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||
for i in range(bs):
|
||||
flatten_input_ids.extend(input_ids[i])
|
||||
extend_lens.append(len(input_ids[i]))
|
||||
|
||||
if len(prefix_indices[i]) == 0:
|
||||
prefix_lens.append(0)
|
||||
else:
|
||||
prefix_lens.append(len(prefix_indices[i]))
|
||||
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
|
||||
: len(prefix_indices[i])
|
||||
] = prefix_indices[i]
|
||||
|
||||
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
|
||||
|
||||
position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
|
||||
|
||||
# Allocate memory
|
||||
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
|
||||
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
|
||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||
if out_cache_loc is None:
|
||||
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
|
||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
logger.error("Prefill out of memory. This should never happen.")
|
||||
self.tree_cache.pretty_print()
|
||||
exit()
|
||||
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
|
||||
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
|
||||
] = out_cache_loc[pt : pt + extend_lens[i]]
|
||||
pt += extend_lens[i]
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
logit_bias = None
|
||||
for i in range(bs):
|
||||
if reqs[i].sampling_params.dtype == "int":
|
||||
if logit_bias is None:
|
||||
logit_bias = torch.zeros(
|
||||
(bs, vocab_size), dtype=torch.float32, device=device
|
||||
)
|
||||
logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(
|
||||
flatten_input_ids, dtype=torch.int32, device=device
|
||||
)
|
||||
self.pixel_values = [r.pixel_values for r in reqs]
|
||||
self.image_sizes = [r.image_size for r in reqs]
|
||||
self.image_offsets = [
|
||||
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
|
||||
]
|
||||
self.req_pool_indices = req_pool_indices
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
|
||||
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
||||
self.position_ids_offsets = position_ids_offsets
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
|
||||
self.temperatures = torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
).view(-1, 1)
|
||||
self.top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
|
||||
)
|
||||
self.top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||
)
|
||||
self.frequency_penalties = torch.tensor(
|
||||
[r.sampling_params.frequency_penalty for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
)
|
||||
self.presence_penalties = torch.tensor(
|
||||
[r.sampling_params.presence_penalty for r in reqs],
|
||||
dtype=torch.float,
|
||||
device=device,
|
||||
)
|
||||
self.logit_bias = logit_bias
|
||||
|
||||
def check_decode_mem(self):
|
||||
bs = len(self.reqs)
|
||||
if self.token_to_kv_pool.available_size() >= bs:
|
||||
return True
|
||||
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
||||
|
||||
if self.token_to_kv_pool.available_size() >= bs:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def retract_decode(self):
|
||||
sorted_indices = [i for i in range(len(self.reqs))]
|
||||
|
||||
# TODO(lsyin): improve retraction policy for radix cache
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (
|
||||
len(self.reqs[i].output_ids),
|
||||
-len(self.reqs[i].origin_input_ids),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||
while (
|
||||
self.token_to_kv_pool.available_size()
|
||||
< len(sorted_indices) * global_config.retract_decode_steps
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
# Corner case: only one request left
|
||||
assert (
|
||||
self.token_to_kv_pool.available_size() > 0
|
||||
), "No space left for only one request"
|
||||
break
|
||||
|
||||
idx = sorted_indices.pop()
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
|
||||
# release the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
req.prefix_indices = None
|
||||
req.last_node = None
|
||||
req.extend_input_len = 0
|
||||
|
||||
# For incremental logprobs
|
||||
req.last_update_decode_tokens = 0
|
||||
req.logprob_start_len = 10**9
|
||||
|
||||
self.filter_batch(sorted_indices)
|
||||
|
||||
# Reqs in batch are filtered
|
||||
total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
|
||||
total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
|
||||
|
||||
new_estimate_ratio = (
|
||||
total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
|
||||
) / total_max_new_tokens
|
||||
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
||||
|
||||
return retracted_reqs, new_estimate_ratio
|
||||
|
||||
def check_for_jump_forward(self, model_runner):
|
||||
jump_forward_reqs = []
|
||||
filter_indices = [i for i in range(len(self.reqs))]
|
||||
|
||||
req_pool_indices_cpu = None
|
||||
|
||||
for i, req in enumerate(self.reqs):
|
||||
if req.jump_forward_map is not None:
|
||||
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
|
||||
req.regex_fsm_state
|
||||
)
|
||||
if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
|
||||
suffix_bytes = []
|
||||
continuation_range = range(0x80, 0xC0)
|
||||
cur_state = req.regex_fsm_state
|
||||
while (
|
||||
len(jump_forward_bytes)
|
||||
and jump_forward_bytes[0][0] in continuation_range
|
||||
):
|
||||
# continuation bytes
|
||||
byte_edge = jump_forward_bytes.pop(0)
|
||||
suffix_bytes.append(byte_edge[0])
|
||||
cur_state = byte_edge[1]
|
||||
|
||||
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
|
||||
suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
|
||||
|
||||
# Current ids, for cache and revert
|
||||
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
|
||||
cur_output_ids = req.output_ids
|
||||
|
||||
req.output_ids.extend(suffix_ids)
|
||||
decode_res, new_text = req.get_next_inc_detokenization()
|
||||
if not decode_res:
|
||||
req.output_ids = cur_output_ids
|
||||
continue
|
||||
|
||||
(
|
||||
jump_forward_str,
|
||||
next_state,
|
||||
) = req.jump_forward_map.jump_forward_symbol(cur_state)
|
||||
|
||||
# Make the incrementally decoded text part of jump_forward_str
|
||||
# so that the UTF-8 will not corrupt
|
||||
jump_forward_str = new_text + jump_forward_str
|
||||
if not req.jump_forward_and_retokenize(
|
||||
jump_forward_str, next_state
|
||||
):
|
||||
req.output_ids = cur_output_ids
|
||||
continue
|
||||
|
||||
# The decode status has diverged from detokenizer_manager
|
||||
req.vid += 1
|
||||
|
||||
# insert the old request into tree_cache
|
||||
if req_pool_indices_cpu is None:
|
||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||
self.tree_cache.cache_req(
|
||||
token_ids=cur_all_ids,
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
)
|
||||
|
||||
# unlock the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# re-applying image padding
|
||||
if req.pixel_values is not None:
|
||||
(
|
||||
req.origin_input_ids,
|
||||
req.image_offset,
|
||||
) = model_runner.model.pad_input_ids(
|
||||
req.origin_input_ids_unpadded,
|
||||
req.pad_value,
|
||||
req.pixel_values.shape,
|
||||
req.image_size,
|
||||
)
|
||||
|
||||
jump_forward_reqs.append(req)
|
||||
filter_indices.remove(i)
|
||||
|
||||
if len(filter_indices) < len(self.reqs):
|
||||
self.filter_batch(filter_indices)
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_for_decode(self, input_ids=None):
|
||||
if input_ids is None:
|
||||
input_ids = [
|
||||
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
|
||||
]
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
|
||||
self.seq_lens.add_(1)
|
||||
self.prefix_lens = None
|
||||
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
|
||||
|
||||
if self.out_cache_loc is None:
|
||||
logger.error("Decode out of memory. This should never happen.")
|
||||
self.tree_cache.pretty_print()
|
||||
exit()
|
||||
|
||||
self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices, self.seq_lens - 1
|
||||
] = self.out_cache_loc
|
||||
|
||||
def filter_batch(self, unfinished_indices: List[int]):
|
||||
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
||||
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.input_ids = None
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.prefix_lens = None
|
||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||
self.out_cache_loc = None
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"frequency_penalties",
|
||||
"presence_penalties",
|
||||
"logit_bias",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
if self_val is not None: # logit_bias can be None
|
||||
setattr(self, item, self_val[new_indices])
|
||||
|
||||
def merge(self, other: "Batch"):
|
||||
self.reqs.extend(other.reqs)
|
||||
|
||||
self.req_pool_indices = torch.concat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||
self.prefix_lens = None
|
||||
self.position_ids_offsets = torch.concat(
|
||||
[self.position_ids_offsets, other.position_ids_offsets]
|
||||
)
|
||||
self.out_cache_loc = None
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"frequency_penalties",
|
||||
"presence_penalties",
|
||||
]:
|
||||
self_val = getattr(self, item, None)
|
||||
other_val = getattr(other, item, None)
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
|
||||
# logit_bias can be None
|
||||
if self.logit_bias is not None or other.logit_bias is not None:
|
||||
vocab_size = (
|
||||
self.logit_bias.shape[1]
|
||||
if self.logit_bias is not None
|
||||
else other.logit_bias.shape[1]
|
||||
)
|
||||
if self.logit_bias is None:
|
||||
self.logit_bias = torch.zeros(
|
||||
(len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
if other.logit_bias is None:
|
||||
other.logit_bias = torch.zeros(
|
||||
(len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
||||
|
||||
def sample(self, logits: torch.Tensor):
|
||||
# Post process logits
|
||||
logits = logits.contiguous()
|
||||
logits.div_(self.temperatures)
|
||||
if self.logit_bias is not None:
|
||||
logits.add_(self.logit_bias)
|
||||
|
||||
has_regex = any(req.regex_fsm is not None for req in self.reqs)
|
||||
if has_regex:
|
||||
allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
|
||||
for i, req in enumerate(self.reqs):
|
||||
if req.regex_fsm is not None:
|
||||
allowed_mask.zero_()
|
||||
allowed_mask[
|
||||
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
|
||||
] = 1
|
||||
logits[i].masked_fill_(~allowed_mask, float("-inf"))
|
||||
|
||||
# TODO(lmzheng): apply penalty
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
if not global_server_args_dict["disable_flashinfer_sampling"]:
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, self.top_ks, self.top_ps
|
||||
)
|
||||
else:
|
||||
# Here we provide a slower fallback implementation.
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
|
||||
probs, self.top_ks, self.top_ps
|
||||
)
|
||||
|
||||
if not torch.all(success):
|
||||
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
|
||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||
argmax_ids = torch.argmax(probs, dim=-1)
|
||||
batch_next_token_ids = torch.where(
|
||||
success, batch_next_token_ids, argmax_ids
|
||||
)
|
||||
|
||||
if has_regex:
|
||||
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
|
||||
for i, req in enumerate(self.reqs):
|
||||
if req.regex_fsm is not None:
|
||||
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
||||
req.regex_fsm_state, batch_next_token_ids_cpu[i]
|
||||
)
|
||||
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputMetadata:
|
||||
"""Store all inforamtion of a forward pass."""
|
||||
|
||||
forward_mode: ForwardMode
|
||||
batch_size: int
|
||||
total_num_tokens: int
|
||||
req_pool_indices: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: TokenToKVPool
|
||||
|
||||
# For extend
|
||||
extend_seq_lens: torch.Tensor
|
||||
extend_start_loc: torch.Tensor
|
||||
extend_no_prefix: bool
|
||||
|
||||
# Output location of the KV cache
|
||||
out_cache_loc: torch.Tensor = None
|
||||
|
||||
# Output options
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: List[int] = None
|
||||
|
||||
# Trition attention backend
|
||||
triton_max_seq_len: int = 0
|
||||
triton_max_extend_len: int = 0
|
||||
triton_start_loc: torch.Tensor = None
|
||||
triton_prefix_lens: torch.Tensor = None
|
||||
|
||||
# FlashInfer attention backend
|
||||
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
||||
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||
use_ragged: bool = False
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model_runner,
|
||||
forward_mode,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
top_logprobs_nums=None,
|
||||
return_logprob=False,
|
||||
skip_flashinfer_init=False,
|
||||
):
|
||||
use_ragged = False
|
||||
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
|
||||
if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
|
||||
use_ragged = True
|
||||
init_flashinfer_args(
|
||||
forward_mode,
|
||||
model_runner,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
model_runner.flashinfer_decode_wrapper,
|
||||
use_ragged,
|
||||
)
|
||||
|
||||
batch_size = len(req_pool_indices)
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
||||
extend_seq_lens = extend_start_loc = extend_no_prefix = None
|
||||
if not model_runner.server_args.disable_flashinfer:
|
||||
# This variable is not needed in this case,
|
||||
# we do not compute it to make it compatbile with cuda graph.
|
||||
total_num_tokens = None
|
||||
else:
|
||||
total_num_tokens = int(torch.sum(seq_lens))
|
||||
else:
|
||||
seq_lens_cpu = seq_lens.cpu().numpy()
|
||||
prefix_lens_cpu = prefix_lens.cpu().numpy()
|
||||
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
|
||||
positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(
|
||||
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
|
||||
)
|
||||
for i in range(batch_size)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
)
|
||||
extend_seq_lens = seq_lens - prefix_lens
|
||||
extend_start_loc = torch.zeros_like(seq_lens)
|
||||
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
|
||||
extend_no_prefix = torch.all(prefix_lens == 0)
|
||||
total_num_tokens = int(torch.sum(seq_lens))
|
||||
|
||||
ret = cls(
|
||||
forward_mode=forward_mode,
|
||||
batch_size=batch_size,
|
||||
total_num_tokens=total_num_tokens,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
positions=positions,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
out_cache_loc=out_cache_loc,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_start_loc=extend_start_loc,
|
||||
extend_no_prefix=extend_no_prefix,
|
||||
return_logprob=return_logprob,
|
||||
top_logprobs_nums=top_logprobs_nums,
|
||||
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
|
||||
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
|
||||
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
|
||||
use_ragged=use_ragged,
|
||||
)
|
||||
|
||||
if model_runner.server_args.disable_flashinfer:
|
||||
(
|
||||
ret.triton_max_seq_len,
|
||||
ret.triton_max_extend_len,
|
||||
ret.triton_start_loc,
|
||||
ret.triton_prefix_lens,
|
||||
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def init_flashinfer_args(
|
||||
forward_mode,
|
||||
model_runner,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
flashinfer_decode_wrapper,
|
||||
use_ragged=False,
|
||||
):
|
||||
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
||||
head_dim = model_runner.model_config.head_dim
|
||||
batch_size = len(req_pool_indices)
|
||||
total_num_tokens = int(torch.sum(seq_lens))
|
||||
|
||||
if use_ragged:
|
||||
paged_kernel_lens = prefix_lens
|
||||
else:
|
||||
paged_kernel_lens = seq_lens
|
||||
|
||||
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||
kv_indices = torch.cat(
|
||||
[
|
||||
model_runner.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
||||
]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dim=0,
|
||||
).contiguous()
|
||||
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
flashinfer_decode_wrapper.end_forward()
|
||||
flashinfer_decode_wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1,
|
||||
)
|
||||
else:
|
||||
# extend part
|
||||
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||
|
||||
if use_ragged:
|
||||
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
|
||||
qo_indptr,
|
||||
qo_indptr,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# cached part
|
||||
model_runner.flashinfer_prefill_wrapper_paged.end_forward()
|
||||
model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
1,
|
||||
)
|
||||
|
||||
|
||||
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
batch_size = len(seq_lens)
|
||||
max_seq_len = int(torch.max(seq_lens))
|
||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
max_extend_len = None
|
||||
else:
|
||||
extend_seq_lens = seq_lens - prefix_lens
|
||||
max_extend_len = int(torch.max(extend_seq_lens))
|
||||
|
||||
return max_seq_len, max_extend_len, start_loc, prefix_lens
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs_torch(
|
||||
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
|
||||
):
|
||||
"""A top-k and top-k sampling implementation with native pytorch operations."""
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
|
||||
probs_sort[
|
||||
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
|
||||
>= top_ks.view(-1, 1)
|
||||
] = 0.0
|
||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||
try:
|
||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||
except RuntimeError:
|
||||
batch_next_token_ids = torch.zeros(
|
||||
(probs_sort.shape[0],), dtype=torch.int64, device=probs.device
|
||||
)
|
||||
success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
|
||||
return batch_next_token_ids, success
|
||||
|
||||
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
|
||||
success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
|
||||
return batch_next_token_ids, success
|
||||
Reference in New Issue
Block a user