Unify the model type checking (#1905)
This commit is contained in:
@@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module):
|
||||
if (
|
||||
hasattr(self.config, "tie_word_embeddings")
|
||||
and self.config.tie_word_embeddings
|
||||
and "lm_head.weight" in params_dict
|
||||
):
|
||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
||||
param = self.lm_head.weight
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, self.model.embed_tokens.weight)
|
||||
|
||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
from functools import lru_cache, partial
|
||||
from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
|
||||
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
|
||||
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
||||
processor = cached_get_processor(self.config._name_or_path)
|
||||
grid_t, grid_h, grid_w = image_grid_thw
|
||||
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2VLConfig,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.visual = Qwen2VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
|
||||
Reference in New Issue
Block a user