Fix select (#64)
This commit is contained in:
@@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||||
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
|
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
|
||||||
logprobs = torch.zeros(
|
logprobs = torch.zeros(
|
||||||
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
|
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
|
||||||
)
|
)
|
||||||
@@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
|||||||
end = torch.cumsum(len_add_1.sub_(1), dim=0)
|
end = torch.cumsum(len_add_1.sub_(1), dim=0)
|
||||||
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
|
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
|
||||||
end.sub_(1)
|
end.sub_(1)
|
||||||
|
torch.cuda.synchronize()
|
||||||
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
|
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
|
||||||
res = sum_logp / len_add_1
|
res = sum_logp / len_add_1
|
||||||
return res
|
return res
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -12,6 +13,10 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
|
|||||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("model_runner")
|
||||||
|
|
||||||
|
|
||||||
# for model_mode
|
# for model_mode
|
||||||
global_model_mode: List[str] = []
|
global_model_mode: List[str] = []
|
||||||
|
|
||||||
@@ -257,6 +262,8 @@ class ModelRunner:
|
|||||||
if model_class is None:
|
if model_class is None:
|
||||||
raise ValueError(f"Unsupported architectures: {architectures}")
|
raise ValueError(f"Unsupported architectures: {architectures}")
|
||||||
|
|
||||||
|
logger.info("load weight begin.")
|
||||||
|
|
||||||
# Load weights
|
# Load weights
|
||||||
linear_method = None
|
linear_method = None
|
||||||
with _set_default_torch_dtype(torch.float16):
|
with _set_default_torch_dtype(torch.float16):
|
||||||
@@ -267,7 +274,7 @@ class ModelRunner:
|
|||||||
if hf_quant_config is not None:
|
if hf_quant_config is not None:
|
||||||
# TODO: config quantization awq etc
|
# TODO: config quantization awq etc
|
||||||
quant_config = AWQConfig.from_config(hf_quant_config)
|
quant_config = AWQConfig.from_config(hf_quant_config)
|
||||||
print(f"quant_config: {quant_config}")
|
logger.info(f"quant_config: {quant_config}")
|
||||||
linear_method = quant_config.get_linear_method()
|
linear_method = quant_config.get_linear_method()
|
||||||
model = model_class(
|
model = model_class(
|
||||||
config=self.model_config.hf_config, linear_method=linear_method
|
config=self.model_config.hf_config, linear_method=linear_method
|
||||||
@@ -280,6 +287,8 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
|
|
||||||
|
logger.info("load weight end.")
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory):
|
def profile_max_num_token(self, total_gpu_memory):
|
||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
self.tp_rank, distributed=self.tp_size > 1
|
self.tp_rank, distributed=self.tp_size > 1
|
||||||
|
|||||||
Reference in New Issue
Block a user