Coverage for langbrainscore/utils/xarray.py: 26%
19 statements
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
1import xarray as xr
2from sklearn.impute import SimpleImputer
def copy_metadata(target: xr.DataArray, source: xr.DataArray, dim: str) -> xr.DataArray:
    """Copy the metadata coordinates along `dim` from `source` onto `target`.

    For this to work, the two `xr.DataArray` objects must have the same size
    along the `dim` dimension.

    Args:
        target (xr.DataArray): xarray to copy the metadata coordinates onto.
            `assign_coords` returns a new object, so `target` itself is not
            modified in place.
        source (xr.DataArray): xarray providing the metadata coordinates
            along `dim`.
        dim (str): name of the dimension whose coordinates are copied
            (see `xr` documentation for help).

    Returns:
        xr.DataArray: `target` with every coordinate of `source` that lies
        along `dim` attached to it.
    """
    for coord in source[dim].coords:
        # `source[dim].coords` also carries scalar (0-d) coordinates; those do
        # not lie along `dim` and cannot be assigned as `(dim, data)`.
        if dim not in source[coord].dims:
            continue
        target = target.assign_coords({coord: (dim, source[coord].data)})
    return target
def collapse_multidim_coord(xr_obj, coord, keep_dim):
    """Collapse a multi-dimensional coordinate to a 1-D coordinate on `keep_dim`.

    As a result of iterative construction of `xarray`s in our various functions
    (such as in the `HuggingFaceEncoder.encode` method), the same values are
    repeated over and over again. The coordinate values are passed through a
    `SimpleImputer(strategy="most_frequent")` and a single vector of the
    result is kept as the new coordinate.

    Args:
        xr_obj: xarray object that holds the coordinate `coord`.
        coord (str): name of the coordinate to collapse. Its values are handed
            to `SimpleImputer.fit_transform`, so they are expected to be 2-D.
        keep_dim (str): name of the dimension the collapsed coordinate is
            attached to.

    Returns:
        The result of `xr_obj.assign_coords`, with `coord` replaced by a 1-D
        coordinate along `keep_dim`.

    Raises:
        ValueError: if neither the first row nor the first column of the
            imputed values can be assigned along `keep_dim`.
    """
    imputer = SimpleImputer(strategy="most_frequent")
    # Impute once; both candidate vectors below are slices of the same result.
    filled = imputer.fit_transform(xr_obj[coord])
    try:
        return xr_obj.assign_coords({coord: (keep_dim, filled[0])})
    except ValueError:
        # Presumably the first row did not fit `keep_dim` (the coordinate is
        # laid out with `keep_dim` first), so fall back to the first column.
        # NOTE(review): the imputer fills values column-wise in both cases, so
        # this path takes column 0 as imputed within itself -- confirm that is
        # the intended result for this layout.
        return xr_obj.assign_coords({coord: (keep_dim, filled.transpose()[0])})
def fix_xr_dtypes(xr_obj):
    """Cast every object-dtype coordinate of `xr_obj` to `str`.

    Sometimes xarrays end up having dtype='O' (object) instead of the
    expected dtypes, which is usually 'str'.

    Args:
        xr_obj: xarray object whose coordinates are inspected. Coordinates are
            reassigned on this very object.

    Returns:
        The same `xr_obj` that was passed in, returned for convenience.
    """
    for name in xr_obj.coords:
        values = xr_obj[name]
        if values.dtype != "O":
            # already has a concrete dtype; leave untouched
            continue
        xr_obj[name] = values.astype(str)
    return xr_obj