This commit is contained in:
Ying Sheng
2024-07-05 10:06:17 -07:00
committed by GitHub
parent 5a57b8addd
commit dc1b8bcfaa
21 changed files with 487 additions and 354 deletions

View File

@@ -5,14 +5,12 @@ import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
)
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaModel
@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
self.quant_config = quant_config
self.model = LlamaModel(config, quant_config=quant_config)
self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
self.classification_head = nn.Linear(
config.hidden_size, config.classification_out_size
)
self.eos_token_id = config.eos_token_id
def forward(
@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):
if scores.shape[0] != input_metadata.batch_size:
print("Warning: the EOS tokens are missing in some sentences.")
scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
scores = torch.ones(
(input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device)
return LogitProcessorOutput(
next_token_logits=scores,
@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = LlamaForClassification
EntryClass = LlamaForClassification