Coverage for langbrainscore/mapping/mapping.py: 17% (149 statements, coverage.py v6.4, created 2022-06-07 21:22 +0000)
import typing
from functools import partial

from tqdm.auto import tqdm
import numpy as np
from joblib import Parallel, delayed
import xarray as xr
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression, RidgeCV

from langbrainscore.interface import _Mapping
from langbrainscore.utils import logging
from langbrainscore.utils.xarray import collapse_multidim_coord

mapping_classes_params = {
    "linreg": (LinearRegression, {}),
    "linridge_cv": (RidgeCV, {"alphas": np.logspace(-3, 3, 13)}),
    "linpls": (PLSRegression, {"n_components": 20}),
}
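
# A minimal sketch (not executed; uses only the registry above) of how a string
# key resolves to an estimator class plus its default keyword arguments:
#
#   cls, default_kwargs = mapping_classes_params["linridge_cv"]
#   model = cls(**default_kwargs)  # RidgeCV(alphas=np.logspace(-3, 3, 13))
#   model.fit(np.random.randn(20, 4), np.random.randn(20))
#   preds = model.predict(np.random.randn(5, 4))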


class IdentityMap(_Mapping):
    """
    Identity mapping for use with metrics that operate
    on non-column-aligned matrices, e.g., RSA, CKA.

    Drops or imputes NaNs for downstream metrics.
    """

    def __init__(self, nan_strategy: str = "drop") -> None:
        self._nan_strategy = nan_strategy

    def fit_transform(
        self,
        X: xr.DataArray,
        Y: xr.DataArray,
        ceiling: bool = False,
    ):
        if ceiling:
            logging.log("ceiling not supported for IdentityMap yet")
        # TODO: figure out how to handle NaNs better...
        if self._nan_strategy == "drop":
            X_clean = X.copy(deep=True).dropna(dim="neuroid")
            Y_clean = Y.copy(deep=True).dropna(dim="neuroid")
        elif self._nan_strategy == "impute":
            X_clean = X.copy(deep=True).fillna(0)
            Y_clean = Y.copy(deep=True).fillna(0)
        else:
            raise NotImplementedError("unsupported nan strategy.")
        return X_clean, Y_clean
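
# A minimal usage sketch (not executed; the toy DataArrays are made up for
# illustration): IdentityMap only cleans NaNs and otherwise returns X and Y
# unchanged, so non-column-aligned metrics (RSA, CKA) can consume them directly.
#
#   X = xr.DataArray(np.random.randn(8, 4), dims=("sampleid", "neuroid"))
#   Y = xr.DataArray(np.random.randn(8, 6), dims=("sampleid", "neuroid"))
#   X_clean, Y_clean = IdentityMap(nan_strategy="drop").fit_transform(X, Y)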


class LearnedMap(_Mapping):
    def __init__(
        self,
        mapping_class: typing.Union[
            str, typing.Tuple[typing.Callable, typing.Mapping[str, typing.Any]]
        ],
        random_seed: int = 42,
        k_fold: int = 5,
        strat_coord: str = None,
        num_split_groups_out: int = None,  # (p, the # of groups in the test split)
        split_coord: str = None,  # (grouping coord)
        # TODO
        # handle predicting a held-out subject; but then we would have to take the
        # mean over ROIs, because individual neuroids do not correspond across
        # subjects. we kind of already have this along the `sampleid` coordinate,
        # but we need to implement it along the `neuroid` coordinate as well
        **kwargs,
    ) -> None:
        """
        Initializes a Mapping object that describes a mapping between two encoder representations.

        Args:
            mapping_class (typing.Union[str, typing.Tuple], required): either a key into
                `mapping_classes_params`, or a (class, kwargs) tuple. The class is
                instantiated to obtain a mapping model, e.g. LinearRegression or RidgeCV
                from the sklearn family; it must implement the sklearn estimator
                interface (`fit`/`predict`).
            random_seed (int, optional): seed used for shuffling in CV splits. Defaults to 42.
            k_fold (int, optional): number of cross-validation folds. Defaults to 5.
            strat_coord (str, optional): coordinate to stratify splits by. Defaults to None.
            num_split_groups_out (int, optional): the number of groups in the test split.
                Defaults to None.
            split_coord (str, optional): coordinate defining the groups used for splitting.
                Defaults to None.
        """
        self.random_seed = random_seed
        self.k_fold = k_fold
        self.strat_coord = strat_coord
        self.num_split_groups_out = num_split_groups_out
        self.split_coord = split_coord
        self.mapping_class_name = mapping_class
        self.mapping_params = kwargs

        if isinstance(mapping_class, str):
            _mapping_class, _kwargs = mapping_classes_params[self.mapping_class_name]
        else:
            # in the spirit of duck-typing, we don't check for `fit`/`predict`
            # attributes here; we will fail automatically if they are missing
            _mapping_class, _kwargs = mapping_class
        self.mapping_params.update(_kwargs)

        # TODO: what is the difference between these two (model; full_model)?
        # let's make this less confusing
        self.full_model = _mapping_class(**self.mapping_params)
        self.model = _mapping_class(**self.mapping_params)
        logging.log(f"initialized Mapping with {type(self.model)}!")
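
    # A minimal construction sketch (not executed): either name a registered
    # mapping, or pass a (class, kwargs) tuple directly.
    #
    #   mapping = LearnedMap("linridge_cv", k_fold=5)
    #   mapping = LearnedMap((LinearRegression, {}), random_seed=0)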

    @staticmethod
    def _construct_splits(
        xr_dataset: xr.Dataset,
        strat_coord: str,
        k_folds: int,
        split_coord: str,
        num_split_groups_out: int,
        random_seed: int,
    ):
        from sklearn.model_selection import (
            GroupKFold,
            KFold,
            StratifiedGroupKFold,
            StratifiedKFold,
        )

        sampleid = xr_dataset.sampleid.values

        if strat_coord and split_coord:
            # stratify on `strat_coord`; keep `split_coord` groups intact
            kf = StratifiedGroupKFold(
                n_splits=k_folds, shuffle=True, random_state=random_seed
            )
            split = partial(
                kf.split,
                sampleid,
                y=xr_dataset[strat_coord].values,
                groups=xr_dataset[split_coord].values,
            )
        elif split_coord:
            kf = GroupKFold(n_splits=k_folds)
            split = partial(kf.split, sampleid, groups=xr_dataset[split_coord].values)
        elif strat_coord:
            kf = StratifiedKFold(
                n_splits=k_folds, shuffle=True, random_state=random_seed
            )
            split = partial(kf.split, sampleid, y=xr_dataset[strat_coord].values)
        else:
            kf = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
            split = partial(kf.split, sampleid)

        logging.log(f"running {type(kf)}!", verbosity_check=True)
        return split()
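
    # A minimal sketch (not executed; `ds` stands for a hypothetical dataset with
    # a `sampleid` coordinate): the returned generator yields (train, test) index
    # arrays into `sampleid`, exactly like sklearn splitters.
    #
    #   for train_idx, test_idx in LearnedMap._construct_splits(
    #       ds, strat_coord=None, k_folds=5, split_coord=None,
    #       num_split_groups_out=None, random_seed=42,
    #   ):
    #       ...  # fit on ds.sampleid[train_idx], evaluate on ds.sampleid[test_idx]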

    def construct_splits(self, A):
        return self._construct_splits(
            A,
            self.strat_coord,
            self.k_fold,
            self.split_coord,
            self.num_split_groups_out,
            random_seed=self.random_seed,
        )

    def fit_full(self, X, Y):
        # TODO: fit a single model on the full dataset (no held-out split);
        # something like `self.fit(X, Y, k_folds=1)` once `fit` exists
        raise NotImplementedError

    def _check_sampleids(
        self,
        X: xr.DataArray,
        Y: xr.DataArray,
    ):
        """
        Checks that the sampleids in X and Y are the same.
        """

        if X.sampleid.values.shape != Y.sampleid.values.shape:
            raise ValueError("X and Y sampleid shapes do not match!")
        if not np.all(X.sampleid.values == Y.sampleid.values):
            raise ValueError("X and Y sampleids do not match!")

        logging.log(
            f"Passed sampleid check for neuroid {Y.neuroid.values}",
            verbosity_check=True,
        )

    def _drop_na(
        self, X: xr.DataArray, Y: xr.DataArray, dim: str = "sampleid", **kwargs
    ):
        """
        Drops samples with missing values in Y along the specified dimension,
        then subsets X so that X and Y have the same sampleids.
        """
        # limit data to the current neuroid, then drop the samples that are
        # missing data for this neuroid
        Y_slice = Y.dropna(dim=dim, **kwargs)
        Y_filtered_ids = Y_slice[dim].values

        assert set(Y_filtered_ids).issubset(set(X[dim].values))

        logging.log(
            f"for neuroid {Y_slice.neuroid.values}, we used {(num_retained := len(Y_filtered_ids))}"
            f" samples; dropped {len(Y[dim]) - num_retained}",
            verbosity_check=True,
        )

        # use only the samples that are in Y
        X_slice = X.sel(sampleid=Y_filtered_ids)

        return X_slice, Y_slice
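
    # A minimal sketch (not executed; the toy arrays are made up for
    # illustration): a NaN at sampleid 2 in Y causes both returned slices to
    # exclude that sample.
    #
    #   coords = {"sampleid": [0, 1, 2, 3], "neuroid": ["n1", "n2", "n3"]}
    #   X = xr.DataArray(np.random.randn(4, 3), dims=("sampleid", "neuroid"), coords=coords)
    #   Y = X.copy(deep=True)
    #   Y[2, :] = np.nan
    #   X_slice, Y_slice = mapping._drop_na(X, Y)  # both keep sampleids [0, 1, 3]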

    # def _permute_X(
    #     self,
    #     X: xr.DataArray,
    #     method: str = "shuffle_X_rows",
    #     random_state: int = 42,
    # ):
    #     """Permute the features of X.
    #
    #     Parameters
    #     ----------
    #     X : xr.DataArray
    #         The embeddings to be permuted
    #     method : str
    #         The method to use for permutation.
    #         'shuffle_X_rows': shuffle the rows of X (i.e., shuffle the sentences,
    #             creating a mismatch between the sentence embeddings and the target)
    #         'shuffle_each_X_col': for each column (feature/unit) of X, permute that
    #             feature's values across all sentences. Retains the statistics of
    #             the original features (e.g., the mean per feature), but shuffles
    #             each feature's values across sentences.
    #     random_state : int
    #         The seed for the random number generator.
    #
    #     Returns
    #     -------
    #     xr.DataArray
    #         The permuted dataarray
    #     """
    #
    #     X_orig = X.copy(deep=True)
    #
    #     if logging.get_verbosity():
    #         logging.log(f"OBS: permuting X with method {method}")
    #
    #     if method == "shuffle_X_rows":
    #         X = X.sample(
    #             n=X.shape[1], random_state=random_state
    #         )  # check whether X_shape is correct
    #
    #     elif method == "shuffle_each_X_col":
    #         np.random.seed(random_state)
    #         for feat in range(X.data.shape[0]):  # per neuroid
    #             np.random.shuffle(X.data[feat, :])
    #
    #     else:
    #         raise ValueError(f"Invalid method: {method}")
    #
    #     assert X.shape == X_orig.shape
    #     assert np.all(X.data != X_orig.data)
    #
    #     return X

    def fit_transform(
        self,
        X: xr.DataArray,
        Y: xr.DataArray,
        # permute_X: typing.Union[bool, str] = False,
        ceiling: bool = False,
        ceiling_coord: str = "subject",
    ) -> typing.Tuple[xr.DataArray, xr.DataArray]:
        """Fits a mapping model per neuroid using k-fold cross-validation,
        with the params from class initialization; `strat_coord` and `split_coord`
        are used to stratify and split across group boundaries.

        Returns:
            typing.Tuple[xr.DataArray, xr.DataArray]: (predictions, held-out targets),
                each with dims (sampleid, neuroid, timeid)
        """
        from sklearn.random_projection import GaussianRandomProjection

        if ceiling:
            n_neuroids = X.neuroid.values.size
            X = Y.copy()

        logging.log(f"X shape: {X.data.shape}", verbosity_check=True)
        logging.log(f"Y shape: {Y.data.shape}", verbosity_check=True)

        if self.strat_coord:
            try:
                assert (X[self.strat_coord].values == Y[self.strat_coord].values).all()
            except AssertionError:
                raise ValueError(
                    f"{self.strat_coord} coordinate does not align across X and Y"
                )
        if self.split_coord:
            try:
                assert (X[self.split_coord].values == Y[self.split_coord].values).all()
            except AssertionError:
                raise ValueError(
                    f"{self.split_coord} coordinate does not align across X and Y"
                )

        def fit_per_neuroid(neuroid):
            Y_neuroid = Y.sel(neuroid=neuroid)

            # limit data to the current neuroid, then drop the samples that are
            # missing data for this neuroid
            X_slice, Y_slice = self._drop_na(X, Y_neuroid, dim="sampleid")

            # assert that X and Y have the same sampleids
            self._check_sampleids(X_slice, Y_slice)

            # select the relevant ceiling split: hold out neuroids that share
            # `ceiling_coord` (e.g., subject) with the current target neuroid
            if ceiling:
                X_slice = X_slice.isel(
                    neuroid=X_slice[ceiling_coord] != Y_slice[ceiling_coord]
                ).dropna(dim="neuroid")

            # We can perform various sanity checks by 'permuting' the source, X
            # NOTE this is a test! do not use under normal workflow!
            # if permute_X:
            #     logging.log(
            #         f"`permute_X` flag is enabled. only do this in an adversarial setting.",
            #         cmap="WARN",
            #         type="WARN",
            #         verbosity_check=True,
            #     )
            #     X_slice = self._permute_X(X_slice, method=permute_X)

            # these collections store each split for our records later
            # TODO: we aren't saving these to the object instance yet
            train_indices = []
            test_indices = []

            splits = self.construct_splits(Y_slice)

            # X_test_collection = []
            Y_test_collection = []
            Y_pred_collection = []

            for cvfoldid, (train_index, test_index) in enumerate(splits):

                train_indices.append(train_index)
                test_indices.append(test_index)

                # NOTE: use the NaN-cleaned `X_slice`/`Y_slice` variants, not X and Y
                X_train, X_test = (
                    X_slice.sel(sampleid=Y_slice.sampleid.values[train_index]),
                    X_slice.sel(sampleid=Y_slice.sampleid.values[test_index]),
                )
                y_train, y_test = (
                    Y_slice.sel(sampleid=Y_slice.sampleid.values[train_index]),
                    Y_slice.sel(sampleid=Y_slice.sampleid.values[test_index]),
                )

                # empty list to house the y_predictions per timeid
                y_pred_over_time = []

                for timeid in y_train.timeid:

                    # TODO: change this code for models that also have a
                    # non-singleton timeid, i.e., whose output evolves in time (RNN?)
                    x_model_train = X_train.sel(timeid=0).values
                    y_model_train = y_train.sel(timeid=timeid).values.reshape(-1, 1)

                    # for the ceiling estimate, project the (wider) source down to
                    # the original number of neuroids; remember whether we did, so
                    # the test split gets the same projection
                    needs_projection = ceiling and x_model_train.shape[1] > n_neuroids
                    if needs_projection:
                        projection = GaussianRandomProjection(
                            n_components=n_neuroids, random_state=0
                        )
                        x_model_train = projection.fit_transform(x_model_train)

                    self.model.fit(
                        x_model_train,
                        y_model_train,
                    )

                    # store the hparam values related to the fitted model; `alpha_`
                    # is only set by RidgeCV or any duck type that fits an alpha hparam
                    alpha = getattr(self.model, "alpha_", np.nan)

                    # deepcopy `y_test` as `y_pred` to inherit some of the metadata
                    # and dims, then populate it with our new predicted values
                    y_pred = (
                        y_test.sel(timeid=timeid)
                        .copy(deep=True)
                        .expand_dims("timeid", 1)
                    )
                    x_model_test = X_test.sel(timeid=0)
                    if needs_projection:
                        x_model_test = projection.transform(x_model_test)
                    y_pred.data = self.model.predict(x_model_test)
                    y_pred = y_pred.assign_coords(timeid=("timeid", [timeid]))
                    y_pred = y_pred.assign_coords(alpha=("timeid", [alpha]))
                    y_pred = y_pred.assign_coords(cvfoldid=("timeid", [cvfoldid]))
                    y_pred_over_time.append(y_pred)

                y_pred_over_time = xr.concat(y_pred_over_time, dim="timeid")
                Y_pred_collection.append(y_pred_over_time)
                Y_test_collection.append(y_test)

            Y_test = xr.concat(Y_test_collection, dim="sampleid").sortby("sampleid")
            Y_pred = xr.concat(Y_pred_collection, dim="sampleid").sortby("sampleid")

            return Y_test, Y_pred

        # loop across each Y neuroid (target); joblib.Parallel preserves input order
        test = []
        pred = []
        # sequential alternative:
        # for neuroid in tqdm(Y.neuroid.values, desc="fitting a model per neuroid"):
        for t, p in Parallel(n_jobs=-2)(
            delayed(fit_per_neuroid)(neuroid)
            for neuroid in tqdm(Y.neuroid.values, desc="fitting a model per neuroid")
        ):
            test += [t]
            pred += [p]

        test_xr = xr.concat(test, dim="neuroid").transpose(
            "sampleid", "neuroid", "timeid"
        )
        pred_xr = xr.concat(pred, dim="neuroid").transpose(
            "sampleid", "neuroid", "timeid"
        )

        if test_xr.stimulus.ndim > 1:
            test_xr = collapse_multidim_coord(test_xr, "stimulus", "sampleid")
        if pred_xr.stimulus.ndim > 1:
            pred_xr = collapse_multidim_coord(pred_xr, "stimulus", "sampleid")

        return pred_xr, test_xr
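
    # A minimal end-to-end sketch (not executed; X and Y stand for hypothetical
    # source embeddings and brain targets sharing a `sampleid` dimension):
    # fit_transform returns CV-held-out predictions and the matching targets,
    # aligned on (sampleid, neuroid, timeid), ready for a downstream metric.
    #
    #   mapping = LearnedMap("linridge_cv", k_fold=5)
    #   pred, test = mapping.fit_transform(X, Y)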

    # def map(self, source, target) -> None:
    #     """
    #     the works: constructs splits, fits models for each split, then evaluates
    #     the fit of each split and returns the result (also for each split)
    #     """
    #     pass

    def save_model(self) -> None:
        """TODO: stuff that needs to be saved eventually

        - model weights
        - CV state (if using CV), and, in general, all arguments needed for
          re-initialization:
            - n_splits
            - random_state
            - split indices (based on the random seed)
            - params per split (alpha, lambda)
            - validation score, etc., for each CV split?
        """
        pass
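
    # One possible sketch for the TODO above (not the library's actual method;
    # joblib.dump is a standard way to persist fitted sklearn estimators, and
    # the file path here is hypothetical):
    #
    #   import joblib
    #   joblib.dump(
    #       {
    #           "model": self.model,
    #           "random_seed": self.random_seed,
    #           "k_fold": self.k_fold,
    #           "mapping_params": self.mapping_params,
    #       },
    #       "learned_map.joblib",
    #   )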

    def predict(self, source) -> None:
        pass