# Coverage for langbrainscore/utils/encoder.py: 15% (205 statements), coverage.py v6.4, 2022-06-07

import random
import typing

import numpy as np
import torch
import xarray as xr
from tqdm.auto import tqdm

from langbrainscore.utils.preprocessing import preprocessor_classes
from langbrainscore.utils.logging import log


def count_zero_threshold_values(
    A: np.ndarray,
    zero_threshold: float = 0.001,
):
    """Given matrix A, count how many values are below the zero_threshold"""
    return np.sum(A < zero_threshold)

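# A minimal usage sketch (the array below is made-up data, not part of this module).
# Note the comparison is on signed values, so large negative entries count too:
#     count_zero_threshold_values(np.array([0.0, 0.5, -2.0]))  # -> 2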

def flatten_activations_per_sample(activations: dict):
    """
    Flatten activations across layers into a single 1D array

    Args:
        activations (dict): key = layer, value = array of emb_dim

    Returns:
        arr_flat (np.ndarray): 1D ndarray of flattened activations across all layers
        layers_arr (np.ndarray): 1D ndarray of layer indices, corresponding to arr_flat
    """
    layers_arr = []
    arr_flat = []
    for layer, arr in activations.items():  # Iterate over layers
        arr = np.array(arr)
        arr_flat.append(arr)
        for _ in range(arr.shape[0]):  # Iterate across units
            layers_arr.append(layer)
    arr_flat = np.concatenate(
        arr_flat, axis=0
    )  # concatenated activations across layers

    return arr_flat, np.array(layers_arr)

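# A minimal usage sketch (toy activations dict, not part of this module):
#     acts = {0: np.ones(4), 1: np.zeros(4)}
#     arr_flat, layers_arr = flatten_activations_per_sample(acts)
#     # arr_flat.shape == (8,); layers_arr == array([0, 0, 0, 0, 1, 1, 1, 1])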

def aggregate_layers(
    hidden_states: dict, mode: typing.Union[str, typing.Callable]
) -> dict:
    """Aggregate a hidden states dictionary over tokens, layer by layer

    Args:
        hidden_states (dict): key = layer (int), value = 2D PyTorch tensor of shape (n_tokens, emb_dim)
        mode (str or Callable): aggregation to apply over the token dimension

    Raises:
        NotImplementedError

    Returns:
        dict: key = layer, value = array of emb_dim
    """
    states_layers = dict()

    emb_aggregation = mode
    # iterate over layers
    for i in hidden_states.keys():
        if emb_aggregation == "last":
            state = hidden_states[i][-1, :]  # get last token
        elif emb_aggregation == "first":
            state = hidden_states[i][0, :]  # get first token
        elif emb_aggregation == "mean":
            state = torch.mean(hidden_states[i], dim=0)  # mean over tokens
        elif emb_aggregation == "median":
            # torch.median(..., dim=0) returns a (values, indices) namedtuple
            state = torch.median(hidden_states[i], dim=0).values  # median over tokens
        elif emb_aggregation == "sum":
            state = torch.sum(hidden_states[i], dim=0)  # sum over tokens
        elif emb_aggregation == "all" or emb_aggregation is None:
            state = hidden_states[i]  # keep all tokens: shape (n_tokens, emb_dim)
        elif callable(emb_aggregation):
            state = emb_aggregation(hidden_states[i])
        else:
            raise NotImplementedError(
                f"Sentence embedding method [{emb_aggregation}] not implemented"
            )

        states_layers[i] = state.detach().cpu().numpy()

    return states_layers

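# A minimal usage sketch (toy hidden states, not part of this module):
#     hs = {0: torch.randn(5, 16), 1: torch.randn(5, 16)}
#     out = aggregate_layers(hs, mode="mean")
#     # out[0].shape == (16,), one mean-pooled embedding per layer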

def get_torch_device():
    """
    get torch device based on whether cuda is available or not
    """
    # Set device to GPU if cuda is available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        device = torch.device("cpu")
    return device


def set_case(sample: str, emb_case: typing.Union[str, None] = None):
    if emb_case == "lower":
        return sample.lower()
    elif emb_case == "upper":
        return sample.upper()
    return sample


def get_context_groups(dataset, context_dimension):
    if context_dimension is None:
        context_groups = np.arange(dataset.stimuli.size)
    else:
        context_groups = dataset.stimuli.coords[context_dimension].values
    return context_groups


def preprocess_activations(*args, **kwargs):
    # alias for postprocess_activations
    return postprocess_activations(*args, **kwargs)


def postprocess_activations(
    activations_2d: np.ndarray = None,
    layer_ids_1d: np.ndarray = None,
    emb_preproc_mode: str = None,  # e.g., "demean"
):

    activations_processed = []
    layer_ids_processed = []

    # log(f"Preprocessing activations with {emb_preproc_mode}")
    for l_id in np.sort(np.unique(layer_ids_1d)):  # For each layer
        preprocessor = preprocessor_classes[emb_preproc_mode]

        # Get the activations for this layer and retain 2d shape: [n_samples, emb_dim]
        activations_2d_layer = activations_2d[:, layer_ids_1d == l_id]

        preprocessor.fit(
            activations_2d_layer
        )  # obtain a scaling per unit (in emb space)

        # Apply the scaling to the activations and reassemble the activations (might have different shape than original)
        activations_2d_layer_processed = preprocessor.transform(activations_2d_layer)
        activations_processed += [activations_2d_layer_processed]
        layer_ids_processed += [np.full(activations_2d_layer_processed.shape[1], l_id)]

    # Concatenate to obtain [n_samples, emb_dim across layers], i.e., flattened activations
    activations_2d_processed = np.hstack(activations_processed)
    layer_ids_1d_processed = np.hstack(layer_ids_processed)

    return activations_2d_processed, layer_ids_1d_processed

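# A minimal usage sketch. It assumes "demean" is a valid key in
# preprocessor_classes (hinted at by the parameter comment above); the arrays are made up:
#     acts = np.random.randn(10, 6)            # 10 samples, 2 layers x 3 units
#     layer_ids = np.array([0, 0, 0, 1, 1, 1])
#     acts_p, ids_p = postprocess_activations(acts, layer_ids, "demean")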

def repackage_flattened_activations(
    activations_2d: np.ndarray = None,
    layer_ids_1d: np.ndarray = None,
    dataset: xr.Dataset = None,
):
    return xr.DataArray(
        np.expand_dims(activations_2d, axis=2),  # add in time dimension
        dims=("sampleid", "neuroid", "timeid"),
        coords={
            "sampleid": dataset.contents.sampleid.values,
            "neuroid": np.arange(len(layer_ids_1d)),
            "timeid": np.arange(1),
            "layer": ("neuroid", np.array(layer_ids_1d, dtype="int64")),
        },
    )


def cos_sim_matrix(A, B):
    """Compute the row-wise cosine similarity between matrices A and B,
    i.e., between each row of A and the corresponding row of B.
    1 means the two vectors are identical. 0 means they are orthogonal.
    -1 means they are opposite."""
    return (A * B).sum(axis=1) / (A * A).sum(axis=1) ** 0.5 / (B * B).sum(axis=1) ** 0.5

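# A minimal usage sketch (toy vectors, not part of this module):
#     A = np.array([[1.0, 0.0], [0.0, 1.0]])
#     B = np.array([[1.0, 0.0], [0.0, -1.0]])
#     cos_sim_matrix(A, B)  # -> array([ 1., -1.])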

def pick_matching_token_ixs(
    batchencoding: "transformers.tokenization_utils_base.BatchEncoding",
    char_span_of_interest: slice,
) -> slice:
    """Picks token indices in a tokenized encoded sequence that best correspond to
    a substring of interest in the original sequence, given by a char span (slice)

    Args:
        batchencoding (transformers.tokenization_utils_base.BatchEncoding): the output of a
            `tokenizer(text)` call on a single text instance (not a batch, i.e. `tokenizer([text])`).
            Requires a fast tokenizer, since `token_to_chars` relies on char-to-token mappings.
        char_span_of_interest (slice): a `slice` object denoting the character indices in the
            original `text` string we want to extract the corresponding tokens for

    Returns:
        slice: the start and stop indices within an encoded sequence that
            best match the `char_span_of_interest`
    """
    from transformers import tokenization_utils_base

    start_token = 0
    end_token = batchencoding.input_ids.shape[-1]
    for i, _ in enumerate(batchencoding.input_ids.reshape(-1)):
        span = batchencoding[0].token_to_chars(
            i
        )  # batchencoding[0] gives access to the encoded string

        if span is None:  # for [CLS], no span is returned
            log(
                f'No span returned for token at {i}: "{batchencoding.tokens()[i]}"',
                type="WARN",
                cmap="WARN",
                verbosity_check=True,
            )
            continue
        else:
            span = tokenization_utils_base.CharSpan(*span)

        if span.start <= char_span_of_interest.start:
            start_token = i
        if span.end >= char_span_of_interest.stop:
            end_token = i + 1
            break

    assert (
        end_token - start_token <= batchencoding.input_ids.shape[-1]
    ), "Extracted span is larger than original span"

    return slice(start_token, end_token)

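# A minimal usage sketch (assumes a fast Hugging Face tokenizer; "bert-base-uncased"
# is just an example checkpoint, not mandated by this module):
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#     enc = tok("The quick brown fox", return_tensors="pt")
#     ixs = pick_matching_token_ixs(enc, slice(4, 9))  # chars of "quick"
#     enc.tokens()[ixs]  # -> ['quick']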

def encode_stimuli_in_context(
    stimuli_in_context,
    tokenizer: "transformers.AutoTokenizer",
    model: "transformers.AutoModel",
    bidirectional: bool,
    include_special_tokens: bool,
    emb_aggregation,
    device=get_torch_device(),
):
    """Encode each stimulus within its context group and yield, per stimulus,
    a dict of layer-wise aggregated encodings."""
    # CONTEXT LOOP
    for i, stimulus in enumerate(stimuli_in_context):

        # extract stim to encode based on the uni/bi-directional nature of models
        if not bidirectional:
            stimuli_directional = stimuli_in_context[: i + 1]
        else:
            stimuli_directional = stimuli_in_context

        # join the stimuli together within a context group using just a single space
        stimuli_directional = " ".join(stimuli_directional)

        tokenized_directional_context = tokenizer(
            stimuli_directional,
            padding=False,
            return_tensors="pt",
            add_special_tokens=True,
        ).to(device)

        # Get the hidden states
        result_model = model(
            tokenized_directional_context.input_ids,
            output_hidden_states=True,
            return_dict=True,
        )

        # tuple of 3D tensors, one per layer, each of dims: [batch, tokens, emb size]
        hidden_states = result_model["hidden_states"]

        layer_wise_activations = dict()

        # Find which indices match the current stimulus in the given context group
        start_of_interest = stimuli_directional.find(stimulus)
        char_span_of_interest = slice(
            start_of_interest, start_of_interest + len(stimulus)
        )
        token_span_of_interest = pick_matching_token_ixs(
            tokenized_directional_context, char_span_of_interest
        )

        log(
            f"Interested in the following stimulus:\n{stimuli_directional[char_span_of_interest]}\n"
            f"Recovered:\n{tokenized_directional_context.tokens()[token_span_of_interest]}",
            cmap="INFO",
            type="INFO",
            verbosity_check=True,
        )

        all_special_ids = set(tokenizer.all_special_ids)

        # Look for special tokens in the beginning and end of the sequence
        insert_first_upto = 0
        insert_last_from = tokenized_directional_context.input_ids.shape[-1]
        # loop through input ids (j avoids shadowing the context-loop index i)
        for j, tid in enumerate(tokenized_directional_context.input_ids[0, :]):
            if tid.item() in all_special_ids:
                insert_first_upto = j + 1
            else:
                break
        for j in range(1, tokenized_directional_context.input_ids.shape[-1] + 1):
            tid = tokenized_directional_context.input_ids[0, -j]
            if tid.item() in all_special_ids:
                insert_last_from -= 1
            else:
                break

        for idx_layer, layer in enumerate(hidden_states):  # Iterate over layers
            # b (1), n (tokens), h (768, ...)
            # collapse batch dim to obtain shape (n_tokens, emb_dim)
            this_extracted = layer[
                :,
                token_span_of_interest,
                :,
            ].squeeze(0)

            if include_special_tokens:
                # get the embeddings for the first special tokens
                this_extracted = torch.cat(
                    [
                        layer[:, :insert_first_upto, :].squeeze(0),
                        this_extracted,
                    ],
                    axis=0,
                )
                # get the embeddings for the last special tokens
                this_extracted = torch.cat(
                    [
                        this_extracted,
                        layer[:, insert_last_from:, :].squeeze(0),
                    ],
                    axis=0,
                )

            layer_wise_activations[idx_layer] = this_extracted.detach()

        # Aggregate hidden states within a sample
        # aggregated_layerwise_sentence_encodings is a dict with key = layer, value = array of emb dimension
        aggregated_layerwise_sentence_encodings = aggregate_layers(
            layer_wise_activations, mode=emb_aggregation
        )
        yield aggregated_layerwise_sentence_encodings
    # END CONTEXT LOOP

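# A minimal usage sketch ("distilgpt2" is just an example checkpoint, not
# something this module prescribes; GPT-2-style models are unidirectional):
#     from transformers import AutoModel, AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("distilgpt2")
#     mdl = AutoModel.from_pretrained("distilgpt2")
#     encodings = list(
#         encode_stimuli_in_context(
#             ["The cat sat.", "It purred."], tok, mdl,
#             bidirectional=False, include_special_tokens=True, emb_aggregation="last",
#         )
#     )
#     # one dict per stimulus: {layer_idx: np.ndarray of emb_dim}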

def dataset_from_stimuli(stimuli: "pd.DataFrame"):
    pass


###############################################################################
# ANALYSIS UTILS: these act upon encoded data, rather than encoders
###############################################################################


def get_decomposition_method(method: str = "pca", n_comp: int = 10, **kwargs):
    """
    Return the sklearn estimator to use for decomposition.

    Args:
        method (str): Method to use for decomposition (default: "pca", other options: "mds", "tsne")
        n_comp (int): Number of components to keep (default: 10)
        **kwargs: Additional keyword arguments forwarded to the estimator constructor

    Returns:
        sklearn estimator
    """

    if method == "pca":
        from sklearn.decomposition import PCA

        decomp_method = PCA(n_components=n_comp, **kwargs)

    elif method == "mds":
        from sklearn.manifold import MDS

        decomp_method = MDS(n_components=n_comp, **kwargs)

    elif method == "tsne":
        # note: sklearn's TSNE only supports n_components < 4 with its default
        # barnes_hut algorithm; pass method="exact" via kwargs for larger n_comp
        from sklearn.manifold import TSNE

        decomp_method = TSNE(n_components=n_comp, **kwargs)

    else:
        raise ValueError(f"Unknown method: {method}")

    return decomp_method

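# A minimal usage sketch:
#     pca = get_decomposition_method("pca", n_comp=5)
#     pca.fit(np.random.randn(100, 20))
#     pca.explained_variance_ratio_.shape  # -> (5,)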

def get_explainable_variance(
    ann_encoded_dataset,
    method: str = "pca",
    variance_threshold: float = 0.80,
    **kwargs,
) -> dict:
    """
    Returns how many components are needed to explain the variance threshold (default 80%) per layer.

    Args:
        ann_encoded_dataset (xr.Dataset): ANN encoded dataset
        method (str): Method to use for decomposition (default: "pca"; note that only PCA
            exposes the `explained_variance_ratio_` attribute used below)
        variance_threshold (float): Variance threshold used for determining how many components
            are needed to explain that fraction of the variance (default: 0.80)
        **kwargs: Additional keyword arguments to pass to the underlying method

    Returns:
        variance_across_layers (dict): Nested dict keyed by the value of interest (e.g., number
            of components needed), mapping layer id (e.g., 0, 1, 2, ...) to the corresponding value.
    """

    ks = [
        f"n_comp-{method}_needed-{variance_threshold}",
        f"first_comp-{method}_explained_variance",
    ]
    variance_across_layers = {k: {} for k in ks}

    # Get the explained variance per layer
    layer_ids = ann_encoded_dataset.layer.values
    _, unique_ixs = np.unique(layer_ids, return_index=True)

    # Make sure that layer order is preserved
    for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]):
        layer_dataset = (
            ann_encoded_dataset.isel(neuroid=(ann_encoded_dataset.layer == layer_id))
            .drop("timeid")
            .squeeze()
        )

        # Figure out how many PCs we attempt to fit
        n_comp = np.min([layer_dataset.shape[1], layer_dataset.shape[0]])

        # Get explained variance
        decomp_method = get_decomposition_method(method=method, n_comp=n_comp, **kwargs)

        decomp_method.fit(layer_dataset.values)
        explained_variance = decomp_method.explained_variance_ratio_

        # Get the number of PCs needed to explain the variance threshold
        explained_variance_cum = np.cumsum(explained_variance)
        n_pc_needed = np.argmax(explained_variance_cum >= variance_threshold) + 1

        # Store per layer
        layer_id = str(layer_id)
        print(
            f"Layer {layer_id}: {n_pc_needed} PCs needed to explain {variance_threshold} variance "
            f"with the 1st PC explaining {explained_variance[0]:.2%} of the total variance"
        )

        variance_across_layers[f"n_comp-{method}_needed-{variance_threshold}"][
            layer_id
        ] = n_pc_needed
        variance_across_layers[f"first_comp-{method}_explained_variance"][
            layer_id
        ] = explained_variance[0]

    return variance_across_layers

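# A minimal usage sketch (builds a toy encoded DataArray with the "layer" and
# "timeid" coords this function expects; all values are made up):
#     data = xr.DataArray(
#         np.random.randn(20, 6, 1),
#         dims=("sampleid", "neuroid", "timeid"),
#         coords={
#             "sampleid": np.arange(20),
#             "neuroid": np.arange(6),
#             "timeid": [0],
#             "layer": ("neuroid", [0, 0, 0, 1, 1, 1]),
#         },
#     )
#     get_explainable_variance(data, method="pca", variance_threshold=0.8)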

def get_layer_sparsity(
    ann_encoded_dataset, zero_threshold: float = 0.0001, **kwargs
) -> dict:
    """
    Check how sparse activations within a given layer are.

    Sparsity is defined as 1 - (number of values below the zero_threshold / total number of values).

    Args:
        ann_encoded_dataset (xr.Dataset): ANN encoded dataset
        zero_threshold (float): Threshold to use for determining sparsity (default: 0.0001)
        **kwargs: Additional keyword arguments to pass to the underlying method

    Returns:
        sparsity_across_layers (dict): Nested dict keyed by the value of interest (e.g., sparsity),
            mapping layer id (e.g., 0, 1, 2, ...) to the corresponding value.
    """
    # Obtain embedding dimension (for sanity checks)
    # if self.model_specs["hidden_emb_dim"]:
    #     hidden_emb_dim = self.model_specs["hidden_emb_dim"]
    # else:
    #     hidden_emb_dim = None
    #     log(
    #         f"Hidden embedding dimension not specified yet",
    #         cmap="WARN",
    #         type="WARN",
    #     )

    ks = [f"sparsity-{zero_threshold}"]
    sparsity_across_layers = {k: {} for k in ks}

    # Get the sparsity per layer
    layer_ids = ann_encoded_dataset.layer.values
    _, unique_ixs = np.unique(layer_ids, return_index=True)

    # Make sure that layer order is preserved
    for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]):
        layer_dataset = (
            ann_encoded_dataset.isel(neuroid=(ann_encoded_dataset.layer == layer_id))
            .drop("timeid")
            .squeeze()
        )

        # if hidden_emb_dim is not None:
        #     assert layer_dataset.shape[1] == hidden_emb_dim

        # Get sparsity
        zero_values = count_zero_threshold_values(layer_dataset.values, zero_threshold)
        sparsity = 1 - (zero_values / layer_dataset.size)

        # Store per layer
        layer_id = str(layer_id)
        print(f"Layer {layer_id}: {sparsity:.3f} sparsity")

        sparsity_across_layers[f"sparsity-{zero_threshold}"][layer_id] = sparsity

    return sparsity_across_layers

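# A minimal usage sketch (reusing the toy `data` DataArray from the previous sketch):
#     get_layer_sparsity(data, zero_threshold=0.0001)
#     # -> {"sparsity-0.0001": {"0": ..., "1": ...}}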

def cos_contrib(
    emb1: np.ndarray,
    emb2: np.ndarray,
):
    """
    Cosine contribution function defined in eq. 3 by Timkey & van Schijndel (2021): https://arxiv.org/abs/2109.04404

    Args:
        emb1 (np.ndarray): Embedding vector 1
        emb2 (np.ndarray): Embedding vector 2

    Returns:
        cos_contrib (np.ndarray): Per-dimension contributions to the cosine similarity
            (summing over dimensions yields the cosine similarity itself)
    """

    numerator_terms = emb1 * emb2
    denom = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    return numerator_terms / denom

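# A minimal usage sketch (toy vectors, not part of this module):
#     e1 = np.array([3.0, 4.0])
#     e2 = np.array([3.0, 4.0])
#     cos_contrib(e1, e2)  # -> array([0.36, 0.64]); sums to cosine similarity 1.0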

537def get_anisotropy( 

538 ann_encoded_dataset: "EncoderRepresentations", num_random_samples: int = 1000 

539): 

540 """ 

541 Calculate the anisotropy of the embedding vectors as Timkey & van Schijndel (2021): https://arxiv.org/abs/2109.04404 

542 (base function from their GitHub repo: https://github.com/wtimkey/rogue-dimensions/blob/main/replication.ipynb, 

543 but modified to work within the Language Brain-Score project) 

544 

545 

546 """ 

547 rogue_dist = [] 

548 num_toks = len(ann_encoded_dataset.sampleid) # Number of stimuli 

549 

550 # randomly sample embedding pairs to compute avg. cosine similiarity contribution 

551 random_pairs = [ 

552 random.sample(range(num_toks), 2) for i in range(num_random_samples) 

553 ] 

554 

555 cos_contribs_by_layer = [] 

556 

557 layer_ids = ann_encoded_dataset.layer.values 

558 _, unique_ixs = np.unique(layer_ids, return_index=True) 

559 

560 for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]): 

561 layer_dataset = ( 

562 ann_encoded_dataset.isel(neuroid=(ann_encoded_dataset.layer == layer_id)) 

563 .drop("timeid") 

564 .squeeze() 

565 ) 

566 

567 layer_cosine_contribs = [] 

568 layer_rogue_cos_contribs = [] 

569 for pair in random_pairs: 

570 emb1 = sample_data[layer, pair[0], :] # fix 

571 emb2 = sample_data[layer, pair[1], :] 

572 layer_cosine_contribs.append(cos_contrib(emb1, emb2)) 

573 

574 layer_cosine_contribs = np.array(layer_cosine_contribs) 

575 layer_cosine_sims = layer_cosine_contribs.sum(axis=1) 

576 layer_cosine_contribs_mean = layer_cosine_contribs.mean(axis=0) 

577 cos_contribs_by_layer.append(layer_cosine_contribs_mean) 

578 cos_contribs_by_layer = np.array(cos_contribs_by_layer) 

579 

580 aniso = cos_contribs_by_layer.sum( 

581 axis=1 

582 ) # total anisotropy, measured as avg. cosine sim between random emb. pairs 

583 

584 for layer in range(num_layers[model_name]): 

585 top_3_dims = np.argsort(cos_contribs_by_layer[layer])[-3:] 

586 top = cos_contribs_by_layer[layer, top_3_dims[2]] / aniso[layer] 

587 second = cos_contribs_by_layer[layer, top_3_dims[1]] / aniso[layer] 

588 third = cos_contribs_by_layer[layer, top_3_dims[0]] / aniso[layer] 

589 print( 

590 "& {} & {:.3f} & {:.3f} & {:.3f} & {:.3f} \\\\".format( 

591 layer, top, second, third, aniso[layer] 

592 ) 

593 )