Unify the model type checking (#1905)

Lianmin Zheng
2024-11-03 12:25:39 -08:00
committed by GitHub
parent c17c578108
commit 0abbf289a8
13 changed files with 146 additions and 160 deletions
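Context for the diff below: it removes per-model multimodal plumbing (e.g. the SupportsMultiModal mixin) in favor of a single, centralized model type check. As a rough illustration only, such a check can be keyed off the architecture strings in the HF config; the helper name and suffix list here are hypothetical, not taken from this commit:

# Hypothetical sketch of a unified model type check. The names
# (is_generation_model, GENERATION_SUFFIXES) are illustrative,
# not read from this commit.
from typing import Iterable

GENERATION_SUFFIXES = ("ForCausalLM", "ForConditionalGeneration", "LMHeadModel")

def is_generation_model(model_architectures: Iterable[str], is_embedding: bool = False) -> bool:
    # Serving in embedding mode disables generation regardless of class name.
    if is_embedding:
        return False
    # Decide once, from the architecture strings, instead of per-model flags.
    return any(arch.endswith(GENERATION_SUFFIXES) for arch in model_architectures)

assert is_generation_model(["Qwen2VLForConditionalGeneration"])
assert not is_generation_model(["MistralModel"])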


@@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module):
         if (
             hasattr(self.config, "tie_word_embeddings")
             and self.config.tie_word_embeddings
+            and "lm_head.weight" in params_dict
         ):
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, self.model.embed_tokens.weight)
 
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
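For reference, the tie in the block above makes lm_head and embed_tokens share one tensor, so a checkpoint that omits lm_head.weight still loads. A self-contained sketch of the same technique with a toy module (TinyLM is a stand-in, not this repo's LlamaForCausalLM):

# Toy illustration of weight tying between output head and input embedding.
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie: both modules now share one tensor, so a checkpoint that
        # omits lm_head.weight still yields a fully initialized model.
        self.lm_head.weight = self.embed_tokens.weight

model = TinyLM()
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()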


@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
+from typing import Iterable, List, Optional, Tuple, Type, TypedDict
 
 import numpy as np
 import torch
@@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsMultiModal
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
 cached_get_processor = lru_cache(get_processor)
 
 
-class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Qwen2VLForConditionalGeneration(nn.Module):
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
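The method above derives the image token count from the (t, h, w) patch grid; the tail of the computation is cut off in this excerpt, but the usual Qwen2-VL formula divides the patch count by the squared spatial merge size. A hedged sketch (merge_size=2 is an assumption here, not a value read from this diff):

# Hedged sketch of the usual Qwen2-VL image-token count: patches in the
# (t, h, w) grid divided by the squared spatial merge size.
def num_image_tokens(grid_t: int, grid_h: int, grid_w: int, merge_size: int = 2) -> int:
    return (grid_t * grid_h * grid_w) // (merge_size ** 2)

assert num_image_tokens(1, 32, 32) == 256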
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
     def __init__(
         self,
         config: Qwen2VLConfig,
-        multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
 
         self.config = config
-        self.multimodal_config = multimodal_config
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
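The norm_eps argument above uses getattr with a fallback so configs that lack rms_norm_eps still construct cleanly. A tiny illustration of that pattern (SimpleNamespace stands in for a HuggingFace config object):

# Illustration of the getattr fallback used for norm_eps above.
from types import SimpleNamespace

cfg_with = SimpleNamespace(rms_norm_eps=1e-5)
cfg_without = SimpleNamespace()

assert getattr(cfg_with, "rms_norm_eps", 1e-6) == 1e-5
assert getattr(cfg_without, "rms_norm_eps", 1e-6) == 1e-6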