Coverage for langbrainscore/utils/xarray.py: 26%

19 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-07 21:22 +0000

1import xarray as xr 

2from sklearn.impute import SimpleImputer 

3 

4 

def copy_metadata(target: xr.DataArray, source: xr.DataArray, dim: str) -> xr.DataArray:
    """Return `target` with the coordinates found on `source[dim]` attached along `dim`.

    Every coordinate listed in `source[dim].coords` is assigned to the result
    as a coordinate on `dim`, carrying the values of `source[coord]`. For this
    to work, `target` and `source` must agree in size along `dim`.

    Args:
        target (xr.DataArray): array that should receive the coordinates
            (it is not changed in place; `assign_coords` returns a new object)
        source (xr.DataArray): array whose coordinates along `dim` are copied
        dim (str): name of the dimension the coordinates are attached to
            (see the `xr` documentation for help)

    Returns:
        xr.DataArray: `target` with the coordinates of `source` along `dim`
    """
    result = target
    for name in source[dim].coords:
        values = source[name].data
        result = result.assign_coords({name: (dim, values)})
    return result

22 

23 

def collapse_multidim_coord(xr_obj, coord, keep_dim):
    """Re-attach the coordinate `coord` to the single dimension `keep_dim`.

    As a result of iterative construction of `xarray`s in our various functions
    (such as in the `HuggingFaceEncoder.encode` method), the same values are
    repeated over and over again. The values of `coord` are passed through a
    `SimpleImputer` with the `most_frequent` strategy, and one slice of the
    result (the first row or, as a fallback, the first column) is kept.

    Args:
        xr_obj: xarray object that carries `coord`
            (assumes `xr_obj[coord]` is two-dimensional -- TODO confirm)
        coord (str): name of the coordinate to collapse
        keep_dim (str): name of the dimension the collapsed coordinate is
            assigned to (it is handed to `assign_coords` as the dimension)

    Returns:
        The result of `xr_obj.assign_coords`, with `coord` assigned on
        `keep_dim`.

    Raises:
        ValueError: if the imputer rejects the input, or if neither slice can
            be assigned along `keep_dim`.
    """
    # Fit once up front; both the primary and the fallback path slice the same
    # imputed array, so there is no need to run the imputer a second time.
    imputed = SimpleImputer(strategy="most_frequent").fit_transform(xr_obj[coord])
    try:
        return xr_obj.assign_coords({coord: (keep_dim, imputed[0])})
    except ValueError:
        # NOTE(review): presumably `assign_coords` raises here when the first
        # row does not have the length of `keep_dim`, i.e. `keep_dim` is the
        # first axis of `xr_obj[coord]` rather than the second -- confirm.
        return xr_obj.assign_coords({coord: (keep_dim, imputed.transpose()[0])})

44 

45 

def fix_xr_dtypes(xr_obj):
    """Cast every coordinate of `xr_obj` that has object dtype to `str`.

    sometimes xarrays end up having dtype='O' (object) instead of the
    expected dtypes, which is usually 'str'

    Args:
        xr_obj: xarray object whose coordinates are checked. Coordinates with
            dtype 'O' are replaced on this object via item assignment, so the
            argument itself is modified.

    Returns:
        The same `xr_obj` that was passed in (returned for convenience).
    """
    # Take a snapshot of the coordinate names first: the loop body assigns
    # back into `xr_obj`, and we do not want to iterate a mapping that is
    # being written to at the same time.
    for name in list(xr_obj.coords):
        if xr_obj[name].dtype == "O":
            xr_obj[name] = xr_obj[name].astype(str)
    return xr_obj