[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)
@@ -16,17 +16,16 @@ limitations under the License.
 from typing import Iterable, Optional, Tuple
 
 import torch
 import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
-from sglang.srt.models.llama2 import LlamaModel
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
 class LlamaForClassification(nn.Module):
@@ -42,10 +41,12 @@ class LlamaForClassification(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
 
         self.classification_head = nn.Linear(
-            config.hidden_size, config.classification_out_size
+            config.hidden_size, config.classification_out_size, bias=False
         )
         self.eos_token_id = config.eos_token_id
 
+        self.param_dict = dict(self.named_parameters())
+
     @torch.no_grad()
     def forward(
         self,
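Two things change in __init__: the classification head drops its bias term, and dict(self.named_parameters()) is cached once in self.param_dict so the rewritten load_weights (further down) does not rebuild the dict on every call. A minimal standalone sketch of the bias-free head, with hypothetical sizes (hidden_size and num_classes are illustrative, not values taken from the config above):

import torch
from torch import nn

hidden_size, num_classes = 4096, 2  # hypothetical sizes for illustration

# bias=False means the checkpoint only needs a single
# "classification_head.weight" tensor; there is no bias to load or train.
head = nn.Linear(hidden_size, num_classes, bias=False)

pooled = torch.randn(3, hidden_size)  # one final hidden state per request
scores = head(pooled)                 # shape (3, num_classes) class logits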
@@ -65,7 +66,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)
 
-        return LogitsProcessorOutput(
+        logits_output = LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
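forward() now keeps the LogitsProcessorOutput in a local instead of returning it directly, because (as the next hunk shows) the method returns a (SampleOutput, LogitsProcessorOutput) pair; the code's own comment calls the sampling half "a dummy to make this work", since a classifier has no tokens to sample. A standalone sketch of those placeholder tensors, assuming a batch of three requests (in the model the size comes from scores.shape[0]):

import torch

batch_size = 3  # assumed batch size for illustration

# Every request "succeeds", with probability 1.0 and dummy token id 0.
success = torch.full(size=(batch_size,), fill_value=True, dtype=torch.bool)
probs = torch.full(size=(batch_size, 1), fill_value=1.0, dtype=torch.float16)
next_token_ids = torch.full(size=(batch_size,), fill_value=0, dtype=torch.long)

print(success.tolist(), probs.shape, next_token_ids.tolist())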
@@ -74,46 +75,38 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
-            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            if "lm_head" in name:
-                continue
+        # A dummy to make this work
+        sample_output = SampleOutput(
+            success=torch.full(
+                size=(scores.shape[0],),
+                fill_value=True,
+                dtype=torch.bool,
+            ),
+            probs=torch.full(
+                size=(scores.shape[0], 1),
+                fill_value=1.0,
+                dtype=torch.float16,
+            ),
+            batch_next_token_ids=torch.full(
+                size=(scores.shape[0],),
+                fill_value=0,
+                dtype=torch.long,
+            ),
+        )
+        return sample_output, logits_output
 
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = self.param_dict
+
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
 
 
 EntryClass = LlamaForClassification
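The rewritten load_weights is where the memory saving comes from: instead of duplicating the llama model's stacked-parameter bookkeeping (qkv_proj, gate_up_proj) and walking the whole checkpoint itself, it forwards each (name, tensor) pair to LlamaForCausalLM.load_weights one at a time, so only one checkpoint tensor needs to be alive at once; the tqdm progress bar and the tensor-parallel rank check that drove it go away with it. A minimal sketch of the principle, with hypothetical names (checkpoint_tensors stands in for the iterator load_weights receives; it is not an sglang API):

from typing import Iterable, Tuple
import torch

def checkpoint_tensors() -> Iterable[Tuple[str, torch.Tensor]]:
    # Stand-in for a lazily read checkpoint: tensors are produced on demand.
    for i in range(4):
        yield f"layers.{i}.weight", torch.randn(1024, 1024)

params = {f"layers.{i}.weight": torch.empty(1024, 1024) for i in range(4)}

# Streaming: copy each tensor as it arrives, so the source buffer can be
# garbage-collected before the next one is read.
for name, loaded in checkpoint_tensors():
    params[name].copy_(loaded)

# By contrast, materializing the iterator first, e.g.
#     weights = list(checkpoint_tensors())
# keeps every checkpoint tensor in memory at once, inflating peak usage.

The unchanged EntryClass = LlamaForClassification line is the module's registration hook; with EntryClassRemapping removed, the loader presumably resolves architectures through these EntryClass exports directly.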