[Model] Support llama3 on v0.11.0
FULL AND PIECEWISE GRAPH ENABLED
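For context, "full and piecewise graph" refers to vLLM's graph-capture modes for the compiled model. A minimal usage sketch follows; it assumes the vllm_kunlun plugin is installed and that the plugin follows upstream vLLM v0.11.0's compilation_config / cudagraph_mode knobs (FULL_AND_PIECEWISE). The model path is only a placeholder.

# --- illustrative sketch, not part of the commit ---
# Assumes upstream vLLM v0.11.0 semantics for compilation_config.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    compilation_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)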
@@ -1,6 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
@@ -24,6 +49,7 @@
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union

import torch
@@ -37,20 +63,22 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm_kunlun.ops.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm_kunlun.ops.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope

from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding

from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
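The removed/added markers were lost in extraction, but the import hunk pairs vLLM's stock layer modules with their vllm_kunlun.ops counterparts (the parallel linear layers, SiluAndMul, VocabParallelEmbedding) and brings in SupportsEagle3, extract_layer_index, and itertools.islice, all of which are used later in the diff. A small check of which implementations actually resolve at runtime, assuming the plugin package is installed:

# --- illustrative sketch, not part of the commit ---
# Prints the file each Kunlun op module resolves to, to confirm the plugin
# implementations are the ones picked up by the patched model file.
import importlib

for mod in ("vllm_kunlun.ops.activation",
            "vllm_kunlun.ops.linear",
            "vllm_kunlun.ops.vocab_parallel_embedding"):
    m = importlib.import_module(mod)
    print(mod, "->", m.__file__)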
@@ -68,6 +96,7 @@ class LlamaMLP(nn.Module):
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
disable_tp: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -75,6 +104,7 @@ class LlamaMLP(nn.Module):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
disable_tp=disable_tp,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
@@ -83,6 +113,7 @@ class LlamaMLP(nn.Module):
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=disable_tp,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@@ -168,20 +199,31 @@ class LlamaAttention(nn.Module):
rope_scaling=rope_scaling,
quant_config=quant_config)

if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window
if isinstance(interleaved_sliding_window, int):
sliding_window = interleaved_sliding_window
elif isinstance(interleaved_sliding_window, list):
sw_idx = layer_idx % len(interleaved_sliding_window)
sliding_window = interleaved_sliding_window[sw_idx]
sliding_window = None
if layer_types := getattr(config, "layer_types", None):
# Fix for Eagle3 compatibility:
# for draft models, subtract target layer count
# to get draft-relative layer index starting from 0
if hasattr(config, 'target_layer_count'):
# This is a draft model,
# adjust layer_idx to be relative to draft layers
effective_layer_idx = layer_idx - config.target_layer_count
else:
raise ValueError(
f"{type(interleaved_sliding_window)} is not supported.")
else:
sliding_window = None
# This is a target model, use layer_idx directly
effective_layer_idx = layer_idx
assert effective_layer_idx < len(layer_types), \
f"effective_layer_idx: {effective_layer_idx} \
is out of bounds for layer_types: {layer_types}"

self.attn = Attention(
is_sliding = layer_types[
effective_layer_idx] == "sliding_attention"
if is_sliding:
sliding_window = config.sliding_window

attn_cls = (EncoderOnlyAttention
if attn_type == AttentionType.ENCODER_ONLY else Attention)

self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,
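The hunk above interleaves the removed interleaved_sliding_window handling with the new layer_types-based selection, so it is hard to read without the diff markers. A minimal restatement of the added logic, factored into a standalone helper for readability (the helper name is invented; the config attributes and the Eagle3 offset are the ones shown in the hunk):

# --- illustrative sketch, not part of the commit ---
def resolve_sliding_window(config, layer_idx: int):
    layer_types = getattr(config, "layer_types", None)
    if not layer_types:
        return None
    if hasattr(config, "target_layer_count"):
        # Eagle3 draft model: shift the absolute index back so it starts at 0
        # relative to the draft layers (see the comment in the hunk).
        effective_layer_idx = layer_idx - config.target_layer_count
    else:
        # Target model: use layer_idx directly.
        effective_layer_idx = layer_idx
    assert effective_layer_idx < len(layer_types), (
        f"effective_layer_idx: {effective_layer_idx} is out of bounds "
        f"for layer_types: {layer_types}")
    if layer_types[effective_layer_idx] == "sliding_attention":
        return config.sliding_window
    return None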
@@ -200,8 +242,7 @@ class LlamaAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
#TODO@hanhaowen:use kunlun ops to speed up
q, k = self.rotary_emb.forward_native(positions, q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -227,14 +268,16 @@ class LlamaAttention(nn.Module):

class LlamaDecoderLayer(nn.Module):

def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[LlamaConfig] = None) -> None:
super().__init__()

config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@@ -306,6 +349,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
@@ -313,7 +357,7 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual


# @support_torch_compile
@support_torch_compile
class LlamaModel(nn.Module):

def __init__(self,
@@ -324,7 +368,6 @@ class LlamaModel(nn.Module):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config

@@ -346,10 +389,7 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
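With the v0.11.0 constructor shown earlier, the layer factory passed to make_layers only forwards the whole VllmConfig and a dotted prefix such as model.layers.17; per-layer details like the layer index are recovered from that prefix via extract_layer_index, which is newly imported near the top of the file. A small sketch of that relationship, with an arbitrary example prefix and the behavior assumed to match upstream vLLM:

# --- illustrative sketch, not part of the commit ---
# extract_layer_index pulls the integer layer number out of a dotted prefix.
from vllm.model_executor.models.utils import extract_layer_index

prefix = "model.layers.17.self_attn"
print(extract_layer_index(prefix))  # expected: 17 (assuming upstream behavior)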
@@ -357,7 +397,7 @@ class LlamaModel(nn.Module):
else:
self.norm = PPMissingLayer()

self.aux_hidden_state_layers: tuple[int] = tuple()
self.aux_hidden_state_layers = tuple[int, ...]()

self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
@@ -387,7 +427,7 @@ class LlamaModel(nn.Module):

aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
islice(self.layers, self.start_layer, self.end_layer)):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
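The loop now walks the pipeline-parallel slice of layers with itertools.islice (hence the new import at the top of the file) instead of slicing the ModuleList, which would build an intermediate container; the layers visited and their order are unchanged. A minimal equivalence sketch:

# --- illustrative sketch, not part of the commit ---
# Both forms visit the same modules in the same order; islice just avoids
# constructing the intermediate ModuleList returned by slicing.
from itertools import islice
import torch.nn as nn

layers = nn.ModuleList([nn.Identity() for _ in range(8)])
start_layer, end_layer = 2, 6

a = list(enumerate(layers[start_layer:end_layer]))
b = list(enumerate(islice(layers, start_layer, end_layer)))
assert [m for _, m in a] == [m for _, m in b]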
@@ -471,7 +511,7 @@ class LlamaModel(nn.Module):
return loaded_params


class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
@@ -557,10 +597,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

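get_eagle3_aux_hidden_state_layers picks three tap points for the Eagle3 draft head: an early layer, the middle layer, and a late layer. As a worked example, Llama-3-8B has 32 decoder layers, so the taps land on layers 2, 16, and 29 (0-indexed):

# --- illustrative sketch, not part of the commit ---
num_layers = 32  # e.g. Llama-3-8B
print((2, num_layers // 2, num_layers - 3))  # -> (2, 16, 29)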
@@ -589,10 +629,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
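Reading the hunk counts (10 lines before, 8 after), both the sampling_metadata parameter and the sampling_metadata argument to the LogitsProcessor call are dropped, matching the v0.11.0 interface. The post-change shape of the method, shown as a class-method fragment inferred from the hunk:

# --- illustrative sketch, not part of the commit ---
# Fragment of LlamaForCausalLM as implied by the hunk above.
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    logits = self.logits_processor(self.lm_head, hidden_states)
    return logits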

def load_weights(self, weights: Iterable[tuple[str,
@@ -614,9 +652,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:

def permute(w: torch.Tensor, n_heads: int):
def permute(w: torch.Tensor, n_heads: int, attn_out: int):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size

return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
@@ -625,12 +662,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
modules = name.split(".")

# rotary embeds should be sliced
# If using quantized model in mistral format,
# quantization scales (qscale_weight) also need to be sliced
if "wk" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
self.config.num_key_value_heads,
self.config.hidden_size)
elif "wk" in modules and modules[
-1] == "qscale_weight" and loaded_weight.numel() > 1:
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads, 1)
elif "wq" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
self.config.num_attention_heads,
self.config.hidden_size)
elif "wq" in modules and modules[
-1] == "qscale_weight" and loaded_weight.numel() > 1:
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads, 1)

num_modules = len(modules)
for i in range(num_modules):
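The two hunks above extend permute so it also handles Mistral-format quantization scales: full attention weights are permuted with attn_out = hidden_size, while single-column qscale_weight tensors are permuted with attn_out = 1. A self-contained restatement of the helper (self.config is replaced by an explicit head_dim argument so the snippet runs on its own; the dimensions below are just an example):

# --- illustrative sketch, not part of the commit ---
import torch

def permute(w: torch.Tensor, n_heads: int, head_dim: int, attn_out: int):
    # Undo the HF rotary-interleaving layout, exactly as in the hunk above.
    attn_in = head_dim * n_heads
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)

# Weights use attn_out = hidden_size; 1-column quantization scales use attn_out = 1.
hidden_size, n_heads, head_dim = 4096, 32, 128
wq = torch.randn(n_heads * head_dim, hidden_size)
qscale = torch.randn(n_heads * head_dim, 1)
assert permute(wq, n_heads, head_dim, hidden_size).shape == wq.shape
assert permute(qscale, n_heads, head_dim, 1).shape == qscale.shape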
@@ -646,3 +695,4 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name = name.replace(item, mapping[item])

return name, loaded_weight