[Model] Support llama3 on v0.11.0

Merge pull request #19 from xyDong0223/v0.11.0dev
Xinyu Dong
2025-12-16 14:15:58 +08:00
committed by GitHub
2 changed files with 73 additions and 44 deletions

View File

@@ -78,6 +78,10 @@ def register_model():
"SeedOssForCausalLM", "SeedOssForCausalLM",
"vllm_kunlun.models.seed_oss:SeedOssForCausalLM") "vllm_kunlun.models.seed_oss:SeedOssForCausalLM")
ModelRegistry.register_model(
"LlamaForCausalLM",
"vllm_kunlun.models.llama:LlamaForCausalLM")
def register_quant_method(): def register_quant_method():
"""to do""" """to do"""

View File

@@ -24,6 +24,7 @@
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from itertools import islice
 from typing import Any, Optional, Union
 
 import torch
@@ -37,20 +38,22 @@ from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm_kunlun.ops.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm_kunlun.ops.linear import (MergedColumnParallelLinear,
+                                    QKVParallelLinear,
+                                    RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                                               is_pp_missing_parameter,
                                               make_empty_intermediate_tensors_factory, make_layers,
@@ -68,6 +71,7 @@ class LlamaMLP(nn.Module):
         bias: bool = False,
         prefix: str = "",
         reduce_results: bool = True,
+        disable_tp: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -75,6 +79,7 @@ class LlamaMLP(nn.Module):
             output_sizes=[intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
+            disable_tp=disable_tp,
             prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
@@ -83,6 +88,7 @@ class LlamaMLP(nn.Module):
             bias=bias,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            disable_tp=disable_tp,
             prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
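
For reference, the block these three hunks thread disable_tp through is the standard gated-SiLU MLP. A plain-PyTorch sketch of the same math, with toy sizes, no tensor parallelism, no quantization, and none of the Kunlun kernels:

import torch
import torch.nn.functional as F

def llama_mlp_reference(x, gate_w, up_w, down_w):
    # gate_up_proj (merged) -> SiluAndMul -> down_proj, written out explicitly
    return (F.silu(x @ gate_w.T) * (x @ up_w.T)) @ down_w.T

x = torch.randn(2, 16)        # [num_tokens, hidden_size]
gate_w = torch.randn(64, 16)  # [intermediate_size, hidden_size]
up_w = torch.randn(64, 16)
down_w = torch.randn(16, 64)  # [hidden_size, intermediate_size]
print(llama_mlp_reference(x, gate_w, up_w, down_w).shape)  # torch.Size([2, 16])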
@@ -168,20 +174,31 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             quant_config=quant_config)
 
-        if hasattr(config, "interleaved_sliding_window"):
-            interleaved_sliding_window = config.interleaved_sliding_window
-            if isinstance(interleaved_sliding_window, int):
-                sliding_window = interleaved_sliding_window
-            elif isinstance(interleaved_sliding_window, list):
-                sw_idx = layer_idx % len(interleaved_sliding_window)
-                sliding_window = interleaved_sliding_window[sw_idx]
-            else:
-                raise ValueError(
-                    f"{type(interleaved_sliding_window)} is not supported.")
-        else:
-            sliding_window = None
+        sliding_window = None
+        if layer_types := getattr(config, "layer_types", None):
+            # Fix for Eagle3 compatibility:
+            # for draft models, subtract target layer count
+            # to get draft-relative layer index starting from 0
+            if hasattr(config, 'target_layer_count'):
+                # This is a draft model,
+                # adjust layer_idx to be relative to draft layers
+                effective_layer_idx = layer_idx - config.target_layer_count
+            else:
+                # This is a target model, use layer_idx directly
+                effective_layer_idx = layer_idx
+            assert effective_layer_idx < len(layer_types), \
+                f"effective_layer_idx: {effective_layer_idx} \
+                is out of bounds for layer_types: {layer_types}"
+            is_sliding = layer_types[
+                effective_layer_idx] == "sliding_attention"
+            if is_sliding:
+                sliding_window = config.sliding_window
 
-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
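
The layer_types branch above replaces the old interleaved_sliding_window probing: the config now labels each layer as full or sliding attention, and Eagle3 draft layers, which are stacked after the target layers, must first subtract the target layer count to index their own entries. A standalone sketch of that index arithmetic, using made-up configs rather than real HF objects:

def pick_sliding_window(layer_types, layer_idx, sliding_window,
                        target_layer_count=None):
    # Draft models index layer_types relative to their own first layer.
    effective_layer_idx = (layer_idx - target_layer_count
                           if target_layer_count is not None else layer_idx)
    assert effective_layer_idx < len(layer_types)
    if layer_types[effective_layer_idx] == "sliding_attention":
        return sliding_window
    return None

# Target model: layer 1 of an alternating full/sliding pattern.
assert pick_sliding_window(["full_attention", "sliding_attention"] * 2, 1, 4096) == 4096
# Eagle3 draft model: one draft layer appended after 4 target layers.
assert pick_sliding_window(["full_attention"], 4, 4096, target_layer_count=4) is None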
@@ -200,8 +217,7 @@ class LlamaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        #TODO@hanhaowen:use kunlun ops to speed up
-        q, k = self.rotary_emb.forward_native(positions, q, k)
+        q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -227,14 +243,16 @@ class LlamaAttention(nn.Module):
 class LlamaDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: LlamaConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 config: Optional[LlamaConfig] = None) -> None:
         super().__init__()
+        config = config or vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -306,6 +324,7 @@ class LlamaDecoderLayer(nn.Module):
                                               hidden_states, residual)
         hidden_states = self.self_attn(positions=positions,
                                        hidden_states=hidden_states)
 
+        # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
@@ -313,7 +332,7 @@ class LlamaDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
-# @support_torch_compile
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(self,
@@ -324,7 +343,6 @@ class LlamaModel(nn.Module):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -346,10 +364,7 @@ class LlamaModel(nn.Module):
             self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: layer_type(config=config,
-                                      cache_config=cache_config,
-                                      quant_config=quant_config,
-                                      prefix=prefix),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         if get_pp_group().is_last_rank:
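
This change pairs with the LlamaDecoderLayer constructor change above: because the layer now unpacks everything from VllmConfig itself, the make_layers factory only has to forward vllm_config and the per-layer prefix. A minimal stand-in for that construction pattern, not vllm's make_layers, which additionally handles the pipeline-parallel start/end layers:

def build_layers(num_layers, layer_factory, prefix):
    # Every layer sees the same vllm_config via the closure; only the
    # per-layer prefix (used for parameter naming) differs.
    return [layer_factory(prefix=f"{prefix}.{i}") for i in range(num_layers)]

# Usage mirroring the diff (names as in this file):
#   build_layers(config.num_hidden_layers,
#                lambda prefix: LlamaDecoderLayer(vllm_config=vllm_config,
#                                                 prefix=prefix),
#                prefix=f"{prefix}.layers")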
@@ -357,7 +372,7 @@ class LlamaModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        self.aux_hidden_state_layers: tuple[int] = tuple()
+        self.aux_hidden_state_layers = tuple[int, ...]()
 
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
@@ -387,7 +402,7 @@ class LlamaModel(nn.Module):
         aux_hidden_states = []
         for idx, layer in enumerate(
-                self.layers[self.start_layer:self.end_layer]):
+                islice(self.layers, self.start_layer, self.end_layer)):
             if idx in self.aux_hidden_state_layers:
                 aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
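
Switching to islice iterates the ModuleList lazily over [start_layer, end_layer) instead of slicing it, presumably to avoid materializing an intermediate container on every forward pass; the layers visited are identical. A tiny check with plain lists standing in for the ModuleList:

from itertools import islice

layers = ["pp_missing", "layer_1", "layer_2", "pp_missing"]
start_layer, end_layer = 1, 3
assert list(islice(layers, start_layer, end_layer)) == layers[start_layer:end_layer]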
@@ -471,7 +486,7 @@ class LlamaModel(nn.Module):
         return loaded_params
 
 
-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
@@ -557,10 +572,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         self.model.aux_hidden_state_layers = layers
 
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
         num_layers = len(self.model.layers)
         return (2, num_layers // 2, num_layers - 3)
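
get_eagle3_aux_hidden_state_layers taps three hidden states for the Eagle3 draft head: an early layer, the midpoint, and a late layer. Worked out for common Llama sizes, assuming num_layers counts decoder layers only:

# 32-layer Llama-3-8B  -> layers (2, 16, 29)
# 80-layer Llama-3-70B -> layers (2, 40, 77)
for num_layers in (32, 80):
    print(num_layers, (2, num_layers // 2, num_layers - 3))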
@@ -589,10 +604,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str,
@@ -614,9 +627,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                      loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
 
-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
 
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
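
permute re-orders the rows of Meta/Mistral-format wq/wk tensors into the rotary layout expected here: within each head, interleaved row pairs are regrouped into two contiguous halves, while columns are untouched. Passing attn_out explicitly is what lets the same helper now handle single-column quantization scales (attn_out=1) as well as full weight matrices (attn_out=hidden_size). A small shape check, with head_dim passed in since there is no self.config here and arbitrary values:

import torch

def permute(w, n_heads, attn_out, head_dim):
    attn_in = head_dim * n_heads
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)

w = torch.arange(8 * 4, dtype=torch.float32).view(8, 4)         # 2 heads, head_dim=4, hidden=4
print(permute(w, n_heads=2, attn_out=4, head_dim=4).shape)      # torch.Size([8, 4])
scales = torch.rand(8, 1)                                       # per-channel qscale_weight
print(permute(scales, n_heads=2, attn_out=1, head_dim=4).shape) # torch.Size([8, 1])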
@@ -625,12 +637,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             modules = name.split(".")
 
             # rotary embeds should be sliced
+            # If using quantized model in mistral format,
+            # quantization scales (qscale_weight) also need to be sliced
             if "wk" in modules and modules[-1] == "weight":
                 loaded_weight = permute(loaded_weight,
-                                        self.config.num_key_value_heads)
+                                        self.config.num_key_value_heads,
+                                        self.config.hidden_size)
+            elif "wk" in modules and modules[
+                    -1] == "qscale_weight" and loaded_weight.numel() > 1:
+                loaded_weight = permute(loaded_weight,
+                                        self.config.num_key_value_heads, 1)
             elif "wq" in modules and modules[-1] == "weight":
                 loaded_weight = permute(loaded_weight,
-                                        self.config.num_attention_heads)
+                                        self.config.num_attention_heads,
+                                        self.config.hidden_size)
+            elif "wq" in modules and modules[
+                    -1] == "qscale_weight" and loaded_weight.numel() > 1:
+                loaded_weight = permute(loaded_weight,
+                                        self.config.num_attention_heads, 1)
 
             num_modules = len(modules)
             for i in range(num_modules):
@@ -646,3 +670,4 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                     name = name.replace(item, mapping[item])
 
         return name, loaded_weight