debugging
This commit is contained in:
@@ -89,15 +89,63 @@ def device_loading_context(module: torch.nn.Module,
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_device_memory_info_loader():
|
||||
"""Get device memory info for debug logging. Returns dict or None."""
|
||||
try:
|
||||
import torch.mlu
|
||||
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
|
||||
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
|
||||
free, total = torch.mlu.mem_get_info()
|
||||
return {"allocated": allocated, "reserved": reserved,
|
||||
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
|
||||
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
return {"allocated": allocated, "reserved": reserved,
|
||||
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _log_mem(tag: str):
|
||||
info = _get_device_memory_info_loader()
|
||||
if info:
|
||||
logger.info(
|
||||
"[DEBUG-MEM] %s: allocated=%.2f GiB, reserved=%.2f GiB, "
|
||||
"free=%.2f GiB, total=%.2f GiB",
|
||||
tag, info["allocated"], info["reserved"],
|
||||
info["free"], info["total"])
|
||||
|
||||
|
||||
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_config = vllm_config.model_config
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
logger.info("[DEBUG-MEM] Model class: %s, dtype: %s",
|
||||
model_class.__name__, model_config.dtype)
|
||||
_log_mem("Before _initialize_model")
|
||||
signatures = inspect.signature(model_class.__init__)
|
||||
all_params = [param.name for param in signatures.parameters.values()]
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
model = model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
_log_mem("After _initialize_model (empty weights created)")
|
||||
# Print model parameter summary
|
||||
total_params = 0
|
||||
total_bytes = 0
|
||||
for name, param in model.named_parameters():
|
||||
total_params += param.numel()
|
||||
total_bytes += param.numel() * param.element_size()
|
||||
logger.info(
|
||||
"[DEBUG-MEM] Model params: %d, "
|
||||
"estimated size: %.2f GiB",
|
||||
total_params, total_bytes / (1024 ** 3))
|
||||
return model
|
||||
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
"input arguments. Possibly you have an old-style model class"
|
||||
" registered from out of tree and it is used for new vLLM version. "
|
||||
@@ -327,11 +375,14 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
_log_mem("load_model start, target_device=%s" % target_device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = _initialize_model(vllm_config=vllm_config)
|
||||
|
||||
_log_mem("Before load_weights")
|
||||
model.load_weights(self._get_all_weights(model_config, model))
|
||||
_log_mem("After load_weights")
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
|
||||
Reference in New Issue
Block a user