638 lines
23 KiB
Python
638 lines
23 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Inference-only OPT model compatible with HuggingFace weights."""
|
|
import logging
from collections.abc import Iterable
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import OPTConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers
|
|
|
|
|
|
def get_activation(name="relu"):
    """Select an activation module by name.

    Args:
        name: str
            activation function name (matched case-insensitively),
            one of ["relu", "gelu", "swish", "silu", "sigmoid"],
            default "relu".

    Returns:
        A newly constructed ``nn.Module`` for the requested activation.
        Names that are not recognised fall back to ``nn.Identity()``.
    """
    # "swish" is documented as supported; it is the same function as SiLU,
    # so both spellings map to ``nn.SiLU``.
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "swish": nn.SiLU,
        "silu": nn.SiLU,
        "sigmoid": nn.Sigmoid,
    }
    return factories.get(name.lower(), nn.Identity)()
|
|
|
|
|
|
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose lookups are shifted by a fixed offset.

    The table holds ``num_embeddings + 2`` rows and every position id is
    increased by 2 before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT offsets the embedding ids by 2 and enlarges the table by the
        # same amount; other models do not use this hack.
        self.offset = 2
        table_rows = num_embeddings + self.offset
        super().__init__(table_rows, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embedding rows for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
|
|
|
|
|
|
class OPTAttention(nn.Module):
    """Self-attention sub-module of an OPT decoder layer.

    Builds a fused QKV projection, a ``RadixAttention`` core and an output
    projection. The attention heads are divided evenly across the tensor
    parallel world size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        layer_id: int = 0,
        bias: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the projections and the attention core.

        Args:
            embed_dim: Width of the hidden states entering the block.
            num_heads: Total number of attention heads (before TP sharding);
                must be divisible by the tensor parallel world size.
            layer_id: Index forwarded to ``RadixAttention``.
            bias: Whether the QKV and output projections carry a bias.
            quant_config: Optional quantization config forwarded to sub-layers.
            prefix: Dotted module path of this block, used to derive the
                prefixes handed to the sub-layers.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        assert num_heads % tp_size == 0
        # Heads handled by this rank.
        self.num_heads = total_num_heads // tp_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        # The prefix must mirror the attribute name (``out_proj``) so that it
        # matches the module path of this layer; the previous "o_proj" value
        # pointed at a module that does not exist in this model.
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("out_proj", prefix),
        )

        # Plain multi-head attention: KV head count equals query head count.
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run QKV projection, attention and output projection.

        Args:
            hidden_states: Input activations.
            forward_batch: Batch metadata handed through to ``RadixAttention``.

        Returns:
            The first element returned by the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Q, K and V have equal width, so an even 3-way split on the last
        # dimension separates them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.out_proj(attn_output)
        return output
|
|
|
|
|
|
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a two-layer MLP.

    ``config.do_layer_norm_before`` selects whether each LayerNorm runs before
    its sub-block or after the residual addition.
    """

    def __init__(
        self,
        config: OPTConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention, feed-forward and normalisation sub-modules.

        Args:
            config: OPT configuration supplying sizes, bias and norm options.
            layer_id: Index of this layer, forwarded to the attention module.
            quant_config: Optional quantization config forwarded to sub-layers.
            prefix: Dotted module path of this layer.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.do_layer_norm_before = config.do_layer_norm_before

        use_bias = config.enable_bias
        norm_affine = config.layer_norm_elementwise_affine

        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            layer_id=layer_id,
            bias=use_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=norm_affine
        )

        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.activation_fn = get_activation(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=norm_affine
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-blocks with residuals.

        Args:
            hidden_states: Input activations of the layer.
            forward_batch: Batch metadata passed on to the attention module.

        Returns:
            The activations after both residual sub-blocks.
        """
        # 125m, 1.7B, ..., 175B normalise BEFORE each sub-block;
        # 350m normalises AFTER the residual addition.
        pre_norm = self.do_layer_norm_before

        # --- Self attention ---
        shortcut = hidden_states
        attn_input = (
            self.self_attn_layer_norm(hidden_states) if pre_norm else hidden_states
        )
        attn_output = self.self_attn(
            hidden_states=attn_input, forward_batch=forward_batch
        )
        hidden_states = shortcut + attn_output
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- Fully connected ---
        shortcut = hidden_states
        ffn_input = (
            self.final_layer_norm(hidden_states) if pre_norm else hidden_states
        )
        ffn_hidden, _ = self.fc1(ffn_input)
        ffn_hidden = self.activation_fn(ffn_hidden)
        ffn_output, _ = self.fc2(ffn_hidden)
        hidden_states = shortcut + ffn_output
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
|
|
|
|
|
|
class OPTDecoder(nn.Module):
    """Embedding layers plus the stack of OPT decoder layers.

    With pipeline parallelism, ``make_layers`` returns the layer list together
    with the ``[start_layer, end_layer)`` range this rank iterates over.
    """

    def __init__(
        self,
        config: OPTConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build embeddings, optional projections, final norm and layers.

        Args:
            config: OPT configuration.
            layer_id: Unused; kept so existing callers keep working.
            quant_config: Optional quantization config forwarded to sub-layers.
            prefix: Dotted module path of this decoder.
        """
        super().__init__()
        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.pp_group = get_pp_group()

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            prefix=add_prefix("embed_tokens", prefix),
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Project out & in exist only when the token-embedding width differs
        # from the hidden size; when present they are replicated.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(
                config.hidden_size,
                config.word_embed_proj_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_out", prefix),
            )
            self.project_in = ReplicatedLinear(
                config.word_embed_proj_dim,
                config.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_in", prefix),
            )
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
            )
        else:
            self.final_layer_norm = None

        # Derive the layer prefix from this module's own prefix so it follows
        # the real module path (``<prefix>.layers``) instead of the previous
        # hard-coded "model.layers".
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: OPTDecoderLayer(
                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this rank's slice of the decoder.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``input_embeds`` is not given.
            positions: Position ids fed to the positional embedding.
            forward_batch: Batch metadata passed to every layer.
            pp_proxy_tensors: Required on non-first pipeline ranks; its
                ``"hidden_states"`` entry is used as the input.
            input_embeds: Optional pre-computed token embeddings.

        Returns:
            ``PPProxyTensors`` wrapping the hidden states on non-last pipeline
            ranks, otherwise the final hidden states tensor.
        """
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                input_embeds = self.embed_tokens(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                input_embeds, _ = self.project_in(input_embeds)
            hidden_states = input_embeds + pos_embeds
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]

        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states = layer(
                hidden_states=hidden_states, forward_batch=forward_batch
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        # Original author's note: execution did not pass through this branch.
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states
|
|
|
|
|
|
class OPTModel(nn.Module):
    """Thin wrapper that owns the ``OPTDecoder`` under the ``decoder`` attribute."""

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the decoder.

        Args:
            config: OPT configuration.
            quant_config: Optional quantization config forwarded to the decoder.
            prefix: Dotted module path of this module.
        """
        super().__init__()

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        self.decoder = OPTDecoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("decoder", prefix),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Delegate to the decoder and return its result unchanged."""
        return self.decoder(
            input_ids,
            positions,
            pp_proxy_tensors=pp_proxy_tensors,
            input_embeds=input_embeds,
            forward_batch=forward_batch,
        )

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV cache scaling factors and assign them to attention.

        Args:
            quantization_param_path: Path handed to ``kv_cache_scales_loader``.

        Raises:
            RuntimeError: If a layer's attention module has no ``k_scale``
                attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            layer = self.decoder.layers[layer_idx]
            # Placeholder layers have no attention module. Skip them instead
            # of falling through with a stale (or unbound) reference to a
            # previous layer's attention, which the old code did.
            if isinstance(layer, nn.Identity):
                continue

            attn = layer.self_attn.attn
            if not hasattr(attn, "k_scale"):
                raise RuntimeError(
                    "Self attention has no KV cache scaling factor attribute!"
                )
            # The same factor is used for both keys and values.
            attn.k_scale = scaling_factor
            attn.v_scale = scaling_factor
|
|
|
|
|
|
logger = logging.getLogger(__name__)


class OPTForCausalLM(nn.Module):
    """OPT decoder with a language-model head, logits processor and pooler."""

    # BitandBytes specific attributes
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    # These are the row-parallel linears of this model (see OPTAttention.out_proj
    # and OPTDecoderLayer.fc2); the previous ".down_proj."/".o_proj." entries
    # name modules that do not exist in OPT.
    column_parallel_weights_modules = [".out_proj.", ".fc2."]

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the model, the (possibly tied) LM head and output helpers.

        Args:
            config: OPT configuration.
            quant_config: Optional quantization config forwarded to the model.
            prefix: Dotted module path of this module.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.model = OPTModel(
            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        if self.config.tie_word_embeddings:
            # Tied weights: the head is the token embedding module itself.
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.word_embed_proj_dim,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.capture_aux_hidden_states = False
        self.pp_group = get_pp_group()
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        """Run the model and post-process its output.

        On the last pipeline rank this returns the logits processor output, or
        the pooler output when ``get_embedding`` is true. On other ranks the
        model output is returned as-is.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if not self.pp_group.is_last_rank:
            return hidden_states
        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
            aux_hidden_states=aux_hidden_states,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        """Copy checkpoint tensors into this module's parameters.

        Names starting with ``decoder.`` are re-rooted under ``model.``;
        ``q_proj``/``k_proj``/``v_proj`` tensors are routed into the fused
        ``qkv_proj`` parameter with their shard id. Tensors belonging to layers
        outside this rank's ``[start_layer, end_layer)`` range are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        # The layer range lives on the decoder; the previous
        # ``hasattr(self.model, "start_layer")`` probe was always false, which
        # silently disabled this filter.
        first_layer, end_layer = self.start_layer, self.end_layer

        for name, loaded_weight in weights:
            if name.startswith("decoder."):
                name = "model." + name

            layer_id = get_layer_id(name)
            if layer_id is not None and not (first_layer <= layer_id < end_layer):
                continue

            # Map an un-fused q/k/v name onto the fused parameter, if it is one.
            fused_shard_id = None
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    fused_shard_id = shard_id
                    break

            if name not in params_dict:
                # Extra biases (e.g. in GPTQ checkpoints) are skipped quietly;
                # anything else is reported but does not abort loading.
                if not name.endswith(".bias"):
                    logger.warning(f"Parameter {name} not found in params_dict")
                continue

            param = params_dict[name]
            if fused_shard_id is not None:
                param.weight_loader(param, loaded_weight, fused_shard_id)
            else:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    @property
    def start_layer(self):
        """First layer index owned by this rank (taken from the decoder)."""
        return self.model.decoder.start_layer

    @property
    def end_layer(self):
        """End of this rank's layer range (taken from the decoder)."""
        return self.model.decoder.end_layer

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module of the decoder."""
        return self.model.decoder.embed_tokens

    def get_module_name_from_weight_name(self, name):
        """Map a checkpoint weight name to ``(module_name, num_shard)``.

        For q/k/v weights the fused module name is returned together with the
        number of mapping entries that share that fused parameter; every other
        name maps to itself with a shard count of 1. The trailing ``.weight``
        is stripped by length, so ``name`` is assumed to end with it.
        """
        suffix_len = len(".weight")
        # Mapping entries are 3-tuples; the shard count is derived here rather
        # than unpacked from a fourth element that does not exist.
        for param_name, weight_name, _shard_id in self.stacked_params_mapping:
            if weight_name in name:
                num_shard = sum(
                    1
                    for fused_name, _, _ in self.stacked_params_mapping
                    if fused_name == param_name
                )
                return name.replace(weight_name, param_name)[:-suffix_len], num_shard
        return name[:-suffix_len], 1

    def get_num_params(self):
        """Return the number of entries yielded by ``named_parameters()``."""
        return len(dict(self.named_parameters()))

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[list]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.

        Args:
            name: Checkpoint-style parameter name.
            truncate_size: Maximum number of leading rows to return.
            tp_size: Tensor parallel size used to compute shard sizes.

        Returns:
            The (truncated) weight as nested Python lists of float32 values, or
            ``None`` if anything goes wrong (the error is logged).
        """
        try:
            if name == "lm_head.weight" and self.config.tie_word_embeddings:
                logger.info(
                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
                )
                return (
                    self.model.decoder.embed_tokens.weight.cpu()
                    .to(torch.float32)
                    .numpy()
                    .tolist()[:truncate_size]
                )

            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            param = dict(self.named_parameters())[mapped_name]

            shard_order = ("q", "k", "v")
            if mapped_shard_id in shard_order:
                # Q, K and V have the same number of heads, so each shard has
                # the same size and they are laid out in q, k, v order along
                # dimension 0 of the fused parameter.
                heads_per_rank = self.config.num_attention_heads // tp_size
                head_dim = self.config.hidden_size // self.config.num_attention_heads
                shard_rows = heads_per_rank * head_dim
                offset = shard_order.index(mapped_shard_id) * shard_rows
                weight = param.data.narrow(0, offset, shard_rows)
            else:
                weight = param.data
                # Row-parallel weights are gathered from all ranks and joined
                # along dim 1. Restricted to ".weight" so that 1-D biases are
                # not concatenated on a dimension they do not have.
                if (
                    tp_size > 1
                    and name.endswith(".weight")
                    and ("out_proj" in name or "fc2" in name)
                ):
                    gathered_weights = [
                        torch.zeros_like(weight) for _ in range(tp_size)
                    ]
                    torch.distributed.all_gather(gathered_weights, weight)
                    weight = torch.cat(gathered_weights, dim=1)
            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]

        except Exception:
            # Best-effort helper: log the traceback and signal failure with None.
            logger.exception(
                "Error getting weights by name %s in OPTForCausalLM", name
            )
            return None

    def get_embed_and_head(self):
        """Return the token embedding weight and the LM head weight."""
        return self.model.decoder.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the token embedding and LM head weights, then free CUDA cache."""
        embed_tokens = self.model.decoder.embed_tokens
        del embed_tokens.weight
        del self.lm_head.weight
        embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def get_embed(self):
        """Return the token embedding weight."""
        return self.model.decoder.embed_tokens.weight

    def set_embed(self, embed):
        """Replace the token embedding weight unless hidden sizes differ."""
        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
        if (
            hasattr(self.config, "target_hidden_size")
            and self.config.target_hidden_size != self.config.hidden_size
        ):
            return
        embed_tokens = self.model.decoder.embed_tokens
        del embed_tokens.weight
        embed_tokens.weight = embed
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Forward to ``OPTModel.load_kv_cache_scales``."""
        self.model.load_kv_cache_scales(quantization_param_path)
|
|
|
|
|
|
# Model classes exposed by this module.
# NOTE(review): presumably read by the framework's model registry — confirm.
EntryClass = [OPTForCausalLM]
|