Coverage for langbrainscore/interface/dataset.py: 82%
22 statements
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
1from abc import ABC
2import typing
4import xarray as xr
6from langbrainscore.interface.cacheable import _Cacheable
7from langbrainscore.utils.xarray import fix_xr_dtypes
class DatasetFormatError(ValueError, AssertionError):
    """Raised when an xarray object does not match the ``_Dataset`` interface.

    Inherits from ``AssertionError`` so callers that caught the failures of the
    earlier ``assert``-based checks keep working, and from ``ValueError``
    because the failure is about a malformed argument. Unlike ``assert``, these
    checks are not stripped when Python runs with ``-O``.
    """


class _Dataset(_Cacheable, ABC):
    """Wrapper around an ``xr.DataArray`` that checks it adheres to the dataset interface.

    Unless checks are skipped, the wrapped object must have exactly the core
    dimensions ``sampleid``, ``neuroid`` and ``timeid``, and at least the
    coordinates ``sampleid``, ``neuroid``, ``timeid``, ``stimulus`` and
    ``subject``.
    """

    # Class-level default name; subclasses may override it, and __init__ may
    # replace it per instance.
    dataset_name: str = None

    def __init__(
        self,
        xr_obj: xr.DataArray,
        dataset_name: str = None,
        # modality: str = None,
        _skip_checks: bool = False,
    ) -> None:
        """Validate (unless skipped) and store ``xr_obj``.

        Name resolution, highest precedence first:
          1. ``xr_obj.attrs["name"]`` if that key exists;
          2. the class-level ``dataset_name`` if truthy;
          3. the ``dataset_name`` argument.

        Args:
            xr_obj (xr.DataArray): xarray object with core dimensions
                ``sampleid``, ``neuroid``, ``timeid`` and at least the
                coordinates ``sampleid``, ``neuroid``, ``timeid``,
                ``stimulus``, ``subject``.
            dataset_name (str, optional): fallback name, used only when
                neither ``xr_obj.attrs["name"]`` nor the class attribute
                supplies a truthy name.
            _skip_checks (bool): if True, skip all structural validation.

        Raises:
            DatasetFormatError: if checks are enabled and ``xr_obj`` is not an
                ``xr.DataArray``, does not have exactly the three core
                dimensions, or is missing a required coordinate.
        """
        # Validate first so a wrong input type yields a clear error instead of
        # an AttributeError from the ``.attrs`` access below.
        if not _skip_checks:
            dims = ("sampleid", "neuroid", "timeid")
            coords = dims + ("stimulus", "subject")

            if not isinstance(xr_obj, xr.DataArray):
                raise DatasetFormatError(
                    f"expected an xr.DataArray, got {type(xr_obj).__name__}"
                )
            if xr_obj.ndim != len(dims):
                raise DatasetFormatError(
                    f"expected exactly {len(dims)} dimensions {dims}, "
                    f"got {xr_obj.ndim}: {tuple(xr_obj.dims)}"
                )
            missing_dims = [dim for dim in dims if dim not in xr_obj.dims]
            if missing_dims:
                raise DatasetFormatError(
                    f"missing required dimension(s) {missing_dims}; "
                    f"got {tuple(xr_obj.dims)}"
                )
            missing_coords = [coord for coord in coords if coord not in xr_obj.coords]
            if missing_coords:
                raise DatasetFormatError(
                    f"missing required coordinate(s) {missing_coords}"
                )

        if xr_obj is not None:
            # A "name" stored on the xarray object overrides the class default;
            # when absent, keep whatever the class attribute already provides.
            self.dataset_name = xr_obj.attrs.get("name", self.dataset_name)
        # Fall back to the constructor argument only if no truthy name was found.
        self.dataset_name = self.dataset_name or dataset_name
        # self.modality = modality

        # fix_xr_dtypes is a project helper; presumably it normalizes coordinate
        # dtypes — see langbrainscore.utils.xarray to confirm.
        self._xr_obj = fix_xr_dtypes(xr_obj)