Launch a thread to overlap CPU and GPU (#1687)
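This change lets the scheduler keep doing CPU work (building the next batch, processing results, checking finish conditions) while the previous forward pass is still running on the GPU. When --enable-overlap-schedule is set, TpModelWorker starts a dedicated forward thread on its own CUDA stream (init_overlap_status); forward_batch_generation_non_blocking only enqueues the batch and immediately returns placeholder "future" handles, and the scheduler later calls resolve_future_token_ids(bid), which blocks on a per-batch threading.Event until the real token IDs are published. Not-yet-known input tokens are encoded as negative indices into a preallocated future_token_ids_map. Below is a minimal, self-contained sketch of the submit/resolve pattern, under the caveat that FutureWorker and its method names are illustrative and not part of this diff.

# Minimal sketch of the overlap pattern (illustrative names, not the sglang API):
# the CPU side submits work and gets back a handle immediately; a worker thread
# fills in the result and signals an event when it is ready.
import threading
from queue import Queue

class FutureWorker:
    def __init__(self):
        self.queue = Queue()
        self.results = {}
        self.events = {}
        threading.Thread(target=self._loop, daemon=True).start()

    def _loop(self):
        while True:
            job_id, fn = self.queue.get()
            self.results[job_id] = fn()   # e.g. the GPU forward pass
            self.events[job_id].set()     # publish the result

    def submit(self, job_id, fn):
        self.events[job_id] = threading.Event()
        self.queue.put((job_id, fn))      # returns immediately

    def resolve(self, job_id):
        self.events[job_id].wait()        # block only when the value is needed
        return self.results.pop(job_id)
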
@@ -193,16 +193,6 @@ class Scheduler:
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
 
-        if self.server_args.enable_overlap_schedule:
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-        else:
-            cache_finished_req = self.tree_cache.cache_finished_req
-        self.cache_finished_req = cache_finished_req
-
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
@@ -245,6 +235,7 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False
 
+        # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
         else:
@@ -261,6 +252,25 @@ class Scheduler:
             with_stack=True,
         )
 
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+
+            def cache_finished_req(req):
+                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
+                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
+
+            self.cache_finished_req = cache_finished_req
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+
     @torch.inference_mode()
     def event_loop_normal(self):
         self.last_batch = None
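After this block the rest of the scheduler is mostly mode-agnostic: run_batch calls self.forward_batch_generation(...) and the result-processing paths call self.resolve_next_token_ids(bid, next_token_ids); only the bindings above differ. In the default path the worker call is synchronous and next_token_ids is a real tensor, so resolving is just x.tolist(); in the overlap path the call only enqueues work, and resolving blocks until the worker thread publishes the IDs for that bid. A rough standalone sketch of the two bindings (make_bindings and worker are illustrative names, not part of this diff):

# Sketch of the two dispatch paths set up in the block above.
def make_bindings(worker, overlap: bool):
    if overlap:
        forward = worker.forward_batch_generation_non_blocking
        # the placeholder tensor is ignored; wait for the real ids by batch id
        resolve = lambda bid, placeholder_ids: worker.resolve_future_token_ids(bid)
    else:
        forward = worker.forward_batch_generation
        # the result is already a real tensor of token ids
        resolve = lambda bid, token_ids: token_ids.tolist()
    return forward, resolve
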
@@ -712,7 +722,7 @@ class Scheduler:
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                logits_output, next_token_ids = self.forward_batch_generation(
                     model_worker_batch
                 )
             else:
@@ -724,12 +734,12 @@ class Scheduler:
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids
+            ret = logits_output, next_token_ids, model_worker_batch.bid
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = embeddings
+            ret = embeddings, model_worker_batch.bid
         return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
@@ -742,7 +752,7 @@ class Scheduler:
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
-            logits_output, next_token_ids = result
+            logits_output, next_token_ids, bid = result
             if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
@@ -761,7 +771,7 @@ class Scheduler:
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
 
-            next_token_ids = next_token_ids.tolist()
+            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
             # Check finish conditions
             logprob_pt = 0
@@ -790,7 +800,8 @@ class Scheduler:
                 )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
-            embeddings = result.tolist()
+            embeddings, bid = result
+            embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -811,7 +822,7 @@ class Scheduler:
         self.stream_output(batch.reqs)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids = result
+        logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
@@ -821,7 +832,7 @@ class Scheduler:
                 next_token_ids,
             ].tolist()
 
-        next_token_ids = next_token_ids.tolist()
+        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
         # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -17,6 +17,11 @@ limitations under the License.
 
 import json
 import logging
+import threading
+import time
+from queue import Queue
+
+import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -75,6 +80,7 @@ class TpModelWorker:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
+        self.device = self.model_runner.device
 
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
@@ -100,6 +106,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        if server_args.enable_overlap_schedule:
+            self.init_overlap_status()
+
     def get_token_and_memory_info(self):
         return (
             self.max_total_num_tokens,
@@ -109,6 +118,83 @@ class TpModelWorker:
             self.random_seed,
         )
 
+    def init_overlap_status(self):
+        self.future_logits_output_dict = dict()
+        self.future_logits_output_ct = 0
+        self.future_token_ids_ct = 0
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_output = dict()
+
+        self.future_event_map = dict()
+        self.forward_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            tic1 = time.time()
+            model_worker_batch, future_logits_output, future_next_token_ids = (
+                self.forward_queue.get()
+            )
+
+            # Resolve future tokens in the input
+            # logger.info(f"raw input {model_worker_batch.input_ids=}")
+            tic2 = time.time()
+            resolved_input_ids = model_worker_batch.input_ids
+            future_mask = resolved_input_ids < 0
+            resolved_input_ids[future_mask] = self.future_token_ids_map[
+                -resolved_input_ids[future_mask]
+            ]
+            # logger.info(f"resolved input {model_worker_batch.input_ids=}")
+
+            # Run forward
+            logits_output, next_token_ids = self.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Set future values
+            if model_worker_batch.return_logprob:
+                self.future_logits_output_dict[future_logits_output] = logits_output
+
+            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
+            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
+                torch.int32
+            )
+            # logger.info("Set event")
+            self.future_token_ids_output[model_worker_batch.bid] = (
+                next_token_ids.tolist()
+            )
+            self.future_event_map[model_worker_batch.bid].set()
+
+            if False:
+                tic3 = time.time()
+                self.acc_time_with_waiting += tic3 - tic1
+                self.acc_time_without_waiting += tic3 - tic2
+                if self.forward_queue.qsize() == 0:
+                    logger.info(
+                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
+                    )
+
+    def resolve_future_token_ids(self, bid: int):
+        self.future_event_map[bid].wait()
+        ret = self.future_token_ids_output[bid]
+        del self.future_event_map[bid]
+        return ret
+
+    def resolve_future_logits_output(self, future_obj):
+        return self.future_logits_output_dict.pop(future_obj)
+
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
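The forward thread resolves placeholder inputs by sign: future token IDs are handed out as negative values, so input_ids < 0 marks the slots to patch, and negating them recovers the index into future_token_ids_map (allocation starts at future_token_ids_ct + 1, so slot 0 is never used). A small self-contained illustration of that lookup, in plain PyTorch with invented values:

# Illustration of the negative-placeholder lookup used in forward_thread_func_.
import torch

future_map = torch.zeros(16, dtype=torch.int32)
future_map[1:4] = torch.tensor([101, 102, 103], dtype=torch.int32)  # filled by an earlier step

# -1, -2, -3 are placeholders pointing at slots 1, 2, 3; positive ids are already real.
input_ids = torch.tensor([7, -1, -2, 9, -3], dtype=torch.int32)
mask = input_ids < 0
input_ids[mask] = future_map[-input_ids[mask]]
print(input_ids.tolist())  # [7, 101, 102, 9, 103]
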
@@ -121,6 +207,31 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings
 
+    def forward_batch_generation_non_blocking(
+        self, model_worker_batch: ModelWorkerBatch
+    ):
+        # Allocate output future objects
+        future_logits_output = self.future_logits_output_ct
+        self.future_logits_output_ct += 1
+
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = -torch.arange(
+            self.future_token_ids_ct + 1,
+            self.future_token_ids_ct + 1 + bs,
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        ret = future_logits_output, future_next_token_ids
+
+        self.future_event_map[model_worker_batch.bid] = threading.Event()
+        self.forward_queue.put(
+            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
+        )
+        return ret
+
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
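forward_batch_generation_non_blocking does no GPU work itself: it reserves bs consecutive negative IDs from a wrapping counter (future_token_ids_ct modulo future_token_ids_limit, with the map sized at max_running_requests * 5, above the wrap limit of max_running_requests * 3), registers a threading.Event for this bid, puts a copy of the batch on the queue, and returns the future handles immediately. A toy version of just the slot-allocation arithmetic; FutureSlotAllocator is an illustrative name, not part of this diff:

# Toy ring allocator for future-token slots, mirroring the counter arithmetic above.
import torch

class FutureSlotAllocator:
    def __init__(self, limit: int):
        self.ct = 0
        self.limit = limit

    def allocate(self, bs: int) -> torch.Tensor:
        # negative ids in [-(ct+1), -(ct+bs)]; slot 0 stays unused
        ids = -torch.arange(self.ct + 1, self.ct + 1 + bs, dtype=torch.int32)
        self.ct = (self.ct + bs) % self.limit
        return ids

alloc = FutureSlotAllocator(limit=8)
print(alloc.allocate(3).tolist())  # [-1, -2, -3]
print(alloc.allocate(3).tolist())  # [-4, -5, -6]
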
@@ -447,7 +447,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
     # Set ulimit
     set_ulimit()
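The bump of CUDA_DEVICE_MAX_CONNECTIONS from 1 to 4 is presumably tied to the new forward stream: with a single hardware connection, work submitted from different CUDA streams is funneled through one queue and effectively serialized, which would undercut the overlap; allowing more connections lets the forward stream and the default stream make progress independently.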
@@ -528,7 +528,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
-    # print(f"{res.json()=}")
+    print(f"{res.json()=}")
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None: