Coverage for langbrainscore/interface/mapping.py: 78%
9 statements
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
1from abc import ABC, abstractmethod
2from typing import Tuple
4import xarray as xr
5from langbrainscore.interface.cacheable import _Cacheable
class _Mapping(_Cacheable, ABC):
    """
    Interface for an object that defines and applies a mapping between two
    xarrays that share the same set of samples.

    Subclasses are expected to override `fit_transform`; the implementation
    here only raises NotImplementedError.
    """

    def __init__(self):
        # NOTE(review): this does not call super().__init__(), so any
        # initialisation _Cacheable may define is skipped here — confirm
        # that is intended before relying on _Cacheable state.
        pass

    def fit_transform(
        self, X: xr.DataArray, Y: xr.DataArray
    ) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Take two xarrays with a shared set of samples and return a new pair
        of xarrays (Y_pred, Y_true) to be compared with a metric.

        Y_pred is either derived from a learned mapping on X, or can be X
        itself when the downstream metric supports comparison of matrices
        with different dimensions, e.g., RSA or CKA.

        args:
            X (xr.DataArray): xarray from which Y_pred is derived (or which
                is returned as Y_pred directly).
            Y (xr.DataArray): xarray sharing its samples with X.

        returns:
            Tuple[xr.DataArray, xr.DataArray]: (Y_pred, Y_true)

        raises:
            NotImplementedError: always, in this base class; concrete
                mappings must override this method.
        """
        # Name the concrete subclass so a missing override is easy to trace.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement fit_transform"
        )