Cleanup readme, llava examples, usage examples and nccl init (#1194)

This commit is contained in:
Lianmin Zheng
2024-08-24 08:02:23 -07:00
committed by GitHub
parent c9064e6fd9
commit f6af3a6561
65 changed files with 174 additions and 317 deletions

View File

@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
model_config = ModelConfig(path=server_args.model_path)
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
)
model_runner = ModelRunner(
model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static,

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
class ChatTemplateStyle(Enum):

View File

@@ -1,29 +0,0 @@
"""Launch the inference server for Llava-video model."""
import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = 16
model_overide_args["model_type"] = "llavavid"
if model_overide_args["num_frames"] == 32:
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_overide_args["max_sequence_length"] = 4096 * 2
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
model_overide_args["model_max_length"] = 4096 * 2
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
if "34b" in args.model_path.lower():
model_overide_args["image_token_index"] = 64002
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, model_overide_args, None)

View File

@@ -26,7 +26,7 @@ import triton.language as tl
from sglang.srt.managers.schedule_batch import global_server_args_dict
if global_server_args_dict.get("attention_reduce_in_fp32", False):
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:

View File

@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
weight_name: str,
shard_id: int,
expert_id: int,
pre_sharded: bool,
use_presharded_weights: bool = False,
):
param_data = param.data
@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
if pre_sharded:
if use_presharded_weights:
shard = slice(None)
else:
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

View File

@@ -180,7 +180,7 @@ class LogitsProcessor(nn.Module):
if hasattr(self.config, "final_logit_softcapping"):
last_logits.div_(self.config.final_logit_softcapping)
last_logits = torch.tanh(last_logits)
torch.tanh(last_logits, out=last_logits)
last_logits.mul_(self.config.final_logit_softcapping)
# Return only last_logits if logprob is not requested
@@ -241,7 +241,7 @@ class LogitsProcessor(nn.Module):
if hasattr(self.config, "final_logit_softcapping"):
all_logits.div_(self.config.final_logit_softcapping)
all_logits = torch.tanh(all_logits)
torch.tanh(all_logits, out=all_logits)
all_logits.mul_(self.config.final_logit_softcapping)
all_logprobs = all_logits

View File

@@ -35,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
global_server_args_dict = {
"disable_flashinfer": False,
"disable_flashinfer_sampling": False,
"attention_reduce_in_fp32": False,
"triton_attention_reduce_in_fp32": False,
"enable_mla": False,
}

View File

@@ -606,6 +606,9 @@ class TokenizerManager:
return background_tasks
def create_handle_loop(self):
if not self.to_create_loop:
return
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())

View File

@@ -20,7 +20,6 @@ import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type
@@ -91,23 +90,35 @@ class ModelRunner:
{
"disable_flashinfer": server_args.disable_flashinfer,
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
}
)
min_per_gpu_memory = self.init_torch_distributed()
self.load_model()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
self.init_cuda_graphs()
def init_torch_distributed(self):
# Init torch distributed
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
if not server_args.enable_p2p_check:
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
if server_args.nccl_init_addr:
nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
if self.server_args.nccl_init_addr:
nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
set_custom_all_reduce(not server_args.disable_custom_all_reduce)
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
init_distributed_environment(
backend="nccl",
world_size=self.tp_size,
@@ -116,32 +127,28 @@ class ModelRunner:
distributed_init_method=nccl_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
min_per_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
self.tp_group = get_tp_group()
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)
# Check memory for tensor parallelism
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if total_local_gpu_memory < total_gpu_memory * 0.9:
local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
# Load the model and create memory pool
self.load_model()
self.init_memory_pool(
total_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
if self.is_generation:
# FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
# Capture cuda graphs
self.init_cuda_graphs()
return min_per_gpu_memory
def load_model(self):
logger.info(
@@ -150,7 +157,7 @@ class ModelRunner:
)
if torch.cuda.get_device_capability()[0] < 8:
logger.info(
"Compute capability below sm80 use float16 due to lack of bfloat16 support."
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
@@ -168,8 +175,9 @@ class ModelRunner:
skip_tokenizer_init=True,
)
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
# Drop this after Sept, 2024.
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
self.vllm_model_config.hf_config.num_key_value_heads = 8
monkey_patch_vllm_qvk_linear_loader()
@@ -191,8 +199,8 @@ class ModelRunner:
cache_config=None,
)
self.sliding_window_size = (
self.model.get_window_size()
if hasattr(self.model, "get_window_size")
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
@@ -206,7 +214,8 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def update_weights(self, model_path, load_format):
def update_weights(self, model_path: str, load_format: str):
"""Update weights in-place."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
@@ -222,6 +231,7 @@ class ModelRunner:
target_device = torch.device(self.device_config.device)
try:
# TODO: Use a better method to check this
vllm_model_config = VllmModelConfig(
model=model_path,
quantization=self.server_args.quantization,
@@ -291,7 +301,7 @@ class ModelRunner:
logger.info(f"[gpu={self.gpu_id}] Update weights end.")
return True, "Succeeded to update model weights"
def profile_max_num_token(self, total_gpu_memory):
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
@@ -319,7 +329,10 @@ class ModelRunner:
return max_num_token
def init_memory_pool(
self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
self,
total_gpu_memory: int,
max_num_reqs: int = None,
max_total_tokens: int = None,
):
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
@@ -388,6 +401,7 @@ class ModelRunner:
return c
def init_flashinfer(self):
"""Init flashinfer attention kernel wrappers."""
if self.server_args.disable_flashinfer:
assert (
self.sliding_window_size is None
@@ -448,6 +462,11 @@ class ModelRunner:
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
@@ -457,7 +476,12 @@ class ModelRunner:
logger.info(
f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
)
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
else:
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.cuda_graph_runner = CudaGraphRunner(
self,
max_batch_size_to_capture=max(batch_size_list),

View File

@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata
# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_window_size(config):
def get_attention_sliding_window_size(config):
return config.sliding_window - 1
@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_idx,
sliding_window_size=get_window_size(config) if use_sliding_window else None,
sliding_window_size=(
get_attention_sliding_window_size(config)
if use_sliding_window
else None
),
logit_cap=self.config.attn_logit_softcapping,
)
@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
def get_window_size(self):
return get_window_size(self.config)
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Grok1Model(config, quant_config=quant_config)
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
self.use_presharded_weights = True
warnings.filterwarnings("ignore", category=FutureWarning)
def forward(
@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
continue
name = name.replace(weight_name, param_name)
if self.use_presharded_weights:
extra_kwargs = {
"use_presharded_weights": self.use_presharded_weights
}
else:
extra_kwargs = {}
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
weight_name,
shard_id=shard_id,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1,
**extra_kwargs,
)
break
else:

View File

@@ -81,13 +81,12 @@ class ServerArgs:
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
disable_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
enable_p2p_check: bool = False
enable_mla: bool = False
attention_reduce_in_fp32: bool = False
efficient_weight_load: bool = False
disable_custom_all_reduce: bool = False
triton_attention_reduce_in_fp32: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
@@ -404,6 +403,12 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
default=False,
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
@@ -425,7 +430,7 @@ class ServerArgs:
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--attention-reduce-in-fp32",
"--triton-attention-reduce-in-fp32",
action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
@@ -435,12 +440,6 @@ class ServerArgs:
action="store_true",
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
default=False,
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):

View File

@@ -347,7 +347,7 @@ def suppress_other_loggers():
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -451,10 +451,6 @@ def monkey_patch_vllm_dummy_weight_loader():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.

View File

@@ -24,7 +24,6 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model
DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt
@@ -63,8 +62,8 @@ class HFRunner:
def __init__(
self,
model_path,
torch_dtype=torch.float16,
is_generation_model=None,
torch_dtype,
is_generation_model,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
@@ -90,11 +89,8 @@ class HFRunner:
trust_remote_code=True,
)
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
self.is_generation_model = is_generation_model
if self.is_generation_model:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
@@ -176,16 +172,12 @@ class SRTRunner:
def __init__(
self,
model_path,
torch_dtype,
is_generation_model,
tp_size=1,
torch_dtype=torch.float16,
is_generation_model=None,
port=5157,
):
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
self.is_generation_model = is_generation_model
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,