diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index fa5800e09..03406eb74 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -29,7 +29,7 @@ from sglang.lang.ir import (
     SglVarScopeBegin,
     SglVarScopeEnd,
 )
-from sglang.utils import encode_image_base64
+from sglang.utils import encode_image_base64, get_exception_traceback


 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -195,6 +195,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
+        self.error = None

         # For completion
         self.text_ = ""  # The full text
@@ -310,17 +311,39 @@ class StreamExecutor:
             self.backend.end_program(self)

     def _thread_worker_func(self):
+        error = None
+
         while True:
             expr = self.queue.get()
             if expr is None:
                 self.queue.task_done()
                 break
-            self._execute(expr)
+            try:
+                self._execute(expr)
+            except Exception as e:
+                print(f"Error in stream_executor: {get_exception_traceback()}")
+                error = e
+                break
             self.queue.task_done()
             if self.stream_text_event:
                 self.stream_text_event.set()

+        # Clean the queue and events
+        if error is not None:
+            try:
+                while True:
+                    self.queue.task_done()
+                    self.queue.get_nowait()
+            except queue.Empty:
+                pass
+            for name in self.variable_event:
+                self.variable_event[name].set()
+            if self.stream_var_event:
+                for name in self.stream_var_event:
+                    self.stream_var_event[name].set()
+            self.error = error
+
         if self.stream_text_event:
             self.stream_text_event.set()
@@ -679,7 +702,9 @@ class ProgramState:
         return self.stream_executor.messages()

     def sync(self):
-        return self.stream_executor.sync()
+        ret = self.stream_executor.sync()
+        self.error = self.stream_executor.error
+        return ret

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
@@ -769,6 +794,9 @@ class ProgramState:
     def __setitem__(self, name, value):
         self.set_var(name, value)

+    def __contains__(self, name):
+        return name in self.stream_executor.variables
+
     def __del__(self):
         self.stream_executor.end()

diff --git a/python/sglang/srt/flush_cache.py b/python/sglang/srt/flush_cache.py
new file mode 100644
index 000000000..3d695d44d
--- /dev/null
+++ b/python/sglang/srt/flush_cache.py
@@ -0,0 +1,16 @@
+"""
+Usage:
+python3 -m sglang.srt.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
\ No newline at end of file
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index ea1b7ae5e..f85faecd0 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -135,6 +135,8 @@ class ModelRpcServer:
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()

         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(
@@ -211,6 +213,7 @@ class ModelRpcServer:
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(10):
+                    self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)

                     if self.running_batch.is_empty():
@@ -226,10 +229,14 @@ class ModelRpcServer:
                         self.token_to_kv_pool.available_size()
                         + self.tree_cache.evictable_size()
                     )
+                    throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+                    self.num_generated_tokens = 0
+                    self.last_stats_tic = time.time()
                     logger.info(
                         f"#running-req: {len(self.running_batch.reqs)}, "
                         f"#token: {num_used}, "
                         f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                        f"gen throughput (token/s): {throughput:.2f}, "
                         f"#queue-req: {len(self.forward_queue)}"
                     )
             else:
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 65363a386..40a32369a 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -17,8 +17,8 @@ from vllm.distributed import initialize_model_parallel

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-from sglang.utils import get_available_gpu_memory
+from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
+

 QUANTIZATION_CONFIG_MAPPING = {
     "awq": AWQConfig,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 56f408db0..11bf139bf 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -4,9 +4,7 @@ import base64
 import os
 import random
 import socket
-import sys
 import time
-import traceback
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
@@ -20,6 +18,8 @@ from packaging import version as pkg_version
 from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware

+from sglang.utils import get_exception_traceback
+
 show_time_cost = False
 time_infos = {}

@@ -90,6 +90,32 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


+def get_available_gpu_memory(gpu_id, distributed=True):
+    """
+    Get available memory for cuda:gpu_id device.
+    When distributed is True, the available memory is the minimum available memory of all GPUs.
+ """ + num_gpus = torch.cuda.device_count() + assert gpu_id < num_gpus + + if torch.cuda.current_device() != gpu_id: + print( + f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ", + "which may cause useless memory allocation for torch CUDA context.", + ) + + free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) + + if distributed: + tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( + torch.device("cuda", gpu_id) + ) + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) + free_gpu_memory = tensor.item() + + return free_gpu_memory / (1 << 30) + + def set_random_seed(seed: int) -> None: random.seed(seed) @@ -158,12 +184,6 @@ def allocate_init_ports( return port, additional_ports -def get_exception_traceback(): - etype, value, tb = sys.exc_info() - err_str = "".join(traceback.format_exception(etype, value, tb)) - return err_str - - def get_int_token_logit_bias(tokenizer, vocab_size): # a bug when model's vocab size > tokenizer.vocab_size vocab_size = tokenizer.vocab_size @@ -314,4 +334,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1 def jsonify_pydantic_model(obj: BaseModel): if IS_PYDANTIC_1: return obj.json(ensure_ascii=False) - return obj.model_dump_json() \ No newline at end of file + return obj.model_dump_json() diff --git a/python/sglang/utils.py b/python/sglang/utils.py index aa993d8f8..bbe9e0844 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -2,7 +2,9 @@ import base64 import json +import sys import threading +import traceback import urllib.request from io import BytesIO from json import dumps @@ -10,32 +12,10 @@ from json import dumps import requests -def get_available_gpu_memory(gpu_id, distributed=True): - """ - Get available memory for cuda:gpu_id device. - When distributed is True, the available memory is the minimum available memory of all GPUs. - """ - import torch - - num_gpus = torch.cuda.device_count() - assert gpu_id < num_gpus - - if torch.cuda.current_device() != gpu_id: - print( - f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ", - "which may cause useless memory allocation for torch CUDA context.", - ) - - free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) - - if distributed: - tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( - torch.device("cuda", gpu_id) - ) - torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) - free_gpu_memory = tensor.item() - - return free_gpu_memory / (1 << 30) +def get_exception_traceback(): + etype, value, tb = sys.exc_info() + err_str = "".join(traceback.format_exception(etype, value, tb)) + return err_str def is_same_type(values): @@ -190,4 +170,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None): if not ret_value: raise RuntimeError() - return ret_value[0] + return ret_value[0] \ No newline at end of file