Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
|
||||
from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
|
||||
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
|
||||
from vllm.model_executor.models.utils import (
|
||||
init_vllm_registered_model,
|
||||
@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
|
||||
compute_retention_mask,
|
||||
)
|
||||
from vllm.multimodal.inputs import (
|
||||
AudioItem,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
VideoItem,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
AudioProcessorItems,
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
||||
# Alternative: Set a specific higher limit
|
||||
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
|
||||
|
||||
|
||||
class NanoNemotronVLAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- b: Number of audio clips
|
||||
- t: Audio feature length
|
||||
- f: Feature size (mel bins)
|
||||
"""
|
||||
|
||||
type: Literal["audio_features"] = "audio_features"
|
||||
input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
|
||||
feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
|
||||
audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
|
||||
|
||||
|
||||
MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
|
||||
|
||||
IMG_START = "<img>"
|
||||
IMG_END = "</img>"
|
||||
IMG_CONTEXT = "<image>"
|
||||
AUDIO_START = "<so_start>"
|
||||
AUDIO_END = "<so_end>"
|
||||
AUDIO_CONTEXT = "<so_embedding>"
|
||||
|
||||
# Profiling
|
||||
# MAX_FRAMES = 16
|
||||
@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
self.video_token = video_token
|
||||
self.video_pruning_rate = video_pruning_rate
|
||||
|
||||
self.audio_extractor: ParakeetExtractor | None = None
|
||||
raw_sound_config = getattr(config, "sound_config", None)
|
||||
if raw_sound_config is not None:
|
||||
self.audio_extractor = ParakeetExtractor(raw_sound_config)
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
self._img_start_token_ids = tokenizer.encode(
|
||||
@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
text = [t.replace("<video>", video_repl_text, 1) for t in text]
|
||||
return text, video_inputs
|
||||
|
||||
def _preprocess_audio(
|
||||
self,
|
||||
text: list[str],
|
||||
audios: list[npt.NDArray],
|
||||
):
|
||||
if len(audios) == 0:
|
||||
return text, {}
|
||||
assert self.audio_extractor is not None
|
||||
|
||||
extractor = self.audio_extractor
|
||||
|
||||
parts = [x for x in re.split(f"({re.escape(AUDIO_CONTEXT)})", text[0]) if x]
|
||||
token_count = parts.count(AUDIO_CONTEXT)
|
||||
if token_count != len(audios):
|
||||
raise ValueError(
|
||||
"Number of audio tokens in text does not match the number "
|
||||
f"of audios (tokens={token_count}, audios={len(audios)})."
|
||||
)
|
||||
audio_index = 0
|
||||
for idx, part in enumerate(parts):
|
||||
if part == AUDIO_CONTEXT:
|
||||
audio_repl = self.get_audio_repl(audios[audio_index])
|
||||
parts[idx] = audio_repl.full
|
||||
audio_index += 1
|
||||
text = ["".join(parts)]
|
||||
audio_inputs = extractor(
|
||||
audios,
|
||||
sampling_rate=extractor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_audio_features = audio_inputs.input_features
|
||||
feature_attention_mask = audio_inputs.attention_mask
|
||||
audio_feature_lengths = feature_attention_mask.sum(dim=1)
|
||||
audio_inputs = {
|
||||
"input_audio_features": input_audio_features,
|
||||
"feature_attention_mask": feature_attention_mask,
|
||||
"audio_feature_lengths": audio_feature_lengths,
|
||||
}
|
||||
|
||||
return text, audio_inputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str] | None = None,
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
|
||||
audios: AudioItem | list[AudioItem] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
max_num_tiles: int | None = None,
|
||||
) -> BatchFeature:
|
||||
@@ -964,8 +1034,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
if max_num_tiles is None:
|
||||
max_num_tiles = self.max_num_tiles
|
||||
|
||||
text, images, videos = [
|
||||
self._make_batch_input(x) for x in (text, images, videos)
|
||||
text, images, videos, audios = [
|
||||
self._make_batch_input(x) for x in (text, images, videos, audios)
|
||||
]
|
||||
|
||||
text, image_inputs = self._preprocess_image(
|
||||
@@ -980,17 +1050,22 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
max_num_tiles=1,
|
||||
)
|
||||
|
||||
text, audio_inputs = self._preprocess_audio(
|
||||
text=text,
|
||||
audios=audios,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer(text, add_special_tokens=False)
|
||||
|
||||
combined_inputs = {**text_inputs, **video_inputs, **audio_inputs}
|
||||
|
||||
if self.dynamic_tiler is None:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs, **image_inputs},
|
||||
{**combined_inputs, **image_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
else:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs}, tensor_type=return_tensors
|
||||
)
|
||||
batch = BatchFeature(combined_inputs, tensor_type=return_tensors)
|
||||
# allow images to be exempt from the BatchFeature validation:
|
||||
# We will .stack() them in _parse_and_validate_image_input
|
||||
batch.update(image_inputs)
|
||||
@@ -1006,6 +1081,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
def get_audio_repl(
|
||||
self,
|
||||
audio: npt.NDArray,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
assert self.audio_extractor is not None
|
||||
num_tokens = self.audio_extractor.audio_token_count(len(audio))
|
||||
repl_full = f"{AUDIO_START}{AUDIO_CONTEXT * num_tokens}{AUDIO_END}"
|
||||
return PromptUpdateDetails.select_text(repl_full, AUDIO_CONTEXT)
|
||||
|
||||
@classmethod
|
||||
def get_video_repl(
|
||||
cls,
|
||||
@@ -1147,15 +1231,28 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
def supports_video(self):
|
||||
return self.get_hf_processor().supports_video
|
||||
|
||||
@property
|
||||
def audio_extractor(self) -> ParakeetExtractor | None:
|
||||
return self.get_hf_processor().audio_extractor
|
||||
|
||||
def get_data_parser(self):
|
||||
target_sr = None
|
||||
target_channels = None
|
||||
if extractor := self.audio_extractor:
|
||||
target_sr = extractor.sampling_rate
|
||||
target_channels = 1
|
||||
|
||||
return MultiModalDataParser(
|
||||
video_needs_metadata=True,
|
||||
target_sr=target_sr,
|
||||
target_channels=target_channels,
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self):
|
||||
video_limit = {"video": None} if self.supports_video else {}
|
||||
return {**super().get_supported_mm_limits(), **video_limit}
|
||||
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
|
||||
return {**super().get_supported_mm_limits(), **video_limit, **audio_limit}
|
||||
|
||||
def get_video_token(self) -> str | None:
|
||||
return IMG_CONTEXT
|
||||
@@ -1304,7 +1401,16 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
else:
|
||||
video_fields = {}
|
||||
|
||||
return image_fields | video_fields
|
||||
if self.info.audio_extractor is not None:
|
||||
audio_fields = dict(
|
||||
input_audio_features=MultiModalFieldConfig.batched("audio"),
|
||||
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
else:
|
||||
audio_fields = {}
|
||||
|
||||
return image_fields | video_fields | audio_fields
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -1373,6 +1479,20 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
),
|
||||
]
|
||||
|
||||
def get_audio_replacement(item_idx: int):
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
return hf_processor.get_audio_repl(audios.get(item_idx))
|
||||
|
||||
if self.info.audio_extractor is not None:
|
||||
prompt_repl = [
|
||||
*prompt_repl,
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=AUDIO_CONTEXT,
|
||||
replacement=get_audio_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
return prompt_repl
|
||||
|
||||
|
||||
@@ -1422,8 +1542,13 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
|
||||
return (
|
||||
super().get_dummy_text(mm_counts)
|
||||
+ "<video>" * num_videos
|
||||
+ AUDIO_CONTEXT * num_audios
|
||||
)
|
||||
|
||||
def _get_dummy_videos(
|
||||
self,
|
||||
@@ -1482,7 +1607,25 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
}
|
||||
else:
|
||||
dummy_video = {}
|
||||
return {**dummy_image, **dummy_video}
|
||||
|
||||
if extractor := self.info.audio_extractor:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
|
||||
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
|
||||
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
|
||||
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
|
||||
dummy_audio = {
|
||||
"audio": self._get_dummy_audios(
|
||||
length=audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides,
|
||||
)
|
||||
}
|
||||
else:
|
||||
dummy_audio = {}
|
||||
|
||||
return {**dummy_image, **dummy_video, **dummy_audio}
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
@@ -1499,12 +1642,15 @@ class NemotronH_Nano_VL_V2(
|
||||
return "<image>"
|
||||
if modality.startswith("video"):
|
||||
return "<video>"
|
||||
if modality.startswith("audio"):
|
||||
return AUDIO_CONTEXT
|
||||
return None
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
model_config = vllm_config.model_config
|
||||
config = model_config.hf_config
|
||||
multimodal_config = model_config.multimodal_config
|
||||
image_size = config.force_image_size
|
||||
patch_size = config.patch_size
|
||||
self.patch_size = patch_size
|
||||
@@ -1523,10 +1669,12 @@ class NemotronH_Nano_VL_V2(
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
llm_dtype = self.language_model.config.dtype
|
||||
assert isinstance(llm_dtype, torch.dtype)
|
||||
self.llm_dtype = llm_dtype
|
||||
with self._mark_tower_model(vllm_config, {"image", "video", "audio"}):
|
||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||
self.language_model.config.dtype
|
||||
llm_dtype
|
||||
)
|
||||
|
||||
# Construct the vision projection.
|
||||
@@ -1547,14 +1695,26 @@ class NemotronH_Nano_VL_V2(
|
||||
ReLUSquaredActivation(),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||
)
|
||||
self.mlp1 = mlp1.to(self.language_model.config.dtype)
|
||||
self.mlp1 = mlp1.to(llm_dtype)
|
||||
self.sound_encoder: ProjectedParakeet | None = None
|
||||
if getattr(config, "sound_config", None) is not None:
|
||||
logger.info_once(
|
||||
"Found sound config, initializing sound encoder for Nemotron AVLM",
|
||||
scope="global",
|
||||
)
|
||||
self.sound_encoder = ProjectedParakeet(
|
||||
config.sound_config,
|
||||
dtype=llm_dtype,
|
||||
llm_hidden_size=llm_hidden_size,
|
||||
max_model_len=model_config.max_model_len,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
self._img_start_token_ids = tokenizer.encode(
|
||||
IMG_START, add_special_tokens=False
|
||||
)
|
||||
@@ -1566,7 +1726,10 @@ class NemotronH_Nano_VL_V2(
|
||||
config
|
||||
)
|
||||
if self.dynamic_resolution:
|
||||
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
|
||||
logger.info_once(
|
||||
"Dynamic resolution is enabled for NanoNemotronVLProcessor",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
@@ -1780,6 +1943,51 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
return final_video_embeddings
|
||||
|
||||
def _process_audio_input(
|
||||
self, audio_input: NanoNemotronVLAudioFeatureInputs
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
assert self.sound_encoder is not None
|
||||
input_audio_features = audio_input.input_audio_features
|
||||
feature_attention_mask = audio_input.feature_attention_mask
|
||||
target_device = next(self.sound_encoder.parameters()).device
|
||||
|
||||
# When cross-request batching combines audio clips with different
|
||||
# time dimensions, _reduce_data returns a list instead of a stacked
|
||||
# tensor. Pad to the max time dim and stack; the attention mask
|
||||
# already marks valid positions so zero-padding is safe.
|
||||
if isinstance(input_audio_features, list):
|
||||
feature_sizes = [f.shape[-2] for f in input_audio_features]
|
||||
max_t = max(feature_sizes)
|
||||
padded_feats = [
|
||||
torch.nn.functional.pad(feat, (0, 0, 0, max_t - feat_size))
|
||||
for feat, feat_size in zip(
|
||||
input_audio_features, feature_sizes, strict=True
|
||||
)
|
||||
]
|
||||
padded_masks = [
|
||||
torch.nn.functional.pad(mask, (0, max_t - mask.shape[-1]))
|
||||
for mask in feature_attention_mask
|
||||
]
|
||||
input_audio_features = torch.stack(padded_feats)
|
||||
feature_attention_mask = torch.stack(padded_masks)
|
||||
|
||||
input_audio_features = input_audio_features.to(
|
||||
dtype=self.llm_dtype, device=target_device
|
||||
)
|
||||
feature_attention_mask = feature_attention_mask.to(device=target_device)
|
||||
sound_embeds = self.sound_encoder(input_audio_features, feature_attention_mask)
|
||||
|
||||
valid_input_lens = feature_attention_mask.sum(dim=1)
|
||||
valid_output_lens = self.sound_encoder.encoder._get_subsampling_output_length(
|
||||
valid_input_lens
|
||||
)
|
||||
truncated_embeds = []
|
||||
for i in range(sound_embeds.shape[0]):
|
||||
valid_len = valid_output_lens[i].item()
|
||||
truncated_embeds.append(sound_embeds[i, :valid_len])
|
||||
|
||||
return tuple(truncated_embeds)
|
||||
|
||||
def _create_final_video_embeddings(
|
||||
self,
|
||||
video_embeddings: torch.Tensor,
|
||||
@@ -1887,6 +2095,18 @@ class NemotronH_Nano_VL_V2(
|
||||
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
|
||||
if (
|
||||
input_key
|
||||
in (
|
||||
"input_audio_features",
|
||||
"feature_attention_mask",
|
||||
"audio_feature_lengths",
|
||||
)
|
||||
and "audios" not in modalities
|
||||
):
|
||||
modalities["audios"] = NanoNemotronVLAudioFeatureInputs(
|
||||
**kwargs, validate=False
|
||||
)
|
||||
|
||||
return modalities
|
||||
|
||||
@@ -1917,6 +2137,10 @@ class NemotronH_Nano_VL_V2(
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_video_input(video_input)
|
||||
multimodal_embeddings += tuple(video_embeddings)
|
||||
if modality == "audios":
|
||||
audio_input = modalities["audios"]
|
||||
audio_embeddings = self._process_audio_input(audio_input)
|
||||
multimodal_embeddings += tuple(audio_embeddings)
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
@@ -1947,8 +2171,8 @@ class NemotronH_Nano_VL_V2(
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="mlp1",
|
||||
tower_model="vision_model",
|
||||
connector=["mlp1", "sound_encoder.projection"],
|
||||
tower_model=["vision_model", "sound_encoder.encoder"],
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
@@ -1969,9 +2193,13 @@ class NemotronH_Nano_VL_V2(
|
||||
def is_vision_weights(name: str) -> bool:
|
||||
return name.startswith("vision_model.radio_model.")
|
||||
|
||||
def is_sound_weights(name: str) -> bool:
|
||||
return name.startswith("sound")
|
||||
|
||||
# Separate weights by component
|
||||
llm_weights = []
|
||||
vision_weights = []
|
||||
sound_weights = []
|
||||
|
||||
for name, w in weights:
|
||||
if is_llm(name):
|
||||
@@ -1987,107 +2215,15 @@ class NemotronH_Nano_VL_V2(
|
||||
# Convert: vision_model.radio_model.* → radio_model.*
|
||||
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
|
||||
vision_weights.append((hf_key, w))
|
||||
elif is_sound_weights(name):
|
||||
assert self.sound_encoder is not None
|
||||
sound_weights.append((name, w))
|
||||
|
||||
self.language_model.load_weights(llm_weights)
|
||||
self.vision_model.load_weights(vision_weights)
|
||||
|
||||
def print_architecture(self, detailed: bool = True, save_to_file: str = None):
|
||||
"""
|
||||
Print model architecture with parameter names, shapes, and sizes.
|
||||
|
||||
Args:
|
||||
detailed: If True, show detailed parameter breakdown
|
||||
save_to_file: If provided, save output to this file path
|
||||
"""
|
||||
import sys
|
||||
from io import StringIO
|
||||
|
||||
# Capture output if saving to file
|
||||
original_stdout = sys.stdout
|
||||
if save_to_file:
|
||||
sys.stdout = StringIO()
|
||||
|
||||
try:
|
||||
print("=" * 100)
|
||||
print("NemotronH_Nano_VL_V2 Model Architecture")
|
||||
print("=" * 100)
|
||||
|
||||
total_params = 0
|
||||
param_groups = {
|
||||
"language_model": [],
|
||||
"vision_model": [],
|
||||
"mlp1": [],
|
||||
"other": [],
|
||||
}
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
param_size = param.numel()
|
||||
total_params += param_size
|
||||
|
||||
# Group parameters by main component
|
||||
if name.startswith("language_model"):
|
||||
param_groups["language_model"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
elif name.startswith("vision_model"):
|
||||
param_groups["vision_model"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
elif name.startswith("mlp1"):
|
||||
param_groups["mlp1"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
else:
|
||||
param_groups["other"].append(
|
||||
(name, param.shape, param_size, param.dtype)
|
||||
)
|
||||
|
||||
if detailed:
|
||||
print(
|
||||
f"{name:<70} | Shape: {str(param.shape):<25} | "
|
||||
f"Size: {param_size:>12,} | Dtype: {param.dtype}"
|
||||
)
|
||||
|
||||
print("=" * 100)
|
||||
print("Summary by Component:")
|
||||
print("-" * 60)
|
||||
|
||||
for component, params in param_groups.items():
|
||||
if params: # Only show components that have parameters
|
||||
component_total = sum(size for _, _, size, _ in params)
|
||||
percentage = (
|
||||
(component_total / total_params) * 100
|
||||
if total_params > 0
|
||||
else 0
|
||||
)
|
||||
print(
|
||||
f"{component:<20} | Parameters: {len(params):>4} | "
|
||||
f"Total Size: {component_total:>15,} | "
|
||||
f"{percentage:>6.2f}%"
|
||||
)
|
||||
|
||||
print("-" * 60)
|
||||
print(f"{'Total Parameters':<20} | {total_params:>15,}")
|
||||
|
||||
# Estimate memory usage (assuming bfloat16 = 2 bytes per parameter)
|
||||
memory_mb = total_params * 2 / (1024**2)
|
||||
memory_gb = memory_mb / 1024
|
||||
print(f"{'Est. Memory (MB)':<20} | {memory_mb:>15.2f}")
|
||||
print(f"{'Est. Memory (GB)':<20} | {memory_gb:>15.2f}")
|
||||
print("=" * 100)
|
||||
|
||||
# Save to file if requested
|
||||
if save_to_file:
|
||||
output = sys.stdout.getvalue()
|
||||
sys.stdout = original_stdout
|
||||
with open(save_to_file, "w") as f:
|
||||
f.write(output)
|
||||
print(f"Architecture saved to: {save_to_file}")
|
||||
print(output) # Also print to console
|
||||
|
||||
finally:
|
||||
if save_to_file and sys.stdout != original_stdout:
|
||||
sys.stdout = original_stdout
|
||||
if self.sound_encoder is not None:
|
||||
assert len(sound_weights) > 0
|
||||
self.sound_encoder.load_weights(sound_weights)
|
||||
|
||||
def get_vit_model_from_radio_config(self, hf_config):
|
||||
hf_config_vision = hf_config.vision_config
|
||||
|
||||
Reference in New Issue
Block a user