Fix the possible bug of decode out of memory (#36)
This commit is contained in:
@@ -40,7 +40,7 @@ def extract_prefix_by_tracing(program, backend):
|
||||
try:
|
||||
with TracingScope(tracer):
|
||||
tracer.ret_value = program.func(tracer, **arguments)
|
||||
except (StopTracing, TypeError):
|
||||
except (StopTracing, TypeError, AttributeError):
|
||||
# Some exceptions may not be catched
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
|
||||
@@ -38,6 +39,7 @@ class Req:
|
||||
|
||||
self.adjust_input_len = 0
|
||||
self.prefix_indices = []
|
||||
self.last_node = None
|
||||
|
||||
self.normalized_logprob = None
|
||||
|
||||
@@ -81,27 +83,56 @@ class Req:
|
||||
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
|
||||
|
||||
|
||||
@dataclass
|
||||
class Batch:
|
||||
def __init__(
|
||||
self,
|
||||
reqs: List[Req],
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: TokenToKVPool,
|
||||
tree_cache: RadixCache,
|
||||
):
|
||||
self.reqs = reqs
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.tree_cache = tree_cache
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: TokenToKVPool
|
||||
tree_cache: RadixCache
|
||||
|
||||
self.return_normalized_logprob = any(
|
||||
req.return_normalized_logprob for req in reqs
|
||||
# batched arguments to model runner
|
||||
input_ids: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
prefix_lens: torch.Tensor = None
|
||||
position_ids_offsets: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
out_cache_cont_start: torch.Tensor = None
|
||||
out_cache_cont_end: torch.Tensor = None
|
||||
return_normalized_logprob: bool = False
|
||||
|
||||
# for multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
image_offsets: List[int] = None
|
||||
|
||||
# other arguments for control
|
||||
output_ids: torch.Tensor = None
|
||||
extend_num_tokens: int = None
|
||||
|
||||
# batched sampling params
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
top_ks: torch.Tensor = None
|
||||
frequency_penalties: torch.Tensor = None
|
||||
presence_penalties: torch.Tensor = None
|
||||
logit_bias: torch.Tensor = None
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
|
||||
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
return_normalized_logprob=return_normalized_logprob,
|
||||
)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.reqs) == 0
|
||||
|
||||
def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
|
||||
device = "cuda"
|
||||
bs = len(self.reqs)
|
||||
reqs = self.reqs
|
||||
@@ -141,7 +172,7 @@ class Batch:
|
||||
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
print("Prefill out of memory.")
|
||||
print("Prefill out of memory. This should nerver happen.")
|
||||
self.tree_cache.pretty_print()
|
||||
exit()
|
||||
|
||||
@@ -196,7 +227,50 @@ class Batch:
|
||||
)
|
||||
self.logit_bias = logit_bias
|
||||
|
||||
def update_for_decode(self, input_ids=None):
|
||||
def check_decode_mem(self):
|
||||
bs = len(self.reqs)
|
||||
avai_size = self.token_to_kv_pool.available_size()
|
||||
if avai_size >= bs:
|
||||
return True
|
||||
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
||||
if self.token_to_kv_pool.available_size() >= bs:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def retract_decode(self):
|
||||
sorted_indices = [i for i in range(len(self.reqs))]
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_np = self.seq_lens.cpu().numpy()
|
||||
req_pool_indices_np = self.req_pool_indices.cpu().numpy()
|
||||
while self.token_to_kv_pool.available_size() < len(self.reqs):
|
||||
idx = sorted_indices.pop()
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
self.tree_cache.dec_ref_counter(req.last_node)
|
||||
req.prefix_indices = None
|
||||
req.last_node = None
|
||||
req.adjust_input_len = 0
|
||||
req.output_ids = []
|
||||
# TODO: apply more fine-grained retraction
|
||||
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_np[idx]
|
||||
][: seq_lens_np[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
|
||||
self.filter_batch(sorted_indices)
|
||||
|
||||
return retracted_reqs
|
||||
|
||||
def prepare_for_decode(self, input_ids=None):
|
||||
if input_ids is None:
|
||||
input_ids = [
|
||||
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
|
||||
@@ -212,13 +286,9 @@ class Batch:
|
||||
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
|
||||
|
||||
if self.out_cache_loc is None:
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
||||
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
|
||||
|
||||
if self.out_cache_loc is None:
|
||||
print("Decode out of memory.")
|
||||
self.tree_cache.pretty_print()
|
||||
exit()
|
||||
print("Decode out of memory. This should nerver happen.")
|
||||
self.tree_cache.pretty_print()
|
||||
exit()
|
||||
|
||||
self.out_cache_cont_start = None
|
||||
self.out_cache_cont_end = None
|
||||
@@ -240,6 +310,9 @@ class Batch:
|
||||
self.prefix_lens = None
|
||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||
self.return_normalized_logprob = any(
|
||||
req.return_normalized_logprob for req in self.reqs
|
||||
)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
@@ -263,6 +336,9 @@ class Batch:
|
||||
[self.position_ids_offsets, other.position_ids_offsets]
|
||||
)
|
||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||
self.return_normalized_logprob = any(
|
||||
req.return_normalized_logprob for req in self.reqs
|
||||
)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
|
||||
@@ -45,7 +45,6 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = server_args.tp_size
|
||||
self.schedule_heuristic = server_args.schedule_heuristic
|
||||
self.schedule_conservativeness = server_args.schedule_conservativeness
|
||||
|
||||
# Init model and tokenizer
|
||||
self.model_config = ModelConfig(
|
||||
@@ -114,6 +113,11 @@ class ModelRpcServer(rpyc.Service):
|
||||
# Init the FSM cache for constrained generation
|
||||
self.regex_fsm_cache = FSMCache(self.tokenizer)
|
||||
|
||||
# Init new token estimation
|
||||
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
|
||||
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
|
||||
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
|
||||
|
||||
def exposed_step(self, recv_reqs):
|
||||
if self.tp_size != 1:
|
||||
recv_reqs = obtain(recv_reqs)
|
||||
@@ -209,11 +213,6 @@ class ModelRpcServer(rpyc.Service):
|
||||
req.stream = recv_req.stream
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
# init the regex fsm
|
||||
if req.sampling_params.regex is not None:
|
||||
req.regex_fsm_state = 0
|
||||
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
|
||||
|
||||
# Truncate long prompts
|
||||
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
|
||||
req.sampling_params.max_new_tokens = min(
|
||||
@@ -249,13 +248,10 @@ class ModelRpcServer(rpyc.Service):
|
||||
available_size = (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
)
|
||||
new_ratio = (
|
||||
self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
|
||||
)
|
||||
if self.running_batch:
|
||||
available_size -= sum(
|
||||
[
|
||||
(r.max_new_tokens() - len(r.output_ids)) * new_ratio
|
||||
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
|
||||
for r in self.running_batch.reqs
|
||||
]
|
||||
)
|
||||
@@ -311,7 +307,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||
)
|
||||
|
||||
new_batch = Batch(
|
||||
new_batch = Batch.init_new(
|
||||
can_run_list,
|
||||
self.req_to_token_pool,
|
||||
self.token_to_kv_pool,
|
||||
@@ -322,7 +318,16 @@ class ModelRpcServer(rpyc.Service):
|
||||
|
||||
def forward_fill_batch(self, batch: Batch):
|
||||
# Build batch tensors
|
||||
batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
|
||||
batch.prepare_for_extend(
|
||||
self.model_config.vocab_size, self.int_token_logit_bias
|
||||
)
|
||||
|
||||
# init the regex fsm before first sampling
|
||||
for req in batch.reqs:
|
||||
if req.sampling_params.regex is not None:
|
||||
req.regex_fsm_state = 0
|
||||
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
|
||||
|
||||
if batch.extend_num_tokens != 0:
|
||||
# Forward
|
||||
logits, normalized_logprobs = self.model_runner.forward(
|
||||
@@ -350,9 +355,27 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
def forward_decode_batch(self, batch: Batch):
|
||||
# check if decode out of memory
|
||||
if not batch.check_decode_mem():
|
||||
old_ratio = self.new_token_ratio
|
||||
self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
|
||||
|
||||
retracted_reqs = batch.retract_decode()
|
||||
logger.info(
|
||||
"decode out of memory happened, "
|
||||
f"#retracted_reqs: {len(retracted_reqs)}, "
|
||||
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
||||
)
|
||||
self.forward_queue.extend(retracted_reqs)
|
||||
else:
|
||||
self.new_token_ratio = max(
|
||||
self.new_token_ratio - self.new_token_ratio_step[0],
|
||||
self.min_new_token_ratio,
|
||||
)
|
||||
|
||||
# Update batch tensors
|
||||
self.decode_forward_ct += 1
|
||||
batch.update_for_decode()
|
||||
batch.prepare_for_decode()
|
||||
|
||||
# Forward
|
||||
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||
|
||||
@@ -17,9 +17,6 @@ class Scheduler:
|
||||
self.max_total_num_token = max_total_num_token
|
||||
self.tree_cache = tree_cache
|
||||
|
||||
def new_token_estimation_ratio(self):
|
||||
return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
|
||||
|
||||
def get_priority_queue(self, forward_queue):
|
||||
if self.schedule_heuristic == "lpm":
|
||||
# longest prefix match
|
||||
|
||||
@@ -119,7 +119,7 @@ class ServerArgs:
|
||||
"--schedule-conservativeness",
|
||||
type=float,
|
||||
default=ServerArgs.schedule_conservativeness,
|
||||
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
|
||||
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-seed",
|
||||
|
||||
Reference in New Issue
Block a user