Update vllm version to support llama3.1 (#705)
This commit is contained in:
@@ -21,7 +21,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
|
||||
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
|
||||
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44"]
|
||||
openai = ["openai>=1.0", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
|
||||
@@ -73,6 +73,8 @@ def get_context_length(config):
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling:
|
||||
rope_scaling_factor = config.rope_scaling["factor"]
|
||||
if config.rope_scaling["rope_type"] == "llama3":
|
||||
rope_scaling_factor = 1
|
||||
else:
|
||||
rope_scaling_factor = 1
|
||||
|
||||
|
||||
@@ -5,14 +5,10 @@
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -375,9 +371,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name is None or loaded_weight is None:
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
else:
|
||||
|
||||
@@ -222,6 +222,7 @@ def launch_server(
|
||||
detokenizer_port=ports[2],
|
||||
nccl_ports=ports[3:],
|
||||
)
|
||||
logger.info(f"{server_args=}")
|
||||
|
||||
# Handle multi-node tensor parallelism
|
||||
if server_args.nnodes > 1:
|
||||
|
||||
Reference in New Issue
Block a user