Coverage for langbrainscore/interface/encoder.py: 70%

37 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-07 21:22 +0000

1""" 

2Interface and (partial) base implementation classes for Encoders and EncodedRepresentations 

3""" 

4 

5import typing 

6from abc import ABC, abstractmethod 

7from dataclasses import dataclass 

8 

9import xarray as xr 

10from langbrainscore.dataset import Dataset 

11from langbrainscore.interface.cacheable import _Cacheable 

12 

13 

class _Encoder(_Cacheable, ABC):
    """
    Abstract base for every *Encoder class.

    Concrete subclasses have to supply an `encode` method that takes a
    `Dataset` object as its input.
    """

    @staticmethod
    def _check_dataset_interface(dataset):
        """
        Verify that `dataset` is an instance of `langbrainscore.dataset.Dataset`.

        Raises:
            TypeError: if `dataset` is an instance of any other type.
        """
        # Guard clause: a conforming dataset needs no further handling.
        if isinstance(dataset, Dataset):
            return
        raise TypeError(
            f"dataset must be of type `langbrainscore.dataset.Dataset`, not {type(dataset)}"
        )

    @abstractmethod
    def encode(self, dataset: Dataset) -> xr.DataArray:
        """Abstract hook; this base implementation only raises NotImplementedError."""
        raise NotImplementedError

33 

34 

class _ModelEncoder(_Encoder):
    """
    Interface for all ANN encoder subclasses, including HuggingFaceEncoder,
    and, in the future, other kinds of ANN encoders.
    """

    def __init__(self, model_id: str, **kwargs) -> None:
        """Store the model identifier and any extra settings on the instance.

        Args:
            model_id (str): identifier of the model; kept as `self._model_id`.
            **kwargs: every key/value pair is set as an attribute of the same
                name on the instance.
        """
        # NOTE(review): no base-class __init__ is invoked here; confirm that
        # `_Cacheable` requires no initialisation of its own.
        self._model_id = model_id
        for k, v in kwargs.items():
            setattr(self, k, v)

    @abstractmethod
    def encode(self, dataset: Dataset) -> xr.DataArray:
        """
        returns computed representations for stimuli passed in as a `Dataset` object

        Args:
            dataset (langbrainscore.dataset.Dataset): a Dataset object with a member
                `xarray.DataArray` instance (`Dataset._xr_obj`) containing stimuli

        Returns:
            xr.DataArray: Model representations of each stimulus in brain dataset

        Raises:
            NotImplementedError: always, in this abstract base implementation.
        """
        raise NotImplementedError

65 

66 

@dataclass(repr=False, eq=False, frozen=False)
class EncoderRepresentations(_Cacheable):
    """
    a class to hold the encoded representations output from an `_Encoder.encode` method

    Attribute lookups that fail on this object are forwarded to the
    `representations` xarray object (see `__getattr__`).
    """

    dataset: Dataset  # pointer to the dataset these are the EncodedRepresentations of
    representations: xr.DataArray  # the xarray holding representations

    # `None` is the default for these two, so the annotation must allow it.
    model_id: typing.Optional[str] = None
    context_dimension: typing.Optional[str] = None
    bidirectional: bool = False
    emb_aggregation: typing.Union[str, None, typing.Callable] = "last"
    # variable-length tuple of strings (`Tuple[str]` would mean exactly one element)
    emb_preproc: typing.Tuple[str, ...] = ()
    include_special_tokens: bool = True

    def __getattr__(self, __name: str) -> typing.Any:
        """falls back on the xarray object in case of an AttributeError using
        __getattribute__ on this object

        Args:
            __name (str): name of the attribute that normal lookup failed to find.

        Returns:
            typing.Any: the attribute of the same name on `self.representations`.

        Raises:
            AttributeError: if `representations` is not set on this instance, or
                if the xarray object has no such attribute either.
        """
        # Fetch `representations` without going through `__getattr__` again:
        # on an instance whose fields are not populated yet (e.g. while `copy`
        # or `pickle` probe a freshly created object for `__setstate__`),
        # `self.representations` would re-enter this method endlessly and
        # surface as a RecursionError instead of an AttributeError.
        try:
            representations = object.__getattribute__(self, "representations")
        except AttributeError:
            raise AttributeError(f"no attribute called `{__name}`") from None

        try:
            return getattr(representations, __name)
        except AttributeError as err:
            raise AttributeError(f"no attribute called `{__name}`") from err