forked from EngineX-Cambricon/enginex-mlu370-vllm
testing dynamic register
This commit is contained in:
234
vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py
Normal file
234
vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright 2024 The vLLM team.
|
||||
"""Transformers modeling backend for causal language models.
|
||||
|
||||
This module provides a wrapper class that enables vLLM to use any HuggingFace
|
||||
causal language model, including custom models that define their implementation
|
||||
via `auto_map` in config.json.
|
||||
|
||||
The key insight is that we use HuggingFace's AutoModelForCausalLM to load the
|
||||
actual model, then wrap it with the vLLM interface (compute_logits, sample, etc).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from vllm.attention import AttentionMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TransformersForCausalLM(nn.Module):
    """A wrapper class that adapts any HuggingFace causal language model
    to the vLLM interface.

    This class provides:
    1. forward() - processes input through the model
    2. compute_logits() - computes output logits
    3. sample() - samples tokens from logits
    4. load_weights() - loads model weights

    The actual HuggingFace model is loaded using AutoModelForCausalLM and
    stored in self.model.

    Interface compliance:
    - Implements VllmModel protocol (vllm_config init, forward with required args)
    - Implements VllmModelForTextGeneration protocol (compute_logits, sample)
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        prefix: str = "",
    ) -> None:
        """Build the wrapper from an aggregated vLLM configuration.

        Args:
            vllm_config: vLLM configuration object; model_config,
                cache_config and quant_config are read from it.
            prefix: Parameter-name prefix used by vLLM's loading machinery.
        """
        super().__init__()

        config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config.hf_config
        self.model_config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.prefix = prefix

        logger.info("Using Transformers modeling backend for %s",
                    config.hf_config.architectures)

        # Instantiate the HuggingFace model; weights arrive later via
        # load_weights().
        self._load_hf_model()

        # compute_logits() always hands this processor ready-made logits and
        # passes lm_head=None, so the processor must treat its input as
        # logits.  With logits_as_input=False (the previous setting) vLLM's
        # LogitsProcessor would attempt to apply the absent lm_head and fail.
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size,
            logits_as_input=True,
        )
        self.sampler = Sampler()

    def _load_hf_model(self) -> None:
        """Load the HuggingFace model using AutoModelForCausalLM.

        Only the architecture is instantiated here (from_config); the actual
        weights are loaded separately by vLLM's weight loader through
        load_weights().
        """
        from transformers import AutoModelForCausalLM

        logger.info("Loading HuggingFace model from config...")

        self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config(
            self.config,
            torch_dtype=self.model_config.dtype,
            trust_remote_code=self.model_config.trust_remote_code,
        )

        # Inference only: freeze everything.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional["IntermediateTensors"] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass through the model.

        This method conforms to the VllmModel protocol by accepting:
        - input_ids: Token IDs
        - positions: Position IDs
        - kv_caches: KV cache tensors (not used in basic HF forward)
        - attn_metadata: Attention metadata (not used in basic HF forward)

        Note: This is a simplified implementation that does not use vLLM's
        optimized attention mechanisms. For production use with KV caching,
        a more sophisticated implementation would be needed.
        """
        # HF models expect a batch dimension; vLLM hands us flat tensors,
        # so add one when needed.
        if inputs_embeds is not None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {
                "input_ids":
                input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids
            }

        if positions is not None:
            model_inputs["position_ids"] = (
                positions.unsqueeze(0) if positions.dim() == 1 else positions)

        # Run HuggingFace's native forward (no vLLM KV cache integration).
        with torch.no_grad():
            outputs = self.model(
                **model_inputs,
                use_cache=False,
                return_dict=True,
            )

        # Prefer the last layer's hidden states when the model returned them.
        # NOTE(review): output_hidden_states is not requested in the call
        # above, so unless the model config enables it, hidden_states is None
        # and we fall back to the logits — compute_logits() handles both.
        if getattr(outputs, "hidden_states", None) is not None:
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = outputs.logits

        # Drop the batch dimension we added so downstream vLLM code sees a
        # flat (num_tokens, dim) tensor.  (The original duplicated this
        # squeeze-and-return in two branches; unified here.)
        if hidden_states.dim() == 3:
            hidden_states = hidden_states.squeeze(0)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states.

        This method conforms to the VllmModelForTextGeneration protocol.
        """
        # forward() may already have produced logits (vocab-sized last dim);
        # otherwise apply the model's LM head.
        if hidden_states.shape[-1] == self.config.vocab_size:
            logits = hidden_states
        else:
            logits = self.model.lm_head(hidden_states)

        # lm_head=None is safe because the processor was constructed with
        # logits_as_input=True (see __init__).
        return self.logits_processor(None, logits, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional["SamplerOutput"]:
        """Sample tokens from logits.

        This method conforms to the VllmModelForTextGeneration protocol.
        """
        return self.sampler(logits, sampling_metadata)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Load weights into the model.

        Loads weights from an iterable of (name, tensor) pairs into the
        HuggingFace model, trying the checkpoint name verbatim, with a
        "model." prefix added, and with a leading "model." prefix stripped.

        Returns:
            The set of checkpoint names that were successfully loaded.
        """
        loaded_params: Set[str] = set()
        model_params = dict(self.model.named_parameters())

        for name, loaded_weight in weights:
            # Candidate names in the wrapped model.  (The original code also
            # retried the bare name with an empty prefix, which could never
            # succeed after the first lookup failed.)
            stripped = name[len("model."):] if name.startswith(
                "model.") else name
            for candidate in (name, f"model.{name}", stripped):
                param = model_params.get(candidate)
                if param is not None:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
                    break
            else:
                # Previously dropped silently; surface it for debuggability.
                logger.warning(
                    "Checkpoint weight %s not found in model; skipping.",
                    name)

        return loaded_params
|
||||
|
||||
|
||||
def is_backend_compatible() -> bool:
    """Report whether the current model can use the Transformers backend.

    This is a simplified check — real compatibility depends on whether the
    model follows standard HuggingFace conventions — so every model is
    optimistically accepted.
    """
    compatible = True
    return compatible
|
||||
Reference in New Issue
Block a user