Optimize broadcast & Reorg code (#1598)
This commit is contained in:
@@ -148,6 +148,6 @@ def get_act_fn(
|
||||
|
||||
if not is_flashinfer_available():
|
||||
logger.info(
|
||||
"FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
|
||||
"FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||
)
|
||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
||||
|
||||
@@ -234,14 +234,9 @@ class Scheduler:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
|
||||
# Run one step
|
||||
self.run_step()
|
||||
|
||||
# Send results
|
||||
if self.tp_rank == 0:
|
||||
for obj in self.out_pyobjs:
|
||||
self.send_to_detokenizer.send_pyobj(obj)
|
||||
self.out_pyobjs = []
|
||||
self.send_results()
|
||||
|
||||
def recv_requests(self):
|
||||
if self.tp_rank == 0:
|
||||
@@ -256,7 +251,8 @@ class Scheduler:
|
||||
else:
|
||||
recv_reqs = None
|
||||
|
||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
||||
if self.tp_size != 1:
|
||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
||||
return recv_reqs
|
||||
|
||||
def process_input_requests(self, recv_reqs: List):
|
||||
@@ -366,43 +362,11 @@ class Scheduler:
|
||||
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def run_step(self):
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
|
||||
if new_batch is not None:
|
||||
# Run a new prefill batch
|
||||
result = self.run_batch(new_batch)
|
||||
self.process_batch_result(new_batch, result)
|
||||
|
||||
if not new_batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
self.running_batch = new_batch
|
||||
else:
|
||||
self.running_batch.merge_batch(new_batch)
|
||||
else:
|
||||
# Run a decode batch
|
||||
if self.running_batch is not None:
|
||||
# Run a few decode batches continuously for reducing overhead
|
||||
for _ in range(global_config.num_continue_decode_steps):
|
||||
batch = self.get_new_batch_decode()
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
# Print stats
|
||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||
self.print_decode_stats()
|
||||
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = None
|
||||
break
|
||||
|
||||
if self.out_pyobjs and self.running_batch.has_stream:
|
||||
break
|
||||
else:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = global_config.init_new_token_ratio
|
||||
def send_results(self):
|
||||
if self.tp_rank == 0:
|
||||
for obj in self.out_pyobjs:
|
||||
self.send_to_detokenizer.send_pyobj(obj)
|
||||
self.out_pyobjs = []
|
||||
|
||||
def print_decode_stats(self):
|
||||
num_used = self.max_total_num_tokens - (
|
||||
@@ -441,6 +405,31 @@ class Scheduler:
|
||||
)
|
||||
exit(1) if crash_on_warning else None
|
||||
|
||||
def run_step(self):
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
if new_batch is not None:
|
||||
# Run a new prefill batch
|
||||
result = self.run_batch(new_batch)
|
||||
self.process_batch_result(new_batch, result)
|
||||
else:
|
||||
if self.running_batch is not None:
|
||||
# Run a few decode batches continuously for reducing overhead
|
||||
for _ in range(global_config.num_continue_decode_steps):
|
||||
batch = self.get_new_batch_decode()
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
|
||||
if self.running_batch is None:
|
||||
break
|
||||
|
||||
if self.out_pyobjs and self.running_batch.has_stream:
|
||||
break
|
||||
else:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = global_config.init_new_token_ratio
|
||||
|
||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||
# Handle the cases where prefill is not allowed
|
||||
if (
|
||||
@@ -612,7 +601,6 @@ class Scheduler:
|
||||
return None
|
||||
|
||||
# Update batch tensors
|
||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||
batch.prepare_for_decode()
|
||||
return batch
|
||||
|
||||
@@ -723,6 +711,12 @@ class Scheduler:
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
if not batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
self.running_batch = batch
|
||||
else:
|
||||
self.running_batch.merge_batch(batch)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids = result
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
@@ -762,6 +756,13 @@ class Scheduler:
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||
self.print_decode_stats()
|
||||
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = None
|
||||
|
||||
def add_logprob_return_values(
|
||||
self,
|
||||
i: int,
|
||||
|
||||
@@ -24,6 +24,7 @@ import random
|
||||
import resource
|
||||
import socket
|
||||
import time
|
||||
import warnings
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
@@ -333,6 +334,10 @@ def suppress_other_loggers():
|
||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, message="The given NumPy array is not writable"
|
||||
)
|
||||
|
||||
|
||||
def assert_pkg_version(pkg: str, min_version: str, message: str):
|
||||
try:
|
||||
@@ -615,7 +620,9 @@ def broadcast_pyobj(
|
||||
else:
|
||||
serialized_data = pickle.dumps(data)
|
||||
size = len(serialized_data)
|
||||
tensor_data = torch.ByteTensor(list(serialized_data))
|
||||
tensor_data = torch.ByteTensor(
|
||||
np.frombuffer(serialized_data, dtype=np.uint8)
|
||||
)
|
||||
tensor_size = torch.tensor([size], dtype=torch.long)
|
||||
|
||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||
|
||||
Reference in New Issue
Block a user