Coverage for langbrainscore/interface/mapping.py: 78%

9 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-07 21:22 +0000

1from abc import ABC, abstractmethod 

2from typing import Tuple 

3 

4import xarray as xr 

5from langbrainscore.interface.cacheable import _Cacheable 

6 

7 

class _Mapping(_Cacheable, ABC):
    """
    Object that defines and applies a map between two xarrays with the same
    number of samples.

    Subclasses provide the concrete mapping by overriding `fit_transform`.
    """

    def __init__(self):
        # NOTE(review): does not call super().__init__(); confirm that
        # _Cacheable requires no initialization of its own.
        pass

    def fit_transform(
        self, X: xr.DataArray, Y: xr.DataArray
    ) -> Tuple[xr.DataArray, xr.DataArray]:
        """
        Take two xarrays with a shared set of samples and return a new pair of
        xarrays (Y_pred, Y_true) to be compared with a metric.

        Y_pred is either derived from a learned mapping on X, or can be X
        itself when the downstream metric supports comparison of matrices with
        different dimensions (e.g., RSA, CKA).

        args:
            xr.DataArray: X
            xr.DataArray: Y

        returns:
            xr.DataArray: Y_pred
            xr.DataArray: Y_true

        raises:
            NotImplementedError: always, in this base class; subclasses must
                override this method.
        """
        # Kept as a runtime error rather than @abstractmethod so that existing
        # subclasses which never override this method can still be instantiated.
        raise NotImplementedError(
            f"{type(self).__name__} must implement fit_transform()"
        )