[Minor] Improve code style (#2422)
@@ -33,7 +33,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
 srt_xpu = ["sglang[runtime_common]"]
 #For Intel Gaudi(device : hpu) follow the installation guide
 #https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
-srt_hpu = ["sglang[runtime_common]"]
+srt_hpu = ["sglang[runtime_common]"]

 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
@@ -50,6 +50,7 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]

 dev = ["sglang[all]", "sglang[test]"]
 dev_hip = ["sglang[all_hip]", "sglang[test]"]
 dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
+dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
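Note: these extras compose. Each all_* target pulls in the matching srt_* runtime plus the OpenAI/Anthropic/LiteLLM client extras, and each dev_* target adds sglang[test] on top, so a device-specific install is a single command such as pip install "sglang[all_hpu]".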
@@ -285,7 +285,7 @@ def throughput_test(
     else:
         raise ValueError('Please set backend to either "engine" or "runtime"')

-    tokenizer_id = server_args.model_path
+    tokenizer_id = server_args.tokenizer_path or server_args.model_path
     tokenizer = get_tokenizer(tokenizer_id)

     # Set global environmnets
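Note: the benchmark previously always tokenized with the model path; the fix prefers an explicitly configured tokenizer_path and falls back via Python's or. A minimal sketch of the idiom, with hypothetical values:

    # `or` returns the first truthy operand, so an unset tokenizer path
    # (None or "") falls through to the model path.
    tokenizer_path = None  # hypothetical: no separate tokenizer configured
    model_path = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical model id
    tokenizer_id = tokenizer_path or model_path
    assert tokenizer_id == model_path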
@@ -117,7 +117,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
+                if key_string == "$$ANY$$":
+                    ctx = self.grammar_compiler.compile_builtin_json_grammar()
+                else:
+                    ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
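Note: previously the sentinel string "$$ANY$$" was passed to compile_json_schema like any other schema; the new branch routes it to xgrammar's built-in JSON grammar, which accepts any valid JSON. A sketch of the dispatch, assuming a compiler object exposing the two methods named in the hunk:

    def compile_json(compiler, key_string: str):
        # Route the any-JSON sentinel to the built-in grammar; every other
        # string is still treated as a JSON schema.
        if key_string == "$$ANY$$":
            return compiler.compile_builtin_json_grammar()
        return compiler.compile_json_schema(schema=key_string)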
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
 import torch

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -48,7 +48,14 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention

-    def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
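Note: a pure formatting change; the signature is reflowed to one parameter per line and save_kv_cache gains an explicit bool annotation. Defaults and behavior are unchanged.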
@@ -484,7 +484,7 @@ bid = 0

 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch on the scheduler."""
+    """Store all information of a batch on the scheduler."""

     # Request, memory pool, and cache
     reqs: List[Req]
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union

 import fastapi
 import uvloop
@@ -127,7 +127,7 @@ class CudaGraphRunner:

         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1, 32)) + [64, 128]
+            self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
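Note: an off-by-one fix; range's upper bound is exclusive, so range(1, 32) stops at 31 and batch size 32 was never captured on this branch. A quick check:

    # range(1, 32) yields 1..31; range(1, 33) yields 1..32.
    before = list(range(1, 32)) + [64, 128]
    after = list(range(1, 33)) + [64, 128]
    assert 32 not in before
    assert 32 in after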
@@ -242,20 +242,22 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")

-        # Prepare the vllm model config
+        # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()

         # Load the model
         self.model = get_model(
             model_config=self.model_config,
             load_config=self.load_config,
             device_config=DeviceConfig(self.device),
         )

         # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +272,10 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-    def update_weights_from_disk(self, model_path: str, load_format: str):
-        """Update engine weights online from disk."""
+    def update_weights_from_disk(
+        self, model_path: str, load_format: str
+    ) -> tuple[bool, str]:
+        """Update engine weights in-place from the disk."""
         from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
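Note: the new annotation documents the (success, message) pair the method hands back, matching how the HTTP layer below consumes it. A hypothetical caller sketch; the runner argument and the "auto" load format are illustrative assumptions, not part of this diff:

    def reload_weights(runner, model_path: str) -> None:
        # `runner` is assumed to be a ModelRunner; "auto" is a hypothetical
        # load_format value used only for illustration.
        success, message = runner.update_weights_from_disk(model_path, "auto")
        if not success:
            raise RuntimeError(f"weight update failed: {message}")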
@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = Gemma2Model(config, quant_config=quant_config)
@@ -33,7 +33,6 @@ class LlamaForClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
@@ -21,7 +21,6 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -33,7 +32,6 @@ class LlamaForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = LlamaModel(config, quant_config=quant_config)
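Note: the same cleanup appears in all three classification/reward models above; each constructor drops its torchao_config attribute.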
@@ -196,7 +196,7 @@ async def stop_profile_async():
 @app.post("/update_weights_from_disk")
 @time_func_latency
 async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
-    """Update the weights from disk inplace without re-launching the server."""
+    """Update the weights from disk in-place without re-launching the server."""
     success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
     content = {"success": success, "message": message}
     if success:
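Note: this endpoint fronts the ModelRunner method above. A hypothetical client call; the host/port and the model_path request field are assumptions here, while the response shape comes straight from the handler:

    import requests

    # Hypothetical server address and request body.
    resp = requests.post(
        "http://localhost:30000/update_weights_from_disk",
        json={"model_path": "/path/to/new/checkpoint"},
    )
    print(resp.json())  # {"success": ..., "message": ...}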
@@ -169,7 +169,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -184,7 +184,8 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
             "which may cause useless memory allocation for torch CUDA context.",
         )

-        torch.cuda.empty_cache()
+        if empty_cache:
+            torch.cuda.empty_cache()
         free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

     elif device == "xpu":
@@ -196,7 +197,9 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
             f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
             "which may cause useless memory allocation for torch XPU context.",
         )
-        torch.xpu.empty_cache()
+
+        if empty_cache:
+            torch.xpu.empty_cache()
         used_memory = torch.xpu.memory_allocated()
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory
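Note: the new empty_cache flag (default True, preserving the old behavior) lets frequent callers skip the allocator flush on both the CUDA and XPU paths, since emptying the cache can be costly when memory is polled in a loop. A usage sketch, assuming the function lives in sglang.srt.utils:

    from sglang.srt.utils import get_available_gpu_memory

    # Flush cached blocks first (previous behavior, still the default).
    free_gb = get_available_gpu_memory("cuda", gpu_id=0)

    # Cheaper repeated polling: leave the allocator state as-is.
    free_gb_fast = get_available_gpu_memory("cuda", gpu_id=0, empty_cache=False)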