Simplify the event loop and expose --num-continuous-decode-steps as an argument (#1652)
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001
 
         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024

@@ -831,6 +831,22 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool=self.token_to_kv_pool,
+            tree_cache=self.tree_cache,
+            forward_mode=self.forward_mode,
+            output_token_ids=self.output_token_ids,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:

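The new ScheduleBatch.copy() above is a shallow copy: the returned batch reuses the same request list and memory pools rather than duplicating them. A minimal sketch of that semantics with a stand-in dataclass (MiniBatch is illustrative, not part of this PR):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class MiniBatch:
        reqs: List[int]
        kv_pool: dict

        def copy(self) -> "MiniBatch":
            # Same fields, same underlying objects; no deep copy of the pools.
            return MiniBatch(reqs=self.reqs, kv_pool=self.kv_pool)

    a = MiniBatch(reqs=[1, 2], kv_pool={"free_slots": 16})
    b = a.copy()
    assert b is not a              # a new batch object...
    assert b.reqs is a.reqs        # ...that shares the request list
    assert b.kv_pool is a.kv_pool  # ...and the KV pool reference
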
@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from types import SimpleNamespace
 from typing import List, Optional, Union
 
 import torch

@@ -106,7 +107,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer = self.send_to_detokenizer = None
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
         self.model_config = ModelConfig(

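Replacing None with SimpleNamespace(send_pyobj=lambda x: None) is a null-object pattern: ranks that do not own the ZMQ socket get an object whose send_pyobj silently drops its argument, so later code can emit results without first checking the TP rank. A small self-contained sketch of the same idea (the make_outbox helper is illustrative only):

    from types import SimpleNamespace

    def make_outbox(is_sender: bool):
        """Return an object exposing send_pyobj; a no-op on non-sender ranks."""
        if is_sender:
            sent = []
            return SimpleNamespace(send_pyobj=sent.append), sent
        return SimpleNamespace(send_pyobj=lambda obj: None), None

    outbox, sent = make_outbox(is_sender=True)
    null_outbox, _ = make_outbox(is_sender=False)

    outbox.send_pyobj({"rid": "abc"})       # recorded
    null_outbox.send_pyobj({"rid": "abc"})  # silently dropped
    assert sent == [{"rid": "abc"}]
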
@@ -190,7 +192,6 @@ class Scheduler:
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: ScheduleBatch = None
-        self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0

@@ -247,13 +248,30 @@ class Scheduler:
 
     @torch.inference_mode()
     def event_loop(self):
+        self.last_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
-            self.run_step()
+            batch = self.get_next_batch_to_run()
 
-            self.send_results()
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+            self.last_batch = batch
 
     def recv_requests(self):
         if self.tp_rank == 0:

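In the simplified loop, --num-continuous-decode-steps controls how many decode iterations run back-to-back per scheduling pass: the scheduled decode batch runs once, then up to N-1 extra decode steps follow without re-entering get_next_batch_to_run. A toy simulation of that amortization (a stand-in, not the SGLang scheduler):

    def simulate(decode_iters_needed: int, num_continuous_decode_steps: int) -> int:
        """Count outer scheduling passes needed to finish a fixed number of decode steps."""
        remaining = decode_iters_needed
        scheduling_passes = 0
        while remaining > 0:
            scheduling_passes += 1  # one trip through the outer event loop
            remaining -= 1          # the decode batch picked by the scheduler
            for _ in range(num_continuous_decode_steps - 1):
                if remaining == 0:  # analogous to `if not self.running_batch: break`
                    break
                remaining -= 1      # an extra back-to-back decode step
        return scheduling_passes

    assert simulate(10, 1) == 10  # default: one decode step per scheduling pass
    assert simulate(10, 4) == 3   # 4 + 4 + 2 decode steps
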
@@ -286,7 +304,9 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+                self.send_to_detokenizer.send_pyobj(
+                    UpdateWeightReqOutput(success, message)
+                )
             elif isinstance(recv_req, ProfileReq):
                 if recv_req == ProfileReq.START_PROFILE:
                     self.start_profile()

@@ -384,12 +404,6 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()

@@ -427,44 +441,32 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None
 
-    def run_step(self):
+    def get_next_batch_to_run(self):
+        # Merge prefill to the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.running_batch is None:
+                self.running_batch = self.last_batch
+            else:
+                self.running_batch.merge_batch(self.last_batch)
+
+        # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            # Run a new prefill batch
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            #     "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
+            return new_batch
+
+        # Run decode
+        if self.running_batch is not None:
+            self.update_running_batch()
+            if not self.running_batch:
+                return None
+            return self.running_batch
         else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        #     "profile_decode_step",
-                        #     self.run_batch,
-                        #     batch,
-                        #     data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-
-                    if self.running_batch is None:
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+            self.check_memory()
+            self.new_token_ratio = global_config.init_new_token_ratio
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed

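get_next_batch_to_run encodes a prefill-first policy: the previous prefill batch (if any) is merged into the running batch, a freshly formed prefill batch is preferred whenever one exists, and decode on the running batch happens only otherwise. A compact sketch of just that selection order (pick_next is a hypothetical helper, not from the PR):

    from typing import Optional

    def pick_next(new_prefill_available: bool, running_batch_alive: bool) -> Optional[str]:
        # Prefill first; otherwise decode the running batch; otherwise idle.
        if new_prefill_available:
            return "prefill"
        if running_batch_alive:
            return "decode"
        return None

    assert pick_next(True, True) == "prefill"
    assert pick_next(False, True) == "decode"
    assert pick_next(False, False) is None
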
@@ -607,7 +609,7 @@ class Scheduler:
 
         return new_batch
 
-    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+    def update_running_batch(self):
         batch = self.running_batch
 
         # Check if decode out of memory

@@ -636,11 +638,11 @@ class Scheduler:
             if jump_forward_reqs:
                 self.batch_is_full = False
             if batch.is_empty():
-                return None
+                self.running_batch = None
+                return
 
         # Update batch tensors
         batch.prepare_for_decode()
-        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:

@@ -657,16 +659,19 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
-            return logits_output, next_token_ids
+            ret = logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            return embeddings
+            ret = embeddings
+        return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)
 

@@ -728,7 +733,7 @@ class Scheduler:
             )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
-            embeddings = result
+            embeddings = result.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):

@@ -750,12 +755,6 @@ class Scheduler:
 
         self.handle_finished_requests(batch)
 
-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
-
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         if batch.sampling_info.penalizer_orchestrator:

@@ -951,7 +950,7 @@ class Scheduler:
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         output_rids,
                         output_vids,

@@ -965,7 +964,7 @@ class Scheduler:
                     )
                 )
             else:  # embedding or reward model
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         output_rids,
                         output_embeddings,

@@ -118,7 +118,7 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings.tolist()
+        embeddings = logits_output.embeddings
         return embeddings
 
     def update_weights(self, recv_req: UpdateWeightReqInput):

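Together with the scheduler-side change above (embeddings = result.tolist()), this moves the tensor-to-list conversion out of TpModelWorker, so the worker now hands back the raw embeddings tensor. For reference, .tolist() turns a 2D tensor into nested Python lists of floats (CPU example):

    import torch

    embeddings_tensor = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
    embeddings = embeddings_tensor.tolist()  # nested Python lists, one row per request
    assert isinstance(embeddings, list)
    assert isinstance(embeddings[0], list) and isinstance(embeddings[0][0], float)
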
@@ -111,6 +111,7 @@ class ServerArgs:
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values

@@ -559,6 +560,14 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--num-continuous-decode-steps",
+            type=int,
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):

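A hedged usage sketch of how the new flag flows from the command line into the dataclass field; the stand-alone parser below mirrors the add_argument call in this diff and is not the full SGLang argument parser. As the help text notes, values above 1 can raise throughput at the cost of time-to-first-token latency.

    import argparse
    from dataclasses import dataclass

    @dataclass
    class Args:  # stand-in for ServerArgs
        num_continuous_decode_steps: int = 1

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-continuous-decode-steps",
        type=int,
        default=Args.num_continuous_decode_steps,
    )
    ns = parser.parse_args(["--num-continuous-decode-steps", "4"])
    args = Args(num_continuous_decode_steps=ns.num_continuous_decode_steps)
    assert args.num_continuous_decode_steps == 4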