Fix the race condition in overlap mode (#1712)
@@ -405,9 +405,9 @@ class ScheduleBatch:
 
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
 
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
@@ -874,12 +874,9 @@ class ScheduleBatch:
     def copy(self):
         return ScheduleBatch(
             reqs=self.reqs,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool=self.token_to_kv_pool,
-            tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
-            output_ids=self.output_ids,
-            sampling_info=self.sampling_info,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
         )
 
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
             forward_mode=self.forward_mode,
             input_ids=self.input_ids.clone(),
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
+            seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
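Note on the `seq_lens.clone()` change above: under the overlap schedule the scheduler keeps mutating batch tensors for the next step while the worker may still be reading the snapshot it was handed, so shared storage is exactly the kind of race this commit targets; cloning gives the in-flight batch its own copy. A minimal, standalone sketch of the hazard (plain PyTorch, not SGLang code):

```python
import torch

# Tensor the scheduler keeps advancing in place between steps.
seq_lens = torch.tensor([5, 9, 12])

aliased = seq_lens           # shares storage with the live tensor
snapshot = seq_lens.clone()  # owns its own storage

# Scheduler moves on to the next decode step.
seq_lens.add_(1)

print(aliased)   # tensor([ 6, 10, 13]) -- sees the later update
print(snapshot)  # tensor([ 5,  9, 12]) -- stable view for the in-flight batch
```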
@@ -261,12 +261,7 @@ class Scheduler:
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-            self.cache_finished_req = cache_finished_req
+            self.cache_finished_req = self.tree_cache.cache_finished_req
         else:
             self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
@@ -798,7 +793,6 @@ class Scheduler:
                     i, req, logprob_pt, next_token_ids, logits_output
                 )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             embeddings, bid = result
             embeddings = embeddings.tolist()
 
@@ -838,6 +832,7 @@ class Scheduler:
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
             req.completion_tokens_wo_jump_forward += 1
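Reading of the added `free` above (an interpretation, not text from the commit): with the overlap schedule, the next step's KV-cache slots in `out_cache_loc` are reserved before the previous step's results reveal that a request has finished. Bailing out with `continue` alone would leak that request's slot, so it is returned to `token_to_kv_pool` explicitly. A toy free-list illustrating the leak, with hypothetical names:

```python
class ToyKVPool:
    """Stand-in for the token-to-KV pool; illustration only."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, n: int):
        slots, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return slots

    def free(self, slots):
        self.free_slots.extend(slots)


pool = ToyKVPool(size=4)
out_cache_loc = pool.alloc(2)   # slots reserved before results are known
finished = [True, False]        # request 0 turns out to be already finished

for i, done in enumerate(finished):
    if done:
        # Without this free, out_cache_loc[0] would never return to the pool.
        pool.free(out_cache_loc[i : i + 1])
        continue

print(len(pool.free_slots))  # 3 -- only the still-running request holds a slot
```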
@@ -149,14 +149,12 @@ class TpModelWorker:
         )
 
         # Resolve future tokens in the input
-        # logger.info(f"raw input {model_worker_batch.input_ids=}")
         tic2 = time.time()
         resolved_input_ids = model_worker_batch.input_ids
         future_mask = resolved_input_ids < 0
         resolved_input_ids[future_mask] = self.future_token_ids_map[
             -resolved_input_ids[future_mask]
         ]
-        # logger.info(f"resolved input {model_worker_batch.input_ids=}")
 
         # Run forward
         logits_output, next_token_ids = self.forward_batch_generation(
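The resolution step in this hunk uses the placeholder scheme visible in the code: token ids that the still-running previous step has not produced yet are encoded as negative indices, and `future_token_ids_map[-id]` swaps in the real values once they are available. A self-contained sketch of that lookup (plain PyTorch, made-up values):

```python
import torch

# Slot k of the map holds the real token id for placeholder -k (slot 0 unused).
future_token_ids_map = torch.tensor([0, 101, 102, 103])

# Next batch's input ids: negative entries are futures from the previous step.
input_ids = torch.tensor([7, -1, 42, -3])

future_mask = input_ids < 0
input_ids[future_mask] = future_token_ids_map[-input_ids[future_mask]]

print(input_ids)  # tensor([  7, 101,  42, 103])
```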
@@ -215,12 +213,13 @@ class TpModelWorker:
             self.future_logits_output_ct += 1
 
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = -torch.arange(
-            self.future_token_ids_ct + 1,
-            self.future_token_ids_ct + 1 + bs,
-            dtype=torch.int32,
-            device=self.device,
-        )
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
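The counter update at the end of this hunk treats the placeholder id space as a ring: each call reserves `bs` consecutive negative ids starting at `-(ct + 1)` and then advances `ct` modulo `future_token_ids_limit`, so slots are reused once earlier futures have been resolved (the hunk also moves the allocation under `torch.cuda.stream(self.forward_stream)`, keeping it ordered with the worker's forward stream). A small sketch of the allocation arithmetic (plain Python, illustrative values):

```python
def reserve_future_ids(ct: int, bs: int, limit: int):
    """Reserve bs placeholder ids; return the negative ids and the new counter."""
    ids = [-(ct + 1 + i) for i in range(bs)]
    ct = (ct + bs) % limit
    return ids, ct


ct, limit = 0, 8
for bs in (3, 3, 4):
    ids, ct = reserve_future_ids(ct, bs, limit)
    print(ids, ct)
# [-1, -2, -3] 3
# [-4, -5, -6] 6
# [-7, -8, -9, -10] 2   <- counter wraps so later steps recycle the slots
```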
@@ -38,16 +38,14 @@ class ChunkCache(BasePrefixCache):
             max_prefix_len = len(key)
         return entry.value[:max_prefix_len], entry
 
-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
         else:
             token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : token_id_len + free_delta
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
@@ -97,9 +97,7 @@ class RadixCache(BasePrefixCache):
         value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
-    def cache_finished_req(
-        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
-    ):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if self.disable:
             if token_ids is None:
@@ -108,7 +106,7 @@ class RadixCache(BasePrefixCache):
                 token_ids_len = len(token_ids)
 
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : token_ids_len + free_delta
+                req.req_pool_idx, :token_ids_len
             ]
             self.token_to_kv_pool.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
@@ -123,12 +121,6 @@ class RadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
-        if free_delta:
-            self.token_to_kv_pool.free(
-                self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
-                ]
-            )
 
         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
@@ -542,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
             kill_child_process(pid, including_parent=False)
             return
 
+    # logger.info(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")