Coverage for langbrainscore/encoder/brain.py: 40%

15 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-07 21:22 +0000

1import langbrainscore 

2import xarray as xr 

3from langbrainscore.interface.encoder import _Encoder, EncoderRepresentations 

4 

5 

class BrainEncoder(_Encoder):
    """
    Extract the relevant contents of a given `langbrainscore.dataset.Dataset`
    object and expose them through the Encoder interface, so that human
    measurements are returned as `EncoderRepresentations`.
    """

    def __init__(
        self, measurement: str = "unknown", aggregate_time: bool = False
    ) -> None:
        """Initialize a BrainEncoder.

        Args:
            measurement (str, optional): Label for the modality/type of human
                data. It is used as the ``model_id`` of the returned
                representations. A ``"measurement"`` attribute on the
                dataset's contents overrides it at encode time.
                Defaults to "unknown".
            aggregate_time (bool, optional): Whether to average over the
                ``timeid`` dimension of the data during encoding. The result
                keeps a single ``timeid`` coordinate of 0. Defaults to False.
        """
        self._measurement = measurement
        self._aggregate_time = aggregate_time

    def encode(
        self,
        dataset: langbrainscore.dataset.Dataset,
    ) -> EncoderRepresentations:
        """Return human measurements related to stimuli (passed in as a Dataset).

        Args:
            dataset (langbrainscore.dataset.Dataset): brain dataset object.

        Returns:
            EncoderRepresentations: the dataset's contents wrapped together
            with the dataset. When ``aggregate_time`` is set, the contents are
            time-averaged. The wrapper is tagged with the measurement label as
            ``model_id``. Both code paths return this same type.
        """
        self._check_dataset_interface(dataset)

        contents = dataset.contents

        # A measurement label stored on the data takes precedence over the one
        # given at construction. This applies to both the aggregated and the
        # raw path; the previous early return skipped it.
        if "measurement" in contents.attrs:
            self._measurement = contents.attrs["measurement"]

        representations = contents
        if self._aggregate_time:
            dim = "timeid"
            # Average over time, then re-insert a singleton `timeid` dimension
            # (at axis 2, as before; presumably dims are
            # (sampleid, neuroid, timeid), TODO confirm) with coordinate 0.
            # Downstream code that expects a timeid dim keeps working.
            # keep_attrs=True preserves metadata that a plain reduction drops.
            representations = (
                contents.mean(dim, keep_attrs=True)
                .expand_dims(dim, 2)
                .assign_coords({dim: (dim, [0])})
            )

        return EncoderRepresentations(
            dataset=dataset,
            representations=representations,
            model_id=self._measurement,
            emb_aggregation=None,
            emb_preproc=(),
            include_special_tokens=None,
            context_dimension=None,
            bidirectional=False,
        )