Coverage for langbrainscore/encoder/brain.py: 40%
15 statements
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
1import langbrainscore
2import xarray as xr
3from langbrainscore.interface.encoder import _Encoder, EncoderRepresentations
class BrainEncoder(_Encoder):
    """
    Encoder that extracts the human (brain) measurements stored in a given
    `langbrainscore.dataset.Dataset` object and returns them through the
    Encoder interface, i.e. wrapped in an `EncoderRepresentations` object.
    """

    def __init__(
        self, measurement: str = "unknown", aggregate_time: bool = False
    ) -> None:
        """Initialize a BrainEncoder.

        Args:
            measurement (str, optional): The modality/type of human data. It is
                used as the `model_id` of the returned representations unless the
                dataset contents carry a "measurement" attribute, which takes
                precedence. Defaults to "unknown".
            aggregate_time (bool, optional): Whether to average over the
                "timeid" dimension of the data during encoding. Defaults to False.
        """
        self._measurement = measurement
        self._aggregate_time = aggregate_time

    def encode(
        self,
        dataset: langbrainscore.dataset.Dataset,
    ) -> EncoderRepresentations:
        """Return the human measurements related to the stimuli in `dataset`.

        Args:
            dataset (langbrainscore.dataset.Dataset): brain dataset object.

        Returns:
            EncoderRepresentations: wraps `dataset.contents` together with
                `dataset`. When `aggregate_time` is set, the contents are
                averaged over "timeid", and "timeid" is re-inserted as a
                length-1 dimension with coordinate 0.
        """
        # Interface check inherited from the encoder base class.
        self._check_dataset_interface(dataset)
        contents = dataset.contents

        # A "measurement" attribute on the data overrides the constructor value.
        # Resolve it into a local variable instead of mutating self._measurement
        # so that one dataset's attribute does not leak into later calls.
        # It is read before any aggregation, so attrs are still present.
        measurement = contents.attrs.get("measurement", self._measurement)

        if self._aggregate_time:
            dim = "timeid"
            # Put the averaged dimension back at its original axis position
            # (previously hard-coded to axis 2). This presumably keeps code that
            # expects a "timeid" dimension working -- NOTE(review): confirm.
            axis = contents.get_axis_num(dim)
            representations = (
                contents.mean(dim, keep_attrs=True)
                .expand_dims(dim, axis=axis)
                .assign_coords({dim: (dim, [0])})
            )
        else:
            representations = contents

        # Both paths return the same wrapper type, as the annotation promises.
        return EncoderRepresentations(
            dataset=dataset,
            representations=representations,
            model_id=measurement,
            emb_aggregation=None,
            emb_preproc=(),
            include_special_tokens=None,
            context_dimension=None,
            bidirectional=False,
        )