Update vllm version to support llama3.1 (#705)

This commit is contained in:
Ying Sheng
2024-07-23 13:49:34 -07:00
committed by GitHub
parent fa7ccb3316
commit 444a02441a
4 changed files with 5 additions and 9 deletions

View File

@@ -21,7 +21,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]

View File

@@ -73,6 +73,8 @@ def get_context_length(config):
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling:
rope_scaling_factor = config.rope_scaling["factor"]
+        if config.rope_scaling["rope_type"] == "llama3":
+            rope_scaling_factor = 1
else:
rope_scaling_factor = 1

View File

@@ -5,14 +5,10 @@
 from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
-import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -375,9 +371,6 @@ class LlamaForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
if name is None or loaded_weight is None:
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
load_weights_per_param(name, loaded_weight)
else:

View File

@@ -222,6 +222,7 @@ def launch_server(
detokenizer_port=ports[2],
nccl_ports=ports[3:],
)
+logger.info(f"{server_args=}")
# Handle multi-node tensor parallelism
if server_args.nnodes > 1: