Coverage for langbrainscore/interface/dataset.py: 82%

22 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-07 21:22 +0000

1from abc import ABC 

2import typing 

3 

4import xarray as xr 

5 

6from langbrainscore.interface.cacheable import _Cacheable 

7from langbrainscore.utils.xarray import fix_xr_dtypes 

8 

9 

class _Dataset(_Cacheable, ABC):
    """
    Wrapper class for an xarray DataArray that confirms its format adheres
    to the interface.
    """

    # Class-level default; `__init__` assigns a per-instance value on top of it.
    dataset_name: typing.Optional[str] = None

    def __init__(
        self,
        xr_obj: xr.DataArray,
        dataset_name: typing.Optional[str] = None,
        # modality: str = None,
        _skip_checks: bool = False,
    ) -> None:
        """
        accepts an xarray with the following core
            dimensions: sampleid, neuroid, timeid
        and at least the following core
            coordinates: sampleid, neuroid, timeid, stimulus, subject

        The dataset name is resolved in this order: `xr_obj.attrs["name"]` if
        that key exists and its value is truthy, otherwise the (possibly
        class-level) `dataset_name` attribute if truthy, otherwise the
        `dataset_name` argument.

        Args:
            xr_obj (xr.DataArray): xarray object with core dimensions and coordinates
            dataset_name (str, optional): fallback name, used only when no
                truthy name is found on `xr_obj.attrs` or on the class.
            _skip_checks (bool): when True, none of the format checks below
                are run on `xr_obj`.

        Raises:
            AssertionError: if checks are enabled and `xr_obj` is not an
                `xr.DataArray` with exactly the three core dimensions and
                all of the core coordinates.
        """

        # Validate before touching `xr_obj.attrs`, so that a wrong input type
        # is reported by the checks rather than as an unrelated AttributeError.
        # NOTE: these are `assert` statements (kept so the exception type stays
        # AssertionError); they are skipped when Python runs with -O.
        if not _skip_checks:
            dims = ("sampleid", "neuroid", "timeid")
            coords = dims + ("stimulus", "subject")
            assert isinstance(
                xr_obj, xr.DataArray
            ), f"expected an xr.DataArray, got {type(xr_obj).__name__}"
            assert xr_obj.ndim == len(
                dims
            ), f"expected {len(dims)} dimensions {dims}, got {xr_obj.ndim}: {xr_obj.dims}"
            missing_dims = [dim for dim in dims if dim not in xr_obj.dims]
            assert not missing_dims, f"missing core dimensions: {missing_dims}"
            missing_coords = [coord for coord in coords if coord not in xr_obj.coords]
            assert not missing_coords, f"missing core coordinates: {missing_coords}"

        if xr_obj is not None:
            try:
                self.dataset_name = xr_obj.attrs["name"]
            except KeyError:
                # No name stored on the xarray object; fall through to the
                # class-level attribute / `dataset_name` argument below.
                pass
        self.dataset_name = self.dataset_name or dataset_name
        # self.modality = modality

        # The object is stored after being passed through `fix_xr_dtypes`.
        self._xr_obj = fix_xr_dtypes(xr_obj)

51 

52 # def __getattr__(self, __name: str) -> typing.Any: 

53 # """falls back on the xarray object in case of a NameError using __getattribute__ 

54 # on this object""" 

55 # try: 

56 # return getattr(self.contents, __name) 

57 # except AttributeError: 

58 # raise AttributeError(f"no attribute called `{__name}` on object")