Coverage for langbrainscore/brainscore/brainscore.py: 19%
116 statements
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-06-07 21:22 +0000
1import typing
3import numpy as np
4import xarray as xr
5from tqdm.auto import tqdm
7# from methodtools import lru_cache
8from pathlib import Path
10from langbrainscore.interface import (
11 _BrainScore,
12 _Mapping,
13 _Metric,
14 EncoderRepresentations,
15)
17# from langbrainscore.metrics import Metric
18from langbrainscore.utils import logging
19from langbrainscore.utils.xarray import collapse_multidim_coord, copy_metadata
class BrainScore(_BrainScore):
    """Scores the alignment between two encoded representations (e.g., model
    activations `X` and brain recordings `Y`) by fitting a Mapping from X to Y
    and evaluating the mapped predictions with a Metric.

    Also supports estimating a noise ceiling (`ceiling()`) and a null
    distribution via sample permutations (`null()`).
    """

    # Populated by score()/ceiling(); class-level None defaults so attribute
    # access before any scoring call does not raise AttributeError.
    scores = None
    ceilings = None

    def __init__(
        self,
        X: typing.Union[xr.DataArray, EncoderRepresentations],
        Y: typing.Union[xr.DataArray, EncoderRepresentations],
        mapping: _Mapping,
        metric: _Metric,
        sample_split_coord: str = None,
        neuroid_split_coord: str = None,
        run=False,
    ) -> "BrainScore":
        """Initializes the [lang]BrainScore object using two encoded representations,
        a mapping class, and a metric for evaluation.

        Args:
            X (typing.Union[xr.DataArray, EncoderRepresentations]): Either an xarray
                DataArray instance, or a wrapper object with a `.representations`
                attribute that stores the xarray DataArray
            Y (typing.Union[xr.DataArray, EncoderRepresentations]): see `X`
            mapping (_Mapping): fitted/fittable mapping from X to Y
            metric (_Metric): similarity metric applied per timepoint
            sample_split_coord (str, optional): coordinate name on `Y` used to
                score sample subgroups separately. Defaults to None.
            neuroid_split_coord (str, optional): coordinate name on `Y` used to
                score neuroid subgroups separately (identity mappings only).
                Defaults to None.
            run (bool, optional): if True, immediately calls `run()`. Defaults to False.

        Returns:
            BrainScore
        """
        self.X = X.representations if hasattr(X, "representations") else X
        self.Y = Y.representations if hasattr(Y, "representations") else Y
        # X and Y must be defined over the same stimuli/samples
        assert self.X.sampleid.size == self.Y.sampleid.size
        self.mapping = mapping
        self.metric = metric
        self._sample_split_coord = sample_split_coord
        self._neuroid_split_coord = neuroid_split_coord
        # BUGFIX: `nulls` used to be a mutable class attribute ([]); appends in
        # score(null=True) were silently shared across ALL instances. Keep it
        # per-instance instead.
        self.nulls = []

        if run:
            self.run()

    def __str__(self) -> str:
        try:
            return f"{self.scores.mean()}"
        except AttributeError as e:
            # self.scores is still None until score()/run() is called
            raise ValueError(
                "missing scores. did you make a call to `score()` or `run()` yet?"
            ) from e

    def to_netcdf(self, filename):
        """
        outputs the xarray.DataArray object for 'scores' to a netCDF file
        identified by `filename`. if it already exists, overwrites it.
        """
        if Path(filename).expanduser().resolve().exists():
            # BUGFIX: the warning previously did not interpolate the filename
            logging.log(f"{filename} already exists. overwriting.", type="WARN")
        self.scores.to_netcdf(filename)

    def load_netcdf(self, filename):
        """
        loads a netCDF object that contains an xarray instance for 'scores' from
        a file at `filename`.
        """
        self.scores = xr.load_dataarray(filename)

    @staticmethod
    def _score(A, B, metric: _Metric) -> np.ndarray:
        """Applies `metric` to a (prediction, target) pair."""
        return metric(A, B)

    def score(
        self,
        ceiling=False,
        null=False,
        seed=0,
    ):
        """
        Computes The BrainScore™ (/s) using predictions/outputs returned by a
        Mapping instance which is a member attribute of a BrainScore instance.

        Args:
            ceiling (bool): if True, the mapping is run in ceiling mode and the
                result is stored in `self.ceilings` instead of `self.scores`.
            null (bool): if True, Y's samples are permuted (seeded by `seed`)
                and the result is appended to `self.nulls`.
            seed (int): RNG seed for the null permutation; also used as the
                `iter` coordinate of the appended null score.
        """
        # ceiling and null modes are mutually exclusive
        assert not (ceiling and null)
        sample_split_coord = self._sample_split_coord
        neuroid_split_coord = self._neuroid_split_coord

        if sample_split_coord:
            assert sample_split_coord in self.Y.coords

        if neuroid_split_coord:
            assert neuroid_split_coord in self.Y.coords

        X = self.X
        if null:
            # permute Y along the sample axis to build a null distribution
            y_shuffle = self.Y.copy()
            y_shuffle.data = np.random.default_rng(seed=seed).permutation(
                y_shuffle.data, axis=0
            )
            Y = y_shuffle
        else:
            Y = self.Y
        y_pred, y_true = self.mapping.fit_transform(X, Y, ceiling=ceiling)

        if not (ceiling or null):
            self.Y_pred = y_pred
            if y_pred.shape == y_true.shape:  # not IdentityMap
                self.Y_pred = copy_metadata(self.Y_pred, self.Y, "sampleid")
                self.Y_pred = copy_metadata(self.Y_pred, self.Y, "neuroid")
                self.Y_pred = copy_metadata(self.Y_pred, self.Y, "timeid")

        scores_over_time = []
        for timeid in y_true.timeid.values:

            y_pred_time = y_pred.sel(timeid=timeid).transpose("sampleid", "neuroid")
            y_true_time = y_true.sel(timeid=timeid).transpose("sampleid", "neuroid")

            if sample_split_coord:
                # groupby needs the split coord indexed on the sampleid dim
                if sample_split_coord not in y_true_time.sampleid.coords:
                    y_pred_time = collapse_multidim_coord(
                        y_pred_time, sample_split_coord, "sampleid"
                    )
                    y_true_time = collapse_multidim_coord(
                        y_true_time, sample_split_coord, "sampleid"
                    )
                score_splits = y_pred_time.sampleid.groupby(sample_split_coord).groups
            else:
                # single pseudo-group spanning all samples
                score_splits = [0]

            scores_over_time_group = []
            for scoreid in score_splits:

                if sample_split_coord:
                    y_pred_time_group = y_pred_time.isel(
                        sampleid=y_pred_time[sample_split_coord] == scoreid
                    )
                    y_true_time_group = y_true_time.isel(
                        sampleid=y_true_time[sample_split_coord] == scoreid
                    )
                else:
                    y_pred_time_group = y_pred_time
                    y_true_time_group = y_true_time

                neuroids = []
                # shape mismatch indicates an IdentityMap; per-neuroid-group
                # scoring is only meaningful there
                if y_pred.shape != y_true.shape and neuroid_split_coord:  # IdentityMap
                    if neuroid_split_coord:
                        if neuroid_split_coord not in y_true_time_group.neuroid.coords:
                            y_true_time_group = collapse_multidim_coord(
                                y_true_time_group, neuroid_split_coord, "neuroid"
                            )
                        neuroid_splits = y_true_time_group.neuroid.groupby(
                            neuroid_split_coord
                        ).groups
                        score_per_time_group = []
                        for neuroid in neuroid_splits:
                            score_per_time_group.append(
                                self._score(
                                    y_pred_time_group,
                                    y_true_time_group.isel(
                                        neuroid=(
                                            y_true_time_group[neuroid_split_coord]
                                            == neuroid
                                        )
                                    ),
                                    self.metric,
                                )
                            )
                            neuroids.append(neuroid)
                        score_per_time_group = np.array(score_per_time_group)
                else:
                    score_per_time_group = self._score(
                        y_pred_time_group, y_true_time_group, self.metric
                    )

                if neuroids:
                    pass  # neuroid labels already collected above
                elif len(score_per_time_group) == 1:  # e.g., RSA, CKA, w/o split
                    neuroids = [np.nan]
                else:
                    neuroids = y_true_time_group.neuroid.data

                scores_over_time_group.append(
                    xr.DataArray(
                        score_per_time_group.reshape(1, -1, 1),
                        dims=("scoreid", "neuroid", "timeid"),
                        coords={
                            "scoreid": ("scoreid", [scoreid]),
                            "neuroid": ("neuroid", neuroids),
                            "timeid": ("timeid", [timeid]),
                        },
                    )
                )

            scores_over_time.append(xr.concat(scores_over_time_group, dim="scoreid"))

        scores = xr.concat(scores_over_time, dim="timeid")

        if scores.neuroid.size == self.Y.neuroid.size:  # not RSA, CKA, etc.
            scores = copy_metadata(scores, self.Y, "neuroid")
        scores = copy_metadata(scores, self.Y, "timeid")

        if not (ceiling or null):
            self.scores = scores
        elif ceiling:
            self.ceilings = scores
        else:
            # tag this null sample with its seed on a new `iter` dimension
            self.nulls.append(
                scores.expand_dims(dim={"iter": [seed]}, axis=-1).assign_coords(
                    iter=[seed]
                )
            )

    def ceiling(self):
        """Computes the ceiling estimate and stores it in `self.ceilings`."""
        logging.log("Calculating ceiling.", type="INFO")
        self.score(ceiling=True)

    def null(
        self,
        iters=100,
    ):
        """Builds a null distribution from `iters` seeded permutations; the
        result is an xr.DataArray concatenated over an `iter` dimension."""
        # BUGFIX: reset the accumulator so a repeat call does not try to
        # append onto the concatenated DataArray left by a previous call
        self.nulls = []
        for i in tqdm([*range(iters)], desc="Running null permutations"):
            self.score(null=True, seed=i)
        self.nulls = xr.concat(self.nulls, dim="iter")

    def run(
        self,
        sample_split_coord=None,
        neuroid_split_coord=None,
        calc_nulls=False,
        iters=100,
    ):
        """Computes scores, ceilings, and (optionally) nulls in one call.

        Args:
            sample_split_coord (str, optional): overrides the split coord set
                at construction time, if provided.
            neuroid_split_coord (str, optional): same, for neuroids.
            calc_nulls (bool): whether to also run `null()`.
            iters (int): number of null permutations.

        Returns:
            dict: 'scores' and 'ceilings' (and 'nulls' when calc_nulls=True).
        """
        # BUGFIX: score() and ceiling() no longer take split-coord kwargs (they
        # live on the instance; see __init__) — forwarding them raised
        # TypeError. Route run()'s kwargs into the instance attributes instead.
        if sample_split_coord is not None:
            self._sample_split_coord = sample_split_coord
        if neuroid_split_coord is not None:
            self._neuroid_split_coord = neuroid_split_coord
        self.score()
        self.ceiling()
        if calc_nulls:
            self.null(iters=iters)
            return {
                "scores": self.scores,
                "ceilings": self.ceilings,
                "nulls": self.nulls,
            }
        return {"scores": self.scores, "ceilings": self.ceilings}