update
This commit is contained in:
136
vllm/model_inspection.py
Normal file
136
vllm/model_inspection.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Model inspection utilities for vLLM."""
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _get_module_info(module: nn.Module) -> str:
|
||||
"""Get info string for a module."""
|
||||
class_name = type(module).__name__
|
||||
parts = []
|
||||
|
||||
# Add quant_method if present
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
quant_name = type(quant_method).__name__
|
||||
# For CompressedTensors, show the underlying scheme instead
|
||||
scheme = getattr(module, "scheme", None)
|
||||
if scheme is not None:
|
||||
quant_name = type(scheme).__name__
|
||||
# Skip unquantized methods
|
||||
if "Unquantized" not in quant_name:
|
||||
parts.append(f"quant={quant_name}")
|
||||
|
||||
# If module has extra_repr, use it
|
||||
if hasattr(module, "extra_repr"):
|
||||
parts.append(module.extra_repr().replace("\n", ""))
|
||||
|
||||
if parts:
|
||||
return f"{class_name}({', '.join(parts)})"
|
||||
|
||||
# For unknown modules, use the default PyTorch repr
|
||||
return str(module)
|
||||
|
||||
|
||||
def _get_child_signature(child: nn.Module) -> str:
|
||||
"""Get a signature for a child module to detect duplicates."""
|
||||
lines = []
|
||||
for name, submodule in child.named_modules():
|
||||
lines.append(f"{name}:{_get_module_info(submodule)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_index_ranges(indices: list[int]) -> str:
|
||||
"""Format indices into range notation (e.g., [0,1,2,4,5,6] -> '0-2, 4-6')."""
|
||||
indices = sorted(indices)
|
||||
ranges = []
|
||||
start = end = indices[0]
|
||||
|
||||
for idx in indices[1:]:
|
||||
if idx == end + 1:
|
||||
end = idx
|
||||
else:
|
||||
ranges.append(str(start) if start == end else f"{start}-{end}")
|
||||
start = end = idx
|
||||
|
||||
ranges.append(str(start) if start == end else f"{start}-{end}")
|
||||
return ", ".join(ranges)
|
||||
|
||||
|
||||
def _format_module_tree(
|
||||
module: nn.Module,
|
||||
name: str = "",
|
||||
indent: int = 0,
|
||||
) -> list[str]:
|
||||
"""Format a module tree with indentation, grouping identical layers.
|
||||
|
||||
Produces output like:
|
||||
(layers): ModuleList(
|
||||
(0-27, 29-47): 47 x LlamaDecoderLayer(
|
||||
...
|
||||
)
|
||||
(28, 48): 2 x DifferentDecoderLayer(
|
||||
...
|
||||
)
|
||||
)
|
||||
"""
|
||||
lines = []
|
||||
prefix = " " * indent
|
||||
children = list(module.named_children())
|
||||
|
||||
# Leaf node - just output the module info
|
||||
if not children:
|
||||
info = _get_module_info(module)
|
||||
lines.append(f"{prefix}({name}): {info}" if name else f"{prefix}{info}")
|
||||
return lines
|
||||
|
||||
# Non-leaf node - output opening line and recurse into children
|
||||
info = _get_module_info(module)
|
||||
lines.append(f"{prefix}({name}): {info}(" if name else f"{prefix}{info}(")
|
||||
|
||||
# Separate numbered children (e.g., "0", "1") from named ones (e.g., "norm")
|
||||
numbered: list[tuple[int, nn.Module]] = []
|
||||
non_numbered: list[tuple[str, nn.Module]] = []
|
||||
for child_name, child_module in children:
|
||||
try:
|
||||
numbered.append((int(child_name), child_module))
|
||||
except ValueError:
|
||||
non_numbered.append((child_name, child_module))
|
||||
|
||||
# Group numbered children by structure signature to collapse identical layers
|
||||
# e.g., layers 0-27 and 29-47 with same structure become "(0-27, 29-47): 47 x"
|
||||
if numbered:
|
||||
sig_to_group: dict[str, list[tuple[int, nn.Module]]] = {}
|
||||
for idx, child_module in numbered:
|
||||
sig = _get_child_signature(child_module)
|
||||
sig_to_group.setdefault(sig, []).append((idx, child_module))
|
||||
|
||||
# Output groups sorted by first index
|
||||
for group in sorted(sig_to_group.values(), key=lambda g: g[0][0]):
|
||||
indices = [idx for idx, _ in group]
|
||||
representative = group[0][1]
|
||||
child_lines = _format_module_tree(representative, "", indent + 1)
|
||||
first_line = child_lines[0].lstrip()
|
||||
child_prefix = " " * (indent + 1)
|
||||
|
||||
if len(indices) > 1:
|
||||
range_str = _format_index_ranges(indices)
|
||||
child_lines[0] = (
|
||||
f"{child_prefix}({range_str}): {len(indices)} x {first_line}"
|
||||
)
|
||||
else:
|
||||
child_lines[0] = f"{child_prefix}({indices[0]}): {first_line}"
|
||||
lines.extend(child_lines)
|
||||
|
||||
# Output non-numbered children (e.g., "embed_tokens", "norm")
|
||||
for child_name, child_module in non_numbered:
|
||||
lines.extend(_format_module_tree(child_module, child_name, indent + 1))
|
||||
|
||||
lines.append(f"{prefix})")
|
||||
return lines
|
||||
|
||||
|
||||
def format_model_inspection(model: nn.Module) -> str:
|
||||
"""Format a model into a transformers-style hierarchical string."""
|
||||
return "\n".join(_format_module_tree(model))
|
||||
Reference in New Issue
Block a user