From 11f3cca64fa7bd91a795075876ed2407c4b1ec86 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 20 Jan 2024 23:20:35 -0800 Subject: [PATCH] Fix select (#64) --- python/sglang/srt/layers/logits_processor.py | 3 ++- python/sglang/srt/managers/router/model_runner.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 315d71869..1442b6db7 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module): def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids): - # assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0] + # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1) logprobs = torch.zeros( (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda" ) @@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids): end = torch.cumsum(len_add_1.sub_(1), dim=0) start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0) end.sub_(1) + torch.cuda.synchronize() sum_logp = cumsum[end] - cumsum[start] + logprobs[start] res = sum_logp / len_add_1 return res diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index bcdb3125b..b200a7295 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum, auto from typing import List +import logging import numpy as np import torch @@ -12,6 +13,10 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.model_loader import _set_default_torch_dtype from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel + +logger = logging.getLogger("model_runner") + + # for model_mode global_model_mode: List[str] = [] @@ -257,6 +262,8 @@ class ModelRunner: if model_class is None: raise ValueError(f"Unsupported architectures: {architectures}") + logger.info("load weight begin.") + # Load weights linear_method = None with _set_default_torch_dtype(torch.float16): @@ -267,7 +274,7 @@ class ModelRunner: if hf_quant_config is not None: # TODO: config quantization awq etc quant_config = AWQConfig.from_config(hf_quant_config) - print(f"quant_config: {quant_config}") + logger.info(f"quant_config: {quant_config}") linear_method = quant_config.get_linear_method() model = model_class( config=self.model_config.hf_config, linear_method=linear_method @@ -280,6 +287,8 @@ class ModelRunner: ) self.model = model.eval() + logger.info("load weight end.") + def profile_max_num_token(self, total_gpu_memory): available_gpu_memory = get_available_gpu_memory( self.tp_rank, distributed=self.tp_size > 1