Coverage for langbrainscore/brainscore/brainscore.py: 19%

116 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-07 21:22 +0000

1import typing 

2 

3import numpy as np 

4import xarray as xr 

5from tqdm.auto import tqdm 

6 

7# from methodtools import lru_cache 

8from pathlib import Path 

9 

10from langbrainscore.interface import ( 

11 _BrainScore, 

12 _Mapping, 

13 _Metric, 

14 EncoderRepresentations, 

15) 

16 

17# from langbrainscore.metrics import Metric 

18from langbrainscore.utils import logging 

19from langbrainscore.utils.xarray import collapse_multidim_coord, copy_metadata 

20 

21 

class BrainScore(_BrainScore):
    """Scores how well predictions produced by a mapping match their targets.

    An instance holds two sets of representations (`X`, `Y`), a mapping and a
    metric. `score()` stores its result in `scores`, `ceiling()` in
    `ceilings`, and `null()` in `nulls`.
    """

    # Class-level defaults so the attributes exist before anything is computed.
    scores = None
    ceilings = None
    # Fallback only: `__init__` gives each instance its own list and `score()`
    # rebinds instead of appending, so this shared list is never mutated.
    nulls = []

    def __init__(
        self,
        X: typing.Union[xr.DataArray, EncoderRepresentations],
        Y: typing.Union[xr.DataArray, EncoderRepresentations],
        mapping: _Mapping,
        metric: _Metric,
        sample_split_coord: typing.Optional[str] = None,
        neuroid_split_coord: typing.Optional[str] = None,
        run=False,
    ) -> None:
        """Initializes the [lang]BrainScore object using two encoded representations and a
        mapping class, and a metric for evaluation

        Args:
            X (typing.Union[xr.DataArray, EncoderRepresentations]): Either an xarray DataArray
                instance, or a wrapper object with a `.representations` attribute that stores
                the xarray DataArray
            Y (typing.Union[xr.DataArray, EncoderRepresentations]): see `X`
            mapping (_Mapping): object whose `fit_transform(X, Y, ceiling=...)` returns a
                `(y_pred, y_true)` pair
            metric (_Metric): callable invoked as `metric(y_pred, y_true)` to produce scores
            sample_split_coord (str, optional): name of a coordinate of `Y`; when set, scores
                are computed separately for each of its groups along `sampleid`
            neuroid_split_coord (str, optional): name of a coordinate of `Y`; when set and the
                predictions and targets differ in shape, scores are computed separately for
                each of its groups along `neuroid`
            run (bool, optional): if True, call `run()` at the end of initialization.
                Defaults to False.
        """
        self.X = X.representations if hasattr(X, "representations") else X
        self.Y = Y.representations if hasattr(Y, "representations") else Y
        assert (
            self.X.sampleid.size == self.Y.sampleid.size
        ), "X and Y must have the same number of samples along `sampleid`"
        self.mapping = mapping
        self.metric = metric
        self._sample_split_coord = sample_split_coord
        self._neuroid_split_coord = neuroid_split_coord
        # Per-instance storage so null results never leak between instances.
        self.nulls = []

        if run:
            self.run()

    def __str__(self) -> str:
        """Returns the mean of `scores` formatted as a string.

        Raises:
            ValueError: if `scores` has not been computed yet.
        """
        try:
            return f"{self.scores.mean()}"
        except AttributeError as e:
            raise ValueError(
                "missing scores. did you make a call to `score()` or `run()` yet?"
            ) from e

    def to_netcdf(self, filename):
        """
        outputs the xarray.DataArray object for 'scores' to a netCDF file
        identified by `filename`. if it already exists, overwrites it.
        """
        if Path(filename).expanduser().resolve().exists():
            logging.log(f"{filename} already exists. overwriting.", type="WARN")
        self.scores.to_netcdf(filename)

    def load_netcdf(self, filename):
        """
        loads a netCDF object that contains an xarray instance for 'scores' from
        a file at `filename`.
        """
        self.scores = xr.load_dataarray(filename)

    @staticmethod
    def _score(A, B, metric: _Metric) -> np.ndarray:
        """Applies `metric` to the pair `(A, B)` and returns its result."""
        return metric(A, B)

    # @lru_cache(maxsize=None)
    def score(
        self,
        ceiling=False,
        null=False,
        seed=0,
    ):
        """
        Computes scores from the predictions/targets returned by the Mapping
        instance held in `self.mapping`, using `self.metric`.

        The result is stored on the instance rather than returned:
        `self.scores` for a regular call, `self.ceilings` when `ceiling` is
        True, and an extra entry in `self.nulls` when `null` is True.

        Args:
            ceiling (bool, optional): forwarded to `mapping.fit_transform`; the
                result is stored in `self.ceilings`. Defaults to False.
            null (bool, optional): shuffle the data of `Y` along its first
                axis before fitting and store the result in `self.nulls`.
                Defaults to False.
            seed (int, optional): seed of the shuffle used when `null` is True;
                also used as the `iter` coordinate of the stored null scores.
                Defaults to 0.

        Raises:
            AssertionError: if both `ceiling` and `null` are set, or if a
                configured split coordinate is not a coordinate of `Y`.
        """
        assert not (ceiling and null), "`ceiling` and `null` are mutually exclusive"
        sample_split_coord = self._sample_split_coord
        neuroid_split_coord = self._neuroid_split_coord

        if sample_split_coord:
            assert (
                sample_split_coord in self.Y.coords
            ), f"`{sample_split_coord}` is not a coordinate of Y"

        if neuroid_split_coord:
            assert (
                neuroid_split_coord in self.Y.coords
            ), f"`{neuroid_split_coord}` is not a coordinate of Y"

        if null:
            # Shuffle only the data along the first axis; the copy keeps Y's
            # coordinates untouched, and the seed makes the shuffle repeatable.
            Y = self.Y.copy()
            Y.data = np.random.default_rng(seed=seed).permutation(Y.data, axis=0)
        else:
            Y = self.Y
        y_pred, y_true = self.mapping.fit_transform(self.X, Y, ceiling=ceiling)

        if not (ceiling or null):
            self.Y_pred = y_pred
            if y_pred.shape == y_true.shape:  # not IdentityMap
                for dim in ("sampleid", "neuroid", "timeid"):
                    self.Y_pred = copy_metadata(self.Y_pred, self.Y, dim)

        # Loop-invariant: neuroid-wise splitting only applies when predictions
        # and targets differ in shape (IdentityMap) and a split coord is set.
        split_by_neuroid = bool(y_pred.shape != y_true.shape and neuroid_split_coord)

        scores_over_time = []
        for timeid in y_true.timeid.values:

            y_pred_time = y_pred.sel(timeid=timeid).transpose("sampleid", "neuroid")
            y_true_time = y_true.sel(timeid=timeid).transpose("sampleid", "neuroid")

            if sample_split_coord:
                if sample_split_coord not in y_true_time.sampleid.coords:
                    y_pred_time = collapse_multidim_coord(
                        y_pred_time, sample_split_coord, "sampleid"
                    )
                    y_true_time = collapse_multidim_coord(
                        y_true_time, sample_split_coord, "sampleid"
                    )
                score_splits = y_pred_time.sampleid.groupby(sample_split_coord).groups
            else:
                # Single placeholder group: score all samples together.
                score_splits = [0]

            scores_over_time_group = []
            for scoreid in score_splits:

                if sample_split_coord:
                    y_pred_time_group = y_pred_time.isel(
                        sampleid=y_pred_time[sample_split_coord] == scoreid
                    )
                    y_true_time_group = y_true_time.isel(
                        sampleid=y_true_time[sample_split_coord] == scoreid
                    )
                else:
                    y_pred_time_group = y_pred_time
                    y_true_time_group = y_true_time

                neuroids = []
                if split_by_neuroid:  # IdentityMap
                    if neuroid_split_coord not in y_true_time_group.neuroid.coords:
                        y_true_time_group = collapse_multidim_coord(
                            y_true_time_group, neuroid_split_coord, "neuroid"
                        )
                    neuroid_splits = y_true_time_group.neuroid.groupby(
                        neuroid_split_coord
                    ).groups
                    score_per_time_group = []
                    for neuroid in neuroid_splits:
                        score_per_time_group.append(
                            self._score(
                                y_pred_time_group,
                                y_true_time_group.isel(
                                    neuroid=(
                                        y_true_time_group[neuroid_split_coord]
                                        == neuroid
                                    )
                                ),
                                self.metric,
                            )
                        )
                        neuroids.append(neuroid)
                    score_per_time_group = np.array(score_per_time_group)
                else:
                    # asarray so that a 0-d / scalar metric result still
                    # supports `.size` and `.reshape` below.
                    score_per_time_group = np.asarray(
                        self._score(y_pred_time_group, y_true_time_group, self.metric)
                    )

                if neuroids:
                    pass
                elif score_per_time_group.size == 1:  # e.g., RSA, CKA, w/o split
                    neuroids = [np.nan]
                else:
                    neuroids = y_true_time_group.neuroid.data

                scores_over_time_group.append(
                    xr.DataArray(
                        score_per_time_group.reshape(1, -1, 1),
                        dims=("scoreid", "neuroid", "timeid"),
                        coords={
                            "scoreid": ("scoreid", [scoreid]),
                            "neuroid": ("neuroid", neuroids),
                            "timeid": ("timeid", [timeid]),
                        },
                    )
                )

            scores_over_time.append(xr.concat(scores_over_time_group, dim="scoreid"))

        scores = xr.concat(scores_over_time, dim="timeid")

        if scores.neuroid.size == self.Y.neuroid.size:  # not RSA, CKA, etc.
            scores = copy_metadata(scores, self.Y, "neuroid")
        scores = copy_metadata(scores, self.Y, "timeid")

        if not (ceiling or null):
            self.scores = scores
        elif ceiling:
            self.ceilings = scores
        else:
            null_scores = scores.expand_dims(
                dim={"iter": [seed]}, axis=-1
            ).assign_coords(iter=[seed])
            # Rebind rather than append so a shared/class-level list is never
            # mutated. If `nulls` was already concatenated by `null()`, keep
            # that array as the first element instead of failing on `.append`.
            previous = self.nulls if isinstance(self.nulls, list) else [self.nulls]
            self.nulls = [*previous, null_scores]

    def ceiling(self):  # , sample_split_coord=None, neuroid_split_coord=None):
        """Computes ceiling scores via `score(ceiling=True)`; stored in `self.ceilings`."""
        logging.log("Calculating ceiling.", type="INFO")
        self.score(
            ceiling=True,
            # sample_split_coord=self._sample_split_coord,
            # neuroid_split_coord=neuroid_split_coord,
        )

    def null(
        self,
        # sample_split_coord=None, neuroid_split_coord=None,
        iters=100,
    ):
        """Runs `iters` null permutations and concatenates them along `iter`.

        Each iteration calls `score(null=True, seed=i)` for `i` in
        `range(iters)`. Afterwards `self.nulls` is a single DataArray.

        Args:
            iters (int, optional): number of permutations. Defaults to 100.
        """
        if not isinstance(self.nulls, list):
            # A previous `null()` call already concatenated its results;
            # start over so re-running does not duplicate `iter` entries.
            self.nulls = []
        for i in tqdm(range(iters), desc="Running null permutations"):
            self.score(
                null=True,
                # sample_split_coord=sample_split_coord,
                # neuroid_split_coord=neuroid_split_coord,
                seed=i,
            )
        self.nulls = xr.concat(self.nulls, dim="iter")

    def run(
        self,
        sample_split_coord=None,
        neuroid_split_coord=None,
        calc_nulls=False,
        iters=100,
    ):
        """Computes scores and ceilings, and optionally null permutations.

        Args:
            sample_split_coord (str, optional): if given, replaces the sample
                split coordinate configured at construction for this and all
                later calls. Defaults to None (keep the configured value).
            neuroid_split_coord (str, optional): same, for the neuroid split
                coordinate. Defaults to None.
            calc_nulls (bool, optional): also run `null(iters=iters)`.
                Defaults to False.
            iters (int, optional): number of null permutations. Defaults to 100.

        Returns:
            dict: `{"scores": ..., "ceilings": ...}`, plus a `"nulls"` entry
            when `calc_nulls` is True.
        """
        # `score()`, `ceiling()` and `null()` read the split coordinates from
        # the instance and do not accept them as arguments, so overrides are
        # applied here instead of being forwarded.
        if sample_split_coord is not None:
            self._sample_split_coord = sample_split_coord
        if neuroid_split_coord is not None:
            self._neuroid_split_coord = neuroid_split_coord

        self.score()
        self.ceiling()
        results = {"scores": self.scores, "ceilings": self.ceilings}
        if calc_nulls:
            self.null(iters=iters)
            results["nulls"] = self.nulls
        return results