Launch a thread to overlap CPU and GPU (#1687)

This commit is contained in:
Lianmin Zheng
2024-10-16 11:20:17 -07:00
committed by GitHub
parent e4b367baa8
commit dbec2f1847
3 changed files with 142 additions and 20 deletions

View File

@@ -193,16 +193,6 @@ class Scheduler:
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
if self.server_args.enable_overlap_schedule:
def cache_finished_req(req):
free_delta = int(self.running_batch and req in self.cur_batch.reqs)
self.tree_cache.cache_finished_req(req, free_delta=free_delta)
else:
cache_finished_req = self.tree_cache.cache_finished_req
self.cache_finished_req = cache_finished_req
# Init running status
self.waiting_queue: List[Req] = []
self.running_batch: Optional[ScheduleBatch] = None
@@ -245,6 +235,7 @@ class Scheduler:
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.batch_is_full = False
# Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
else:
@@ -261,6 +252,25 @@ class Scheduler:
with_stack=True,
)
# Init states for overlap schedule
if self.server_args.enable_overlap_schedule:
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
def cache_finished_req(req):
free_delta = int(self.running_batch and req in self.cur_batch.reqs)
self.tree_cache.cache_finished_req(req, free_delta=free_delta)
self.cache_finished_req = cache_finished_req
else:
self.forward_batch_generation = self.tp_worker.forward_batch_generation
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.cache_finished_req = self.tree_cache.cache_finished_req
@torch.inference_mode()
def event_loop_normal(self):
self.last_batch = None
@@ -712,7 +722,7 @@ class Scheduler:
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
logits_output, next_token_ids = self.forward_batch_generation(
model_worker_batch
)
else:
@@ -724,12 +734,12 @@ class Scheduler:
else:
next_token_ids = torch.full((batch.batch_size(),), 0)
batch.output_ids = next_token_ids
ret = logits_output, next_token_ids
ret = logits_output, next_token_ids, model_worker_batch.bid
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = embeddings
ret = embeddings, model_worker_batch.bid
return ret
def process_batch_result(self, batch: ScheduleBatch, result):
@@ -742,7 +752,7 @@ class Scheduler:
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids = result
logits_output, next_token_ids, bid = result
if batch.return_logprob:
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
@@ -761,7 +771,7 @@ class Scheduler:
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
# Check finish conditions
logprob_pt = 0
@@ -790,7 +800,8 @@ class Scheduler:
)
else: # embedding or reward model
assert batch.extend_num_tokens != 0
embeddings = result.tolist()
embeddings, bid = result
embeddings = embeddings.tolist()
# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -811,7 +822,7 @@ class Scheduler:
self.stream_output(batch.reqs)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
logits_output, next_token_ids, bid = result
self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu
@@ -821,7 +832,7 @@ class Scheduler:
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):