Optimize broadcast & Reorg code (#1598)
This commit is contained in:
@@ -148,6 +148,6 @@ def get_act_fn(
|
|||||||
|
|
||||||
if not is_flashinfer_available():
|
if not is_flashinfer_available():
|
||||||
logger.info(
|
logger.info(
|
||||||
"FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
|
"FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
||||||
|
|||||||
@@ -234,14 +234,9 @@ class Scheduler:
|
|||||||
recv_reqs = self.recv_requests()
|
recv_reqs = self.recv_requests()
|
||||||
self.process_input_requests(recv_reqs)
|
self.process_input_requests(recv_reqs)
|
||||||
|
|
||||||
# Run one step
|
|
||||||
self.run_step()
|
self.run_step()
|
||||||
|
|
||||||
# Send results
|
self.send_results()
|
||||||
if self.tp_rank == 0:
|
|
||||||
for obj in self.out_pyobjs:
|
|
||||||
self.send_to_detokenizer.send_pyobj(obj)
|
|
||||||
self.out_pyobjs = []
|
|
||||||
|
|
||||||
def recv_requests(self):
|
def recv_requests(self):
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
@@ -256,6 +251,7 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
recv_reqs = None
|
recv_reqs = None
|
||||||
|
|
||||||
|
if self.tp_size != 1:
|
||||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
||||||
return recv_reqs
|
return recv_reqs
|
||||||
|
|
||||||
@@ -366,43 +362,11 @@ class Scheduler:
|
|||||||
|
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def run_step(self):
|
def send_results(self):
|
||||||
new_batch = self.get_new_batch_prefill()
|
if self.tp_rank == 0:
|
||||||
|
for obj in self.out_pyobjs:
|
||||||
if new_batch is not None:
|
self.send_to_detokenizer.send_pyobj(obj)
|
||||||
# Run a new prefill batch
|
self.out_pyobjs = []
|
||||||
result = self.run_batch(new_batch)
|
|
||||||
self.process_batch_result(new_batch, result)
|
|
||||||
|
|
||||||
if not new_batch.is_empty():
|
|
||||||
if self.running_batch is None:
|
|
||||||
self.running_batch = new_batch
|
|
||||||
else:
|
|
||||||
self.running_batch.merge_batch(new_batch)
|
|
||||||
else:
|
|
||||||
# Run a decode batch
|
|
||||||
if self.running_batch is not None:
|
|
||||||
# Run a few decode batches continuously for reducing overhead
|
|
||||||
for _ in range(global_config.num_continue_decode_steps):
|
|
||||||
batch = self.get_new_batch_decode()
|
|
||||||
|
|
||||||
if batch:
|
|
||||||
result = self.run_batch(batch)
|
|
||||||
self.process_batch_result(batch, result)
|
|
||||||
|
|
||||||
# Print stats
|
|
||||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
|
||||||
self.print_decode_stats()
|
|
||||||
|
|
||||||
if self.running_batch.is_empty():
|
|
||||||
self.running_batch = None
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.out_pyobjs and self.running_batch.has_stream:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
self.check_memory()
|
|
||||||
self.new_token_ratio = global_config.init_new_token_ratio
|
|
||||||
|
|
||||||
def print_decode_stats(self):
|
def print_decode_stats(self):
|
||||||
num_used = self.max_total_num_tokens - (
|
num_used = self.max_total_num_tokens - (
|
||||||
@@ -441,6 +405,31 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
exit(1) if crash_on_warning else None
|
exit(1) if crash_on_warning else None
|
||||||
|
|
||||||
|
def run_step(self):
|
||||||
|
new_batch = self.get_new_batch_prefill()
|
||||||
|
if new_batch is not None:
|
||||||
|
# Run a new prefill batch
|
||||||
|
result = self.run_batch(new_batch)
|
||||||
|
self.process_batch_result(new_batch, result)
|
||||||
|
else:
|
||||||
|
if self.running_batch is not None:
|
||||||
|
# Run a few decode batches continuously for reducing overhead
|
||||||
|
for _ in range(global_config.num_continue_decode_steps):
|
||||||
|
batch = self.get_new_batch_decode()
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
result = self.run_batch(batch)
|
||||||
|
self.process_batch_result(batch, result)
|
||||||
|
|
||||||
|
if self.running_batch is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.out_pyobjs and self.running_batch.has_stream:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.check_memory()
|
||||||
|
self.new_token_ratio = global_config.init_new_token_ratio
|
||||||
|
|
||||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||||
# Handle the cases where prefill is not allowed
|
# Handle the cases where prefill is not allowed
|
||||||
if (
|
if (
|
||||||
@@ -612,7 +601,6 @@ class Scheduler:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Update batch tensors
|
# Update batch tensors
|
||||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
|
||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@@ -723,6 +711,12 @@ class Scheduler:
|
|||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
|
if not batch.is_empty():
|
||||||
|
if self.running_batch is None:
|
||||||
|
self.running_batch = batch
|
||||||
|
else:
|
||||||
|
self.running_batch.merge_batch(batch)
|
||||||
|
|
||||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||||
logits_output, next_token_ids = result
|
logits_output, next_token_ids = result
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
@@ -762,6 +756,13 @@ class Scheduler:
|
|||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
|
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||||
|
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||||
|
self.print_decode_stats()
|
||||||
|
|
||||||
|
if self.running_batch.is_empty():
|
||||||
|
self.running_batch = None
|
||||||
|
|
||||||
def add_logprob_return_values(
|
def add_logprob_return_values(
|
||||||
self,
|
self,
|
||||||
i: int,
|
i: int,
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import random
|
|||||||
import resource
|
import resource
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
@@ -333,6 +334,10 @@ def suppress_other_loggers():
|
|||||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||||
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
|
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", category=UserWarning, message="The given NumPy array is not writable"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def assert_pkg_version(pkg: str, min_version: str, message: str):
|
def assert_pkg_version(pkg: str, min_version: str, message: str):
|
||||||
try:
|
try:
|
||||||
@@ -615,7 +620,9 @@ def broadcast_pyobj(
|
|||||||
else:
|
else:
|
||||||
serialized_data = pickle.dumps(data)
|
serialized_data = pickle.dumps(data)
|
||||||
size = len(serialized_data)
|
size = len(serialized_data)
|
||||||
tensor_data = torch.ByteTensor(list(serialized_data))
|
tensor_data = torch.ByteTensor(
|
||||||
|
np.frombuffer(serialized_data, dtype=np.uint8)
|
||||||
|
)
|
||||||
tensor_size = torch.tensor([size], dtype=torch.long)
|
tensor_size = torch.tensor([size], dtype=torch.long)
|
||||||
|
|
||||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||||
|
|||||||
Reference in New Issue
Block a user