forked from EngineX-Cambricon/enginex-mlu370-vllm
Testing dynamic registration
This commit is contained in:
118
vllm-v0.6.2/vllm/model_executor/models/transformers/legacy.py
Normal file
118
vllm-v0.6.2/vllm/model_executor/models/transformers/legacy.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright 2024 The vLLM team.
|
||||
"""Transformers modeling backend mixin for legacy models.
|
||||
|
||||
This module provides LegacyMixin for BERT-like encoder models that have
|
||||
different weight naming conventions and special position handling.
|
||||
|
||||
Following latest vLLM architecture patterns adapted for v0.6.2.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LegacyMixin:
    """
    Mixin adding legacy/encoder (BERT, RoBERTa) support on top of Base.

    Responsibilities:
      - Translate legacy checkpoint weight names (.gamma/.beta) to vLLM names.
      - Map BERT-like root prefixes ("bert"/"roberta") onto "model".
      - Apply the RoBERTa position offset during the forward pass.
      - Record weight prefixes/patterns that vLLM must not load.

    Intended to be combined with the Base class, e.g.:
        class TransformersForLegacy(LegacyMixin, Base): ...
    """

    # Checkpoint -> vLLM weight-name translation. Entries are applied in
    # order, so ordering is significant.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # BERT-family roots are renamed to the generic "model" prefix.
            "roberta": "model",
            "bert": "model",
        },
        orig_to_new_suffix={
            # Old-style LayerNorm parameter names.
            ".gamma": ".weight",
            ".beta": ".bias",
        },
    )

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Delegate to the next class in the MRO (expected to be Base).
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Output-head layers these encoder models may ship but vLLM ignores.
        self.skip_prefixes.extend([
            "model.lm_head.",
            "model.predictions.",
            "model.qa_outputs.",
            "model.embeddings_project.",
            "model.discriminator_predictions.",
        ])

        # v0.6.2 has no skip_substrs hook; keep substring patterns locally
        # and consult them via _should_skip_weight during weight loading.
        self._legacy_skip_patterns: List[str] = [
            "position_ids",  # persistent buffer in some encoder checkpoints
            "score.bias",  # final classifier bias that vLLM never consumes
        ]

        # RoBERTa variants reserve position indices up to the padding index,
        # so the forward pass must shift positions (see forward()).
        model_type = getattr(self.text_config, "model_type", "").lower()
        self.is_roberta = "roberta" in model_type
        self.padding_idx = getattr(self.text_config, "pad_token_id", 1)

        if self.is_roberta:
            logger.info("LegacyMixin detected RoBERTa model, enabling position padding")

        logger.info("LegacyMixin initialized for legacy/encoder model")

    def _should_skip_weight(self, name: str) -> bool:
        """Return True when *name* matches any pattern excluded from loading."""
        return any(pattern in name for pattern in self._legacy_skip_patterns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the backbone forward pass, first adjusting positions for RoBERTa.

        RoBERTa checkpoints place real token positions after the padding
        index, so positions are offset by padding_idx + 1 before delegation.
        """
        if self.is_roberta and positions is not None:
            positions = positions + self.padding_idx + 1

        return super().forward(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
|
||||
Reference in New Issue
Block a user