Fix the possible bug of decode out of memory (#36)

Liangsheng Yin
2024-01-20 03:01:15 +08:00
committed by GitHub
parent 199e82a15d
commit 40ab1f0129
7 changed files with 274 additions and 46 deletions

View File

@@ -40,7 +40,7 @@ def extract_prefix_by_tracing(program, backend):
try:
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
except (StopTracing, TypeError):
except (StopTracing, TypeError, AttributeError):
# Some exceptions may not be caught
pass
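A note on the hunk above: the tracer runs the user's program against a stand-in object, and the handler is widened to also swallow AttributeError, which is raised when the traced program touches an attribute the stand-in does not provide. A minimal illustration of that failure mode, with stub names that are not from the repository:

class StubTracer:
    """Stands in for the real tracer; it provides none of the attributes a user program may touch."""

def user_program(s, name="world"):
    return s.missing_helper(name)  # raises AttributeError on the stub

try:
    user_program(StubTracer())
except (TypeError, AttributeError):
    # Mirror the widened handler: give up on tracing instead of crashing.
    pass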

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
@@ -38,6 +39,7 @@ class Req:
self.adjust_input_len = 0
self.prefix_indices = []
self.last_node = None
self.normalized_logprob = None
@@ -81,27 +83,56 @@ class Req:
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
@dataclass
class Batch:
def __init__(
self,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: TokenToKVPool,
tree_cache: RadixCache,
):
self.reqs = reqs
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.tree_cache = tree_cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
tree_cache: RadixCache
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in reqs
)
# batched arguments to model runner
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
prefix_lens: torch.Tensor = None
position_ids_offsets: torch.Tensor = None
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
return_normalized_logprob: bool = False
# for multimodal
pixel_values: List[torch.Tensor] = None
image_offsets: List[int] = None
# other arguments for control
output_ids: torch.Tensor = None
extend_num_tokens: int = None
# batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
frequency_penalties: torch.Tensor = None
presence_penalties: torch.Tensor = None
logit_bias: torch.Tensor = None
@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
return_normalized_logprob=return_normalized_logprob,
)
def is_empty(self):
return len(self.reqs) == 0
def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
device = "cuda"
bs = len(self.reqs)
reqs = self.reqs
@@ -141,7 +172,7 @@ class Batch:
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
print("Prefill out of memory.")
print("Prefill out of memory. This should nerver happen.")
self.tree_cache.pretty_print()
exit()
@@ -196,7 +227,50 @@ class Batch:
)
self.logit_bias = logit_bias
def update_for_decode(self, input_ids=None):
def check_decode_mem(self):
bs = len(self.reqs)
avai_size = self.token_to_kv_pool.available_size()
if avai_size >= bs:
return True
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
if self.token_to_kv_pool.available_size() >= bs:
return True
return False
def retract_decode(self):
sorted_indices = [i for i in range(len(self.reqs))]
sorted_indices.sort(
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
reverse=True,
)
retracted_reqs = []
seq_lens_np = self.seq_lens.cpu().numpy()
req_pool_indices_np = self.req_pool_indices.cpu().numpy()
while self.token_to_kv_pool.available_size() < len(self.reqs):
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)
self.tree_cache.dec_ref_counter(req.last_node)
req.prefix_indices = None
req.last_node = None
req.adjust_input_len = 0
req.output_ids = []
# TODO: apply more fine-grained retraction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_np[idx]
][: seq_lens_np[idx]]
self.token_to_kv_pool.free(token_indices)
self.filter_batch(sorted_indices)
return retracted_reqs
def prepare_for_decode(self, input_ids=None):
if input_ids is None:
input_ids = [
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
]
@@ -212,13 +286,9 @@ class Batch:
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
print("Decode out of memory.")
self.tree_cache.pretty_print()
exit()
print("Decode out of memory. This should nerver happen.")
self.tree_cache.pretty_print()
exit()
self.out_cache_cont_start = None
self.out_cache_cont_end = None
@@ -240,6 +310,9 @@ class Batch:
self.prefix_lens = None
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in self.reqs
)
for item in [
"temperatures",
@@ -263,6 +336,9 @@ class Batch:
[self.position_ids_offsets, other.position_ids_offsets]
)
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in self.reqs
)
for item in [
"temperatures",

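To make the new retraction path concrete, here is a small standalone illustration (plain Python, not code from this file) of the victim order implied by retract_decode's sort key: indices are ordered by (number of generated tokens, negative prompt length) descending and popped from the end, so the request that has generated the fewest tokens, preferring the longer prompt among ties, is retracted first, presumably so the least decode work is thrown away.

# Illustrative only: three toy requests with their output and input lengths.
reqs = [
    {"rid": "a", "output": 12, "input": 30},
    {"rid": "b", "output": 3, "input": 50},
    {"rid": "c", "output": 3, "input": 20},
]
order = sorted(
    range(len(reqs)),
    key=lambda i: (reqs[i]["output"], -reqs[i]["input"]),
    reverse=True,
)
# order == [0, 2, 1]; popping from the end retracts "b" first (fewest outputs,
# longest prompt among ties), then "c", and "a" only as a last resort.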
View File

@@ -45,7 +45,6 @@ class ModelRpcServer(rpyc.Service):
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.schedule_conservativeness = server_args.schedule_conservativeness
# Init model and tokenizer
self.model_config = ModelConfig(
@@ -114,6 +113,11 @@ class ModelRpcServer(rpyc.Service):
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(self.tokenizer)
# Init new token estimation
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
recv_reqs = obtain(recv_reqs)
@@ -209,11 +213,6 @@ class ModelRpcServer(rpyc.Service):
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# init the regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
# Truncate long prompts
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
@@ -249,13 +248,10 @@ class ModelRpcServer(rpyc.Service):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
new_ratio = (
self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
)
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * new_ratio
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
for r in self.running_batch.reqs
]
)
@@ -311,7 +307,7 @@ class ModelRpcServer(rpyc.Service):
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
new_batch = Batch(
new_batch = Batch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
@@ -322,7 +318,16 @@ class ModelRpcServer(rpyc.Service):
def forward_fill_batch(self, batch: Batch):
# Build batch tensors
batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
batch.prepare_for_extend(
self.model_config.vocab_size, self.int_token_logit_bias
)
# init the regex fsm before first sampling
for req in batch.reqs:
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
if batch.extend_num_tokens != 0:
# Forward
logits, normalized_logprobs = self.model_runner.forward(
@@ -350,9 +355,27 @@ class ModelRpcServer(rpyc.Service):
self.handle_finished_requests(batch)
def forward_decode_batch(self, batch: Batch):
# check if decode out of memory
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
retracted_reqs = batch.retract_decode()
logger.info(
"decode out of memory happened, "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self.forward_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_step[0],
self.min_new_token_ratio,
)
# Update batch tensors
self.decode_forward_ct += 1
batch.update_for_decode()
batch.prepare_for_decode()
# Forward
logits = self.model_runner.forward(batch, ForwardMode.DECODE)

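The new-token estimate used by admission is now adaptive instead of a fixed constant: it starts at 0.4 * schedule_conservativeness, jumps by 0.05 whenever a decode step runs out of memory and requests are retracted, and decays by 0.0001 per healthy step toward a floor of 0.2 * schedule_conservativeness. A minimal sketch of that update rule (standalone, not the server's code), shown with conservativeness 1.0:

def update_new_token_ratio(ratio, retracted, step=(0.0001, 0.05), floor=0.2, ceiling=1.0):
    """Raise the ratio sharply after a retraction, relax it slowly otherwise."""
    down, up = step
    if retracted:
        return min(ratio + up, ceiling)
    return max(ratio - down, floor)

ratio = 0.4                                    # initial value
ratio = update_new_token_ratio(ratio, True)    # 0.45 after an out-of-memory decode step
ratio = update_new_token_ratio(ratio, False)   # 0.4499 once decoding proceeds normally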
View File

@@ -17,9 +17,6 @@ class Scheduler:
self.max_total_num_token = max_total_num_token
self.tree_cache = tree_cache
def new_token_estimation_ratio(self):
return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
def get_priority_queue(self, forward_queue):
if self.schedule_heuristic == "lpm":
# longest prefix match

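With this commit the fixed 0.5/0.6 estimate above disappears from the scheduler; the admission logic in model_rpc.py instead discounts the running batch's remaining generation budget by the adaptive new_token_ratio. A back-of-the-envelope version of that check, with assumed names and a simplified final comparison:

def can_admit(prompt_tokens, pool_free, evictable, remaining_per_req, new_token_ratio):
    # KV slots free now, plus slots the radix cache could evict on demand,
    # minus the tokens the running requests are still expected to generate.
    available = pool_free + evictable
    available -= sum(remaining * new_token_ratio for remaining in remaining_per_req)
    return prompt_tokens <= available

# A 512-token prompt, 8000 free slots, 1000 evictable slots, and three running
# requests with 200/400/100 tokens left to generate, at ratio 0.4:
print(can_admit(512, 8000, 1000, [200, 400, 100], 0.4))  # True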
View File

@@ -119,7 +119,7 @@ class ServerArgs:
"--schedule-conservativeness",
type=float,
default=ServerArgs.schedule_conservativeness,
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)
parser.add_argument(
"--random-seed",