Coverage for langbrainscore/interface/encoder.py: 70%
37 statements
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
1"""
2Interface and (partial) base implementation classes for Encoders and EncodedRepresentations
3"""
5import typing
6from abc import ABC, abstractmethod
7from dataclasses import dataclass
9import xarray as xr
10from langbrainscore.dataset import Dataset
11from langbrainscore.interface.cacheable import _Cacheable
class _Encoder(_Cacheable, ABC):
    """
    Abstract base shared by every *Encoder class.
    Concrete subclasses have to supply an `encode` method that accepts a Dataset object.
    """

    @staticmethod
    def _check_dataset_interface(dataset):
        """
        verifies that `dataset` is an instance of `langbrainscore.dataset.Dataset`.

        Returns None when the check passes and raises TypeError otherwise.
        """
        # guard clause: nothing more to do for a conforming object
        if isinstance(dataset, Dataset):
            return
        raise TypeError(
            f"dataset must be of type `langbrainscore.dataset.Dataset`, not {type(dataset)}"
        )

    @abstractmethod
    def encode(self, dataset: Dataset) -> xr.DataArray:
        """
        abstract hook that subclasses override; this base version only
        raises NotImplementedError.
        """
        raise NotImplementedError
class _ModelEncoder(_Encoder):
    """
    Interface for all ANN encoder subclasses (e.g. HuggingFaceEncoder).
    """

    def __init__(self, model_id: str, **kwargs) -> None:
        """This class is intended to be an interface for all ANN subclasses,
        including HuggingFaceEncoder, and, in the future, other kinds of
        ANN encoders

        Args:
            model_id (str): identifier of the model; stored on the instance
                as `_model_id`
            **kwargs: every additional keyword argument is set on the
                instance as an attribute of the same name
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm that `_Cacheable` needs no initialization of its own.
        self._model_id = model_id
        for k, v in kwargs.items():
            setattr(self, k, v)

    @abstractmethod
    def encode(self, dataset: Dataset) -> xr.DataArray:
        """
        returns computed representations for stimuli passed in as a `Dataset` object

        Args:
            dataset (langbrainscore.dataset.Dataset): a Dataset object with a member
                `xarray.DataArray` instance (`Dataset._xr_obj`) containing stimuli

        Returns:
            xr.DataArray: Model representations of each stimulus in brain dataset

        Raises:
            NotImplementedError: always, in this abstract base version
        """
        raise NotImplementedError
@dataclass(repr=False, eq=False, frozen=False)
class EncoderRepresentations(_Cacheable):
    """
    a class to hold the encoded representations output from an `_Encoder.encode` method

    Attributes that are not found on this object are looked up on the
    wrapped `representations` xarray object instead (see `__getattr__`).
    """

    dataset: Dataset  # pointer to the dataset these are the EncodedRepresentations of
    representations: xr.DataArray  # the xarray holding representations

    # optional metadata fields; presumably they record the encoder settings
    # used to produce `representations` -- confirm against the encoder code
    model_id: typing.Optional[str] = None
    context_dimension: typing.Optional[str] = None
    bidirectional: bool = False
    emb_aggregation: typing.Union[str, None, typing.Callable] = "last"
    emb_preproc: typing.Tuple[str, ...] = ()  # variable-length tuple; default is empty
    include_special_tokens: bool = True

    def __getattr__(self, __name: str) -> typing.Any:
        """falls back on the xarray object in case normal attribute lookup
        (`__getattribute__`) on this object raises an AttributeError

        Args:
            __name (str): name of the attribute that was not found on this object

        Returns:
            typing.Any: the attribute of the same name on `self.representations`

        Raises:
            AttributeError: if `representations` itself is not set yet, or if
                the xarray object has no such attribute either
        """
        # `__getattr__` only runs when regular lookup failed. If the missing
        # name is `representations` itself (e.g. on a not-yet-initialized
        # instance that pickle/copy is probing), reading `self.representations`
        # below would re-enter this method forever and end in a RecursionError,
        # so fail fast with the AttributeError callers expect.
        if __name == "representations":
            raise AttributeError(f"no attribute called `{__name}`")
        try:
            return getattr(self.representations, __name)
        except AttributeError as err:
            # keep the underlying xarray error attached as the explicit cause
            raise AttributeError(f"no attribute called `{__name}`") from err