Coverage for langbrainscore/encoder/ann.py: 21%

122 statements  

coverage.py v6.4, created at 2022-06-07 21:22 +0000

import typing
from enum import unique

import os
import numpy as np
import torch
from tqdm import tqdm
import xarray as xr

from langbrainscore.dataset import Dataset
from langbrainscore.interface import EncoderRepresentations, _ModelEncoder
from langbrainscore.utils.encoder import (
    aggregate_layers,
    cos_sim_matrix,
    count_zero_threshold_values,
    flatten_activations_per_sample,
    get_context_groups,
    get_torch_device,
    pick_matching_token_ixs,
    postprocess_activations,
    repackage_flattened_activations,
    encode_stimuli_in_context,
)

from langbrainscore.utils.logging import log
from langbrainscore.utils.xarray import copy_metadata, fix_xr_dtypes
from langbrainscore.utils.resources import model_classes, config_name_mappings

os.environ["TOKENIZERS_PARALLELISM"] = "true"

class HuggingFaceEncoder(_ModelEncoder):
    def __init__(
        self,
        model_id,
        emb_aggregation: typing.Union[str, None, typing.Callable],
        device=get_torch_device(),
        context_dimension: str = None,
        bidirectional: bool = False,
        emb_preproc: typing.Tuple[str] = (),
        include_special_tokens: bool = True,
    ):
        """
        Args:
            model_id (str): the HuggingFace model id
            emb_aggregation (typing.Union[str, None, typing.Callable]): how to aggregate the hidden states
                of the encoder representations for each sampleid (stimulus), e.g., "last"
            device (torch.device, optional): the torch device to run the model on
                [default: get_torch_device()]
            context_dimension (str, optional): the dimension to use for extracting strings using context.
                If None, each sampleid (stimulus) is treated as its own context group.
                If a string is specified, it must name a dimension in the xarray-like dataset object
                (langbrainscore.dataset.Dataset) that provides groupings of sampleids (stimuli) to be
                used as context when generating encoder representations [default: None]
            bidirectional (bool): whether to use a bidirectional encoder (i.e., access both forward and
                backward context) [default: False]
            emb_preproc (tuple): a tuple of strings naming preprocessing functions to apply to the
                aggregated embeddings; processing is performed layer-wise
            include_special_tokens (bool): whether to include special tokens in the encoder representations
        """

        super().__init__(
            model_id,
            _context_dimension=context_dimension,
            _bidirectional=bidirectional,
            _emb_aggregation=emb_aggregation,
            _emb_preproc=emb_preproc,
            _include_special_tokens=include_special_tokens,
        )

        from transformers import AutoConfig, AutoModel, AutoTokenizer
        from transformers import logging as transformers_logging

        transformers_logging.set_verbosity_error()

        self.device = device or get_torch_device()
        self.config = AutoConfig.from_pretrained(self._model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._model_id, multiprocessing=True
        )
        self.model = AutoModel.from_pretrained(self._model_id, config=self.config)
        try:
            self.model = self.model.to(self.device)
        except RuntimeError:
            self.device = "cpu"
            self.model = self.model.to(self.device)

    def get_encoder_representations_template(
        self, dataset=None, representations=xr.DataArray()
    ) -> EncoderRepresentations:
        """
        Returns an `EncoderRepresentations` object with all the appropriate attributes
        set; `dataset` and `representations` default to empty and are meant to be
        filled in later.
        """
        return EncoderRepresentations(
            dataset=dataset,
            representations=representations,
            model_id=self._model_id,
            context_dimension=self._context_dimension,
            bidirectional=self._bidirectional,
            emb_aggregation=self._emb_aggregation,
            emb_preproc=self._emb_preproc,
            include_special_tokens=self._include_special_tokens,
        )

    def encode(
        self,
        dataset: Dataset,
        read_cache: bool = True,  # avoid recomputing if cached `EncoderRepresentations` exists, recompute if not
        write_cache: bool = True,  # dump the result of this computation to cache?
    ) -> EncoderRepresentations:
        """
        Input a langbrainscore Dataset, encode the stimuli according to the parameters specified in init,
        and return an `EncoderRepresentations` object whose `representations` attribute is an xarray
        DataArray of aggregated representations for each stimulus.

        Args:
            dataset (langbrainscore.dataset.Dataset): the dataset whose stimuli to encode
            read_cache (bool): avoid recomputing if cached `EncoderRepresentations` exists, recompute if not
            write_cache (bool): dump and write the result of the computed encoder representations to cache

        Raises:
            NotImplementedError: [description]
            ValueError: [description]

        Returns:
            EncoderRepresentations: the (possibly cache-loaded) encoder representations
        """

        # before computing the representations from scratch, we will first see if any
        # cached representations exist already.

        if read_cache:
            to_check_in_cache: EncoderRepresentations = (
                self.get_encoder_representations_template(dataset=dataset)
            )

            try:
                to_check_in_cache.load_cache()
                return to_check_in_cache
            except FileNotFoundError:
                log(
                    f"couldn't load cached reprs for {to_check_in_cache.identifier_string}; recomputing.",
                    cmap="WARN",
                    type="WARN",
                )

        self.model.eval()
        stimuli = dataset.stimuli.values

        # Initialize the context group coordinate (obtain embeddings with context)
        context_groups = get_context_groups(dataset, self._context_dimension)
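        # Illustrative note (added; not in the original source): `context_groups` is
        # assumed to be an array aligned with `stimuli`, one group label per sampleid.
        # E.g., with a hypothetical "passage" context_dimension, three stimuli from
        # passage 0 followed by two from passage 1 would yield [0, 0, 0, 1, 1], so all
        # stimuli of a passage are encoded together as one context window.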

        # list for storing activations for each stimulus with all layers flattened
        # list for storing layer ids ([0 0 0 0 ... 1 1 1 ...]) indicating which layer each
        # neuroid (representation dimension) came from
        flattened_activations, layer_ids = [], []

        ###############################################################################
        # ALL SAMPLES LOOP
        ###############################################################################
        _, unique_ixs = np.unique(context_groups, return_index=True)
        # Make sure context group order is preserved
        for group in tqdm(context_groups[np.sort(unique_ixs)], desc="Encoding stimuli"):
            # Mask based on the context group
            mask_context = context_groups == group
            stimuli_in_context = stimuli[mask_context]

            # store model states for each stimulus in this context group
            encoded_stimuli = []

            ###############################################################################
            # CONTEXT LOOP
            ###############################################################################
            for encoded_stim in encode_stimuli_in_context(
                stimuli_in_context=stimuli_in_context,
                tokenizer=self.tokenizer,
                model=self.model,
                bidirectional=self._bidirectional,
                include_special_tokens=self._include_special_tokens,
                emb_aggregation=self._emb_aggregation,
                device=self.device,
            ):
                encoded_stimuli += [encoded_stim]
            ###############################################################################
            # END CONTEXT LOOP
            ###############################################################################

            # Flatten activations across layers and package as xarray
            flattened_activations_and_layer_ids = [
                *map(flatten_activations_per_sample, encoded_stimuli)
            ]
            for f_as, l_ids in flattened_activations_and_layer_ids:
                flattened_activations += [f_as]
                layer_ids += [l_ids]
                assert len(f_as) == len(l_ids)  # activations and layer ids must have equal length

        ###############################################################################
        # END ALL SAMPLES LOOP
        ###############################################################################

        # Stack flattened activations and layer ids to obtain [n_samples, emb_dim * n_layers]
        activations_2d = np.vstack(flattened_activations)
        layer_ids_1d = np.squeeze(np.unique(np.vstack(layer_ids), axis=0))
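        # Shape sketch (added; not in the original source): for a hypothetical model
        # with 2 layers and a 3-dimensional embedding, each stimulus contributes a
        # flattened vector of length 6, so
        #     activations_2d.shape == (n_samples, 6)
        #     layer_ids_1d         == [0, 0, 0, 1, 1, 1]
        # i.e., layer_ids_1d records which layer each neuroid (column) came from.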

        # Post-process activations after obtaining them (or "pre-process" them before computing brainscore)
        if len(self._emb_preproc) > 0:
            for mode in self._emb_preproc:
                activations_2d, layer_ids_1d = postprocess_activations(
                    activations_2d=activations_2d,
                    layer_ids_1d=layer_ids_1d,
                    emb_preproc_mode=mode,
                )

        assert activations_2d.shape[1] == len(layer_ids_1d)
        assert activations_2d.shape[0] == len(stimuli)

        # Package activations as xarray and reapply metadata
        encoded_dataset: xr.DataArray = repackage_flattened_activations(
            activations_2d=activations_2d,
            layer_ids_1d=layer_ids_1d,
            dataset=dataset,
        )
        encoded_dataset: xr.DataArray = copy_metadata(
            encoded_dataset,
            dataset.contents,
            "sampleid",
        )

        to_return: EncoderRepresentations = self.get_encoder_representations_template()
        to_return.dataset = dataset
        to_return.representations = fix_xr_dtypes(encoded_dataset)

        if write_cache:
            to_return.to_cache(overwrite=True)

        return to_return

    def get_modelcard(self):
        """
        Returns the model card of the model (model-wise, not layer-wise).
        """

        model_classes = [
            "gpt",
            "bert",
        ]  # continuously update based on new model classes supported

        # based on the model_id, figure out which model class it is
        model_class = next((x for x in model_classes if x in self._model_id), None)
        assert model_class is not None, f"model_id {self._model_id} not supported"

        config_specs_of_interest = config_name_mappings[model_class]
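        # Illustrative note (added; not in the original source): the mapping is assumed
        # to look roughly like, for a GPT-style model,
        #     {"n_layers": "n_layer", "vocab_size": "vocab_size", ...}
        # i.e., keys are the field names used in the model card and values are the
        # corresponding attribute names on the HuggingFace config (or None if absent).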

        model_specs = {}
        for (
            k_spec,  # key: the name we want to use in the model card
            v_spec,  # value: the corresponding attribute name in the config
        ) in config_specs_of_interest.items():
            if v_spec is not None:
                model_specs[k_spec] = getattr(self.config, v_spec)
            else:
                model_specs[k_spec] = None

        self.model_specs = model_specs

        return model_specs

class PTEncoder(_ModelEncoder):
    def __init__(self, model_id: str) -> None:
        super().__init__(model_id)

    def encode(self, dataset: "langbrainscore.dataset.Dataset") -> xr.DataArray:
        # TODO
        ...

class EncoderCheck:
    """
    Class for checking whether embeddings obtained from an Encoder are correct and
    similar to those of other encoder objects.
    """

    def __init__(
        self,
    ):
        pass

    def _load_cached_activations(self, encoded_ann_identifier: str):
        raise NotImplementedError

    def similiarity_metric_across_layers(
        self,
        sim_metric: str = "tol",
        enc1: xr.DataArray = None,
        enc2: xr.DataArray = None,
        tol: float = 1e-8,
        threshold: float = 1e-4,
    ) -> typing.Tuple[bool, set]:
        """
        Given two activations, iterate across layers and check np.allclose using different tolerance levels.

        Parameters:
            sim_metric: str
                Similarity metric to use.
            enc1: xr.DataArray
                First encoder activations.
            enc2: xr.DataArray
                Second encoder activations.
            tol: float
                Tolerance level to start at (iteratively increased). Default is 1e-8.
            threshold: float
                Tolerance level above which a layer's stimuli are flagged as dissimilar. Default is 1e-4.

        Returns:
            bool: whether the tolerance level was met (True) or not (False)
            bad_stim: set of stimuli indices that did not meet tolerance level `threshold` (if any)
        """
        # First check is whether the number of layers / shapes match
        assert enc1.shape == enc2.shape
        assert (
            enc1.sampleid.values == enc2.sampleid.values
        ).all()  # ensure that we are looking at the same stimuli
        layer_ids = enc1.layer.values
        _, unique_ixs = np.unique(layer_ids, return_index=True)
        print(f"\n\nChecking similarity across layers using sim_metric: {sim_metric}")

        all_good = True
        bad_stim = set()  # store indices of stimuli that are not similar

        # Iterate across layers
        for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]):
            enc1_layer = enc1.isel(neuroid=(enc1.layer == layer_id))  # .squeeze()
            enc2_layer = enc2.isel(neuroid=(enc2.layer == layer_id))  # .squeeze()

            # Check whether values match. If not, iteratively increase tolerance until values match
            if sim_metric in ("tol", "diff"):
                abs_diff = np.abs(enc1_layer - enc2_layer)
                abs_diff_per_stim = np.max(
                    abs_diff, axis=1
                )  # obtain the biggest difference across neuroids (units)
                while (abs_diff_per_stim > tol).all():
                    tol *= 10
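                # Worked example (added; not in the original source): with tol = 1e-8
                # and the smallest per-stimulus difference equal to 3e-6, the loop
                # multiplies tol by 10 until 3e-6 <= tol, i.e., it exits with
                # tol = 1e-5, which is then compared against `threshold` below.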

            elif "cos" in sim_metric:
                # Check cosine distance between each row, e.g., sentence vector
                cos_sim = cos_sim_matrix(enc1_layer, enc2_layer)
                cos_dist = (
                    1 - cos_sim
                )  # 0 means identical, 1 means orthogonal, 2 means opposite
                # We still want this as close to zero as possible for similar vectors.
                cos_dist_abs = np.abs(cos_dist)
                abs_diff_per_stim = cos_dist_abs

                # Check how close the cosine distance is to 0
                while (cos_dist_abs > tol).all():
                    tol *= 10
            else:
                raise NotImplementedError(f"Invalid `sim_metric`: {sim_metric}")

            print(f"Layer {layer_id}: Similarity at tolerance: {tol:.3e}")
            if tol > threshold:
                print(f"WARNING: tolerance {tol:.3e} exceeds threshold {threshold:.3e}")
                all_good = False
                bad_stim.update(
                    enc1.sampleid[np.where(abs_diff_per_stim > tol)[0]]
                )  # get sampleids of stimuli that are not similar

        return all_good, bad_stim
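

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# A minimal way to exercise HuggingFaceEncoder, assuming the langbrainscore
# package and `transformers` are installed. "distilgpt2" is only an example
# model id; constructing a langbrainscore.dataset.Dataset is outside the scope
# of this module, so the encode() call is sketched in comments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = HuggingFaceEncoder(
        model_id="distilgpt2",  # example HuggingFace model id
        emb_aggregation="last",  # aggregate hidden states per stimulus
        context_dimension=None,  # each sampleid is its own context group
        emb_preproc=(),  # no layer-wise preprocessing
    )
    log(f"model card: {encoder.get_modelcard()}")

    # With a Dataset instance `ds` (construction omitted here):
    #     reprs = encoder.encode(ds, read_cache=False, write_cache=False)
    #     all_good, bad_stim = EncoderCheck().similiarity_metric_across_layers(
    #         sim_metric="tol", enc1=reprs.representations, enc2=reprs.representations
    #     )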