Coverage for langbrainscore/mapping/mapping.py: 17%

149 statements  


import typing
from functools import partial

from tqdm.auto import tqdm
import numpy as np
from joblib import Parallel, delayed
import xarray as xr
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression, RidgeCV

from langbrainscore.interface import _Mapping
from langbrainscore.utils import logging
from langbrainscore.utils.xarray import collapse_multidim_coord

mapping_classes_params = {
    "linreg": (LinearRegression, {}),
    "linridge_cv": (RidgeCV, {"alphas": np.logspace(-3, 3, 13)}),
    "linpls": (PLSRegression, {"n_components": 20}),
}
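
# For example (hypothetical usage; see LearnedMap below): mapping_class="linridge_cv"
# instantiates RidgeCV(alphas=np.logspace(-3, 3, 13)), and "linpls" yields
# PLSRegression(n_components=20).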


class IdentityMap(_Mapping):
    """
    Identity mapping for use with metrics that operate
    on non-column-aligned matrices, e.g., RSA, CKA.

    Drops or imputes NaNs for downstream metrics (see `nan_strategy`).
    """

    def __init__(self, nan_strategy: str = "drop") -> None:
        self._nan_strategy = nan_strategy

    def fit_transform(
        self,
        X: xr.DataArray,
        Y: xr.DataArray,
        ceiling: bool = False,
    ):
        if ceiling:
            logging.log("ceiling not supported for IdentityMap yet")
        # TODO: figure out how to handle NaNs better...
        if self._nan_strategy == "drop":
            X_clean = X.copy(deep=True).dropna(dim="neuroid")
            Y_clean = Y.copy(deep=True).dropna(dim="neuroid")
        elif self._nan_strategy == "impute":
            X_clean = X.copy(deep=True).fillna(0)
            Y_clean = Y.copy(deep=True).fillna(0)
        else:
            raise NotImplementedError("unsupported nan strategy.")
        return X_clean, Y_clean
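
# A minimal usage sketch (hypothetical data; assumes X and Y are xr.DataArrays
# with a "neuroid" dimension):
#
#   id_map = IdentityMap(nan_strategy="impute")
#   X_clean, Y_clean = id_map.fit_transform(X, Y)
#   # X_clean / Y_clean are NaN-free and ready for RSA/CKA-style metrics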


class LearnedMap(_Mapping):
    def __init__(
        self,
        mapping_class: typing.Union[
            str, typing.Tuple[typing.Callable, typing.Mapping[str, typing.Any]]
        ],
        random_seed: int = 42,
        k_fold: int = 5,
        strat_coord: str = None,
        num_split_groups_out: int = None,  # (p, the # of groups in the test split)
        split_coord: str = None,  # (grouping coord)
        # TODO: support predicting a held-out subject; this requires a mean over
        # ROIs, since individual neuroids do not correspond across subjects. we
        # already have something like this along the `sampleid` coordinate, but
        # it still needs to be implemented along the `neuroid` coordinate
        **kwargs,
    ) -> None:
        """
        Initializes a Mapping object that describes a mapping between two encoder
        representations.

        Args:
            mapping_class (typing.Union[str, typing.Tuple], required): either a key into
                `mapping_classes_params` or a (class, kwargs) tuple. The class is
                instantiated to obtain a mapping model, e.g. LinearRegression or RidgeCV
                from the sklearn family; it must implement the sklearn estimator
                interface (`fit` and `predict`).
            random_seed (int, optional): seed used to shuffle cross-validation splits.
                Defaults to 42.
            k_fold (int, optional): number of cross-validation folds. Defaults to 5.
            strat_coord (str, optional): coordinate to stratify splits by. Defaults to None.
            num_split_groups_out (int, optional): number of groups in the test split.
                Defaults to None.
            split_coord (str, optional): grouping coordinate for splits. Defaults to None.
        """

        self.random_seed = random_seed
        self.k_fold = k_fold
        self.strat_coord = strat_coord
        self.num_split_groups_out = num_split_groups_out
        self.split_coord = split_coord
        self.mapping_class_name = mapping_class
        self.mapping_params = kwargs

        if isinstance(mapping_class, str):
            _mapping_class, _kwargs = mapping_classes_params[self.mapping_class_name]
        else:
            # in the spirit of duck-typing, we run no explicit checks here; anything
            # missing a `fit` or `predict` attribute will fail on its own downstream
            _mapping_class, _kwargs = mapping_class
        self.mapping_params.update(_kwargs)

        # TODO: what is the difference between these two (model; full_model)?
        # let's make this less confusing
        self.full_model = _mapping_class(**self.mapping_params)
        self.model = _mapping_class(**self.mapping_params)
        logging.log(f"initialized Mapping with {type(self.model)}!")
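
    # For example (hypothetical usage): LearnedMap("linpls") wraps
    # PLSRegression(n_components=20), while the tuple form
    # LearnedMap((RidgeCV, {"alphas": [1.0]})) wraps RidgeCV(alphas=[1.0]).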

    @staticmethod
    def _construct_splits(
        xr_dataset: xr.Dataset,
        strat_coord: str,
        k_folds: int,
        split_coord: str,
        num_split_groups_out: int,
        random_seed: int,
    ):
        from sklearn.model_selection import (
            GroupKFold,
            KFold,
            StratifiedGroupKFold,
            StratifiedKFold,
        )

        sampleid = xr_dataset.sampleid.values

        if strat_coord and split_coord:
            # stratify on `strat_coord` labels while keeping `split_coord` groups
            # intact (mirroring the two single-criterion branches below)
            kf = StratifiedGroupKFold(
                n_splits=k_folds, shuffle=True, random_state=random_seed
            )
            split = partial(
                kf.split,
                sampleid,
                y=xr_dataset[strat_coord].values,
                groups=xr_dataset[split_coord].values,
            )
        elif split_coord:
            kf = GroupKFold(n_splits=k_folds)
            split = partial(kf.split, sampleid, groups=xr_dataset[split_coord].values)
        elif strat_coord:
            kf = StratifiedKFold(
                n_splits=k_folds, shuffle=True, random_state=random_seed
            )
            split = partial(kf.split, sampleid, y=xr_dataset[strat_coord].values)
        else:
            kf = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
            split = partial(kf.split, sampleid)

        logging.log(f"running {type(kf)}!", verbosity_check=True)
        return split()
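
    # For reference, the splitter selected above, illustrated with hypothetical
    # coordinate names:
    #   strat_coord + split_coord -> StratifiedGroupKFold
    #   split_coord only          -> GroupKFold
    #   strat_coord only          -> StratifiedKFold
    #   neither                   -> KFold
    # e.g., strat_coord="subject", split_coord="passage" would stratify folds by
    # subject while keeping all samples of a passage in the same fold.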

    def construct_splits(self, A):
        return self._construct_splits(
            A,
            self.strat_coord,
            self.k_fold,
            self.split_coord,
            self.num_split_groups_out,
            random_seed=self.random_seed,
        )

    def fit_full(self, X, Y):
        # TODO
        self.fit(X, Y, k_folds=1)
        raise NotImplementedError

    def _check_sampleids(
        self,
        X: xr.DataArray,
        Y: xr.DataArray,
    ):
        """
        Checks that the sampleids in X and Y are the same.
        """

        if X.sampleid.values.shape != Y.sampleid.values.shape:
            raise ValueError("X and Y sampleid shapes do not match!")
        if not np.all(X.sampleid.values == Y.sampleid.values):
            raise ValueError("X and Y sampleids do not match!")

        logging.log(
            f"Passed sampleid check for neuroid {Y.neuroid.values}",
            verbosity_check=True,
        )

    def _drop_na(
        self, X: xr.DataArray, Y: xr.DataArray, dim: str = "sampleid", **kwargs
    ):
        """
        Drops samples with missing values in Y along the specified dimension, then
        subsets X to the same sampleids so that X and Y stay aligned.
        """
        # limit data to the current neuroid, then drop the samples that are missing
        # data for this neuroid
        Y_slice = Y.dropna(dim=dim, **kwargs)
        Y_filtered_ids = Y_slice[dim].values

        assert set(Y_filtered_ids).issubset(set(X[dim].values))

        logging.log(
            f"for neuroid {Y_slice.neuroid.values}, we used {(num_retained := len(Y_filtered_ids))}"
            f" samples; dropped {len(Y[dim]) - num_retained}",
            verbosity_check=True,
        )

        # use only the samples that are present in Y
        X_slice = X.sel(sampleid=Y_filtered_ids)

        return X_slice, Y_slice
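
    # e.g., if Y has NaNs at 12 of 200 sampleids for the current neuroid
    # (hypothetical numbers), both returned slices contain the 188 retained
    # sampleids, in matching order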

    # def _permute_X(
    #     self,
    #     X: xr.DataArray,
    #     method: str = "shuffle_X_rows",
    #     random_state: int = 42,
    # ):
    #     """Permute the features of X.
    #
    #     Parameters
    #     ----------
    #     X : xr.DataArray
    #         The embeddings to be permuted
    #     method : str
    #         The method to use for permutation.
    #         'shuffle_X_rows': shuffle the rows of X (i.e., shuffle the sentences,
    #             creating a mismatch between the sentence embeddings and the target)
    #         'shuffle_each_X_col': for each column (feature/unit) of X, permute that
    #             feature's values across all sentences. This retains the statistics
    #             of the original features (e.g., the mean per feature) while
    #             shuffling each feature's values across sentences.
    #     random_state : int
    #         The seed for the random number generator.
    #
    #     Returns
    #     -------
    #     xr.DataArray
    #         The permuted DataArray
    #     """
    #
    #     X_orig = X.copy(deep=True)
    #
    #     if logging.get_verbosity():
    #         logging.log(f"OBS: permuting X with method {method}")
    #
    #     if method == "shuffle_X_rows":
    #         X = X.sample(
    #             n=X.shape[1], random_state=random_state
    #         )  # check whether X_shape is correct
    #
    #     elif method == "shuffle_each_X_col":
    #         np.random.seed(random_state)
    #         for feat in range(X.data.shape[0]):  # per neuroid
    #             np.random.shuffle(X.data[feat, :])
    #
    #     else:
    #         raise ValueError(f"Invalid method: {method}")
    #
    #     assert X.shape == X_orig.shape
    #     assert np.all(X.data != X_orig.data)
    #
    #     return X

    def fit_transform(
        self,
        X: xr.DataArray,
        Y: xr.DataArray,
        # permute_X: typing.Union[bool, str] = False,
        ceiling: bool = False,
        ceiling_coord: str = "subject",
    ) -> typing.Tuple[xr.DataArray, xr.DataArray]:
        """Creates a mapping model using k-fold cross-validation. Uses params from
        the class initialization, with strat_coord and split_coord to stratify and
        split across group boundaries.

        Returns:
            typing.Tuple[xr.DataArray, xr.DataArray]: (pred_xr, test_xr), the
            cross-validated predictions and the corresponding held-out targets,
            concatenated across folds and neuroids.
        """

        from sklearn.random_projection import GaussianRandomProjection
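
        # NOTE (ceiling): when `ceiling` is set, the source X is replaced below with
        # a copy of the target Y, so each neuroid gets predicted from other subjects'
        # target data (see the `ceiling_coord` filtering in fit_per_neuroid)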

        if ceiling:
            n_neuroids = X.neuroid.values.size
            X = Y.copy()

        logging.log(f"X shape: {X.data.shape}", verbosity_check=True)
        logging.log(f"Y shape: {Y.data.shape}", verbosity_check=True)

        if self.strat_coord:
            try:
                assert (X[self.strat_coord].values == Y[self.strat_coord].values).all()
            except AssertionError:
                raise ValueError(
                    f"{self.strat_coord} coordinate does not align across X and Y"
                )
        if self.split_coord:
            try:
                assert (X[self.split_coord].values == Y[self.split_coord].values).all()
            except AssertionError:
                raise ValueError(
                    f"{self.split_coord} coordinate does not align across X and Y"
                )

        def fit_per_neuroid(neuroid):
            Y_neuroid = Y.sel(neuroid=neuroid)

            # limit data to the current neuroid, then drop the samples that are
            # missing data for this neuroid
            X_slice, Y_slice = self._drop_na(X, Y_neuroid, dim="sampleid")

            # Assert that X and Y have the same sampleids
            self._check_sampleids(X_slice, Y_slice)

            # select the relevant ceiling split: exclude neuroids that share the
            # current target's `ceiling_coord` value (e.g., the same subject)
            if ceiling:
                X_slice = X_slice.isel(
                    neuroid=X_slice[ceiling_coord] != Y_slice[ceiling_coord]
                ).dropna(dim="neuroid")

            # We can perform various sanity checks by 'permuting' the source, X
            # NOTE this is a test! do not use under normal workflow!
            # if permute_X:
            #     logging.log(
            #         f"`permute_X` flag is enabled. only do this in an adversarial setting.",
            #         cmap="WARN",
            #         type="WARN",
            #         verbosity_check=True,
            #     )
            #     X_slice = self._permute_X(X_slice, method=permute_X)

            # these collections store each split for our records later
            # TODO: we aren't saving these to the object instance yet
            train_indices = []
            test_indices = []

            splits = self.construct_splits(Y_slice)

            # X_test_collection = []
            Y_test_collection = []
            Y_pred_collection = []

            for cvfoldid, (train_index, test_index) in enumerate(splits):

                train_indices.append(train_index)
                test_indices.append(test_index)

                # !! NOTE we index into the NaN-filtered slices (X_slice, Y_slice)
                # here, not the original X and Y
                X_train, X_test = (
                    X_slice.sel(sampleid=Y_slice.sampleid.values[train_index]),
                    X_slice.sel(sampleid=Y_slice.sampleid.values[test_index]),
                )
                y_train, y_test = (
                    Y_slice.sel(sampleid=Y_slice.sampleid.values[train_index]),
                    Y_slice.sel(sampleid=Y_slice.sampleid.values[test_index]),
                )

                # empty list to collect the y predictions per timeid
                y_pred_over_time = []

                for timeid in y_train.timeid:

                    # TODO: adapt this code for models that also have a non-singleton
                    # timeid, i.e., output that evolves in time (RNN?)

                    x_model_train = X_train.sel(timeid=0).values
                    y_model_train = y_train.sel(timeid=timeid).values.reshape(-1, 1)
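
                    # NOTE (ceiling): if the stand-in X (other subjects' neuroids)
                    # is wider than the original source (n_neuroids), it is randomly
                    # projected down to n_neuroids before fitting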

                    # remember the fitted projection so that the test split below
                    # receives the identical transform
                    projection = None
                    if ceiling and x_model_train.shape[1] > n_neuroids:
                        projection = GaussianRandomProjection(
                            n_components=n_neuroids, random_state=0
                        )
                        x_model_train = projection.fit_transform(x_model_train)

                    self.model.fit(
                        x_model_train,
                        y_model_train,
                    )

                    # store the hparam values of the fitted model (only meaningful for
                    # ridge_cv or any duck type that exposes an `alpha_` attribute)
                    alpha = getattr(self.model, "alpha_", np.nan)

                    # deep-copy `y_test` as `y_pred` to inherit its metadata and dims,
                    # then populate it with the newly predicted values
                    y_pred = (
                        y_test.sel(timeid=timeid)
                        .copy(deep=True)
                        .expand_dims("timeid", 1)
                    )
                    x_model_test = X_test.sel(timeid=0)
                    if projection is not None:
                        x_model_test = projection.transform(x_model_test)
                    y_pred.data = self.model.predict(x_model_test)
                    y_pred = y_pred.assign_coords(timeid=("timeid", [timeid]))
                    y_pred = y_pred.assign_coords(alpha=("timeid", [alpha]))
                    y_pred = y_pred.assign_coords(cvfoldid=("timeid", [cvfoldid]))
                    y_pred_over_time.append(y_pred)

                y_pred_over_time = xr.concat(y_pred_over_time, dim="timeid")
                Y_pred_collection.append(y_pred_over_time)
                Y_test_collection.append(y_test)

            Y_test = xr.concat(Y_test_collection, dim="sampleid").sortby("sampleid")
            Y_pred = xr.concat(Y_pred_collection, dim="sampleid").sortby("sampleid")

            # test.append(Y_test)
            # pred.append(Y_pred)
            return Y_test, Y_pred

        # Loop across each Y neuroid (target)
        test = []
        pred = []
        # NOTE: joblib.Parallel preserves the ordering of its inputs in the results
        # for neuroid in tqdm(Y.neuroid.values, desc="fitting a model per neuroid"):
        for t, p in Parallel(n_jobs=-2)(
            delayed(fit_per_neuroid)(neuroid)
            for neuroid in tqdm(Y.neuroid.values, desc="fitting a model per neuroid")
        ):
            test += [t]
            pred += [p]

        test_xr = xr.concat(test, dim="neuroid").transpose(
            "sampleid", "neuroid", "timeid"
        )
        pred_xr = xr.concat(pred, dim="neuroid").transpose(
            "sampleid", "neuroid", "timeid"
        )

        if test_xr.stimulus.ndim > 1:
            test_xr = collapse_multidim_coord(test_xr, "stimulus", "sampleid")
        if pred_xr.stimulus.ndim > 1:
            pred_xr = collapse_multidim_coord(pred_xr, "stimulus", "sampleid")

        return pred_xr, test_xr

    # def map(self, source, target) -> None:
    #     '''
    #     the works: constructs splits, fits models for each split, then evaluates
    #     the fit of each split and returns the result (also for each split)
    #     '''
    #     pass

    def save_model(self) -> None:
        """TODO: things that need to be saved eventually

        - model weights
        - CV arguments (if using CV); in general, all arguments needed for
          initialization
            - n_splits
            - random_state
            - split indices (based on the random seed)
            - params per split (alpha, lambda)
            - validation score, etc. for each CV split?
        """
        pass

    def predict(self, source) -> None:
        pass
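

# A minimal end-to-end sketch (hypothetical variable names; assumes
# `model_embeddings` (source) and `brain_data` (target) are xr.DataArrays with
# "sampleid", "neuroid", and "timeid" dims and a "stimulus" coordinate):
#
#   mapping = LearnedMap("linridge_cv", k_fold=5)
#   pred_xr, test_xr = mapping.fit_transform(model_embeddings, brain_data)
#   # pred_xr holds the cross-validated predictions, aligned sample-for-sample
#   # with the held-out targets in test_xr, ready for a downstream metric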