[Minor] Improve code style (#2422)

This commit is contained in:
Lianmin Zheng
2024-12-09 06:30:35 -08:00
committed by GitHub
parent 0ce091a82d
commit 641b7d0ae0
15 changed files with 33 additions and 21 deletions

View File

@@ -33,7 +33,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
srt_xpu = ["sglang[runtime_common]"]
#For Intel Gaudi(device : hpu) follow the installation guide
#https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
srt_hpu = ["sglang[runtime_common]"]
srt_hpu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
@@ -50,6 +50,7 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"]
dev_hip = ["sglang[all_hip]", "sglang[test]"]
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]

View File

@@ -285,7 +285,7 @@ def throughput_test(
else:
raise ValueError('Please set backend to either "engine" or "runtime"')
tokenizer_id = server_args.model_path
tokenizer_id = server_args.tokenizer_path or server_args.model_path
tokenizer = get_tokenizer(tokenizer_id)
    # Set global environments

View File

@@ -117,7 +117,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
key_type, key_string = key
if key_type == "json":
try:
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
if key_string == "$$ANY$$":
ctx = self.grammar_compiler.compile_builtin_json_grammar()
else:
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
except RuntimeError as e:
logging.warning(
f"Skip invalid json_schema: json_schema={key_string}, {e=}"

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
import torch
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:

View File

@@ -48,7 +48,14 @@ class RadixAttention(nn.Module):
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention
def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
def forward(
self,
q,
k,
v,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None

View File

@@ -484,7 +484,7 @@ bid = 0
@dataclasses.dataclass
class ScheduleBatch:
"""Store all inforamtion of a batch on the scheduler."""
"""Store all information of a batch on the scheduler."""
# Request, memory pool, and cache
reqs: List[Req]

View File

@@ -22,7 +22,7 @@ import signal
import sys
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
import fastapi
import uvloop

View File

@@ -127,7 +127,7 @@ class CudaGraphRunner:
# Batch sizes to capture
if model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 32)) + [64, 128]
self.capture_bs = list(range(1, 33)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]

View File

@@ -242,20 +242,22 @@ class ModelRunner:
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")
# Prepare the vllm model config
# Prepare the model config
self.load_config = LoadConfig(
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
)
if self.server_args.load_format == "gguf":
monkey_patch_vllm_gguf_config()
# Load the model
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
# Parse other args
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +272,10 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def update_weights_from_disk(self, model_path: str, load_format: str):
"""Update engine weights online from disk."""
def update_weights_from_disk(
self, model_path: str, load_format: str
) -> tuple[bool, str]:
"""Update engine weights in-place from the disk."""
from sglang.srt.model_loader.loader import (
DefaultModelLoader,
device_loading_context,

View File

@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
) -> None:
super().__init__()
self.config = config
self.torchao_config = None
self.quant_config = quant_config
self.num_labels = config.num_labels
self.model = Gemma2Model(config, quant_config=quant_config)

View File

@@ -33,7 +33,6 @@ class LlamaForClassification(nn.Module):
) -> None:
super().__init__()
self.config = config
self.torchao_config = None
self.quant_config = quant_config
self.model = LlamaModel(config, quant_config=quant_config)

View File

@@ -21,7 +21,6 @@ from transformers import LlamaConfig
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -33,7 +32,6 @@ class LlamaForSequenceClassification(nn.Module):
) -> None:
super().__init__()
self.config = config
self.torchao_config = None
self.quant_config = quant_config
self.num_labels = config.num_labels
self.model = LlamaModel(config, quant_config=quant_config)

View File

@@ -196,7 +196,7 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk")
@time_func_latency
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk inplace without re-launching the server."""
"""Update the weights from disk in-place without re-launching the server."""
success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
content = {"success": success, "message": message}
if success:

View File

@@ -169,7 +169,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
return wrapper
def get_available_gpu_memory(device, gpu_id, distributed=False):
def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -184,7 +184,8 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
"which may cause useless memory allocation for torch CUDA context.",
)
torch.cuda.empty_cache()
if empty_cache:
torch.cuda.empty_cache()
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
elif device == "xpu":
@@ -196,7 +197,9 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
"which may cause useless memory allocation for torch XPU context.",
)
torch.xpu.empty_cache()
if empty_cache:
torch.xpu.empty_cache()
used_memory = torch.xpu.memory_allocated()
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
free_gpu_memory = total_gpu_memory - used_memory