# langbrainscore/utils/encoder.py
import typing
import random

import numpy as np
import torch
import xarray as xr
from tqdm.auto import tqdm

from langbrainscore.utils.logging import log
from langbrainscore.utils.preprocessing import preprocessor_classes


def count_zero_threshold_values(
    A: np.ndarray,
    zero_threshold: float = 0.001,
):
    """Given matrix A, count how many values are below the zero_threshold."""
    # note: this compares signed values, not magnitudes; large negative
    # entries also count as "below threshold"
    return np.sum(A < zero_threshold)


def flatten_activations_per_sample(activations: dict):
    """
    Flatten activations across layers into a single 1D array, with a parallel
    array recording which layer each unit came from.

    Args:
        activations (dict): key = layer, value = array of emb_dim

    Returns:
        arr_flat (np.ndarray): 1D ndarray of flattened activations across all layers
        layers_arr (np.ndarray): 1D ndarray of layer indices, corresponding to arr_flat
    """
    layers_arr = []
    arr_flat = []
    for layer, arr in activations.items():  # iterate over layers
        arr = np.array(arr)
        arr_flat.append(arr)
        for i in range(arr.shape[0]):  # iterate across units
            layers_arr.append(layer)
    arr_flat = np.concatenate(arr_flat, axis=0)  # concatenated activations across layers

    return arr_flat, np.array(layers_arr)
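

# Example (illustrative sketch, not part of the library API): flattening a
# two-layer activation dict whose values each have emb_dim = 3 yields a
# 6-element vector plus a matching layer-label vector.
#
#   acts = {0: np.ones(3), 1: np.zeros(3)}
#   flat, layers = flatten_activations_per_sample(acts)
#   # flat.shape == (6,); layers.tolist() == [0, 0, 0, 1, 1, 1]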


def aggregate_layers(
    hidden_states: dict, mode: typing.Union[str, typing.Callable]
) -> dict:
    """Aggregate a hidden states dictionary (key = layer, value = 2D array of n_tokens x emb_dim)
    across the token dimension.

    Args:
        hidden_states (dict): key = layer (int), value = 2D PyTorch tensor of shape (n_tokens, emb_dim)
        mode (str or Callable): "last", "first", "mean", "median", "sum", "all"/None,
            or a custom callable applied to each layer's tensor

    Raises:
        NotImplementedError: if `mode` is not one of the supported options

    Returns:
        dict: key = layer, value = array of emb_dim
    """
    states_layers = dict()

    emb_aggregation = mode
    # iterate over layers
    for i in hidden_states.keys():
        if emb_aggregation == "last":
            state = hidden_states[i][-1, :]  # get last token
        elif emb_aggregation == "first":
            state = hidden_states[i][0, :]  # get first token
        elif emb_aggregation == "mean":
            state = torch.mean(hidden_states[i], dim=0)  # mean over tokens
        elif emb_aggregation == "median":
            # torch.median with dim returns a (values, indices) namedtuple
            state = torch.median(hidden_states[i], dim=0).values  # median over tokens
        elif emb_aggregation == "sum":
            state = torch.sum(hidden_states[i], dim=0)  # sum over tokens
        elif emb_aggregation == "all" or emb_aggregation is None:
            state = hidden_states[i]  # keep all token embeddings for this layer
        elif callable(emb_aggregation):
            state = emb_aggregation(hidden_states[i])
        else:
            raise NotImplementedError(
                f"Sentence embedding method [{emb_aggregation}] not implemented"
            )

        states_layers[i] = state.detach().cpu().numpy()

    return states_layers
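

# Example (illustrative sketch): aggregating a single 5-token, 8-dim layer
# with mode="mean" collapses the token dimension.
#
#   hs = {0: torch.randn(5, 8)}
#   out = aggregate_layers(hs, mode="mean")
#   # out[0].shape == (8,)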


def get_torch_device():
    """
    get torch device based on whether cuda is available or not
    """
    import torch

    # Set device to GPU if cuda is available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        device = torch.device("cpu")
    return device


def set_case(sample: str, emb_case: typing.Union[str, None] = None):
    if emb_case == "lower":
        return sample.lower()
    elif emb_case == "upper":
        return sample.upper()
    return sample


def get_context_groups(dataset, context_dimension):
    if context_dimension is None:
        context_groups = np.arange(0, dataset.stimuli.size, 1)
    else:
        context_groups = dataset.stimuli.coords[context_dimension].values
    return context_groups


def preprocess_activations(*args, **kwargs):
    return postprocess_activations(*args, **kwargs)


def postprocess_activations(
    activations_2d: np.ndarray = None,
    layer_ids_1d: np.ndarray = None,
    emb_preproc_mode: str = None,  # e.g., "demean"
):

    activations_processed = []
    layer_ids_processed = []

    # log(f"Preprocessing activations with {emb_preproc_mode}")
    for l_id in np.sort(np.unique(layer_ids_1d)):  # for each layer
        preprocessor = preprocessor_classes[emb_preproc_mode]

        # Get the activations for this layer and retain 2d shape: [n_samples, emb_dim]
        activations_2d_layer = activations_2d[:, layer_ids_1d == l_id]

        preprocessor.fit(activations_2d_layer)  # obtain a scaling per unit (in emb space)

        # Apply the scaling to the activations and reassemble them
        # (the result might have a different shape than the original)
        activations_2d_layer_processed = preprocessor.transform(activations_2d_layer)
        activations_processed += [activations_2d_layer_processed]
        layer_ids_processed += [np.full(activations_2d_layer_processed.shape[1], l_id)]

    # Concatenate to obtain [n_samples, emb_dim across layers], i.e., flattened activations
    activations_2d_processed = np.hstack(activations_processed)
    layer_ids_1d_processed = np.hstack(layer_ids_processed)

    return activations_2d_processed, layer_ids_1d_processed
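

# Example (hedged sketch; assumes "demean" is a valid key in
# `preprocessor_classes`, per the inline comment above):
#
#   acts = np.random.randn(10, 6)     # 10 samples, 6 units
#   layer_ids = np.repeat([0, 1], 3)  # 3 units per layer
#   acts_p, ids_p = postprocess_activations(acts, layer_ids, "demean")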


def repackage_flattened_activations(
    activations_2d: np.ndarray = None,
    layer_ids_1d: np.ndarray = None,
    dataset: xr.Dataset = None,
):
    return xr.DataArray(
        np.expand_dims(activations_2d, axis=2),  # add in time dimension
        dims=("sampleid", "neuroid", "timeid"),
        coords={
            "sampleid": dataset.contents.sampleid.values,
            "neuroid": np.arange(len(layer_ids_1d)),
            "timeid": np.arange(1),
            "layer": ("neuroid", np.array(layer_ids_1d, dtype="int64")),
        },
    )


def cos_sim_matrix(A, B):
    """Compute the row-wise cosine similarity between two matrices A and B,
    returning one value per pair of corresponding rows.
    1 means the two vectors are identical. 0 means they are orthogonal.
    -1 means they are opposite."""
    return (A * B).sum(axis=1) / (A * A).sum(axis=1) ** 0.5 / (B * B).sum(axis=1) ** 0.5
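

# Example (illustrative sketch): row-wise similarities for two 2x2 matrices.
#
#   A = np.array([[1.0, 0.0], [1.0, 1.0]])
#   B = np.array([[0.0, 1.0], [1.0, 1.0]])
#   cos_sim_matrix(A, B)  # -> array([0., 1.]): orthogonal rows, identical rows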


def pick_matching_token_ixs(
    batchencoding: "transformers.tokenization_utils_base.BatchEncoding",
    char_span_of_interest: slice,
) -> slice:
    """Picks token indices in a tokenized encoded sequence that best correspond to
    a substring of interest in the original sequence, given by a char span (slice)

    Args:
        batchencoding (transformers.tokenization_utils_base.BatchEncoding): the output of a
            `tokenizer(text)` call on a single text instance (not a batch, i.e. `tokenizer([text])`).
        char_span_of_interest (slice): a `slice` object denoting the character indices in the
            original `text` string we want to extract the corresponding tokens for

    Returns:
        slice: the start and stop indices within an encoded sequence that
            best match the `char_span_of_interest`
    """
    from transformers import tokenization_utils_base

    start_token = 0
    end_token = batchencoding.input_ids.shape[-1]
    for i, _ in enumerate(batchencoding.input_ids.reshape(-1)):
        span = batchencoding[0].token_to_chars(i)  # batchencoding[0] gives access to the encoded sequence

        if span is None:  # for special tokens such as [CLS], no span is returned
            log(
                f'No span returned for token at {i}: "{batchencoding.tokens()[i]}"',
                type="WARN",
                cmap="WARN",
                verbosity_check=True,
            )
            continue
        else:
            span = tokenization_utils_base.CharSpan(*span)

        if span.start <= char_span_of_interest.start:
            start_token = i
        if span.end >= char_span_of_interest.stop:
            end_token = i + 1
            break

    assert (
        end_token - start_token <= batchencoding.input_ids.shape[-1]
    ), "Extracted span is larger than original span"

    return slice(start_token, end_token)
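

# Example (hypothetical sketch; assumes `transformers` and the
# "bert-base-uncased" tokenizer are available locally):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   enc = tok("the quick brown fox", return_tensors="pt")
#   span = pick_matching_token_ixs(enc, slice(4, 9))  # chars 4..9 = "quick"
#   enc.tokens()[span]  # -> ['quick']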


def encode_stimuli_in_context(
    stimuli_in_context,
    tokenizer: "transformers.AutoTokenizer",
    model: "transformers.AutoModel",
    bidirectional: bool,
    include_special_tokens: bool,
    emb_aggregation,
    device=get_torch_device(),
):
    """Encode each stimulus within its context group, yielding one dict of
    aggregated layer-wise encodings (key = layer, value = array of emb_dim)
    per stimulus."""
    # CONTEXT LOOP
    for i, stimulus in enumerate(stimuli_in_context):

        # extract stim to encode based on the uni/bi-directional nature of models
        if not bidirectional:
            stimuli_directional = stimuli_in_context[: i + 1]
        else:
            stimuli_directional = stimuli_in_context

        # join the stimuli together within a context group using just a single space
        stimuli_directional = " ".join(stimuli_directional)

        tokenized_directional_context = tokenizer(
            stimuli_directional,
            padding=False,
            return_tensors="pt",
            add_special_tokens=True,
        ).to(device)

        # Get the hidden states
        result_model = model(
            tokenized_directional_context.input_ids,
            output_hidden_states=True,
            return_dict=True,
        )

        # tuple with one entry per layer; each is a 3D tensor of dims: [batch, tokens, emb size]
        hidden_states = result_model["hidden_states"]

        layer_wise_activations = dict()

        # Find which indices match the current stimulus in the given context group
        start_of_interest = stimuli_directional.find(stimulus)
        char_span_of_interest = slice(
            start_of_interest, start_of_interest + len(stimulus)
        )
        token_span_of_interest = pick_matching_token_ixs(
            tokenized_directional_context, char_span_of_interest
        )

        log(
            f"Interested in the following stimulus:\n{stimuli_directional[char_span_of_interest]}\n"
            f"Recovered:\n{tokenized_directional_context.tokens()[token_span_of_interest]}",
            cmap="INFO",
            type="INFO",
            verbosity_check=True,
        )

        all_special_ids = set(tokenizer.all_special_ids)

        # Look for special tokens at the beginning and end of the sequence
        insert_first_upto = 0
        insert_last_from = tokenized_directional_context.input_ids.shape[-1]
        # loop through input ids from the front (`j`, not `i`, to avoid
        # shadowing the stimulus index above) ...
        for j, tid in enumerate(tokenized_directional_context.input_ids[0, :]):
            if tid.item() in all_special_ids:
                insert_first_upto = j + 1
            else:
                break
        # ... and from the back
        for j in range(1, tokenized_directional_context.input_ids.shape[-1] + 1):
            tid = tokenized_directional_context.input_ids[0, -j]
            if tid.item() in all_special_ids:
                insert_last_from -= 1
            else:
                break

        for idx_layer, layer in enumerate(hidden_states):  # iterate over layers
            # b (1), n (tokens), h (768, ...)
            # collapse batch dim to obtain shape (n_tokens, emb_dim)
            this_extracted = layer[
                :,
                token_span_of_interest,
                :,
            ].squeeze(0)

            if include_special_tokens:
                # get the embeddings for the first special tokens
                this_extracted = torch.cat(
                    [
                        layer[:, :insert_first_upto, :].squeeze(0),
                        this_extracted,
                    ],
                    dim=0,
                )
                # get the embeddings for the last special tokens
                this_extracted = torch.cat(
                    [
                        this_extracted,
                        layer[:, insert_last_from:, :].squeeze(0),
                    ],
                    dim=0,
                )

            layer_wise_activations[idx_layer] = this_extracted.detach()

        # Aggregate hidden states within a sample
        # aggregated_layerwise_sentence_encodings is a dict with key = layer, value = array of emb dimension
        aggregated_layerwise_sentence_encodings = aggregate_layers(
            layer_wise_activations, mode=emb_aggregation
        )
        yield aggregated_layerwise_sentence_encodings
    # END CONTEXT LOOP
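

# Example (hypothetical usage sketch; assumes `transformers` and the "gpt2"
# checkpoint are available; GPT-2 is unidirectional, hence bidirectional=False):
#
#   from transformers import AutoModel, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   mdl = AutoModel.from_pretrained("gpt2")
#   stimuli = ["The cat sat.", "It purred."]
#   for enc in encode_stimuli_in_context(
#       stimuli, tok, mdl, bidirectional=False,
#       include_special_tokens=True, emb_aggregation="mean", device="cpu",
#   ):
#       print({layer: arr.shape for layer, arr in enc.items()})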


def dataset_from_stimuli(stimuli: "pd.DataFrame"):
    pass


###############################################################################
# ANALYSIS UTILS: these act upon encoded data, rather than encoders
###############################################################################


def get_decomposition_method(method: str = "pca", n_comp: int = 10, **kwargs):
    """
    Return the (unfitted) sklearn estimator to use for decomposition.

    Args:
        method (str): Method to use for decomposition (default: "pca", other options: "mds", "tsne")
        n_comp (int): Number of components to keep (default: 10)
        **kwargs: Additional keyword arguments forwarded to the estimator constructor

    Returns:
        sklearn estimator implementing the requested decomposition
    """

    if method == "pca":
        from sklearn.decomposition import PCA

        decomp_method = PCA(n_components=n_comp, **kwargs)
    elif method == "mds":
        from sklearn.manifold import MDS

        decomp_method = MDS(n_components=n_comp, **kwargs)
    elif method == "tsne":
        from sklearn.manifold import TSNE

        decomp_method = TSNE(n_components=n_comp, **kwargs)
    else:
        raise ValueError(f"Unknown method: {method}")

    return decomp_method
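

# Example (illustrative sketch): request an unfitted 5-component PCA.
#
#   pca = get_decomposition_method(method="pca", n_comp=5)
#   # pca.fit(X) can then be called on any [n_samples, n_features] array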


def get_explainable_variance(
    ann_encoded_dataset,
    method: str = "pca",
    variance_threshold: float = 0.80,
    **kwargs,
) -> dict:
    """
    Returns how many components are needed to explain the variance threshold (default 80%) per layer.

    Args:
        ann_encoded_dataset (xr.Dataset): ANN encoded dataset
        method (str): Method to use for decomposition (default: "pca"; note that this function
            relies on `explained_variance_ratio_`, which only PCA exposes)
        variance_threshold (float): Threshold of variance the components must explain (default: 0.80)
        **kwargs: Additional keyword arguments to pass to the underlying method

    Returns:
        variance_across_layers (dict): Nested dict keyed first by the quantity of interest
            (e.g., number of components needed) and then by layer id (e.g., 0, 1, 2, ...).
    """

    ks = [
        f"n_comp-{method}_needed-{variance_threshold}",
        f"first_comp-{method}_explained_variance",
    ]
    variance_across_layers = {k: {} for k in ks}

    # Get the explained variance per layer
    layer_ids = ann_encoded_dataset.layer.values
    _, unique_ixs = np.unique(layer_ids, return_index=True)

    # Make sure that layer order is preserved
    for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]):
        layer_dataset = (
            ann_encoded_dataset.isel(neuroid=(ann_encoded_dataset.layer == layer_id))
            .drop("timeid")
            .squeeze()
        )

        # Figure out how many PCs we attempt to fit
        n_comp = np.min([layer_dataset.shape[1], layer_dataset.shape[0]])

        # Get explained variance
        decomp_method = get_decomposition_method(method=method, n_comp=n_comp, **kwargs)

        decomp_method.fit(layer_dataset.values)
        explained_variance = decomp_method.explained_variance_ratio_

        # Get the number of PCs needed to explain the variance threshold
        explained_variance_cum = np.cumsum(explained_variance)
        n_pc_needed = np.argmax(explained_variance_cum >= variance_threshold) + 1

        # Store per layer
        layer_id = str(layer_id)
        print(
            f"Layer {layer_id}: {n_pc_needed} PCs needed to explain {variance_threshold} variance, "
            f"with the 1st PC explaining {explained_variance[0]:.2%} of the total variance"
        )

        variance_across_layers[f"n_comp-{method}_needed-{variance_threshold}"][
            layer_id
        ] = n_pc_needed
        variance_across_layers[f"first_comp-{method}_explained_variance"][
            layer_id
        ] = explained_variance[0]

    return variance_across_layers
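

# Example (illustrative sketch): a synthetic encoded dataset with the
# sampleid/neuroid/timeid layout produced by repackage_flattened_activations,
# here with two layers of 10 units each over 50 samples.
#
#   da = xr.DataArray(
#       np.random.randn(50, 20, 1),
#       dims=("sampleid", "neuroid", "timeid"),
#       coords={
#           "sampleid": np.arange(50),
#           "neuroid": np.arange(20),
#           "timeid": np.arange(1),
#           "layer": ("neuroid", np.repeat([0, 1], 10)),
#       },
#   )
#   get_explainable_variance(da, method="pca", variance_threshold=0.8)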


def get_layer_sparsity(
    ann_encoded_dataset, zero_threshold: float = 0.0001, **kwargs
) -> dict:
    """
    Check how sparse the activations within a given layer are.

    Sparsity is defined as 1 minus the fraction of values below `zero_threshold`.

    Args:
        ann_encoded_dataset (xr.Dataset): ANN encoded dataset
        zero_threshold (float): Threshold to use for determining sparsity (default: 0.0001)
        **kwargs: Additional keyword arguments to pass to the underlying method

    Returns:
        sparsity_across_layers (dict): Nested dict keyed first by the quantity of interest
            (e.g., sparsity) and then by layer id (e.g., 0, 1, 2, ...).
    """
    # Obtain embedding dimension (for sanity checks)
    # if self.model_specs["hidden_emb_dim"]:
    #     hidden_emb_dim = self.model_specs["hidden_emb_dim"]
    # else:
    #     hidden_emb_dim = None
    #     log(
    #         f"Hidden embedding dimension not specified yet",
    #         cmap="WARN",
    #         type="WARN",
    #     )

    ks = [f"sparsity-{zero_threshold}"]
    sparsity_across_layers = {k: {} for k in ks}

    # Get the sparsity per layer
    layer_ids = ann_encoded_dataset.layer.values
    _, unique_ixs = np.unique(layer_ids, return_index=True)

    # Make sure that layer order is preserved
    for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]):
        layer_dataset = (
            ann_encoded_dataset.isel(neuroid=(ann_encoded_dataset.layer == layer_id))
            .drop("timeid")
            .squeeze()
        )

        # if hidden_emb_dim is not None:
        #     assert layer_dataset.shape[1] == hidden_emb_dim

        # Get sparsity
        zero_values = count_zero_threshold_values(layer_dataset.values, zero_threshold)
        sparsity = 1 - (zero_values / layer_dataset.size)

        # Store per layer
        layer_id = str(layer_id)
        print(f"Layer {layer_id}: {sparsity:.3f} sparsity")

        sparsity_across_layers[f"sparsity-{zero_threshold}"][layer_id] = sparsity

    return sparsity_across_layers
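

# The synthetic DataArray sketched above for get_explainable_variance can be
# passed here as well, e.g.:
#
#   get_layer_sparsity(da, zero_threshold=1e-4)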


def cos_contrib(
    emb1: np.ndarray,
    emb2: np.ndarray,
):
    """
    Cosine contribution function defined in eq. 3 by Timkey & van Schijndel (2021): https://arxiv.org/abs/2109.04404

    Args:
        emb1 (np.ndarray): Embedding vector 1
        emb2 (np.ndarray): Embedding vector 2

    Returns:
        np.ndarray: per-dimension contributions to the cosine similarity
            (summing across dimensions yields the full cosine similarity)
    """

    numerator_terms = emb1 * emb2
    denom = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    return numerator_terms / denom
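

# Example (illustrative sketch): for identical unit-norm vectors, the
# per-dimension contributions sum to a cosine similarity of 1.
#
#   v = np.array([0.6, 0.8])
#   cos_contrib(v, v)        # -> array([0.36, 0.64])
#   cos_contrib(v, v).sum()  # -> 1.0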


def get_anisotropy(
    ann_encoded_dataset: "EncoderRepresentations", num_random_samples: int = 1000
):
    """
    Calculate the anisotropy of the embedding vectors following Timkey & van Schijndel (2021): https://arxiv.org/abs/2109.04404
    (base function from their GitHub repo: https://github.com/wtimkey/rogue-dimensions/blob/main/replication.ipynb,
    but modified to work within the Language Brain-Score project).
    """
    num_samples = len(ann_encoded_dataset.sampleid)  # number of stimuli

    # randomly sample embedding pairs to compute avg. cosine similarity contribution
    random_pairs = [
        random.sample(range(num_samples), 2) for _ in range(num_random_samples)
    ]

    cos_contribs_by_layer = []

    layer_ids = ann_encoded_dataset.layer.values
    _, unique_ixs = np.unique(layer_ids, return_index=True)

    for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]):
        layer_dataset = (
            ann_encoded_dataset.isel(neuroid=(ann_encoded_dataset.layer == layer_id))
            .drop("timeid")
            .squeeze()
        )

        layer_cosine_contribs = []
        for pair in random_pairs:
            # index the embeddings for the sampled stimulus pair within this layer
            # (the original notebook indexed `sample_data[layer, ...]`; here we
            # index the per-layer xarray slice instead)
            emb1 = layer_dataset.values[pair[0], :]
            emb2 = layer_dataset.values[pair[1], :]
            layer_cosine_contribs.append(cos_contrib(emb1, emb2))

        layer_cosine_contribs = np.array(layer_cosine_contribs)
        layer_cosine_contribs_mean = layer_cosine_contribs.mean(axis=0)
        cos_contribs_by_layer.append(layer_cosine_contribs_mean)
    cos_contribs_by_layer = np.array(cos_contribs_by_layer)

    aniso = cos_contribs_by_layer.sum(
        axis=1
    )  # total anisotropy, measured as avg. cosine sim between random emb. pairs

    # print one LaTeX table row per layer (format kept from the original notebook):
    # the three largest per-dimension contributions, as fractions of the total anisotropy
    for layer in range(cos_contribs_by_layer.shape[0]):
        top_3_dims = np.argsort(cos_contribs_by_layer[layer])[-3:]
        top = cos_contribs_by_layer[layer, top_3_dims[2]] / aniso[layer]
        second = cos_contribs_by_layer[layer, top_3_dims[1]] / aniso[layer]
        third = cos_contribs_by_layer[layer, top_3_dims[0]] / aniso[layer]
        print(
            "& {} & {:.3f} & {:.3f} & {:.3f} & {:.3f} \\\\".format(
                layer, top, second, third, aniso[layer]
            )
        )