testing dynamic register

This commit is contained in:
Chranos
2026-02-05 17:11:09 +08:00
parent b399840b8d
commit 6e38461af6
3 changed files with 657 additions and 98 deletions

View File

@@ -2,13 +2,39 @@
# Copyright 2024 The vLLM team.
"""Wrapper around `transformers` models for vLLM v0.6.2.
This module provides an advanced Transformers modeling backend that wraps
any HuggingFace model with the vLLM interface, enabling support for custom
models that define their implementation via `auto_map` in config.json.
Key optimizations and features:
- Meta device initialization for memory efficiency
- Module replacement (Linear, RMSNorm, Embedding) with vLLM optimized versions
- VocabParallelEmbedding for input embeddings
- vLLM Attention instances for proper KV cache allocation
- AutoWeightsLoader for efficient weight loading with name mapping
- vLLM attention backend integration (when supported by upgraded transformers)
"""
from vllm.model_executor.models.transformers.causal import (
TransformersForCausalLM,
is_backend_compatible,
)
from vllm.model_executor.models.transformers.utils import (
init_on_device_without_buffers,
replace_linear_class,
replace_rms_norm_class,
log_replacement,
maybe_prefix,
)
__all__ = [
# Main wrapper classes
"TransformersForCausalLM",
"is_backend_compatible",
# Utility functions
"init_on_device_without_buffers",
"replace_linear_class",
"replace_rms_norm_class",
"log_replacement",
"maybe_prefix",
]
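# A hedged usage sketch of the re-exported API (consumers are expected to
# import from this package rather than the submodules):
#
#     from vllm.model_executor.models.transformers import (
#         TransformersForCausalLM, is_backend_compatible)
#
#     if is_backend_compatible():
#         model_cls = TransformersForCausalLM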

View File

@@ -6,11 +6,15 @@ This module provides a wrapper class that enables vLLM to use any HuggingFace
causal language model, including custom models that define their implementation
via `auto_map` in config.json.
The key insight is that we use HuggingFace's AutoModelForCausalLM to load the
actual model, then wrap it with the vLLM interface (compute_logits, sample, etc.).
Key optimizations:
1. Use meta device for delayed memory allocation
2. Replace nn.Linear with vLLM's optimized Linear classes
3. Replace RMSNorm with vLLM's fused RMSNorm
4. Replace input embeddings with VocabParallelEmbedding
5. Use vLLM's weight loading infrastructure (AutoWeightsLoader)
"""
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
@@ -19,8 +23,22 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.attention.layer import Attention
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
)
from vllm.model_executor.models.transformers.utils import (
init_on_device_without_buffers,
replace_linear_class,
replace_rms_norm_class,
log_replacement,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
@@ -30,25 +48,153 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# Note: In v0.6.2, the vLLM Attention.forward requires (query, key, value, kv_cache, attn_metadata).
# The transformers backend integration works differently from the one in the latest vLLM.
# We keep the vllm_flash_attention_forward for reference, but it may not be compatible
# with all transformers versions or MLU backends.
# Global variable to store current attention metadata (set during forward pass)
_current_attn_metadata = None
_current_kv_caches = None
def set_attention_context(attn_metadata, kv_caches):
"""Set the current attention context for vLLM attention functions."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = attn_metadata
_current_kv_caches = kv_caches
def clear_attention_context():
"""Clear the current attention context."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = None
_current_kv_caches = None
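# Hedged sketch of the intended protocol (it mirrors forward() below): the
# caller publishes per-step state before the HF model runs, and always clears
# it so stale metadata cannot leak into the next step.
#
#     set_attention_context(attn_metadata, kv_caches)
#     try:
#         hidden_states = hf_model(**model_inputs)  # attention reads the globals
#     finally:
#         clear_attention_context()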
def vllm_flash_attention_forward(
# Transformers args
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
# Transformers kwargs
scaling: Optional[float] = None,
# vLLM kwargs
attention_instances: Optional[Dict[int, Attention]] = None,
**kwargs,
):
"""
vLLM's optimized attention function that replaces HuggingFace's attention.
This function is registered to transformers' ALL_ATTENTION_FUNCTIONS.
Note: In v0.6.2, this function may have limited functionality due to
API differences in the Attention layer. For full functionality, the model
should fall back to HuggingFace's native attention when vLLM attention
is not properly configured.
"""
# Get the attention instance for this layer
layer_idx = getattr(module, 'layer_idx', 0)
if attention_instances is None or layer_idx not in attention_instances:
# Fall back to standard attention computation
logger.debug("No attention instance for layer %d, using standard attention", layer_idx)
# Standard scaled dot-product attention
attn_weights = torch.matmul(query, key.transpose(-2, -1))
if scaling is not None:
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
return attn_output, None
self_attn = attention_instances[layer_idx]
# v0.6.2 Attention.forward requires: (query, key, value, kv_cache, attn_metadata)
# We need to get these from the global context
global _current_attn_metadata, _current_kv_caches
if _current_attn_metadata is None or _current_kv_caches is None:
# No context set, fall back to standard attention
logger.debug("No attention context, using standard attention for layer %d", layer_idx)
attn_weights = torch.matmul(query, key.transpose(-2, -1))
if scaling is not None:
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
return attn_output, None
# Update scale if provided
if scaling is not None:
self_attn.impl.scale = float(scaling)
# Reshape tensors for vLLM: [batch, heads, seq, head_dim] -> [seq, heads * head_dim]
# The reshape assumes batch size 1 (forward() unsqueezes inputs to batch 1)
seq_len = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(seq_len, -1) for x in (query, key, value))
# Get KV cache for this layer
kv_cache = _current_kv_caches[layer_idx] if layer_idx < len(_current_kv_caches) else None
# Call vLLM attention
output = self_attn.forward(query, key, value, kv_cache, _current_attn_metadata)
return output, None
# Try to register vLLM attention to transformers
_vllm_attention_registered = False
try:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
_vllm_attention_registered = True
logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError) as e:
logger.warning("Could not register vLLM attention function - "
"transformers version may not support custom attention: %s", e)
class TransformersForCausalLM(nn.Module):
"""
A wrapper class that adapts any HuggingFace causal language model
to the vLLM interface with memory optimizations.
This class provides:
1. forward() - processes input through the model
2. compute_logits() - computes output logits
3. sample() - samples tokens from logits
4. load_weights() - loads model weights
The actual HuggingFace model is loaded using AutoModelForCausalLM and
stored in self.model.
Key optimizations (following latest vLLM):
1. Meta device initialization - no GPU memory until weights are loaded
2. Module replacement - Linear/RMSNorm replaced with vLLM optimized versions
3. VocabParallelEmbedding for input embeddings
4. AutoWeightsLoader for efficient weight loading with name mapping
Interface compliance:
- Implements VllmModel protocol (vllm_config init, forward with required args)
- Implements VllmModelForTextGeneration protocol (compute_logits, sample)
"""
# Weight name mapping from HuggingFace to vLLM format
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints
"": "model.",
"model.model.": "model.",
# Heads will be adjacent to `model`
"model.lm_head.": "lm_head.",
"lm_head.": "lm_head.",
# Handle different model architectures
"transformer.": "model.",
"model.transformer.": "model.",
# Handle embeddings
"embed_tokens.": "model.embed_tokens.",
"model.embed_tokens.": "model.embed_tokens.",
# Handle attention weights
"self_attn.": "self_attn.",
"attention.": "self_attn.",
}
)
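# Illustrative (hedged) renames produced by the mapper; exact results depend
# on WeightsMapper's prefix-matching order:
#   "embed_tokens.weight"         -> "model.embed_tokens.weight"
#   "transformer.h.0.ln_1.weight" -> "model.h.0.ln_1.weight"
#   "lm_head.weight"              -> "lm_head.weight"  (kept beside `model`)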
def __init__(
self,
vllm_config: VllmConfig,
@@ -61,42 +207,250 @@ class TransformersForCausalLM(nn.Module):
quant_config = vllm_config.quant_config
self.config = config.hf_config
self.text_config = getattr(self.config, "text_config", self.config)
self.model_config = config
self.cache_config = cache_config
self.device_config = vllm_config.device_config
self.quant_config = quant_config
self.prefix = prefix
# Get model dimensions from config
self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
self.vocab_size = getattr(self.text_config, "vocab_size", 32000)
# Weight loading configuration
self.skip_prefixes: List[str] = []
self.ignore_unexpected_prefixes: List[str] = [
"model.layers.*.self_attn.rotary_emb.inv_freq", # Skip RoPE weights
"model.norm.bias", # Some models don't have bias in final norm
]
logger.info("Using Transformers modeling backend for %s",
config.hf_config.architectures)
# Configure vLLM attention backend
self._configure_attention_backend()
# Load the HuggingFace model structure on meta device (no memory allocation)
self._load_hf_model_on_meta()
# Replace modules with vLLM optimized versions
self._replace_modules()
# Replace input embeddings with VocabParallelEmbedding
self._replace_input_embeddings()
# Create attention instances for KV cache
self.attention_instances = self._create_attention_instances()
# Initialize parameters (allocate memory on target device)
self._init_parameters()
# Setup logits processor and sampler
self.logits_processor = LogitsProcessor(
self.vocab_size,
logits_as_input=True,  # compute_logits() applies the LM head itself
)
self.sampler = Sampler()
def _load_hf_model_on_meta(self) -> None:
"""
Load the HuggingFace model structure on meta device.
This creates the model structure without allocating GPU memory.
Memory will be allocated later during weight loading.
"""
from transformers import AutoModelForCausalLM
# We load with minimal config first - weights will be loaded separately
# by vLLM's weight loader
logger.info("Loading HuggingFace model from config...")
logger.info("Creating model structure on meta device...")
self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config(
self.config,
torch_dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# Create model on meta device - no GPU memory allocated
with init_on_device_without_buffers("meta"):
self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config(
self.config,
torch_dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# Disable gradient computation for inference
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
logger.info("Model structure created on meta device")
def _replace_modules(self) -> None:
"""
Replace HuggingFace modules with vLLM optimized versions.
This replaces:
- nn.Linear with ReplicatedLinear (memory efficient, supports quantization)
- RMSNorm variants with vLLM's fused RMSNorm
"""
logger.info("Replacing modules with vLLM optimized versions...")
replaced_count = 0
def _recursive_replace(module: nn.Module, prefix: str = ""):
nonlocal replaced_count
for name, child in list(module.named_children()):
qual_name = maybe_prefix(prefix, name)
if isinstance(child, nn.Linear):
# Replace Linear with vLLM's ReplicatedLinear
new_module = replace_linear_class(
child,
style="replicate",
quant_config=self.quant_config,
prefix=qual_name,
)
setattr(module, name, new_module)
log_replacement(qual_name, child, new_module)
replaced_count += 1
elif child.__class__.__name__.endswith("RMSNorm"):
# Replace RMSNorm with vLLM's optimized version
new_module = replace_rms_norm_class(child, self.hidden_size)
setattr(module, name, new_module)
log_replacement(qual_name, child, new_module)
replaced_count += 1
elif child.__class__.__name__.endswith(("LayerNorm", "GroupNorm")):
# Also handle other normalization layers
logger.debug("Found normalization layer %s: %s", qual_name, type(child).__name__)
# Could add specialized replacement here if needed
_recursive_replace(child, qual_name)
elif "Attention" in child.__class__.__name__:
# Mark attention layers for potential replacement
logger.debug("Found attention layer %s: %s", qual_name, type(child).__name__)
# Note: We don't replace the attention module itself,
# but we create separate vLLM Attention instances
_recursive_replace(child, qual_name)
else:
# Recursively process children
_recursive_replace(child, qual_name)
_recursive_replace(self.model, "model")
logger.info("Replaced %d modules with vLLM optimized versions", replaced_count)
def _replace_input_embeddings(self) -> None:
"""
Replace the input embeddings with VocabParallelEmbedding.
This provides memory efficiency for large vocabularies.
"""
input_embeddings = self.model.get_input_embeddings()
if input_embeddings is None:
logger.warning("Could not find input embeddings to replace")
return
# Get embedding dimension
if hasattr(input_embeddings, "embedding_dim"):
embedding_dim = input_embeddings.embedding_dim
elif hasattr(input_embeddings, "weight"):
embedding_dim = input_embeddings.weight.shape[1]
else:
embedding_dim = self.hidden_size
# Store embed_scale if present (some models scale embeddings)
self.embed_scale = getattr(input_embeddings, "embed_scale", None)
logger.info("Replacing input embeddings with VocabParallelEmbedding "
"(vocab_size=%d, embedding_dim=%d)", self.vocab_size, embedding_dim)
new_embeddings = VocabParallelEmbedding(
self.vocab_size,
embedding_dim,
org_num_embeddings=self.vocab_size,
quant_config=self.quant_config,
)
self.model.set_input_embeddings(new_embeddings)
def _init_parameters(self) -> None:
"""
Initialize parameters from meta device to target device.
This allocates the actual GPU memory for all parameters.
"""
logger.info("Initializing parameters on target device...")
# Use device_config to get the correct device (supports MLU, CUDA, etc.)
device = self.device_config.device
if device is None:
# Fallback for TPU or other special cases
device = torch.device("cpu")
dtype = self.model_config.dtype
def _init_params(module: nn.Module):
for name, param in list(module.named_parameters(recurse=False)):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(
param.data,
dtype=dtype,
device=device,
),
requires_grad=False,
)
setattr(module, name, new_param)
for child in module.children():
_init_params(child)
_init_params(self.model)
logger.info("Parameters initialized on %s", device)
def _create_attention_instances(self) -> Dict[int, Attention]:
"""
Create vLLM Attention instances for each layer.
This enables proper KV cache allocation and vLLM's optimized attention.
Returns a dict mapping layer_idx to Attention instance.
"""
attention_instances: Dict[int, Attention] = {}
num_layers = getattr(self.text_config, "num_hidden_layers",
getattr(self.text_config, "num_layers", 32))
num_heads = getattr(self.text_config, "num_attention_heads", 32)
head_size = self.hidden_size // num_heads
scale = 1.0 / (head_size ** 0.5)
num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads)
logger.info("Creating %d attention instances for KV cache "
"(num_heads=%d, head_size=%d, num_kv_heads=%d)",
num_layers, num_heads, head_size, num_kv_heads)
for layer_idx in range(num_layers):
attention = Attention(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"model.layers.{layer_idx}.self_attn",
)
attention_instances[layer_idx] = attention
return attention_instances
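# Worked example (hedged, Llama-7B-like config):
#   hidden_size=4096, num_attention_heads=32 -> head_size=128
#   scale = 1 / sqrt(128) ~= 0.0884
#   One Attention instance (and hence one KV-cache binding) per decoder layer,
#   keyed by layer_idx for lookup in vllm_flash_attention_forward.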
def _configure_attention_backend(self) -> None:
"""
Configure vLLM attention backend for the model.
This sets up the attention implementation BEFORE model creation.
Only sets 'vllm' implementation if the attention function was registered.
"""
global _vllm_attention_registered
if _vllm_attention_registered:
# Set vLLM attention implementation in config (must be before from_config)
self.text_config._attn_implementation = "vllm"
logger.info("Set attention implementation to 'vllm' in text_config")
else:
# Use default eager attention if vLLM attention is not available
logger.info("Using default HuggingFace attention (vLLM attention not registered)")
def forward(
self,
@@ -114,49 +468,66 @@ class TransformersForCausalLM(nn.Module):
This method conforms to the VllmModel protocol by accepting:
- input_ids: Token IDs
- positions: Position IDs
- kv_caches: KV cache tensors
- attn_metadata: Attention metadata
"""
# Set attention context for vLLM attention function
set_attention_context(attn_metadata, kv_caches)
try:
# Prepare inputs - add batch dimension if needed
if inputs_embeds is not None:
if inputs_embeds.dim() == 2:
inputs_embeds = inputs_embeds.unsqueeze(0)
model_inputs = {"inputs_embeds": inputs_embeds}
else:
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
model_inputs = {"input_ids": input_ids}
# Position IDs - add batch dimension
if positions is not None:
if positions.dim() == 1:
positions = positions.unsqueeze(0)
model_inputs["position_ids"] = positions
# Apply embed_scale if needed
if (
self.embed_scale is not None
and "input_ids" in model_inputs
and "inputs_embeds" not in model_inputs
):
inputs_embeds = self.model.get_input_embeddings()(model_inputs["input_ids"])
inputs_embeds = inputs_embeds * self.embed_scale
model_inputs = {"inputs_embeds": inputs_embeds}
if positions is not None:
model_inputs["position_ids"] = positions
# Run the model with vLLM attention instances
with torch.no_grad():
outputs = self.model(
**model_inputs,
use_cache=False,
return_dict=True,
output_hidden_states=True,
attention_instances=self.attention_instances,
)
# Get hidden states from the last layer
if outputs.hidden_states is not None:
hidden_states = outputs.hidden_states[-1]
else:
# Fallback: use logits directly
hidden_states = outputs.logits
# Remove batch dimension
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
hidden_states = hidden_states.squeeze(0)
return hidden_states
finally:
# Clear attention context
clear_attention_context()
def compute_logits(
self,
@@ -168,12 +539,24 @@ class TransformersForCausalLM(nn.Module):
This method conforms to the VllmModelForTextGeneration protocol.
"""
# Check if hidden_states are already logits
if hidden_states.shape[-1] == self.vocab_size:
logits = hidden_states
else:
lm_head = getattr(self.model, "lm_head", None)
if lm_head is None:
lm_head = getattr(self.model, "embed_out", None)
if lm_head is not None:
output = lm_head(hidden_states)
# Handle tuple output from vLLM Linear layers
if isinstance(output, tuple):
logits = output[0]
else:
logits = output
else:
logits = hidden_states
return self.logits_processor(None, logits, sampling_metadata)
@@ -195,40 +578,23 @@ class TransformersForCausalLM(nn.Module):
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Set[str]:
"""
Load weights into the model using AutoWeightsLoader.
This uses vLLM's efficient weight loading infrastructure with
automatic name mapping.
"""
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
)
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
logger.info("Loaded %d weight tensors", len(loaded))
return set(loaded)
def is_backend_compatible() -> bool:
"""
Check if the current model is compatible with the Transformers backend.
This is a simplified check - in practice, compatibility depends on
whether the model follows standard HuggingFace conventions.
"""
return True
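# Hedged end-to-end sketch (illustrative only; in practice the vLLM engine,
# not user code, drives these calls):
#
#     model = TransformersForCausalLM(vllm_config, prefix="")
#     model.load_weights(weights_iter)
#     hidden = model.forward(input_ids, positions, kv_caches, attn_metadata)
#     logits = model.compute_logits(hidden, sampling_metadata)
#     tokens = model.sample(logits, sampling_metadata)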

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend utilities for v0.6.2.
This module provides utility functions for the Transformers backend,
including context managers for meta device initialization and
module replacement functions.
"""
from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
logger = init_logger(__name__)
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
"""
A context manager under which models are initialized with all
parameters on the specified device. However, buffers are not
initialized on the specified device.
This is useful for creating model structure without allocating
GPU memory, which is essential for memory efficiency.
Args:
device: Device to initialize all parameters on (e.g., "meta").
Example:
with init_on_device_without_buffers("meta"):
model = AutoModel.from_config(config)
# Now model is on meta device, no GPU memory allocated
"""
if isinstance(device, str):
device = torch.device(device)
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs
)
try:
nn.Module.register_parameter = register_empty_parameter
yield
finally:
nn.Module.register_parameter = old_register_parameter
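# Hedged illustration of the "without buffers" behaviour: parameters created
# inside the context are intercepted and moved, buffers are left untouched.
#
#     with init_on_device_without_buffers("meta"):
#         bn = nn.BatchNorm1d(8)
#     bn.weight.device        # device(type='meta')  -- parameter, intercepted
#     bn.running_mean.device  # device(type='cpu')   -- buffer, untouched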
# Linear replacement styles
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
linear: nn.Linear,
style: Style = "replicate",
quant_config: Optional["QuantizationConfig"] = None,
prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
This replacement provides:
- Memory efficiency through proper tensor allocation
- Support for quantization
- Tensor parallel support (when using ColumnParallel/RowParallel)
Args:
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear:
- "colwise": Column parallel (split output dim)
- "colwise_rep": Column parallel with gather output
- "rowwise": Row parallel (split input dim)
- "rowwise_rep": Row parallel without parallel input
- "replicate": Replicated (no parallelism)
quant_config: Quantization config for the new linear.
prefix: The name of the layer for weight loading.
Returns:
The new vLLM linear layer.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
prefix=prefix,
**vllm_linear_kwargs,
)
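# Hedged usage sketch:
#
#     old = nn.Linear(4096, 11008, bias=False)
#     new = replace_linear_class(old, style="replicate",
#                                quant_config=None, prefix="mlp.gate_proj")
#     # new is ReplicatedLinear(input_size=4096, output_size=11008)
#
# Caveat: vLLM linear layers return an (output, bias) tuple from forward(),
# unlike nn.Linear; compute_logits() in causal.py unpacks exactly this.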
def replace_rms_norm_class(
rms_norm: nn.Module,
hidden_size: int,
) -> RMSNorm:
"""
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.
vLLM's RMSNorm provides:
- Fused CUDA kernels for better performance
- Support for fused add + norm operations
Args:
rms_norm: The RMSNorm module to replace.
hidden_size: The hidden size of the model.
Returns:
The new vLLM RMSNorm layer.
"""
# Try to get epsilon from various attribute names
eps = getattr(rms_norm, "eps", None)
if eps is None:
eps = getattr(rms_norm, "variance_epsilon", None)
if eps is None:
eps = 1e-6
# Check if weight exists and get its size
weight = getattr(rms_norm, "weight", None)
if weight is not None:
hidden_size = weight.size(0)
return RMSNorm(hidden_size=hidden_size, eps=eps)
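# Hedged usage sketch (HF's LlamaRMSNorm stores eps as `variance_epsilon`):
#
#     hf_norm = LlamaRMSNorm(hidden_size=4096, eps=1e-5)
#     new = replace_rms_norm_class(hf_norm, hidden_size=4096)
#     # -> RMSNorm(hidden_size=4096, eps=1e-5); the hidden size is read from
#     #    hf_norm.weight when the weight exists.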
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
"""Log module replacement for debugging."""
logger.debug("Replaced %s: %s -> %s", name, type(old_module).__name__, type(new_module).__name__)
def maybe_prefix(prefix: str, name: str) -> str:
"""Combine prefix and name with a dot separator."""
if prefix:
return f"{prefix}.{name}"
return name