feat: update linear deps 1/N (#1305)
This commit is contained in:
@@ -26,7 +26,7 @@ import struct
|
||||
import time
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@@ -682,3 +682,23 @@ def replace_submodule(
|
||||
target_name = module_name.split(".")[-1]
|
||||
setattr(parent, target_name, new_module)
|
||||
return new_module
|
||||
|
||||
|
||||
def set_weight_attrs(
|
||||
weight: torch.Tensor,
|
||||
weight_attrs: Optional[Dict[str, Any]],
|
||||
):
|
||||
"""Set attributes on a weight tensor.
|
||||
|
||||
This method is used to set attributes on a weight tensor. This method
|
||||
will not overwrite existing attributes.
|
||||
|
||||
Args:
|
||||
weight: The weight tensor.
|
||||
weight_attrs: A dictionary of attributes to set on the weight tensor.
|
||||
"""
|
||||
if weight_attrs is None:
|
||||
return
|
||||
for key, value in weight_attrs.items():
|
||||
assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
|
||||
setattr(weight, key, value)
|
||||
|
||||
Reference in New Issue
Block a user