Files
enginex-bi_series-vllm/pkgs/xformers/components/attention/feature_maps/base.py
2025-08-05 19:02:46 +08:00

62 lines
1.6 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar
import torch
"""
Feature maps allow for a given query or key to be encoded in a different space.
"""
# Type variable bound to FeatureMap, used so that the ``from_config``
# classmethod can be annotated as returning an instance of the class it is
# called on (``cls: Type[Self]`` -> ``Self``).
Self = TypeVar("Self", bound="FeatureMap")
@dataclass
class FeatureMapConfig:
    """Configuration for building a feature map via ``FeatureMap.from_config``.

    ``name`` and ``dim_features`` are required. The remaining fields are
    optional and default to ``None``; a ``None`` value means "not set", and
    ``from_config`` skips such values so that the constructor's own defaults
    apply.
    """

    # Name of the feature map. NOTE(review): presumably selects which feature
    # map type to build — confirm against callers.
    name: str
    # Number of features.
    dim_features: int
    # Optional settings; None means "use the constructor default".
    iter_before_redraw: Optional[int] = None
    normalize_inputs: Optional[bool] = None
    epsilon: Optional[float] = None
class FeatureMap(torch.nn.Module):
    """Base class for feature maps that encode a query or key in another space.

    This base class only stores the shared settings. Building the actual
    feature map is delegated to subclasses through ``_get_feature_map``.
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int] = None,
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
    ):
        """Store the shared feature-map settings.

        Args:
            dim_features: Number of features. Also used as the initial value
                of ``dim_feature_map``.
            iter_before_redraw: Stored as-is; ``None`` by default.
                NOTE(review): presumably the number of iterations before a
                subclass redraws its features — confirm in subclasses.
            normalize_inputs: Stored as-is; how it is applied is up to
                subclasses.
            epsilon: Stored as-is. NOTE(review): presumably a small constant
                for numerical stability — confirm in subclasses.
        """
        super().__init__()
        self.dim_features = dim_features
        # Starts equal to dim_features. NOTE(review): subclasses may change it.
        self.dim_feature_map = dim_features
        self.iter_before_redraw = iter_before_redraw
        # Nothing is drawn here; stays None until a subclass sets it.
        self.features: Optional[torch.Tensor] = None
        self.epsilon = epsilon
        self.normalize_inputs = normalize_inputs
        # Starts at 0. NOTE(review): presumably compared against
        # iter_before_redraw by subclasses — confirm.
        self._iter_counter = 0

    @abstractmethod
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """Build the feature map for the given sizes on ``device``.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``. NOTE(review): ``torch.nn.Module`` is
        presumably not an ABC, so ``@abstractmethod`` is likely not enforced
        at instantiation time — confirm.
        """
        raise NotImplementedError()

    @classmethod
    def from_config(cls: "Type[Self]", config: "FeatureMapConfig") -> "Self":
        """Instantiate ``cls`` from a config dataclass.

        Every config field is forwarded as a keyword argument, except:

        * fields whose value is ``None``, which are skipped so that the
          constructor defaults apply;
        * ``name``, which is forwarded only if the constructor can accept it
          (an explicit ``name`` parameter or ``**kwargs``). ``FeatureMap``'s
          own ``__init__`` has no ``name`` parameter, so forwarding it
          unconditionally made this method raise ``TypeError``.

        Args:
            config: Dataclass instance (a ``FeatureMapConfig``) to read
                fields from.

        Returns:
            A new instance of ``cls``.
        """
        import inspect

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in asdict(config).items() if v is not None}

        # Drop `name` unless the constructor can actually receive it.
        params = inspect.signature(cls.__init__).parameters
        accepts_any_kwarg = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if "name" not in params and not accepts_any_kwarg:
            fields.pop("name", None)

        return cls(**fields)