Module langbrainscore.interface.encoder
Interface and (partial) base implementation classes for Encoders and EncodedRepresentations
Expand source code
"""
Interface and (partial) base implementation classes for Encoders and EncodedRepresentations
"""
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass
import xarray as xr
from langbrainscore.dataset import Dataset
from langbrainscore.interface.cacheable import _Cacheable
class _Encoder(_Cacheable, ABC):
"""
Interface for *Encoder classes.
Must implement an `encode` method that operates on a Dataset object.
"""
@staticmethod
def _check_dataset_interface(dataset):
"""
confirms that dataset adheres to `langbrainscore.dataset.Dataset` interface.
"""
if not isinstance(dataset, Dataset):
raise TypeError(
f"dataset must be of type `langbrainscore.dataset.Dataset`, not {type(dataset)}"
)
@abstractmethod
def encode(self, dataset: Dataset) -> xr.DataArray:
raise NotImplementedError
class _ModelEncoder(_Encoder):
def __init__(self, model_id: str, **kwargs) -> "_ModelEncoder":
"""This class is intended to be an interface for all ANN subclasses,
including HuggingFaceEncoder, and, in the future, other kinds of
ANN encoders
Args:
model_id (str): _description_
Returns:
_ModelEncoder: _description_
"""
self._model_id = model_id
for k, v in kwargs.items():
setattr(self, k, v)
@abstractmethod
def encode(self, dataset: Dataset) -> xr.DataArray:
"""
returns computed representations for stimuli passed in as a `Dataset` object
Args:
langbrainscore.dataset.Dataset: a Dataset object with a member `xarray.DataArray`
instance (`Dataset._xr_obj`) containing stimuli
Returns:
xr.DataArray: Model representations of each stimulus in brain dataset
"""
raise NotImplementedError
@dataclass(repr=False, eq=False, frozen=False)
class EncoderRepresentations(_Cacheable):
"""
a class to hold the encoded representations output from an `_Encoder.encode` method
"""
dataset: Dataset # pointer to the dataset these are the EncodedRepresentations of
representations: xr.DataArray # the xarray holding representations
model_id: str = None
context_dimension: str = None
bidirectional: bool = False
emb_aggregation: typing.Union[str, None, typing.Callable] = "last"
emb_preproc: typing.Tuple[str] = ()
include_special_tokens: bool = True
def __getattr__(self, __name: str) -> typing.Any:
"""falls back on the xarray object in case of a NameError using __getattribute__
on this object"""
try:
return getattr(self.representations, __name)
except AttributeError:
raise AttributeError(f"no attribute called `{__name}`")
Classes
class EncoderRepresentations (dataset: Dataset, representations: xarray.core.dataarray.DataArray, model_id: str = None, context_dimension: str = None, bidirectional: bool = False, emb_aggregation: Union[str, None, Callable] = 'last', emb_preproc: Tuple[str] = (), include_special_tokens: bool = True)
-
a class to hold the encoded representations output from an
_Encoder.encode
methodExpand source code
class EncoderRepresentations(_Cacheable): """ a class to hold the encoded representations output from an `_Encoder.encode` method """ dataset: Dataset # pointer to the dataset these are the EncodedRepresentations of representations: xr.DataArray # the xarray holding representations model_id: str = None context_dimension: str = None bidirectional: bool = False emb_aggregation: typing.Union[str, None, typing.Callable] = "last" emb_preproc: typing.Tuple[str] = () include_special_tokens: bool = True def __getattr__(self, __name: str) -> typing.Any: """falls back on the xarray object in case of a NameError using __getattribute__ on this object""" try: return getattr(self.representations, __name) except AttributeError: raise AttributeError(f"no attribute called `{__name}`")
Ancestors
- langbrainscore.interface.cacheable._Cacheable
- typing.Protocol
- typing.Generic
Class variables
var bidirectional : bool
var context_dimension : str
var dataset : Dataset
var emb_aggregation : Union[str, None, Callable]
var emb_preproc : Tuple[str]
var include_special_tokens : bool
var model_id : str
var representations : xarray.core.dataarray.DataArray