Improve error handling (#433)
This commit is contained in:
@@ -29,7 +29,7 @@ from sglang.lang.ir import (
|
|||||||
SglVarScopeBegin,
|
SglVarScopeBegin,
|
||||||
SglVarScopeEnd,
|
SglVarScopeEnd,
|
||||||
)
|
)
|
||||||
from sglang.utils import encode_image_base64
|
from sglang.utils import encode_image_base64, get_exception_traceback
|
||||||
|
|
||||||
|
|
||||||
def run_internal(state, program, func_args, func_kwargs, sync):
|
def run_internal(state, program, func_args, func_kwargs, sync):
|
||||||
@@ -195,6 +195,7 @@ class StreamExecutor:
|
|||||||
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
||||||
self.meta_info = {} # Dict[name: str -> info: str]
|
self.meta_info = {} # Dict[name: str -> info: str]
|
||||||
self.is_finished = False
|
self.is_finished = False
|
||||||
|
self.error = None
|
||||||
|
|
||||||
# For completion
|
# For completion
|
||||||
self.text_ = "" # The full text
|
self.text_ = "" # The full text
|
||||||
@@ -310,17 +311,39 @@ class StreamExecutor:
|
|||||||
self.backend.end_program(self)
|
self.backend.end_program(self)
|
||||||
|
|
||||||
def _thread_worker_func(self):
|
def _thread_worker_func(self):
|
||||||
|
error = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
expr = self.queue.get()
|
expr = self.queue.get()
|
||||||
if expr is None:
|
if expr is None:
|
||||||
self.queue.task_done()
|
self.queue.task_done()
|
||||||
break
|
break
|
||||||
|
|
||||||
self._execute(expr)
|
try:
|
||||||
|
self._execute(expr)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in stream_executor: {get_exception_traceback()}")
|
||||||
|
error = e
|
||||||
|
break
|
||||||
self.queue.task_done()
|
self.queue.task_done()
|
||||||
if self.stream_text_event:
|
if self.stream_text_event:
|
||||||
self.stream_text_event.set()
|
self.stream_text_event.set()
|
||||||
|
|
||||||
|
# Clean the queue and events
|
||||||
|
if error is not None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self.queue.task_done()
|
||||||
|
self.queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
for name in self.variable_event:
|
||||||
|
self.variable_event[name].set()
|
||||||
|
if self.stream_var_event:
|
||||||
|
for name in self.stream_var_event:
|
||||||
|
self.stream_var_event[name].set()
|
||||||
|
self.error = error
|
||||||
|
|
||||||
if self.stream_text_event:
|
if self.stream_text_event:
|
||||||
self.stream_text_event.set()
|
self.stream_text_event.set()
|
||||||
|
|
||||||
@@ -679,7 +702,9 @@ class ProgramState:
|
|||||||
return self.stream_executor.messages()
|
return self.stream_executor.messages()
|
||||||
|
|
||||||
def sync(self):
|
def sync(self):
|
||||||
return self.stream_executor.sync()
|
ret = self.stream_executor.sync()
|
||||||
|
self.error = self.stream_executor.error
|
||||||
|
return ret
|
||||||
|
|
||||||
def text_iter(self, var_name: Optional[str] = None):
|
def text_iter(self, var_name: Optional[str] = None):
|
||||||
if self.stream_executor.stream:
|
if self.stream_executor.stream:
|
||||||
@@ -769,6 +794,9 @@ class ProgramState:
|
|||||||
def __setitem__(self, name, value):
|
def __setitem__(self, name, value):
|
||||||
self.set_var(name, value)
|
self.set_var(name, value)
|
||||||
|
|
||||||
|
def __contains__(self, name):
|
||||||
|
return name in self.stream_executor.variables
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.stream_executor.end()
|
self.stream_executor.end()
|
||||||
|
|
||||||
|
|||||||
16
python/sglang/srt/flush_cache.py
Normal file
16
python/sglang/srt/flush_cache.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m sglang.srt.flush_cache --url http://localhost:30000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
response = requests.get(args.url + "/flush_cache")
|
||||||
|
assert response.status_code == 200
|
||||||
@@ -135,6 +135,8 @@ class ModelRpcServer:
|
|||||||
self.out_pyobjs = []
|
self.out_pyobjs = []
|
||||||
self.decode_forward_ct = 0
|
self.decode_forward_ct = 0
|
||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
|
self.num_generated_tokens = 0
|
||||||
|
self.last_stats_tic = time.time()
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the FSM cache for constrained generation
|
||||||
self.regex_fsm_cache = FSMCache(
|
self.regex_fsm_cache = FSMCache(
|
||||||
@@ -211,6 +213,7 @@ class ModelRpcServer:
|
|||||||
if self.running_batch is not None:
|
if self.running_batch is not None:
|
||||||
# Run a few decode batches continuously for reducing overhead
|
# Run a few decode batches continuously for reducing overhead
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
|
self.num_generated_tokens += len(self.running_batch.reqs)
|
||||||
self.forward_decode_batch(self.running_batch)
|
self.forward_decode_batch(self.running_batch)
|
||||||
|
|
||||||
if self.running_batch.is_empty():
|
if self.running_batch.is_empty():
|
||||||
@@ -226,10 +229,14 @@ class ModelRpcServer:
|
|||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool.available_size()
|
||||||
+ self.tree_cache.evictable_size()
|
+ self.tree_cache.evictable_size()
|
||||||
)
|
)
|
||||||
|
throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
|
||||||
|
self.num_generated_tokens = 0
|
||||||
|
self.last_stats_tic = time.time()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"#running-req: {len(self.running_batch.reqs)}, "
|
f"#running-req: {len(self.running_batch.reqs)}, "
|
||||||
f"#token: {num_used}, "
|
f"#token: {num_used}, "
|
||||||
f"token usage: {num_used / self.max_total_num_token:.2f}, "
|
f"token usage: {num_used / self.max_total_num_token:.2f}, "
|
||||||
|
f"gen throughput (token/s): {throuhgput:.2f}, "
|
||||||
f"#queue-req: {len(self.forward_queue)}"
|
f"#queue-req: {len(self.forward_queue)}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from vllm.distributed import initialize_model_parallel
|
|||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.utils import is_multimodal_model
|
from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
|
||||||
from sglang.utils import get_available_gpu_memory
|
|
||||||
|
|
||||||
QUANTIZATION_CONFIG_MAPPING = {
|
QUANTIZATION_CONFIG_MAPPING = {
|
||||||
"awq": AWQConfig,
|
"awq": AWQConfig,
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ import base64
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
@@ -20,6 +18,8 @@ from packaging import version as pkg_version
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
show_time_cost = False
|
show_time_cost = False
|
||||||
time_infos = {}
|
time_infos = {}
|
||||||
|
|
||||||
@@ -90,6 +90,32 @@ def calculate_time(show=False, min_cost_ms=0.0):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_gpu_memory(gpu_id, distributed=True):
|
||||||
|
"""
|
||||||
|
Get available memory for cuda:gpu_id device.
|
||||||
|
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
||||||
|
"""
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
assert gpu_id < num_gpus
|
||||||
|
|
||||||
|
if torch.cuda.current_device() != gpu_id:
|
||||||
|
print(
|
||||||
|
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||||
|
"which may cause useless memory allocation for torch CUDA context.",
|
||||||
|
)
|
||||||
|
|
||||||
|
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
||||||
|
|
||||||
|
if distributed:
|
||||||
|
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
||||||
|
torch.device("cuda", gpu_id)
|
||||||
|
)
|
||||||
|
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
|
||||||
|
free_gpu_memory = tensor.item()
|
||||||
|
|
||||||
|
return free_gpu_memory / (1 << 30)
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed: int) -> None:
|
def set_random_seed(seed: int) -> None:
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
@@ -158,12 +184,6 @@ def allocate_init_ports(
|
|||||||
return port, additional_ports
|
return port, additional_ports
|
||||||
|
|
||||||
|
|
||||||
def get_exception_traceback():
|
|
||||||
etype, value, tb = sys.exc_info()
|
|
||||||
err_str = "".join(traceback.format_exception(etype, value, tb))
|
|
||||||
return err_str
|
|
||||||
|
|
||||||
|
|
||||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||||
# a bug when model's vocab size > tokenizer.vocab_size
|
# a bug when model's vocab size > tokenizer.vocab_size
|
||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
import traceback
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from json import dumps
|
from json import dumps
|
||||||
@@ -10,32 +12,10 @@ from json import dumps
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def get_available_gpu_memory(gpu_id, distributed=True):
|
def get_exception_traceback():
|
||||||
"""
|
etype, value, tb = sys.exc_info()
|
||||||
Get available memory for cuda:gpu_id device.
|
err_str = "".join(traceback.format_exception(etype, value, tb))
|
||||||
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
return err_str
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
|
|
||||||
num_gpus = torch.cuda.device_count()
|
|
||||||
assert gpu_id < num_gpus
|
|
||||||
|
|
||||||
if torch.cuda.current_device() != gpu_id:
|
|
||||||
print(
|
|
||||||
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
|
||||||
"which may cause useless memory allocation for torch CUDA context.",
|
|
||||||
)
|
|
||||||
|
|
||||||
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
|
||||||
|
|
||||||
if distributed:
|
|
||||||
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
|
||||||
torch.device("cuda", gpu_id)
|
|
||||||
)
|
|
||||||
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
|
|
||||||
free_gpu_memory = tensor.item()
|
|
||||||
|
|
||||||
return free_gpu_memory / (1 << 30)
|
|
||||||
|
|
||||||
|
|
||||||
def is_same_type(values):
|
def is_same_type(values):
|
||||||
|
|||||||
Reference in New Issue
Block a user