Cleanup readme, llava examples, usage examples and nccl init (#1194)
@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    model_config = ModelConfig(path=server_args.model_path)
+    model_config = ModelConfig(
+        server_args.model_path,
+        server_args.trust_remote_code,
+        context_length=server_args.context_length,
+    )
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,
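ModelConfig is now built from the model path plus trust_remote_code (positional) and context_length (keyword) instead of a single path= argument. A hypothetical call with concrete values, for illustration only (the model name and values are placeholders, not part of this commit):

# Illustrative values only; mirrors the new ModelConfig call shape above.
model_config = ModelConfig(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # server_args.model_path (placeholder)
    False,                                  # server_args.trust_remote_code
    context_length=None,                    # assumed to fall back to the model's own default
)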
@@ -1,6 +1,6 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple


 class ChatTemplateStyle(Enum):
@@ -1,29 +0,0 @@
-"""Launch the inference server for Llava-video model."""
-
-import argparse
-
-from sglang.srt.server import ServerArgs, launch_server
-
-if __name__ == "__main__":
-    model_overide_args = {}
-
-    model_overide_args["mm_spatial_pool_stride"] = 2
-    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
-    model_overide_args["num_frames"] = 16
-    model_overide_args["model_type"] = "llavavid"
-    if model_overide_args["num_frames"] == 32:
-        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
-        model_overide_args["max_sequence_length"] = 4096 * 2
-        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
-        model_overide_args["model_max_length"] = 4096 * 2
-
-    parser = argparse.ArgumentParser()
-    ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-
-    if "34b" in args.model_path.lower():
-        model_overide_args["image_token_index"] = 64002
-
-    server_args = ServerArgs.from_cli_args(args)
-
-    launch_server(server_args, model_overide_args, None)
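Even with the example script removed, a Llava-video server can still be launched programmatically the same way the deleted file did; a minimal sketch (the model path is a placeholder and not part of this commit):

# Minimal sketch based on the removed example; the model path is a placeholder.
from sglang.srt.server import ServerArgs, launch_server

model_overide_args = {
    "architectures": ["LlavaVidForCausalLM"],
    "model_type": "llavavid",
    "mm_spatial_pool_stride": 2,
    "num_frames": 16,
}
server_args = ServerArgs(model_path="lmms-lab/LLaVA-NeXT-Video-7B")
launch_server(server_args, model_overide_args, None)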
@@ -26,7 +26,7 @@ import triton.language as tl
 from sglang.srt.managers.schedule_batch import global_server_args_dict

-if global_server_args_dict.get("attention_reduce_in_fp32", False):
+if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: int,
         expert_id: int,
-        pre_sharded: bool,
+        use_presharded_weights: bool = False,
     ):
         param_data = param.data
@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
-            if pre_sharded:
+            if use_presharded_weights:
                 shard = slice(None)
             else:
                 shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
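The renamed flag decides how much of each expert weight a tensor-parallel rank copies: with pre-sharded checkpoints, every rank's file already holds only its own slice, so the loader keeps the whole tensor; otherwise it cuts out its tp_rank-sized window. A standalone sketch of the same slicing logic (shapes are invented for illustration):

import torch

def select_shard(weight, tp_rank, shard_size, use_presharded_weights):
    # Pre-sharded checkpoint: the file already contains only this rank's slice.
    if use_presharded_weights:
        shard = slice(None)
    else:
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    return weight[shard]

full = torch.randn(8, 4)  # hypothetical expert weight
print(select_shard(full, tp_rank=1, shard_size=2, use_presharded_weights=False).shape)  # torch.Size([2, 4])
print(select_shard(full, tp_rank=1, shard_size=2, use_presharded_weights=True).shape)   # torch.Size([8, 4])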
@@ -180,7 +180,7 @@ class LogitsProcessor(nn.Module):

         if hasattr(self.config, "final_logit_softcapping"):
             last_logits.div_(self.config.final_logit_softcapping)
-            last_logits = torch.tanh(last_logits)
+            torch.tanh(last_logits, out=last_logits)
             last_logits.mul_(self.config.final_logit_softcapping)

         # Return only last_logits if logprob is not requested
@@ -241,7 +241,7 @@ class LogitsProcessor(nn.Module):

         if hasattr(self.config, "final_logit_softcapping"):
             all_logits.div_(self.config.final_logit_softcapping)
-            all_logits = torch.tanh(all_logits)
+            torch.tanh(all_logits, out=all_logits)
             all_logits.mul_(self.config.final_logit_softcapping)

         all_logprobs = all_logits
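Both logits hunks apply the same final-logit softcapping, logits = cap * tanh(logits / cap); writing the tanh with out= keeps the whole chain in-place instead of allocating a new tensor for the tanh result. A tiny self-contained check (the cap and values are arbitrary):

import torch

cap = 30.0                                 # e.g. a final_logit_softcapping value
logits = torch.tensor([10.0, 50.0, 200.0])
expected = cap * torch.tanh(logits / cap)  # reference formula

logits.div_(cap)
torch.tanh(logits, out=logits)             # in-place, no extra allocation
logits.mul_(cap)
assert torch.allclose(logits, expected)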
@@ -35,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "attention_reduce_in_fp32": False,
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }

@@ -606,6 +606,9 @@ class TokenizerManager:
         return background_tasks

     def create_handle_loop(self):
+        if not self.to_create_loop:
+            return
+
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
@@ -20,7 +20,6 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
 import warnings
 from functools import lru_cache
 from typing import Optional, Type

@@ -91,23 +90,35 @@ class ModelRunner:
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
-                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
             }
         )

+        min_per_gpu_memory = self.init_torch_distributed()
+        self.load_model()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
+        self.init_cublas()
+        self.init_flashinfer()
+        self.init_cuda_graphs()
+
+    def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

-        if not server_args.enable_p2p_check:
+        if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)

-        if server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        if self.server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
-        set_custom_all_reduce(not server_args.disable_custom_all_reduce)
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
@@ -116,32 +127,28 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        total_gpu_memory = get_available_gpu_memory(
+        min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()

+        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
+
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
-            if total_local_gpu_memory < total_gpu_memory * 0.9:
+            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

-        # Load the model and create memory pool
-        self.load_model()
-        self.init_memory_pool(
-            total_gpu_memory,
-            server_args.max_num_reqs,
-            server_args.max_total_tokens,
-        )
-        self.init_cublas()
-        self.init_flashinfer()
-
-        if self.is_generation:
-            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
-            # Capture cuda graphs
-            self.init_cuda_graphs()
+        return min_per_gpu_memory

     def load_model(self):
         logger.info(
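Taken together, these two hunks turn __init__ into a linear pipeline: NCCL setup moves into init_torch_distributed(), which returns the minimum free memory across ranks, and that number then sizes the KV-cache pool before cublas/flashinfer setup and cuda-graph capture. A toy sketch of the resulting order (every body is a stub, not the real implementation):

# Toy sketch of the new ModelRunner startup order; all method bodies are stubs.
class ModelRunnerSketch:
    def init_torch_distributed(self) -> float:
        # Real code: NCCL init, p2p patches, multi-node checks; returns min free memory across ranks.
        return 40.0

    def load_model(self): ...
    def init_memory_pool(self, min_per_gpu_memory, max_num_reqs=None, max_total_tokens=None): ...
    def init_cublas(self): ...
    def init_flashinfer(self): ...
    def init_cuda_graphs(self): ...  # early-returns for non-generation models

    def __init__(self):
        min_per_gpu_memory = self.init_torch_distributed()
        self.load_model()
        self.init_memory_pool(min_per_gpu_memory)
        self.init_cublas()
        self.init_flashinfer()
        self.init_cuda_graphs()

ModelRunnerSketch()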
@@ -150,7 +157,7 @@ class ModelRunner:
         )
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
-                "Compute capability below sm80 use float16 due to lack of bfloat16 support."
+                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
@@ -168,8 +175,9 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )

+        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+        # Drop this after Sept, 2024.
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
@@ -191,8 +199,8 @@ class ModelRunner:
             cache_config=None,
         )
         self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
+            self.model.get_attention_sliding_window_size()
+            if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
         self.is_generation = is_generation_model(
@@ -206,7 +214,8 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

-    def update_weights(self, model_path, load_format):
+    def update_weights(self, model_path: str, load_format: str):
+        """Update weights in-place."""
         from vllm.model_executor.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
@@ -222,6 +231,7 @@ class ModelRunner:
         target_device = torch.device(self.device_config.device)

         try:
+            # TODO: Use a better method to check this
             vllm_model_config = VllmModelConfig(
                 model=model_path,
                 quantization=self.server_args.quantization,
@@ -291,7 +301,7 @@ class ModelRunner:
         logger.info(f"[gpu={self.gpu_id}] Update weights end.")
         return True, "Succeeded to update model weights"

-    def profile_max_num_token(self, total_gpu_memory):
+    def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -319,7 +329,10 @@ class ModelRunner:
         return max_num_token

     def init_memory_pool(
-        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: int = None,
+        max_total_tokens: int = None,
     ):
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
@@ -388,6 +401,7 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
+        """Init flashinfer attention kernel wrappers."""
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None
@@ -448,6 +462,11 @@ class ModelRunner:
         )

     def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+            return
+
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
@@ -457,7 +476,12 @@ class ModelRunner:
         logger.info(
             f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
         )
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+
+        if self.server_args.disable_cuda_graph_padding:
+            batch_size_list = list(range(1, 32)) + [64, 128]
+        else:
+            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         self.cuda_graph_runner = CudaGraphRunner(
             self,
             max_batch_size_to_capture=max(batch_size_list),
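The captured batch sizes now depend on whether cuda-graph padding is disabled; a quick check of what each branch evaluates to:

# What the two branches above produce.
without_padding = list(range(1, 32)) + [64, 128]          # 1..31, then 64 and 128
with_padding = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, 16, ..., 160

print(without_padding[-3:])                # [31, 64, 128]
print(with_padding[:6], with_padding[-1])  # [1, 2, 4, 8, 16, 24] 160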
@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata

 # Aligned with HF's implementation, using sliding window inclusive with the last token
 # SGLang assumes exclusive
-def get_window_size(config):
+def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
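The rename does not change the conversion itself: HF treats config.sliding_window as including the current token while SGLang's attention expects the exclusive count, hence the minus one. For example (4096 is a typical Gemma-2 setting, used here only as an illustration):

# HF-style sliding window (inclusive) converted to SGLang's exclusive count.
hf_sliding_window = 4096
sglang_sliding_window_size = hf_sliding_window - 1  # 4095 tokens of history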
@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=get_window_size(config) if use_sliding_window else None,
+            sliding_window_size=(
+                get_attention_sliding_window_size(config)
+                if use_sliding_window
+                else None
+            ),
             logit_cap=self.config.attn_logit_softcapping,
         )
@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )

-    def get_window_size(self):
-        return get_window_size(self.config)
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
-        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

+        self.use_presharded_weights = True
+
         warnings.filterwarnings("ignore", category=FutureWarning)

     def forward(
@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)

+                if self.use_presharded_weights:
+                    extra_kwargs = {
+                        "use_presharded_weights": self.use_presharded_weights
+                    }
+                else:
+                    extra_kwargs = {}
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    **extra_kwargs,
                 )
                 break
             else:
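Rather than inferring "pre-sharded" from the tensor-parallel world size, the Grok-1 loader now forwards an explicit use_presharded_weights flag to FusedMoE.weight_loader via extra_kwargs, matching the parameter renamed earlier in this diff. A reduced sketch of the call pattern (argument names beyond those visible in the diff are assumptions):

# Reduced sketch of how the flag reaches FusedMoE.weight_loader; names are simplified.
def load_expert_weight(weight_loader, param, loaded_weight, weight_name,
                       shard_id, expert_id, use_presharded_weights):
    extra_kwargs = (
        {"use_presharded_weights": use_presharded_weights}
        if use_presharded_weights
        else {}
    )
    weight_loader(
        param,
        loaded_weight,
        weight_name,
        shard_id=shard_id,
        expert_id=expert_id,
        **extra_kwargs,
    )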
@@ -81,13 +81,12 @@ class ServerArgs:
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
-    attention_reduce_in_fp32: bool = False
+    triton_attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
-    disable_custom_all_reduce: bool = False

     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -404,6 +403,12 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -425,7 +430,7 @@ class ServerArgs:
             help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
@@ -435,12 +440,6 @@ class ServerArgs:
             action="store_true",
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
-        parser.add_argument(
-            "--disable-custom-all-reduce",
-            action="store_true",
-            default=False,
-            help="Disable the custom all-reduce kernel and fall back to NCCL.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -347,7 +347,7 @@ def suppress_other_loggers():
         logging.WARN
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.ERROR)


 def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -451,10 +451,6 @@ def monkey_patch_vllm_dummy_weight_loader():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
                 quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()

     # NOTE(woosuk): For accurate performance evaluation, we assign
     # random values to the weights.
@@ -24,7 +24,6 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from sglang.srt.server import Runtime
-from sglang.srt.utils import is_generation_model

 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
@@ -63,8 +62,8 @@ class HFRunner:
     def __init__(
         self,
         model_path,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
+        torch_dtype,
+        is_generation_model,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()
@@ -90,11 +89,8 @@ class HFRunner:
             trust_remote_code=True,
         )

-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation_model = is_generation_model
+
         if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
@@ -176,16 +172,12 @@ class SRTRunner:
     def __init__(
         self,
        model_path,
+        torch_dtype,
+        is_generation_model,
         tp_size=1,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
         port=5157,
     ):
-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation_model = is_generation_model
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
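With the defaults removed, the test runners must be told explicitly which dtype to use and whether the model is a generation model; the earlier fallback of calling the is_generation_model helper is gone along with its import. A hypothetical call site under the new signatures (the model path is a placeholder):

import torch

hf_runner = HFRunner(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model path
    torch_dtype=torch.float16,
    is_generation_model=True,
)
srt_runner = SRTRunner(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.float16,
    is_generation_model=True,
    tp_size=1,
)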