Format (#593)
This commit is contained in:
@@ -5,14 +5,12 @@ import tqdm
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
)
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaModel
|
||||
|
||||
|
||||
@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
|
||||
self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
|
||||
self.classification_head = nn.Linear(
|
||||
config.hidden_size, config.classification_out_size
|
||||
)
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
def forward(
|
||||
@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):
|
||||
|
||||
if scores.shape[0] != input_metadata.batch_size:
|
||||
print("Warning: the EOS tokens are missing in some sentences.")
|
||||
scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
|
||||
scores = torch.ones(
|
||||
(input_metadata.batch_size, self.config.classification_out_size)
|
||||
).to(input_ids.device)
|
||||
|
||||
return LogitProcessorOutput(
|
||||
next_token_logits=scores,
|
||||
@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
EntryClass = LlamaForClassification
|
||||
|
||||
EntryClass = LlamaForClassification
|
||||
|
||||
Reference in New Issue
Block a user