# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
|
|
|
|
|
|
Store information about requests and batches.
|
|
|
|
|
|
|
|
|
|
|
|
The following is the flow of data structures for a batch:
|
|
|
|
|
|
|
|
|
|
|
|
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
|
|
|
|
|
|
|
|
|
|
|
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
|
|
|
|
|
It contains high-level scheduling data. Most of the data is on the CPU.
|
|
|
|
|
|
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
2024-10-19 23:19:26 -07:00
|
|
|
|
It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
|
|
|
|
|
|
It will be transformed from CPU scheduler to GPU model runner.
|
2024-09-30 06:41:49 -07:00
|
|
|
|
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
|
|
|
|
|
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
|
|
|
|
|
"""

import dataclasses
import logging
from typing import List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
}
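# Note: the entries above are the `ServerArgs` class-level defaults, not values
# parsed from the command line; they are expected to be overwritten with the
# actual server arguments at runtime.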

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.array]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.to_abort = False
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""
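        # Illustrative example (assumed values): for a request whose unpadded
        # prompt has 10 tokens, the first call to init_incremental_detokenize()
        # sets read_offset = 10 and surr_offset = 10 - INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5,
        # so tokens 5..9 are re-decoded only as surrounding context and new text
        # is emitted starting from position 10.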

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        self.extend_input_len = 0
        self.last_node = None

        # Chunked prefill
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs_val = None
        self.input_token_logprobs_idx = None
        self.input_top_logprobs_val = None
        self.input_top_logprobs_idx = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = None

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
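        # Example (illustrative): if fill_ids has 100 tokens and 60 of them are
        # matched as a cached prefix, extend_input_len is 40, i.e. only the
        # uncached suffix needs to run through prefill.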

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True

        # For incremental logprobs
        # TODO: Fix the `logprob_start_len`
        self.last_update_decode_tokens = 0
        self.logprob_start_len = 10**9

    def __repr__(self):
        return (
            f"rid(n={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
        )


bid = 0
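# Global counter used to assign a unique, monotonically increasing batch id
# (`bid`) to every ModelWorkerBatch created in get_model_worker_batch() below.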


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len
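            # Illustrative example: a request with 100 fill_ids of which 30 are
            # already in the prefix cache has pre_len = 30, seq_len = 100, and
            # extend_input_len = 70 new tokens to prefill in this batch.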

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )
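        # Note (illustrative): pop() below takes from the end of this descending
        # sort, so the requests with the fewest decoded tokens (cheapest to
        # recompute later) are retracted first.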

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not corrupt
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)
def merge_batch(self, other: "ScheduleBatch"):
|
2024-08-09 04:46:24 -07:00
|
|
|
|
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
|
|
|
|
|
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
|
|
|
|
|
# needs to be called with pre-merged Batch.reqs.
|
2024-09-30 06:41:49 -07:00
|
|
|
|
self.sampling_info.merge_batch(other.sampling_info)
|
2024-08-09 04:46:24 -07:00
|
|
|
|
|
2024-10-21 15:01:21 -07:00
|
|
|
|
# Encoder-decoder infos
|
|
|
|
|
|
if self.model_config.is_encoder_decoder:
|
|
|
|
|
|
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
|
|
|
|
|
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
self.req_pool_indices = torch.concat(
|
|
|
|
|
|
[self.req_pool_indices, other.req_pool_indices]
|
|
|
|
|
|
)
|
|
|
|
|
|
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
2024-07-13 15:24:03 -07:00
|
|
|
|
self.out_cache_loc = None
|
2024-10-21 01:43:16 -07:00
|
|
|
|
self.seq_lens_sum += other.seq_lens_sum
|
2024-10-13 19:54:02 -07:00
|
|
|
|
if self.output_ids is not None:
|
|
|
|
|
|
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
2024-09-30 06:41:49 -07:00
|
|
|
|
if self.return_logprob and other.return_logprob:
|
|
|
|
|
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
|
|
|
|
|
elif self.return_logprob:
|
|
|
|
|
|
self.top_logprobs_nums.extend([0] * len(other.reqs))
|
|
|
|
|
|
elif other.return_logprob:
|
|
|
|
|
|
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
2024-09-30 17:09:54 -07:00
|
|
|
|
self.reqs.extend(other.reqs)
|
2024-10-02 17:18:04 -07:00
|
|
|
|
|
2024-12-16 14:11:09 -08:00
|
|
|
|
self.return_logprob |= other.return_logprob
|
|
|
|
|
|
self.has_stream |= other.has_stream
|
|
|
|
|
|
self.has_grammar |= other.has_grammar
|
2024-09-30 06:41:49 -07:00
|
|
|
|
|
|
|
|
|
|
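
    # Illustrative usage (hypothetical variable names, not part of the original
    # source):
    #     running_batch.merge_batch(new_prefill_batch)
    # Order matters here: sampling_info is merged before reqs (see the comment in
    # merge_batch), and top_logprobs_nums is padded with zeros for whichever side
    # did not request logprobs so it stays aligned with the merged reqs list.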

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
        )
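
    # Added note (not in the original source): for decode and idle batches the
    # extend_* fields are set to None because no prompt tokens are being extended,
    # and `bid` is a module-level counter, so every call hands out a strictly
    # increasing batch id. A hypothetical call site:
    #     model_worker_batch = batch.get_model_worker_batch()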

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={len(self.reqs)})"
        )
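
    # Added note (not in the original source): copy() is intentionally shallow; it
    # shares the same Req objects and tensors and only carries the fields read back
    # when batch results are processed. For example (hypothetical values):
    #     snapshot = batch.copy()
    #     print(snapshot)   # ScheduleBatch(forward_mode=DECODE, #req=8)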


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence lengths
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None
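
# Added note (not in the original source): ModelWorkerBatch is a plain dataclass,
# so it is normally constructed with keyword arguments only, as in
# ScheduleBatch.get_model_worker_batch() above; a decode-only batch would pass
# None for the extend_*, image_inputs, and encoder_* fields.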


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
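
# Illustrative launch (hypothetical tensor names; the actual call site is outside
# this excerpt). One program instance is launched per request; program `pid` copies
# its slice of out_cache_loc into row `req_pool_indices[pid]` of the req_to_token
# table, starting at column pre_len:
#     write_req_to_token_pool_triton[(bs,)](
#         req_to_token_pool.req_to_token,           # [max_batch, max_context_len]
#         req_pool_indices,                         # [bs]
#         prefix_lens_tensor,                       # [bs] tokens already cached
#         seq_lens,                                 # [bs] prefix + newly extended
#         extend_lens_tensor,                       # [bs] seq_lens - prefix_lens
#         out_cache_loc,                            # [sum(extend_lens)] KV slots
#         req_to_token_pool.req_to_token.shape[1],  # row stride
#     )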