Simplify the event loop and expose --num-continuous-decode-steps as an argument (#1652)
This commit is contained in:
@@ -19,7 +19,6 @@ class GlobalConfig:
|
||||
self.new_token_ratio_decay = 0.001
|
||||
|
||||
# Runtime constants: others
|
||||
self.num_continue_decode_steps = 10
|
||||
self.retract_decode_steps = 20
|
||||
self.flashinfer_workspace_size = os.environ.get(
|
||||
"FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
|
||||
|
||||
@@ -831,6 +831,22 @@ class ScheduleBatch:
|
||||
sampling_info=self.sampling_info,
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
tree_cache=self.tree_cache,
|
||||
forward_mode=self.forward_mode,
|
||||
output_token_ids=self.output_token_ids,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
|
||||
f"#req={(len(self.reqs))})"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelWorkerBatch:
|
||||
|
||||
@@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -106,7 +107,8 @@ class Scheduler:
|
||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||
self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
|
||||
else:
|
||||
self.recv_from_tokenizer = self.send_to_detokenizer = None
|
||||
self.recv_from_tokenizer = None
|
||||
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||
|
||||
# Init tokenizer
|
||||
self.model_config = ModelConfig(
|
||||
@@ -190,7 +192,6 @@ class Scheduler:
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
self.running_batch: ScheduleBatch = None
|
||||
self.out_pyobjs = []
|
||||
self.decode_forward_ct = 0
|
||||
self.stream_interval = server_args.stream_interval
|
||||
self.num_generated_tokens = 0
|
||||
@@ -247,13 +248,30 @@ class Scheduler:
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
self.last_batch = None
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
|
||||
self.run_step()
|
||||
batch = self.get_next_batch_to_run()
|
||||
|
||||
self.send_results()
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
# Decode multiple steps to reduce the overhead
|
||||
if batch.forward_mode.is_decode():
|
||||
for _ in range(self.server_args.num_continuous_decode_steps - 1):
|
||||
if not self.running_batch:
|
||||
break
|
||||
self.update_running_batch()
|
||||
if not self.running_batch:
|
||||
break
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
def recv_requests(self):
|
||||
if self.tp_rank == 0:
|
||||
@@ -286,7 +304,9 @@ class Scheduler:
|
||||
self.abort_request(recv_req)
|
||||
elif isinstance(recv_req, UpdateWeightReqInput):
|
||||
success, message = self.update_weights(recv_req)
|
||||
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
UpdateWeightReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
@@ -384,12 +404,6 @@ class Scheduler:
|
||||
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def send_results(self):
|
||||
if self.tp_rank == 0:
|
||||
for obj in self.out_pyobjs:
|
||||
self.send_to_detokenizer.send_pyobj(obj)
|
||||
self.out_pyobjs = []
|
||||
|
||||
def print_decode_stats(self):
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
@@ -427,44 +441,32 @@ class Scheduler:
|
||||
)
|
||||
exit(1) if crash_on_warning else None
|
||||
|
||||
def run_step(self):
|
||||
def get_next_batch_to_run(self):
|
||||
# Merge prefill to the running batch
|
||||
if (
|
||||
self.last_batch
|
||||
and not self.last_batch.forward_mode.is_decode()
|
||||
and not self.last_batch.is_empty()
|
||||
):
|
||||
if self.running_batch is None:
|
||||
self.running_batch = self.last_batch
|
||||
else:
|
||||
self.running_batch.merge_batch(self.last_batch)
|
||||
|
||||
# Prefill first
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
if new_batch is not None:
|
||||
# Run a new prefill batch
|
||||
# replace run_batch with the uncommented line to use pytorch profiler
|
||||
# result = pytorch_profile(
|
||||
# "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
|
||||
# )
|
||||
result = self.run_batch(new_batch)
|
||||
self.process_batch_result(new_batch, result)
|
||||
return new_batch
|
||||
|
||||
# Run decode
|
||||
if self.running_batch is not None:
|
||||
self.update_running_batch()
|
||||
if not self.running_batch:
|
||||
return None
|
||||
return self.running_batch
|
||||
else:
|
||||
if self.running_batch is not None:
|
||||
# Run a few decode batches continuously for reducing overhead
|
||||
for _ in range(global_config.num_continue_decode_steps):
|
||||
batch = self.get_new_batch_decode()
|
||||
|
||||
if batch:
|
||||
# replace run_batch with the uncommented line to use pytorch profiler
|
||||
# result = pytorch_profile(
|
||||
# "profile_decode_step",
|
||||
# self.run_batch,
|
||||
# batch,
|
||||
# data_size=len(batch.reqs),
|
||||
# )
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = None
|
||||
|
||||
if self.running_batch is None:
|
||||
break
|
||||
|
||||
if self.out_pyobjs and self.running_batch.has_stream:
|
||||
break
|
||||
else:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = global_config.init_new_token_ratio
|
||||
self.check_memory()
|
||||
self.new_token_ratio = global_config.init_new_token_ratio
|
||||
|
||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||
# Handle the cases where prefill is not allowed
|
||||
@@ -607,7 +609,7 @@ class Scheduler:
|
||||
|
||||
return new_batch
|
||||
|
||||
def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
|
||||
def update_running_batch(self):
|
||||
batch = self.running_batch
|
||||
|
||||
# Check if decode out of memory
|
||||
@@ -636,11 +638,11 @@ class Scheduler:
|
||||
if jump_forward_reqs:
|
||||
self.batch_is_full = False
|
||||
if batch.is_empty():
|
||||
return None
|
||||
self.running_batch = None
|
||||
return
|
||||
|
||||
# Update batch tensors
|
||||
batch.prepare_for_decode()
|
||||
return batch
|
||||
|
||||
def run_batch(self, batch: ScheduleBatch):
|
||||
if self.is_generation:
|
||||
@@ -657,16 +659,19 @@ class Scheduler:
|
||||
)
|
||||
else:
|
||||
next_token_ids = torch.full((batch.batch_size(),), 0)
|
||||
return logits_output, next_token_ids
|
||||
ret = logits_output, next_token_ids
|
||||
else: # embedding or reward model
|
||||
assert batch.extend_num_tokens != 0
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||
return embeddings
|
||||
ret = embeddings
|
||||
return ret
|
||||
|
||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
self.running_batch = None
|
||||
else:
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
|
||||
@@ -728,7 +733,7 @@ class Scheduler:
|
||||
)
|
||||
else: # embedding or reward model
|
||||
assert batch.extend_num_tokens != 0
|
||||
embeddings = result
|
||||
embeddings = result.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
@@ -750,12 +755,6 @@ class Scheduler:
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
if not batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
self.running_batch = batch
|
||||
else:
|
||||
self.running_batch.merge_batch(batch)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids = result
|
||||
if batch.sampling_info.penalizer_orchestrator:
|
||||
@@ -951,7 +950,7 @@ class Scheduler:
|
||||
# Send to detokenizer
|
||||
if output_rids:
|
||||
if self.is_generation:
|
||||
self.out_pyobjs.append(
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchTokenIDOut(
|
||||
output_rids,
|
||||
output_vids,
|
||||
@@ -965,7 +964,7 @@ class Scheduler:
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
self.out_pyobjs.append(
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchEmbeddingOut(
|
||||
output_rids,
|
||||
output_embeddings,
|
||||
|
||||
@@ -118,7 +118,7 @@ class TpModelWorker:
|
||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
embeddings = logits_output.embeddings.tolist()
|
||||
embeddings = logits_output.embeddings
|
||||
return embeddings
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
|
||||
@@ -111,6 +111,7 @@ class ServerArgs:
|
||||
torchao_config: str = ""
|
||||
enable_p2p_check: bool = False
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
num_continuous_decode_steps: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
@@ -559,6 +560,14 @@ class ServerArgs:
|
||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
|
||||
"This only affects Triton attention kernels.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-continuous-decode-steps",
|
||||
type=int,
|
||||
default=ServerArgs.num_continuous_decode_steps,
|
||||
help="Run multiple continuous decoding steps to reduce scheduling overhead. "
|
||||
"This can potentially increase throughput but may also increase time-to-first-token latency. "
|
||||
"The default value is 1, meaning only run one decoding step at a time.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user