Cleanup readme, llava examples, usage examples and nccl init (#1194)

This commit is contained in:
Lianmin Zheng
2024-08-24 08:02:23 -07:00
committed by GitHub
parent c9064e6fd9
commit f6af3a6561
65 changed files with 174 additions and 317 deletions

View File

@@ -20,7 +20,6 @@ import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type
@@ -91,23 +90,35 @@ class ModelRunner:
{
"disable_flashinfer": server_args.disable_flashinfer,
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
}
)
min_per_gpu_memory = self.init_torch_distributed()
self.load_model()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
self.init_cuda_graphs()
def init_torch_distributed(self):
# Init torch distributed
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
if not server_args.enable_p2p_check:
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
if server_args.nccl_init_addr:
nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
if self.server_args.nccl_init_addr:
nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
set_custom_all_reduce(not server_args.disable_custom_all_reduce)
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
init_distributed_environment(
backend="nccl",
world_size=self.tp_size,
@@ -116,32 +127,28 @@ class ModelRunner:
distributed_init_method=nccl_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
min_per_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
self.tp_group = get_tp_group()
# Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
# so we disable padding in cuda graph.
if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)
# Check memory for tensor parallelism
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if total_local_gpu_memory < total_gpu_memory * 0.9:
local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
# Load the model and create memory pool
self.load_model()
self.init_memory_pool(
total_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
if self.is_generation:
# FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
# Capture cuda graphs
self.init_cuda_graphs()
return min_per_gpu_memory
def load_model(self):
logger.info(
@@ -150,7 +157,7 @@ class ModelRunner:
)
if torch.cuda.get_device_capability()[0] < 8:
logger.info(
"Compute capability below sm80 use float16 due to lack of bfloat16 support."
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
@@ -168,8 +175,9 @@ class ModelRunner:
skip_tokenizer_init=True,
)
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
# Drop this after Sept, 2024.
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
self.vllm_model_config.hf_config.num_key_value_heads = 8
monkey_patch_vllm_qvk_linear_loader()
@@ -191,8 +199,8 @@ class ModelRunner:
cache_config=None,
)
self.sliding_window_size = (
self.model.get_window_size()
if hasattr(self.model, "get_window_size")
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
@@ -206,7 +214,8 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def update_weights(self, model_path, load_format):
def update_weights(self, model_path: str, load_format: str):
"""Update weights in-place."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
@@ -222,6 +231,7 @@ class ModelRunner:
target_device = torch.device(self.device_config.device)
try:
# TODO: Use a better method to check this
vllm_model_config = VllmModelConfig(
model=model_path,
quantization=self.server_args.quantization,
@@ -291,7 +301,7 @@ class ModelRunner:
logger.info(f"[gpu={self.gpu_id}] Update weights end.")
return True, "Succeeded to update model weights"
def profile_max_num_token(self, total_gpu_memory):
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
@@ -319,7 +329,10 @@ class ModelRunner:
return max_num_token
def init_memory_pool(
self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
self,
total_gpu_memory: int,
max_num_reqs: int = None,
max_total_tokens: int = None,
):
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
@@ -388,6 +401,7 @@ class ModelRunner:
return c
def init_flashinfer(self):
"""Init flashinfer attention kernel wrappers."""
if self.server_args.disable_flashinfer:
assert (
self.sliding_window_size is None
@@ -448,6 +462,11 @@ class ModelRunner:
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
@@ -457,7 +476,12 @@ class ModelRunner:
logger.info(
f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
)
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
else:
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.cuda_graph_runner = CudaGraphRunner(
self,
max_batch_size_to_capture=max(batch_size_list),