From a6f39375e59db3ac51f50b27b9820b3c520bbc37 Mon Sep 17 00:00:00 2001
From: Chranos <826995883@qq.com>
Date: Tue, 10 Feb 2026 16:10:28 +0800
Subject: [PATCH] debugging: log device memory usage during model load

---
 .../model_executor/model_loader/loader.py | 53 ++++++++++++++++++-
 .../vllm/model_executor/models/utils.py   | 51 ++++++++++++++++--
 2 files changed, 99 insertions(+), 5 deletions(-)

diff --git a/vllm-v0.6.2/vllm/model_executor/model_loader/loader.py b/vllm-v0.6.2/vllm/model_executor/model_loader/loader.py
index 140b61f..0b5f6f3 100644
--- a/vllm-v0.6.2/vllm/model_executor/model_loader/loader.py
+++ b/vllm-v0.6.2/vllm/model_executor/model_loader/loader.py
@@ -89,15 +89,63 @@ def device_loading_context(module: torch.nn.Module,
 logger = init_logger(__name__)
 
 
+def _get_device_memory_info_loader():
+    """Get device memory info for debug logging. Returns a dict or None."""
+    try:
+        import torch.mlu
+        allocated = torch.mlu.memory_allocated() / (1024 ** 3)
+        reserved = torch.mlu.memory_reserved() / (1024 ** 3)
+        free, total = torch.mlu.mem_get_info()
+        return {"allocated": allocated, "reserved": reserved,
+                "free": free / (1024 ** 3), "total": total / (1024 ** 3)}
+    except Exception:
+        pass
+    try:
+        if torch.cuda.is_available():
+            allocated = torch.cuda.memory_allocated() / (1024 ** 3)
+            reserved = torch.cuda.memory_reserved() / (1024 ** 3)
+            free, total = torch.cuda.mem_get_info()
+            return {"allocated": allocated, "reserved": reserved,
+                    "free": free / (1024 ** 3), "total": total / (1024 ** 3)}
+    except Exception:
+        pass
+    return None
+
+
+def _log_mem(tag: str):
+    info = _get_device_memory_info_loader()
+    if info:
+        logger.info(
+            "[DEBUG-MEM] %s: allocated=%.2f GiB, reserved=%.2f GiB, "
+            "free=%.2f GiB, total=%.2f GiB",
+            tag, info["allocated"], info["reserved"],
+            info["free"], info["total"])
+
+
 def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
     """Initialize a model with the given configurations."""
     model_config = vllm_config.model_config
     model_class, _ = get_model_architecture(model_config)
+    logger.info("[DEBUG-MEM] Model class: %s, dtype: %s",
+                model_class.__name__, model_config.dtype)
+    _log_mem("Before _initialize_model")
     signatures = inspect.signature(model_class.__init__)
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
-        return model_class(vllm_config=vllm_config, prefix=prefix)
+        model = model_class(vllm_config=vllm_config, prefix=prefix)
+        _log_mem("After _initialize_model (empty weights created)")
+        # Log a model parameter summary
+        total_params = 0
+        total_bytes = 0
+        for name, param in model.named_parameters():
+            total_params += param.numel()
+            total_bytes += param.numel() * param.element_size()
+        logger.info(
+            "[DEBUG-MEM] Model params: %d, "
+            "estimated size: %.2f GiB",
+            total_params, total_bytes / (1024 ** 3))
+        return model
     msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
            "input arguments. Possibly you have an old-style model class"
            " registered from out of tree and it is used for new vLLM version. "
" @@ -327,11 +375,14 @@ class DefaultModelLoader(BaseModelLoader): model_config = vllm_config.model_config target_device = torch.device(device_config.device) + _log_mem("load_model start, target_device=%s" % target_device) with set_default_torch_dtype(model_config.dtype): with target_device: model = _initialize_model(vllm_config=vllm_config) + _log_mem("Before load_weights") model.load_weights(self._get_all_weights(model_config, model)) + _log_mem("After load_weights") for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/vllm-v0.6.2/vllm/model_executor/models/utils.py b/vllm-v0.6.2/vllm/model_executor/models/utils.py index 1d51885..4a0681f 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/utils.py +++ b/vllm-v0.6.2/vllm/model_executor/models/utils.py @@ -492,6 +492,29 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: return module +def _get_device_memory_info() -> Tuple[Optional[float], Optional[float], Optional[float]]: + """Get device memory info in GiB. Returns (allocated, reserved, total) or Nones.""" + try: + import torch.mlu + allocated = torch.mlu.memory_allocated() / (1024 ** 3) + reserved = torch.mlu.memory_reserved() / (1024 ** 3) + free, total = torch.mlu.mem_get_info() + total = total / (1024 ** 3) + return allocated, reserved, total + except Exception: + pass + try: + if torch.cuda.is_available(): + allocated = torch.cuda.memory_allocated() / (1024 ** 3) + reserved = torch.cuda.memory_reserved() / (1024 ** 3) + free, total = torch.cuda.mem_get_info() + total = total / (1024 ** 3) + return allocated, reserved, total + except Exception: + pass + return None, None, None + + def make_layers( num_hidden_layers: int, layer_fn: LayerFn, @@ -505,11 +528,31 @@ def make_layers( start_layer, end_layer = get_pp_indices(num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size) + + alloc_before, _, total = _get_device_memory_info() + if alloc_before is not None: + logger.info( + "[DEBUG-MEM] make_layers start: allocated=%.2f GiB, " + "total=%.2f GiB, layers to create: %d-%d / %d", + alloc_before, total, start_layer, end_layer, num_hidden_layers) + + created_layers = [] + for idx in range(start_layer, end_layer): + layer = maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + alloc_after, reserved, _ = _get_device_memory_info() + if alloc_after is not None: + delta = alloc_after - alloc_before + logger.info( + "[DEBUG-MEM] Layer %s.%d created: " + "allocated=%.2f GiB (+%.4f GiB), reserved=%.2f GiB", + prefix, idx, alloc_after, delta, reserved) + alloc_before = alloc_after + created_layers.append(layer) + modules = torch.nn.ModuleList( - [PPMissingLayer() for _ in range(start_layer)] + [ - maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) - for idx in range(start_layer, end_layer) - ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + [PPMissingLayer() for _ in range(start_layer)] + + created_layers + + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules