forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
18
vllm-v0.6.2/vllm/model_executor/model_loader/__init__.py
Normal file
18
vllm-v0.6.2/vllm/model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
||||
get_model_loader)
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name, get_model_architecture)
|
||||
|
||||
|
||||
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
|
||||
loader = get_model_loader(vllm_config.load_config)
|
||||
return loader.load_model(vllm_config=vllm_config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_model", "get_model_loader", "BaseModelLoader",
|
||||
"get_architecture_class_name", "get_model_architecture"
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1197
vllm-v0.6.2/vllm/model_executor/model_loader/loader.py
Normal file
1197
vllm-v0.6.2/vllm/model_executor/model_loader/loader.py
Normal file
File diff suppressed because it is too large
Load Diff
211
vllm-v0.6.2/vllm/model_executor/model_loader/neuron.py
Normal file
211
vllm-v0.6.2/vllm/model_executor/model_loader/neuron.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Utilities for selecting and loading neuron models."""
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import get_quantization_config
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceOutput)
|
||||
|
||||
TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
"auto": "f32",
|
||||
"half": "f16",
|
||||
"float16": "f16",
|
||||
"bfloat16": "bf16",
|
||||
"float": "f32",
|
||||
"float32": "f32",
|
||||
torch.float16: "f16",
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float32: "f32",
|
||||
}
|
||||
|
||||
# Models supported by Neuron.
|
||||
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
|
||||
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
|
||||
"LlamaForSampling", "LlamaForCausalLM"),
|
||||
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
|
||||
"MistralForSampling", "MistralForCausalLM")
|
||||
}
|
||||
|
||||
|
||||
class NeuronCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
on_device_sampling_disabled: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
|
||||
self.on_device_sampling_disabled = on_device_sampling_disabled
|
||||
if self.on_device_sampling_disabled:
|
||||
# Use default sampler
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
logits = self.model(input_ids,
|
||||
cache_ids=positions,
|
||||
start_ids=input_block_ids)
|
||||
return logits
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
|
||||
if self.on_device_sampling_disabled:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
# On-device sampling outputs the token ids directly.
|
||||
sampled_token_ids = logits.flatten()
|
||||
next_tokens = []
|
||||
sample_idx = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
samples = []
|
||||
for seq_id in seq_group.seq_ids:
|
||||
token_id = sampled_token_ids[sample_idx].item()
|
||||
samples.append(
|
||||
SequenceOutput(parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)}))
|
||||
sample_idx += 1
|
||||
next_tokens.append(
|
||||
CompletionSequenceGroupOutput(samples=samples,
|
||||
prompt_logprobs=None))
|
||||
|
||||
return SamplerOutput(outputs=next_tokens)
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
|
||||
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
|
||||
**kwargs)
|
||||
self.model.to_neuron()
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _NEURON_SUPPORTED_MODELS:
|
||||
return arch
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported on Neuron "
|
||||
f"for now. Supported architectures: "
|
||||
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
||||
|
||||
|
||||
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
|
||||
env_value = os.getenv(env)
|
||||
if env_value is None:
|
||||
return default_value
|
||||
buckets_remove_empty = filter(
|
||||
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
|
||||
buckets_int = map(int, buckets_remove_empty)
|
||||
buckets_list = list(buckets_int)
|
||||
return buckets_list
|
||||
|
||||
|
||||
def _get_default_neuron_config(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig):
|
||||
from transformers_neuronx.config import ContinuousBatchingConfig
|
||||
from transformers_neuronx.constants import LAYOUT_BSH
|
||||
|
||||
continuous_batching_config = ContinuousBatchingConfig(
|
||||
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
|
||||
quant_config = dict(
|
||||
dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
quantize_method="vector_dynamic")
|
||||
neuron_quantization_config_builder = lambda quant: get_quantization_config(
|
||||
quant).from_config(quant_config).get_quant_method(None, "")
|
||||
# TODO: Add Paged attention config to the default neuron arguments.
|
||||
default_neuron_args = dict(
|
||||
collectives_layout=LAYOUT_BSH,
|
||||
attention_layout=LAYOUT_BSH,
|
||||
fuse_qkv=True,
|
||||
quant=neuron_quantization_config_builder(model_config.quantization)
|
||||
if model_config.quantization else None,
|
||||
continuous_batching=continuous_batching_config,
|
||||
weight_tiling=bool(model_config.quantization),
|
||||
on_device_generation=_get_neuron_on_device_generation_config(
|
||||
model_config))
|
||||
return default_neuron_args
|
||||
|
||||
|
||||
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
|
||||
if not _is_neuron_on_device_sampling_disabled(model_config):
|
||||
return copy.deepcopy(model_config.neuron_sampling_params)
|
||||
return None
|
||||
|
||||
|
||||
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
|
||||
return not getattr(model_config, "neuron_sampling_params", None)
|
||||
|
||||
|
||||
def _get_neuron_config_after_override(default_neuron_config,
|
||||
overridden_neuron_config):
|
||||
from transformers_neuronx.config import NeuronConfig
|
||||
overridden_neuron_config = overridden_neuron_config or {}
|
||||
default_neuron_config.update(overridden_neuron_config)
|
||||
return NeuronConfig(**default_neuron_config)
|
||||
|
||||
|
||||
def get_neuron_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||
|
||||
# Create a model instance.
|
||||
model = NeuronCausalLM(
|
||||
model_config.hf_config,
|
||||
_is_neuron_on_device_sampling_disabled(model_config))
|
||||
|
||||
default_neuron_config_args = _get_default_neuron_config(
|
||||
model_config, parallel_config, scheduler_config)
|
||||
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_config.model,
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
return model.eval()
|
||||
203
vllm-v0.6.2/vllm/model_executor/model_loader/openvino.py
Normal file
203
vllm-v0.6.2/vllm/model_executor/model_loader/openvino.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# ruff: noqa: SIM117
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import openvino as ov
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from openvino._offline_transformations import paged_attention_transformation
|
||||
from optimum.intel import OVModelForCausalLM
|
||||
from torch import nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
|
||||
from vllm.config import DeviceConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
|
||||
_prune_hidden_states)
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _flattenize_inputs(inputs):
|
||||
"""
|
||||
Helper function for making nested inputs flattens
|
||||
"""
|
||||
flatten_inputs = []
|
||||
for input_data in inputs:
|
||||
if input_data is None:
|
||||
continue
|
||||
if isinstance(input_data, (list, tuple)):
|
||||
flatten_inputs.extend(_flattenize_inputs(input_data))
|
||||
elif isinstance(input_data, dict):
|
||||
flatten_inputs.extend(_flattenize_inputs(list(
|
||||
input_data.values())))
|
||||
else:
|
||||
flatten_inputs.append(input_data)
|
||||
return flatten_inputs
|
||||
|
||||
|
||||
def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
|
||||
is_cpu: bool):
|
||||
# Apply hardware dependent modifications to KV tensors
|
||||
for parameter in model.get_parameters():
|
||||
input = parameter.get_output_tensor(0)
|
||||
input_names = input.get_names()
|
||||
if len(input_names) != 1:
|
||||
continue
|
||||
input_name = next(iter(input_names))
|
||||
shape = parameter.get_partial_shape()
|
||||
# use real block size if available, just a placeholder
|
||||
# to provide the expected rank
|
||||
num_blocks = ov.Dimension()
|
||||
block_size = ov.Dimension()
|
||||
head_size = ov.Dimension()
|
||||
if input_name.startswith("key_cache."):
|
||||
cpu_shape = [num_blocks, shape[1], block_size, head_size]
|
||||
gpu_shape = [num_blocks, shape[1], shape[2], block_size]
|
||||
elif input_name.startswith("value_cache."):
|
||||
cpu_shape = [num_blocks, shape[1], block_size, head_size]
|
||||
gpu_shape = [num_blocks, shape[1], block_size, shape[2]]
|
||||
else:
|
||||
continue
|
||||
parameter.set_partial_shape(
|
||||
ov.PartialShape(cpu_shape if is_cpu else gpu_shape))
|
||||
parameter.set_element_type(kv_cache_dtype)
|
||||
model.validate_nodes_and_infer_types()
|
||||
|
||||
|
||||
def _require_model_export(model_id, revision=None, subfolder=None):
|
||||
model_dir = Path(model_id)
|
||||
if subfolder is not None:
|
||||
model_dir = model_dir / subfolder
|
||||
if model_dir.is_dir():
|
||||
return (not (model_dir / "openvino_model.xml").exists()
|
||||
or not (model_dir / "openvino_model.bin").exists())
|
||||
|
||||
hf_api = HfApi()
|
||||
try:
|
||||
model_info = hf_api.model_info(model_id, revision=revision or "main")
|
||||
normalized_subfolder = (None if subfolder is None else
|
||||
Path(subfolder).as_posix())
|
||||
model_files = [
|
||||
file.rfilename for file in model_info.siblings
|
||||
if normalized_subfolder is None
|
||||
or file.rfilename.startswith(normalized_subfolder)
|
||||
]
|
||||
ov_model_path = ("openvino_model.xml" if normalized_subfolder is None
|
||||
else f"{normalized_subfolder}/openvino_model.xml")
|
||||
return (ov_model_path not in model_files
|
||||
or ov_model_path.replace(".xml", ".bin") not in model_files)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
class OpenVINOCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ov_core: ov.Core,
|
||||
model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
kv_cache_dtype: ov.Type,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.logits_processor = LogitsProcessor(
|
||||
model_config.hf_config.vocab_size, logits_as_input=True)
|
||||
self.sampler = Sampler()
|
||||
|
||||
export = _require_model_export(model_config.model)
|
||||
if export:
|
||||
logger.warning(
|
||||
f"Provided model id {model_config.model} does not " # noqa: G004
|
||||
"contain OpenVINO IR, the model will be converted to IR with "
|
||||
"default options. If you need to use specific options for "
|
||||
"model conversion, use optimum-cli export openvino with "
|
||||
"desired options.")
|
||||
else:
|
||||
logger.warning(
|
||||
"OpenVINO IR is available for provided model id " # noqa: G004
|
||||
f"{model_config.model}. This IR will be used for inference "
|
||||
"as-is, all possible options that may affect model conversion "
|
||||
"are ignored.")
|
||||
|
||||
load_in_8bit = envs.VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
|
||||
pt_model = OVModelForCausalLM.from_pretrained(
|
||||
model_config.model,
|
||||
export=export,
|
||||
compile=False,
|
||||
load_in_8bit=load_in_8bit,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
ov_device = envs.VLLM_OPENVINO_DEVICE
|
||||
paged_attention_transformation(pt_model.model)
|
||||
_modify_cache_parameters(pt_model.model, kv_cache_dtype,
|
||||
current_platform.is_openvino_cpu())
|
||||
|
||||
ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
|
||||
self.ov_request = ov_compiled.create_infer_request()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
|
||||
attn_metadata: OpenVINOAttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
flatten_kv_cache = _flattenize_inputs(kv_caches)
|
||||
|
||||
inputs = [
|
||||
input_ids,
|
||||
positions,
|
||||
*flatten_kv_cache,
|
||||
attn_metadata.past_lens,
|
||||
attn_metadata.subsequence_begins,
|
||||
attn_metadata.block_indices,
|
||||
attn_metadata.block_indices_begins,
|
||||
attn_metadata.max_context_len,
|
||||
]
|
||||
|
||||
self.ov_request.start_async(inputs, share_inputs=True)
|
||||
self.ov_request.wait()
|
||||
|
||||
logits = torch.from_numpy(self.ov_request.get_tensor("logits").data)
|
||||
|
||||
# TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension
|
||||
return logits.view(-1, logits.shape[-1])
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
|
||||
def get_model(
|
||||
model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
kv_cache_dtype: ov.Type,
|
||||
**kwargs,
|
||||
) -> torch.nn.Module:
|
||||
lora_config = kwargs.get("lora_config")
|
||||
ov_core = kwargs.get("ov_core")
|
||||
if lora_config:
|
||||
raise ValueError(
|
||||
"OpenVINO modeling does not support LoRA, "
|
||||
"but LoRA is enabled. Support for this model may "
|
||||
"be added in the future. If this is important to you, "
|
||||
"please open an issue on github.")
|
||||
|
||||
return OpenVINOCausalLM(ov_core, model_config, device_config,
|
||||
kv_cache_dtype)
|
||||
470
vllm-v0.6.2/vllm/model_executor/model_loader/tensorizer.py
Normal file
470
vllm-v0.6.2/vllm/model_executor/model_loader/tensorizer.py
Normal file
@@ -0,0 +1,470 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import BinaryIO, Generator, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, ParallelConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
tensorizer_error_msg = None
|
||||
|
||||
try:
|
||||
from tensorizer import (DecryptionParams, EncryptionParams,
|
||||
TensorDeserializer, TensorSerializer)
|
||||
from tensorizer.stream_io import open_stream
|
||||
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
||||
no_init_or_tensor)
|
||||
|
||||
_read_stream, _write_stream = (partial(
|
||||
open_stream,
|
||||
mode=mode,
|
||||
) for mode in ("rb", "wb+"))
|
||||
except ImportError as e:
|
||||
tensorizer_error_msg = str(e)
|
||||
|
||||
__all__ = [
|
||||
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
|
||||
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
|
||||
'no_init_or_tensor', 'TensorizerConfig'
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig:
|
||||
tensorizer_uri: str
|
||||
vllm_tensorized: Optional[bool] = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_class: Optional[Type[torch.nn.Module]] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Optional[Union[str, torch.dtype]] = None
|
||||
_is_sharded: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# check if the configuration is for a sharded vLLM model
|
||||
self._is_sharded = isinstance(self.tensorizer_uri, str) \
|
||||
and re.search(r'%0\dd', self.tensorizer_uri) is not None
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
tensorizer_args = {
|
||||
"tensorizer_uri": self.tensorizer_uri,
|
||||
"vllm_tensorized": self.vllm_tensorized,
|
||||
"verify_hash": self.verify_hash,
|
||||
"num_readers": self.num_readers,
|
||||
"encryption_keyfile": self.encryption_keyfile,
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
return TensorizerArgs(**tensorizer_args) # type: ignore
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if parallel_config.tensor_parallel_size > 1 \
|
||||
and not self._is_sharded:
|
||||
raise ValueError(
|
||||
"For a sharded model, tensorizer_uri should include a"
|
||||
" string format template like '%04d' to be formatted"
|
||||
" with the rank of the shard")
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if (model_config.quantization is not None
|
||||
and self.tensorizer_uri is not None):
|
||||
logger.warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors.")
|
||||
|
||||
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
|
||||
if tensorizer_args is None:
|
||||
tensorizer_args = self._construct_tensorizer_args()
|
||||
|
||||
return open_stream(self.tensorizer_uri,
|
||||
**tensorizer_args.stream_params)
|
||||
|
||||
|
||||
def load_with_tensorizer(tensorizer_config: TensorizerConfig,
|
||||
**extra_kwargs) -> nn.Module:
|
||||
tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
|
||||
return tensorizer.deserialize()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerArgs:
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
|
||||
bytes, os.PathLike, int]
|
||||
vllm_tensorized: Optional[bool] = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
"""
|
||||
Args for the TensorizerAgent class. These are used to configure the behavior
|
||||
of the TensorDeserializer when loading tensors from a serialized model.
|
||||
|
||||
Args:
|
||||
tensorizer_uri: Path to serialized model tensors. Can be a local file
|
||||
path or a S3 URI.
|
||||
vllm_tensorized: If True, indicates that the serialized model is a
|
||||
vLLM model. This is used to determine the behavior of the
|
||||
TensorDeserializer when loading tensors from a serialized model.
|
||||
It is far faster to deserialize a vLLM model as it utilizes
|
||||
tensorizer's optimized GPU loading. Note that this is now
|
||||
deprecated, as serialized vLLM models are now automatically
|
||||
inferred as vLLM models.
|
||||
verify_hash: If True, the hashes of each tensor will be verified against
|
||||
the hashes stored in the metadata. A `HashMismatchError` will be
|
||||
raised if any of the hashes do not match.
|
||||
num_readers: Controls how many threads are allowed to read concurrently
|
||||
from the source file. Default is `None`, which will dynamically set
|
||||
the number of readers based on the number of available
|
||||
resources and model size. This greatly increases performance.
|
||||
encryption_keyfile: File path to a binary file containing a
|
||||
binary key to use for decryption. `None` (the default) means
|
||||
no decryption. See the example script in
|
||||
examples/tensorize_vllm_model.py.
|
||||
s3_access_key_id: The access key for the S3 bucket. Can also be set via
|
||||
the S3_ACCESS_KEY_ID environment variable.
|
||||
s3_secret_access_key: The secret access key for the S3 bucket. Can also
|
||||
be set via the S3_SECRET_ACCESS_KEY environment variable.
|
||||
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
|
||||
S3_ENDPOINT_URL environment variable.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
self.file_obj = self.tensorizer_uri
|
||||
self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
|
||||
self.s3_secret_access_key = (self.s3_secret_access_key
|
||||
or envs.S3_SECRET_ACCESS_KEY)
|
||||
self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
|
||||
self.stream_params = {
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
|
||||
self.deserializer_params = {
|
||||
"verify_hash": self.verify_hash,
|
||||
"encryption": self.encryption_keyfile,
|
||||
"num_readers": self.num_readers
|
||||
}
|
||||
|
||||
if self.encryption_keyfile:
|
||||
with open_stream(
|
||||
self.encryption_keyfile,
|
||||
**self.stream_params,
|
||||
) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
self.deserializer_params['encryption'] = decryption_params
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Tensorizer CLI arguments"""
|
||||
|
||||
# Tensorizer options arg group
|
||||
group = parser.add_argument_group(
|
||||
'tensorizer options',
|
||||
description=('Options for configuring the behavior of the'
|
||||
' tensorizer deserializer when '
|
||||
'load_format=tensorizer is specified when '
|
||||
'initializing an LLMEngine, either via the CLI '
|
||||
'when running the vLLM OpenAI inference server '
|
||||
'with a JSON string passed to '
|
||||
'--model-loader-extra-config or as arguments given '
|
||||
'to TensorizerConfig when passed to '
|
||||
'model_loader_extra_config in the constructor '
|
||||
'for LLMEngine.'))
|
||||
|
||||
group.add_argument(
|
||||
"--tensorizer-uri",
|
||||
help="Path to serialized model tensors. Can be a local file path,"
|
||||
" or an HTTP(S) or S3 URI.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--verify-hash",
|
||||
action="store_true",
|
||||
help="If enabled, the hashes of each tensor will be verified"
|
||||
" against the hashes stored in the file metadata. An exception"
|
||||
" will be raised if any of the hashes do not match.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--encryption-keyfile",
|
||||
default=None,
|
||||
help="The file path to a binary file containing a binary key to "
|
||||
"use for decryption. Can be a file path or S3 network URI.")
|
||||
group.add_argument(
|
||||
"--num-readers",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Controls how many threads are allowed to read concurrently "
|
||||
"from the source file. Default is `None`, which will dynamically "
|
||||
"set the number of readers based on the available resources "
|
||||
"and model size. This greatly increases performance.")
|
||||
group.add_argument(
|
||||
"--s3-access-key-id",
|
||||
default=None,
|
||||
help="The access key for the S3 bucket. Can also be set via the "
|
||||
"S3_ACCESS_KEY_ID environment variable.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--s3-secret-access-key",
|
||||
default=None,
|
||||
help="The secret access key for the S3 bucket. Can also be set via "
|
||||
"the S3_SECRET_ACCESS_KEY environment variable.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--s3-endpoint",
|
||||
default=None,
|
||||
help="The endpoint for the S3 bucket. Can also be set via the "
|
||||
"S3_ENDPOINT_URL environment variable.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
tensorizer_args = cls(**{
|
||||
attr: getattr(args, attr)
|
||||
for attr in attrs if hasattr(args, attr)
|
||||
})
|
||||
return tensorizer_args
|
||||
|
||||
|
||||
class TensorizerAgent:
|
||||
"""
|
||||
A class for performing tensorizer deserializations specifically for
|
||||
vLLM models using plaid_mode. Uses TensorizerArgs to configure the
|
||||
behavior of the TensorDeserializer when loading tensors from a serialized
|
||||
model. For deserializations of HuggingFace models, TensorDeserializer is
|
||||
instead used as an iterator directly in the func hf_model_weights_iterator
|
||||
in vllm/model_executor/model_loader/weight_utils.py
|
||||
"""
|
||||
|
||||
def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
|
||||
if tensorizer_error_msg is not None:
|
||||
raise ImportError(
|
||||
"Tensorizer is not installed. Please install tensorizer "
|
||||
"to use this feature with `pip install vllm[tensorizer]`. "
|
||||
"Error message: {}".format(tensorizer_error_msg))
|
||||
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.tensorizer_args = (
|
||||
self.tensorizer_config._construct_tensorizer_args())
|
||||
self.vllm_config = vllm_config
|
||||
self.model = self._init_model()
|
||||
|
||||
def _init_model(self):
|
||||
assert self.tensorizer_config.hf_config is not None
|
||||
model_args = self.tensorizer_config.hf_config
|
||||
model_args.torch_dtype = self.tensorizer_config.dtype
|
||||
assert self.tensorizer_config.model_class is not None
|
||||
with no_init_or_tensor():
|
||||
return self.tensorizer_config.model_class(
|
||||
vllm_config=self.vllm_config, )
|
||||
|
||||
def _resize_lora_embeddings(self):
|
||||
"""Modify LoRA embedding layers to use bigger tensors
|
||||
to allow for adapter added tokens."""
|
||||
for child in self.model.modules():
|
||||
if (isinstance(child, VocabParallelEmbedding)
|
||||
and child.weight.shape[0] <
|
||||
child.num_embeddings_per_partition):
|
||||
new_weight = torch.empty(child.num_embeddings_per_partition,
|
||||
child.embedding_dim,
|
||||
dtype=child.weight.dtype,
|
||||
device=child.weight.device)
|
||||
new_weight[:child.weight.shape[0]].copy_(child.weight.data)
|
||||
new_weight[child.weight.shape[0]:].fill_(0)
|
||||
child.weight.data = new_weight
|
||||
|
||||
def _check_tensors_on_meta_device(self):
|
||||
for tensor in self.model.state_dict().values():
|
||||
if tensor.device.type == 'meta':
|
||||
raise ValueError(
|
||||
"The serialized model contains tensors on the meta device,"
|
||||
" indicating that some tensors were not loaded properly."
|
||||
" Please check that the parameters of the model being"
|
||||
" specified match that of the serialized model, such as"
|
||||
" its quantization.")
|
||||
|
||||
def deserialize(self):
|
||||
"""
|
||||
Deserialize the model using the TensorDeserializer. This method is
|
||||
specifically for vLLM models using tensorizer's plaid_mode.
|
||||
|
||||
The deserializer makes use of tensorizer_args.stream_params
|
||||
to configure the behavior of the stream when loading tensors from a
|
||||
serialized model. The deserializer_params are used to configure the
|
||||
behavior of the TensorDeserializer when loading tensors themselves.
|
||||
Documentation on these params can be found in TensorizerArgs
|
||||
|
||||
Returns:
|
||||
nn.Module: The deserialized model.
|
||||
"""
|
||||
before_mem = get_mem_usage()
|
||||
start = time.perf_counter()
|
||||
with _read_stream(
|
||||
self.tensorizer_config.tensorizer_uri,
|
||||
**self.tensorizer_args.stream_params
|
||||
) as stream, TensorDeserializer(
|
||||
stream,
|
||||
dtype=self.tensorizer_config.dtype,
|
||||
device=f'cuda:{torch.cuda.current_device()}',
|
||||
**self.tensorizer_args.deserializer_params) as deserializer:
|
||||
deserializer.load_into_module(self.model)
|
||||
end = time.perf_counter()
|
||||
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
deserializer.close()
|
||||
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
|
||||
end - start, per_second)
|
||||
logger.info("Memory usage before: %s", before_mem)
|
||||
logger.info("Memory usage after: %s", after_mem)
|
||||
|
||||
self._check_tensors_on_meta_device()
|
||||
self._resize_lora_embeddings()
|
||||
del self.model.vllm_tensorized_marker
|
||||
return self.model.eval()
|
||||
|
||||
|
||||
def tensorizer_weights_iterator(
|
||||
tensorizer_args: "TensorizerArgs"
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
logger.warning(
|
||||
"Deserializing HuggingFace models is not optimized for "
|
||||
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||
"Consider deserializing a vLLM model instead for faster "
|
||||
"load times. See the examples/tensorize_vllm_model.py example "
|
||||
"script for serializing vLLM models.")
|
||||
|
||||
deserializer_args = tensorizer_args.deserializer_params
|
||||
stream_params = tensorizer_args.stream_params
|
||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
||||
with TensorDeserializer(stream, **deserializer_args,
|
||||
device="cpu") as state:
|
||||
yield from state.items()
|
||||
del state
|
||||
|
||||
|
||||
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
|
||||
"""
|
||||
Infer if the model is a vLLM model by checking the weights for
|
||||
a vLLM tensorized marker.
|
||||
|
||||
Args:
|
||||
tensorizer_config: The TensorizerConfig object containing the
|
||||
tensorizer_uri to the serialized model.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is a vLLM model, False otherwise.
|
||||
"""
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
deserializer = TensorDeserializer(open_stream(
|
||||
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params),
|
||||
**tensorizer_args.deserializer_params,
|
||||
lazy_load=True)
|
||||
if tensorizer_config.vllm_tensorized:
|
||||
logger.warning(
|
||||
"Please note that newly serialized vLLM models are automatically "
|
||||
"inferred as vLLM models, so setting vllm_tensorized=True is "
|
||||
"only necessary for models serialized prior to this change.")
|
||||
return True
|
||||
return ".vllm_tensorized_marker" in deserializer
|
||||
|
||||
|
||||
def serialize_vllm_model(
|
||||
model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
) -> nn.Module:
|
||||
model.register_parameter(
|
||||
"vllm_tensorized_marker",
|
||||
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
|
||||
encryption_params = None
|
||||
if (keyfile := tensorizer_config.encryption_keyfile) is not None:
|
||||
with open(keyfile, "rb") as f:
|
||||
key = f.read()
|
||||
encryption_params = EncryptionParams(key=key)
|
||||
|
||||
output_file = tensorizer_args.tensorizer_uri
|
||||
if tensorizer_config._is_sharded:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
output_file = output_file % get_tensor_model_parallel_rank()
|
||||
|
||||
with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
logger.info("Successfully serialized model to %s", str(output_file))
|
||||
return model
|
||||
|
||||
|
||||
def tensorize_vllm_model(engine_args: EngineArgs,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
generate_keyfile: bool = True):
|
||||
"""Utility to load a model and then serialize it with Tensorizer
|
||||
|
||||
Intended to be used separately from running a vLLM server since it
|
||||
creates its own Engine instance.
|
||||
"""
|
||||
engine_config = engine_args.create_engine_config()
|
||||
tensorizer_config.verify_with_model_config(engine_config.model_config)
|
||||
tensorizer_config.verify_with_parallel_config(
|
||||
engine_config.parallel_config)
|
||||
|
||||
# generate the encryption key before creating the engine to support sharding
|
||||
if generate_keyfile and (keyfile :=
|
||||
tensorizer_config.encryption_keyfile) is not None:
|
||||
encryption_params = EncryptionParams.random()
|
||||
with _write_stream(
|
||||
keyfile,
|
||||
s3_access_key_id=tensorizer_config.s3_access_key_id,
|
||||
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
|
||||
s3_endpoint=tensorizer_config.s3_endpoint,
|
||||
) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
if tensorizer_config._is_sharded:
|
||||
# if the engine is a distributed engine (for tensor parallel) then each
|
||||
# worker shard needs to serialize its part of the model.
|
||||
engine.model_executor._run_workers(
|
||||
"save_tensorized_model",
|
||||
tensorizer_config=tensorizer_config,
|
||||
)
|
||||
else:
|
||||
# with a single worker, we can get to the underlying model directly
|
||||
serialize_vllm_model(
|
||||
engine.model_executor.driver_worker.model_runner.model,
|
||||
tensorizer_config,
|
||||
)
|
||||
39
vllm-v0.6.2/vllm/model_executor/model_loader/utils.py
Normal file
39
vllm-v0.6.2/vllm/model_executor/model_loader/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
# Special handling for quantized Mixtral.
|
||||
# FIXME(woosuk): This is a temporary hack.
|
||||
mixtral_supported = [
|
||||
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
|
||||
]
|
||||
|
||||
if (model_config.quantization is not None
|
||||
and model_config.quantization not in mixtral_supported
|
||||
and "MixtralForCausalLM" in architectures):
|
||||
architectures = ["QuantMixtralForCausalLM"]
|
||||
|
||||
return ModelRegistry.resolve_model_cls(architectures)
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
||||
681
vllm-v0.6.2/vllm/model_executor/model_loader/weight_utils.py
Normal file
681
vllm-v0.6.2/vllm/model_executor/model_loader/weight_utils.py
Normal file
@@ -0,0 +1,681 @@
|
||||
"""Utilities for downloading and initializing model weights."""
|
||||
import fnmatch
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
|
||||
Tuple, Union)
|
||||
|
||||
import filelock
|
||||
import gguf
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# use system-level temp directory for file locks, so that multiple users
|
||||
# can share the same lock without error.
|
||||
# lock files in the temp directory will be automatically deleted when the
|
||||
# system reboots, so users will not complain about annoying lock files
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
|
||||
def enable_hf_transfer():
|
||||
"""automatically activates hf_transfer
|
||||
"""
|
||||
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
||||
try:
|
||||
# enable hf hub transfer if available
|
||||
import hf_transfer # type: ignore # noqa
|
||||
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
enable_hf_transfer()
|
||||
|
||||
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||
lock_dir = cache_dir or temp_dir
|
||||
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
||||
model_name = model_name_or_path.replace("/", "-")
|
||||
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
||||
# add hash to avoid conflict with old users' lock files
|
||||
lock_file_name = hash_name + model_name + ".lock"
|
||||
# mode 0o666 is required for the filelock to be shared across users
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
|
||||
mode=0o666)
|
||||
return lock
|
||||
|
||||
|
||||
def _shared_pointers(tensors):
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in tensors.items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
failing = []
|
||||
for _, names in ptrs.items():
|
||||
if len(names) > 1:
|
||||
failing.append(names)
|
||||
return failing
|
||||
|
||||
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
for shared_weights in shared:
|
||||
for name in shared_weights[1:]:
|
||||
loaded.pop(name)
|
||||
|
||||
# For tensors to be contiguous
|
||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
|
||||
# check file size
|
||||
sf_size = os.stat(sf_filename).st_size
|
||||
pt_size = os.stat(pt_filename).st_size
|
||||
if (sf_size - pt_size) / pt_size > 0.01:
|
||||
raise RuntimeError(f"""The file size different is more than 1%:
|
||||
- {sf_filename}: {sf_size}
|
||||
- {pt_filename}: {pt_size}
|
||||
""")
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> QuantizationConfig:
|
||||
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
# GGUF doesn't have config file
|
||||
if model_config.quantization == "gguf":
|
||||
return quant_cls.from_config({})
|
||||
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
None)
|
||||
# some vision model may keep quantization_config in their text_config
|
||||
hf_text_config = getattr(model_config.hf_config, "text_config", None)
|
||||
if hf_quant_config is None and hf_text_config is not None:
|
||||
hf_quant_config = getattr(hf_text_config, "quantization_config", None)
|
||||
if hf_quant_config is None:
|
||||
# compressed-tensors uses a compressions_config
|
||||
hf_quant_config = getattr(model_config.hf_config, "compression_config",
|
||||
None)
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
if (not load_config.model_loader_extra_config
|
||||
or "qlora_adapter_name_or_path"
|
||||
not in load_config.model_loader_extra_config):
|
||||
return quant_cls.from_config({"adapter_name_or_path": ""})
|
||||
model_name_or_path = load_config.model_loader_extra_config[
|
||||
"qlora_adapter_name_or_path"]
|
||||
|
||||
else:
|
||||
model_name_or_path = model_config.model
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, load_config.download_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
tqdm_class=DisabledTqdm,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
possible_config_filenames = quant_cls.get_config_filenames()
|
||||
|
||||
# If the quantization config is not found, use the default config.
|
||||
if not possible_config_filenames:
|
||||
return quant_cls()
|
||||
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(
|
||||
f.endswith(x) for x in possible_config_filenames)
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(
|
||||
f"Cannot find the config file for {model_config.quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(
|
||||
f"Found multiple config files for {model_config.quantization}: "
|
||||
f"{quant_config_files}")
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
config["adapter_name_or_path"] = model_name_or_path
|
||||
elif model_config.quantization == "modelopt":
|
||||
if config["producer"]["name"] == "modelopt":
|
||||
return quant_cls.from_config(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization config"
|
||||
f" found for {model_config.quantization} in {f}.")
|
||||
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
allow_patterns: List[str],
|
||||
revision: Optional[str] = None,
|
||||
ignore_patterns: Optional[Union[str, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
allow_patterns (List[str]): The allowed patterns for the
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
|
||||
filter out the weight files. Files matched by any of the patterns
|
||||
will be ignored.
|
||||
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
|
||||
logger.info("Using model weights format %s", allow_patterns)
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
return hf_folder
|
||||
|
||||
|
||||
def download_safetensors_index_file_from_hf(
|
||||
model_name_or_path: str,
|
||||
index_file: str,
|
||||
cache_dir: Optional[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Download hf safetensors index file from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
"""
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
try:
|
||||
# Download the safetensors index file.
|
||||
hf_hub_download(
|
||||
repo_id=model_name_or_path,
|
||||
filename=index_file,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
# If file not found on remote or locally, we should not fail since
|
||||
# only some models will have index_file.
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
logger.info("No %s found in remote.", index_file)
|
||||
except huggingface_hub.utils.LocalEntryNotFoundError:
|
||||
logger.info("No %s found in local cache.", index_file)
|
||||
|
||||
|
||||
# For models like Mistral-7B-v0.3, there are both sharded
|
||||
# safetensors files and a consolidated safetensors file.
|
||||
# Passing both of these to the weight loader functionality breaks.
|
||||
# So, we use the index_file to
|
||||
# look up which safetensors files should be used.
|
||||
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
|
||||
hf_folder: str,
|
||||
index_file: str) -> List[str]:
|
||||
# model.safetensors.index.json is a mapping from keys in the
|
||||
# torch state_dict to safetensors file holding that weight.
|
||||
index_file_name = os.path.join(hf_folder, index_file)
|
||||
if not os.path.isfile(index_file_name):
|
||||
return hf_weights_files
|
||||
|
||||
# Iterate through the weight_map (weight_name: safetensors files)
|
||||
# to identify weights that we should use.
|
||||
with open(index_file_name) as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(
|
||||
os.path.join(hf_folder, weight_map[weight_name]))
|
||||
# Filter out any fields that are not found in the index file.
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files if f in weight_files_in_index
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
def filter_files_not_needed_for_inference(
|
||||
hf_weights_files: List[str]) -> List[str]:
|
||||
"""
|
||||
Exclude files that are not needed for inference.
|
||||
|
||||
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
"""
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
def np_cache_weights_iterator(
|
||||
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model np files.
|
||||
|
||||
Will dump the model weights to numpy files if they are not already dumped.
|
||||
"""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names: List[str] = []
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading np_cache checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, "w") as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
with open(weight_names_file) as f:
|
||||
weight_names = json.load(f)
|
||||
|
||||
for name in weight_names:
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading pt checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
yield from state.items()
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def get_gguf_extra_tensor_names(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
expected_gguf_keys = set(gguf_to_hf_name_map.keys())
|
||||
exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
|
||||
extra_keys = expected_gguf_keys - exact_gguf_keys
|
||||
return [gguf_to_hf_name_map[key] for key in extra_keys]
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""
|
||||
Iterate over the quant weights in the model gguf files and convert
|
||||
them to torch tensors
|
||||
"""
|
||||
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
weight_type_name = name.replace("weight", "qweight_type")
|
||||
weight_type = torch.tensor(weight_type)
|
||||
yield weight_type_name, weight_type
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight = tensor.data
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
name = name.replace("weight", "qweight")
|
||||
param = torch.tensor(weight)
|
||||
yield name, param
|
||||
|
||||
|
||||
def kv_cache_scales_loader(
|
||||
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
|
||||
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
|
||||
"""
|
||||
A simple utility to read in KV cache scaling factors that have been
|
||||
previously serialized to disk. Used by the model to populate the appropriate
|
||||
KV cache scaling factors. The serialization should represent a dictionary
|
||||
whose keys are the TP ranks and values are another dictionary mapping layers
|
||||
to their KV cache scaling factors.
|
||||
Keep this function in sync with the output of examples/fp8/extract_scales.py
|
||||
"""
|
||||
try:
|
||||
with open(filename) as f:
|
||||
context = {
|
||||
"model_type": model_type,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"tp_rank": tp_rank,
|
||||
"tp_size": tp_size,
|
||||
}
|
||||
schema_dct = json.load(f)
|
||||
schema = QuantParamSchema.model_validate(schema_dct,
|
||||
context=context)
|
||||
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
|
||||
return layer_scales_map.items()
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error("File or directory '%s' not found.", filename)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding JSON in file '%s'.", filename)
|
||||
except Exception:
|
||||
logger.exception("An error occurred while reading '%s'.", filename)
|
||||
# This section is reached if and only if any of the excepts are hit
|
||||
# Return an empty iterable (list) => no KV cache scales are loaded
|
||||
# which ultimately defaults to 1.0 scales
|
||||
logger.warning(
|
||||
"Defaulting to KV cache scaling factors = 1.0 for all "
|
||||
"layers in TP rank %d as an error occurred during loading.", tp_rank)
|
||||
return []
|
||||
|
||||
|
||||
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||
|
||||
PySafeSlice object supports indexing, which is done before loading the
|
||||
actual tensor and can reduce the amount of memory being read into the
|
||||
memory. However, it does not support more advanced functionalities
|
||||
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||
tensor with these more complicated operators, we need to convert to
|
||||
tensor first.
|
||||
"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = x[:]
|
||||
return x
|
||||
|
||||
|
||||
def default_weight_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
"""Default weight loader."""
|
||||
try:
|
||||
if param.numel() == 1 and loaded_weight.numel() == 1:
|
||||
# Sometimes scalar values aren't considered tensors with shapes
|
||||
# so if both param and loaded_weight are a scalar,
|
||||
# "broadcast" instead of copy
|
||||
param.data.fill_(loaded_weight.item())
|
||||
else:
|
||||
assert param.size() == loaded_weight.size(), (
|
||||
f"Attempted to load weight ({loaded_weight.size()}) "
|
||||
f"into parameter ({param.size()})")
|
||||
|
||||
param.data.copy_(loaded_weight)
|
||||
except Exception:
|
||||
# NOTE: This exception is added for the purpose of setting breakpoint to
|
||||
# debug weight loading issues.
|
||||
raise
|
||||
|
||||
|
||||
def row_parallel_weight_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
"""Load weights that are row-parallelized."""
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_dim = 0 if param.dim() != 1 else None
|
||||
|
||||
if shard_dim is not None:
|
||||
shard_size = param.data.shape[shard_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
||||
"""Create a weight loader that shards the weights along the given axis"""
|
||||
|
||||
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = param.data.shape[shard_axis]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def composed_weight_loader(
|
||||
loader: LoaderFunction, fn: Callable[[torch.Tensor],
|
||||
torch.Tensor]) -> LoaderFunction:
|
||||
"""Create a weight loader that post-processes the weights after loading"""
|
||||
|
||||
def composed_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
loader(param, loaded_weight)
|
||||
param.data.copy_(fn(param))
|
||||
return
|
||||
|
||||
return composed_loader
|
||||
|
||||
|
||||
def initialize_dummy_weights(
|
||||
model: torch.nn.Module,
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
"""Initialize model weights with random values.
|
||||
|
||||
The model weights must be randomly initialized for accurate performance
|
||||
measurements. Additionally, the model weights should not cause NaNs in the
|
||||
forward pass. We empirically found that initializing the weights with
|
||||
values between -1e-3 and 1e-3 works well for most models.
|
||||
|
||||
We use per-parameter random seed, so that dummy weights are consistent,
|
||||
even if the model is partitioned across multiple devices. When the seed
|
||||
is fixed, the random values generated by this function only depends on
|
||||
the parameter's number of elements and its data type.
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
if torch.is_floating_point(param):
|
||||
if current_platform.is_tpu():
|
||||
# XLA device does not support torch.Generator()
|
||||
param.uniform_(low, high)
|
||||
continue
|
||||
|
||||
generator = torch.Generator(device=param.data.device)
|
||||
generator.manual_seed(seed)
|
||||
if torch.finfo(param.data.dtype).bits < 16:
|
||||
# uniform_ doesn't support < 16-bit datatypes (FP8)
|
||||
dtype = param.data.dtype
|
||||
tmp_param = param.data.to(torch.float16)
|
||||
tmp_param = tmp_param.uniform_(low, high,
|
||||
generator=generator).to(dtype)
|
||||
param.data.copy_(tmp_param)
|
||||
else:
|
||||
param.uniform_(low, high, generator=generator)
|
||||
|
||||
|
||||
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
||||
"""Remap the name of FP8 k/v_scale parameters.
|
||||
|
||||
This function handles the remapping of FP8 k/v_scale parameter names.
|
||||
It detects if the given name ends with a suffix and attempts to remap
|
||||
it to the expected name format in the model. If the remapped name is not
|
||||
found in the params_dict, a warning is printed and None is returned.
|
||||
|
||||
Args:
|
||||
name (str): The original loaded checkpoint parameter name.
|
||||
params_dict (dict): Dictionary containing the model's named parameters.
|
||||
|
||||
Returns:
|
||||
str: The remapped parameter name if successful, or the original name
|
||||
if no remapping is needed.
|
||||
None: If the remapped name is not found in params_dict.
|
||||
"""
|
||||
if name.endswith(".kv_scale"):
|
||||
print_warning_once(
|
||||
"DEPRECATED. Found kv_scale in the checkpoint. "
|
||||
"This format is deprecated in favor of separate k_scale and "
|
||||
"v_scale tensors and will be removed in a future release. "
|
||||
"Functionally, we will remap kv_scale to k_scale and duplicate "
|
||||
"k_scale to v_scale")
|
||||
# NOTE: we remap the deprecated kv_scale to k_scale
|
||||
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
|
||||
if remapped_name not in params_dict:
|
||||
print_warning_once(
|
||||
f"Found kv_scale in the checkpoint (e.g. {name}), "
|
||||
"but not found the expected name in the model "
|
||||
f"(e.g. {remapped_name}). kv_scale is "
|
||||
"not loaded.")
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
possible_scale_names = [".k_scale", ".v_scale"]
|
||||
for scale_name in possible_scale_names:
|
||||
if name.endswith(scale_name):
|
||||
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
||||
if remapped_name not in params_dict:
|
||||
print_warning_once(
|
||||
f"Found {scale_name} in the checkpoint (e.g. {name}), "
|
||||
"but not found the expected name in the model "
|
||||
f"(e.g. {remapped_name}). {scale_name} is "
|
||||
"not loaded.")
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
# If there were no matches, return the untouched param name
|
||||
return name
|
||||
Reference in New Issue
Block a user