Fix select (#64)
This commit is contained in:
@@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
|
||||
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
|
||||
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
|
||||
logprobs = torch.zeros(
|
||||
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
@@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||
end = torch.cumsum(len_add_1.sub_(1), dim=0)
|
||||
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
|
||||
end.sub_(1)
|
||||
torch.cuda.synchronize()
|
||||
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
|
||||
res = sum_logp / len_add_1
|
||||
return res
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -12,6 +13,10 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||
|
||||
|
||||
logger = logging.getLogger("model_runner")
|
||||
|
||||
|
||||
# for model_mode
|
||||
global_model_mode: List[str] = []
|
||||
|
||||
@@ -257,6 +262,8 @@ class ModelRunner:
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported architectures: {architectures}")
|
||||
|
||||
logger.info("load weight begin.")
|
||||
|
||||
# Load weights
|
||||
linear_method = None
|
||||
with _set_default_torch_dtype(torch.float16):
|
||||
@@ -267,7 +274,7 @@ class ModelRunner:
|
||||
if hf_quant_config is not None:
|
||||
# TODO: config quantization awq etc
|
||||
quant_config = AWQConfig.from_config(hf_quant_config)
|
||||
print(f"quant_config: {quant_config}")
|
||||
logger.info(f"quant_config: {quant_config}")
|
||||
linear_method = quant_config.get_linear_method()
|
||||
model = model_class(
|
||||
config=self.model_config.hf_config, linear_method=linear_method
|
||||
@@ -280,6 +287,8 @@ class ModelRunner:
|
||||
)
|
||||
self.model = model.eval()
|
||||
|
||||
logger.info("load weight end.")
|
||||
|
||||
def profile_max_num_token(self, total_gpu_memory):
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
self.tp_rank, distributed=self.tp_size > 1
|
||||
|
||||
Reference in New Issue
Block a user