diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 8f2b0e116..94d48e82b 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -148,6 +148,6 @@ def get_act_fn( if not is_flashinfer_available(): logger.info( - "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries." + "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 46789568c..fe8fd895b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -234,14 +234,9 @@ class Scheduler: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) - # Run one step self.run_step() - # Send results - if self.tp_rank == 0: - for obj in self.out_pyobjs: - self.send_to_detokenizer.send_pyobj(obj) - self.out_pyobjs = [] + self.send_results() def recv_requests(self): if self.tp_rank == 0: @@ -256,7 +251,8 @@ class Scheduler: else: recv_reqs = None - recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) + if self.tp_size != 1: + recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) return recv_reqs def process_input_requests(self, recv_reqs: List): @@ -366,43 +362,11 @@ class Scheduler: self.waiting_queue.append(req) - def run_step(self): - new_batch = self.get_new_batch_prefill() - - if new_batch is not None: - # Run a new prefill batch - result = self.run_batch(new_batch) - self.process_batch_result(new_batch, result) - - if not new_batch.is_empty(): - if self.running_batch is None: - self.running_batch = new_batch - else: - self.running_batch.merge_batch(new_batch) - else: - # Run a decode batch - if self.running_batch is not None: - # Run a few decode batches continuously for reducing overhead - for _ in range(global_config.num_continue_decode_steps): - batch = self.get_new_batch_decode() - - if batch: - result = self.run_batch(batch) - self.process_batch_result(batch, result) - - # Print stats - if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: - self.print_decode_stats() - - if self.running_batch.is_empty(): - self.running_batch = None - break - - if self.out_pyobjs and self.running_batch.has_stream: - break - else: - self.check_memory() - self.new_token_ratio = global_config.init_new_token_ratio + def send_results(self): + if self.tp_rank == 0: + for obj in self.out_pyobjs: + self.send_to_detokenizer.send_pyobj(obj) + self.out_pyobjs = [] def print_decode_stats(self): num_used = self.max_total_num_tokens - ( @@ -441,6 +405,31 @@ class Scheduler: ) exit(1) if crash_on_warning else None + def run_step(self): + new_batch = self.get_new_batch_prefill() + if new_batch is not None: + # Run a new prefill batch + result = self.run_batch(new_batch) + self.process_batch_result(new_batch, result) + else: + if self.running_batch is not None: + # Run a few decode batches continuously for reducing overhead + for _ in range(global_config.num_continue_decode_steps): + batch = self.get_new_batch_decode() + + if batch: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + + if self.running_batch is None: + break + + if self.out_pyobjs and self.running_batch.has_stream: + break + else: + self.check_memory() + self.new_token_ratio = global_config.init_new_token_ratio + def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Handle the cases where prefill is not allowed if ( @@ -612,7 +601,6 @@ class Scheduler: return None # Update batch tensors - self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) batch.prepare_for_decode() return batch @@ -723,6 +711,12 @@ class Scheduler: self.handle_finished_requests(batch) + if not batch.is_empty(): + if self.running_batch is None: + self.running_batch = batch + else: + self.running_batch.merge_batch(batch) + def process_batch_result_decode(self, batch: ScheduleBatch, result): logits_output, next_token_ids = result batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( @@ -762,6 +756,13 @@ class Scheduler: self.handle_finished_requests(batch) + self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) + if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: + self.print_decode_stats() + + if self.running_batch.is_empty(): + self.running_batch = None + def add_logprob_return_values( self, i: int, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 39106581a..6d63d42b0 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -24,6 +24,7 @@ import random import resource import socket import time +import warnings from importlib.metadata import PackageNotFoundError, version from io import BytesIO from typing import Any, Dict, List, Optional, Union @@ -333,6 +334,10 @@ def suppress_other_loggers(): logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.ERROR) + warnings.filterwarnings( + "ignore", category=UserWarning, message="The given NumPy array is not writable" + ) + def assert_pkg_version(pkg: str, min_version: str, message: str): try: @@ -615,7 +620,9 @@ def broadcast_pyobj( else: serialized_data = pickle.dumps(data) size = len(serialized_data) - tensor_data = torch.ByteTensor(list(serialized_data)) + tensor_data = torch.ByteTensor( + np.frombuffer(serialized_data, dtype=np.uint8) + ) tensor_size = torch.tensor([size], dtype=torch.long) dist.broadcast(tensor_size, src=0, group=dist_group)