debugging
This commit is contained in:
@@ -89,15 +89,63 @@ def device_loading_context(module: torch.nn.Module,
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_device_memory_info_loader():
|
||||
"""Get device memory info for debug logging. Returns dict or None."""
|
||||
try:
|
||||
import torch.mlu
|
||||
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
|
||||
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
|
||||
free, total = torch.mlu.mem_get_info()
|
||||
return {"allocated": allocated, "reserved": reserved,
|
||||
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
|
||||
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
return {"allocated": allocated, "reserved": reserved,
|
||||
"free": free / (1024 ** 3), "total": total / (1024 ** 3)}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _log_mem(tag: str):
|
||||
info = _get_device_memory_info_loader()
|
||||
if info:
|
||||
logger.info(
|
||||
"[DEBUG-MEM] %s: allocated=%.2f GiB, reserved=%.2f GiB, "
|
||||
"free=%.2f GiB, total=%.2f GiB",
|
||||
tag, info["allocated"], info["reserved"],
|
||||
info["free"], info["total"])
|
||||
|
||||
|
||||
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_config = vllm_config.model_config
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
logger.info("[DEBUG-MEM] Model class: %s, dtype: %s",
|
||||
model_class.__name__, model_config.dtype)
|
||||
_log_mem("Before _initialize_model")
|
||||
signatures = inspect.signature(model_class.__init__)
|
||||
all_params = [param.name for param in signatures.parameters.values()]
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
model = model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
_log_mem("After _initialize_model (empty weights created)")
|
||||
# Print model parameter summary
|
||||
total_params = 0
|
||||
total_bytes = 0
|
||||
for name, param in model.named_parameters():
|
||||
total_params += param.numel()
|
||||
total_bytes += param.numel() * param.element_size()
|
||||
logger.info(
|
||||
"[DEBUG-MEM] Model params: %d, "
|
||||
"estimated size: %.2f GiB",
|
||||
total_params, total_bytes / (1024 ** 3))
|
||||
return model
|
||||
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
"input arguments. Possibly you have an old-style model class"
|
||||
" registered from out of tree and it is used for new vLLM version. "
|
||||
@@ -327,11 +375,14 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
_log_mem("load_model start, target_device=%s" % target_device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = _initialize_model(vllm_config=vllm_config)
|
||||
|
||||
_log_mem("Before load_weights")
|
||||
model.load_weights(self._get_all_weights(model_config, model))
|
||||
_log_mem("After load_weights")
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
|
||||
@@ -492,6 +492,29 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
return module
|
||||
|
||||
|
||||
def _get_device_memory_info() -> Tuple[Optional[float], Optional[float], Optional[float]]:
|
||||
"""Get device memory info in GiB. Returns (allocated, reserved, total) or Nones."""
|
||||
try:
|
||||
import torch.mlu
|
||||
allocated = torch.mlu.memory_allocated() / (1024 ** 3)
|
||||
reserved = torch.mlu.memory_reserved() / (1024 ** 3)
|
||||
free, total = torch.mlu.mem_get_info()
|
||||
total = total / (1024 ** 3)
|
||||
return allocated, reserved, total
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
|
||||
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
total = total / (1024 ** 3)
|
||||
return allocated, reserved, total
|
||||
except Exception:
|
||||
pass
|
||||
return None, None, None
|
||||
|
||||
|
||||
def make_layers(
|
||||
num_hidden_layers: int,
|
||||
layer_fn: LayerFn,
|
||||
@@ -505,11 +528,31 @@ def make_layers(
|
||||
start_layer, end_layer = get_pp_indices(num_hidden_layers,
|
||||
get_pp_group().rank_in_group,
|
||||
get_pp_group().world_size)
|
||||
|
||||
alloc_before, _, total = _get_device_memory_info()
|
||||
if alloc_before is not None:
|
||||
logger.info(
|
||||
"[DEBUG-MEM] make_layers start: allocated=%.2f GiB, "
|
||||
"total=%.2f GiB, layers to create: %d-%d / %d",
|
||||
alloc_before, total, start_layer, end_layer, num_hidden_layers)
|
||||
|
||||
created_layers = []
|
||||
for idx in range(start_layer, end_layer):
|
||||
layer = maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
|
||||
alloc_after, reserved, _ = _get_device_memory_info()
|
||||
if alloc_after is not None:
|
||||
delta = alloc_after - alloc_before
|
||||
logger.info(
|
||||
"[DEBUG-MEM] Layer %s.%d created: "
|
||||
"allocated=%.2f GiB (+%.4f GiB), reserved=%.2f GiB",
|
||||
prefix, idx, alloc_after, delta, reserved)
|
||||
alloc_before = alloc_after
|
||||
created_layers.append(layer)
|
||||
|
||||
modules = torch.nn.ModuleList(
|
||||
[PPMissingLayer() for _ in range(start_layer)] + [
|
||||
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
|
||||
for idx in range(start_layer, end_layer)
|
||||
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
|
||||
[PPMissingLayer() for _ in range(start_layer)]
|
||||
+ created_layers
|
||||
+ [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
|
||||
return start_layer, end_layer, modules
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user