Use CUDA event wait and synchronization instead of busy waiting (#2089)
@@ -1063,7 +1063,7 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
-            sampling_info=dataclasses.replace(self.sampling_info),
+            sampling_info=self.sampling_info,
         )

     def __str__(self):
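Note on the hunk above: `dataclasses.replace(obj)` with no field overrides builds a new instance whose fields alias the originals, i.e. a shallow copy. Dropping it means `copy()` now hands out the live `sampling_info` object; the overlap worker takes its own `dataclasses.replace` copy before the batch is mutated (see the TpModelWorkerClient hunk below). A minimal illustration of the shallow-copy semantics, using a hypothetical dataclass:

import dataclasses

@dataclasses.dataclass
class SamplingInfo:  # hypothetical stand-in, not the real class
    temperature: float
    top_p: float

a = SamplingInfo(temperature=0.7, top_p=0.9)
b = dataclasses.replace(a)  # new object, same field values
assert b == a and b is not a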
@@ -387,9 +387,6 @@ class Scheduler:
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
             if batch:
-                # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
-                _ = batch.seq_lens[0].item()
-
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

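The removed `_ = batch.seq_lens[0].item()` worked as a synchronization point because `.item()` on a CUDA tensor does a device-to-host copy and blocks the CPU until every kernel producing that value has finished. The commit replaces this implicit, whole-stream stall with explicit CUDA events recorded at the exact points that matter. A minimal sketch of the implicit-sync behavior being removed (illustrative only, assumes a CUDA device):

import torch

x = torch.ones(1024, device="cuda")
y = x * 2  # kernel is launched asynchronously; the CPU does not wait here

# .item() copies one element to the host, which forces the CPU to block
# until the multiply above has completed -- an implicit synchronization.
_ = y[0].item()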
@@ -142,12 +142,12 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_event: Optional[threading.Event] = None,
+        launch_done: Optional[threading.Event] = None,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        if launch_event:
-            launch_event.set()
+        if launch_done:
+            launch_done.set()
         next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
         return logits_output, next_token_ids

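`launch_done` (renamed from `launch_event`) is a plain `threading.Event`. It is set as soon as the forward pass has been launched and `logits_output` exists, so the consumer thread can wait on it before reading any forward outputs. A minimal sketch of that handshake, with hypothetical names:

import threading

launch_done = threading.Event()
result = {}

def worker():
    result["logits"] = "launched"  # stand-in for launching the forward pass
    launch_done.set()  # signal: the output object is now safe to read

threading.Thread(target=worker).start()
launch_done.wait()  # consumer blocks (without busy waiting) until signaled
print(result["logits"])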
@@ -96,19 +96,22 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, compute_info_done = (
+                self.input_queue.get()
+            )
             if not model_worker_batch:
                 break
-            self.launch_event = threading.Event()
-            copy_event = torch.cuda.Event()
+            self.launch_done = threading.Event()
+            copy_done = torch.cuda.Event()

             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
+            compute_info_done.wait()
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch, self.launch_event
+                model_worker_batch, self.launch_done
             )

             # Update the future token ids map
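`compute_info_done.wait()` is the GPU-side half of the new handshake: it makes the stream that is current in the forward thread hold back subsequently launched kernels until the event recorded by the scheduler thread has fired, without blocking the CPU at all (unlike `Event.synchronize()`). A minimal sketch of that cross-stream ordering, with hypothetical tensors:

import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()
done = torch.cuda.Event()

x = torch.zeros(1024, device="cuda")
with torch.cuda.stream(producer):
    x.add_(1)      # producer kernel
    done.record()  # mark the point the consumer must not run past

with torch.cuda.stream(consumer):
    done.wait()    # consumer *stream* stalls on the GPU; the CPU keeps going
    y = x * 2      # guaranteed to observe the incremented x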
@@ -133,15 +136,14 @@ class TpModelWorkerClient:
                 )
             )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event.record()
+            copy_done.record()

-            self.output_queue.put((copy_event, logits_output, next_token_ids))
+            self.output_queue.put((copy_done, logits_output, next_token_ids))

     def resolve_batch_result(self, bid: int):
-        copy_event, logits_output, next_token_ids = self.output_queue.get()
-        while not copy_event.query():
-            time.sleep(1e-5)
-        self.launch_event.wait()
+        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done.synchronize()
+        self.launch_done.wait()

         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
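This hunk is the change the commit title describes. Polling `copy_event.query()` in a sleep loop keeps a CPU core spinning and adds up to the sleep granularity of latency per batch; `copy_done.synchronize()` instead parks the thread in the CUDA driver until all work recorded before the event (here, the non-blocking device-to-host copies) has finished. A before/after sketch of the two waiting styles (illustrative, not the exact SGLang code):

import time
import torch

ev = torch.cuda.Event()
x = torch.ones(1 << 20, device="cuda").sum()  # some asynchronous GPU work
ev.record()

# Before: busy waiting -- poll and sleep until the event reports done.
while not ev.query():
    time.sleep(1e-5)

# After: one blocking call; the thread sleeps until all work recorded
# before the event has completed, with no polling loop.
ev.synchronize()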
@@ -162,7 +164,11 @@ class TpModelWorkerClient:
         model_worker_batch.sampling_info = dataclasses.replace(
             model_worker_batch.sampling_info
         )
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        compute_info_done = torch.cuda.Event()
+        compute_info_done.record()
+        self.input_queue.put(
+            (model_worker_batch, self.future_token_ids_ct, compute_info_done)
+        )

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
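On the scheduler side, the event is recorded only after the batch's sampling metadata has been copied, so when the forward thread later calls `compute_info_done.wait()` its stream is ordered after exactly the work this batch depends on. A minimal sketch of handing a freshly recorded event through a queue (hypothetical payload, standard `queue` and `torch.cuda` APIs):

import queue
import torch

input_queue: queue.Queue = queue.Queue()

def submit(batch):
    # Record *after* everything the batch depends on has been enqueued.
    ready = torch.cuda.Event()
    ready.record()
    input_queue.put((batch, ready))

def forward_loop():
    batch, ready = input_queue.get()
    ready.wait()  # order this thread's stream after the producer's work
    # ... run the forward pass on `batch` ...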
@@ -38,7 +38,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             api_key=cls.api_key,
             other_args=(
                 "--max-total-token",
-                "1024",
+                "1536",
                 "--context-len",
                 "8192",
                 "--decode-log-interval",
@@ -29,7 +29,7 @@ class TestSRTEngine(unittest.TestCase):

         sampling_params = {"temperature": 0, "max_new_tokens": 8}

-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
         out1 = engine.generate(prompt, sampling_params)["text"]
         engine.shutdown()

@@ -51,7 +51,7 @@ class TestSRTEngine(unittest.TestCase):

         sampling_params = {"temperature": 0, "max_new_tokens": 8}

-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(model_path=model_path, random_seed=42)
         engine.generate(prompt, sampling_params)
         engine.generate(prompt, sampling_params)
         engine.shutdown()
@@ -74,7 +74,6 @@ class TestSRTEngine(unittest.TestCase):
         # Create an LLM.
         llm = sgl.Engine(
             model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-            log_level="error",
         )

         # 1. sync + non streaming
@@ -118,7 +117,9 @@ class TestSRTEngine(unittest.TestCase):
         prompt = "The capital of UK is"

         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-        engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
+        engine = sgl.Engine(
+            model_path=model_path, random_seed=42, disable_radix_cache=True
+        )
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
         out1 = engine.generate(prompt, sampling_params)["text"]

@@ -141,9 +142,7 @@ class TestSRTEngine(unittest.TestCase):
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST

-        engine = sgl.Engine(
-            model_path=model_path, is_embedding=True, random_seed=42, log_level="error"
-        )
+        engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
         out1 = torch.tensor(engine.encode(prompt)["embedding"])
         engine.shutdown()
