import typing

import os
import numpy as np
import torch
from tqdm import tqdm
import xarray as xr

from langbrainscore.dataset import Dataset
from langbrainscore.interface import EncoderRepresentations, _ModelEncoder
from langbrainscore.utils.encoder import (
    aggregate_layers,
    cos_sim_matrix,
    count_zero_threshold_values,
    flatten_activations_per_sample,
    get_context_groups,
    get_torch_device,
    pick_matching_token_ixs,
    postprocess_activations,
    repackage_flattened_activations,
    encode_stimuli_in_context,
)

from langbrainscore.utils.logging import log
from langbrainscore.utils.xarray import copy_metadata, fix_xr_dtypes
from langbrainscore.utils.resources import config_name_mappings

os.environ["TOKENIZERS_PARALLELISM"] = "true"
class HuggingFaceEncoder(_ModelEncoder):
    def __init__(
        self,
        model_id,
        emb_aggregation: typing.Union[str, None, typing.Callable],
        device=None,
        context_dimension: typing.Optional[str] = None,
        bidirectional: bool = False,
        emb_preproc: typing.Tuple[str, ...] = (),
        include_special_tokens: bool = True,
    ):
43 """
44 Args:
45 model_id (str): the model id
46 device (None, ?): the device to use
47 context_dimension (str, optional): the dimension to use for extracting strings using context.
48 if None, each sampleid (stimuli) will be treated as a single context group.
49 if a string is specified, the string must refer to the name of a dimension in the xarray-like dataset
50 object (langbrainscore.dataset.Dataset) that provides groupings of sampleids (stimuli) that should be
51 used as context when generating encoder representations [default: None].
52 bidirectional (bool): whether to use bidirectional encoder (i.e., access both forward and backward context)
53 [default: False]
54 emb_aggregation (typing.Union[str, None, typing.Callable], optional): how to aggregate the hidden states of
55 the encoder representations for each sampleid (stimuli). [default: "last"]
56 emb_preproc (tuple): a list of strings specifying preprocessing functions to apply to the aggregated embeddings.
57 Processing is performed layer-wise.
58 include_special_tokens (bool): whether to include special tokens in the encoder representations.
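
        Example (illustrative sketch; the model id and the `dataset` object are
        placeholders, and "passage" is a hypothetical context-grouping dimension):
            >>> encoder = HuggingFaceEncoder(
            ...     model_id="distilgpt2",
            ...     emb_aggregation="last",
            ...     context_dimension="passage",
            ... )
            >>> representations = encoder.encode(dataset)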
59 """
        super().__init__(
            model_id,
            _context_dimension=context_dimension,
            _bidirectional=bidirectional,
            _emb_aggregation=emb_aggregation,
            _emb_preproc=emb_preproc,
            _include_special_tokens=include_special_tokens,
        )

        from transformers import AutoConfig, AutoModel, AutoTokenizer
        from transformers import logging as transformers_logging

        transformers_logging.set_verbosity_error()

        self.device = device or get_torch_device()
        self.config = AutoConfig.from_pretrained(self._model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._model_id, multiprocessing=True
        )
        self.model = AutoModel.from_pretrained(self._model_id, config=self.config)
        try:
            self.model = self.model.to(self.device)
        except RuntimeError:
            # could not move the model to the requested device (e.g., out of GPU
            # memory); fall back to CPU
            self.device = "cpu"
            self.model = self.model.to(self.device)
    def get_encoder_representations_template(
        self, dataset=None, representations: xr.DataArray = None
    ) -> EncoderRepresentations:
        """
        Returns an `EncoderRepresentations` object with all the appropriate encoder
        attributes filled in, but with `dataset` and `representations` left empty,
        to be filled in later.
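
        Example (illustrative): `encode` uses this template as a cache key before
        recomputing representations; `encoder` and `dataset` here are placeholders:
            >>> template = encoder.get_encoder_representations_template(dataset=dataset)
            >>> template.load_cache()  # raises FileNotFoundError on a cache miss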
94 """
        if representations is None:
            representations = xr.DataArray()
        return EncoderRepresentations(
            dataset=dataset,
            representations=representations,
            model_id=self._model_id,
            context_dimension=self._context_dimension,
            bidirectional=self._bidirectional,
            emb_aggregation=self._emb_aggregation,
            emb_preproc=self._emb_preproc,
            include_special_tokens=self._include_special_tokens,
        )
    def encode(
        self,
        dataset: Dataset,
        read_cache: bool = True,  # avoid recomputing if a cached `EncoderRepresentations` exists; recompute if not
        write_cache: bool = True,  # dump the result of this computation to cache?
    ) -> EncoderRepresentations:
112 """
113 Input a langbrainscore Dataset, encode the stimuli according to the parameters specified in init, and return
114 the an xarray DataArray of aggregated representations for each stimulus.
116 Args:
117 dataset (langbrainscore.dataset.DataSet): [description]
118 read_cache (bool): Avoid recomputing if cached `EncoderRepresentations` exists, recompute if not
119 write_cache (bool): Dump and write the result of the computed encoder representations to cache
121 Raises:
122 NotImplementedError: [description]
123 ValueError: [description]
125 Returns:
126 [type]: [description]
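
        Example (illustrative; `dataset` is a placeholder for a langbrainscore Dataset):
            >>> reprs = encoder.encode(dataset, read_cache=False)
            >>> reprs.representations  # xr.DataArray of shape [n_samples, emb_dim * n_layers]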
127 """
        # before computing the representations from scratch, we will first see if any
        # cached representations exist already.
        if read_cache:
            to_check_in_cache: EncoderRepresentations = (
                self.get_encoder_representations_template(dataset=dataset)
            )

            try:
                to_check_in_cache.load_cache()
                return to_check_in_cache
            except FileNotFoundError:
                log(
                    f"couldn't load cached reprs for {to_check_in_cache.identifier_string}; recomputing.",
                    cmap="WARN",
                    type="WARN",
                )

        self.model.eval()
        stimuli = dataset.stimuli.values

        # Initialize the context group coordinate (obtain embeddings with context)
        context_groups = get_context_groups(dataset, self._context_dimension)

        # list for storing activations for each stimulus with all layers flattened
        # list for storing layer ids ([0 0 0 0 ... 1 1 1 ...]) indicating which layer each
        # neuroid (representation dimension) came from
        flattened_activations, layer_ids = [], []

        ###############################################################################
        # ALL SAMPLES LOOP
        ###############################################################################
        _, unique_ixs = np.unique(context_groups, return_index=True)
        # Make sure context group order is preserved
        for group in tqdm(context_groups[np.sort(unique_ixs)], desc="Encoding stimuli"):
            # Mask based on the context group
            mask_context = context_groups == group
            stimuli_in_context = stimuli[mask_context]

            # store model states for each stimulus in this context group
            encoded_stimuli = []

            ###############################################################################
            # CONTEXT LOOP
            ###############################################################################
            for encoded_stim in encode_stimuli_in_context(
                stimuli_in_context=stimuli_in_context,
                tokenizer=self.tokenizer,
                model=self.model,
                bidirectional=self._bidirectional,
                include_special_tokens=self._include_special_tokens,
                emb_aggregation=self._emb_aggregation,
                device=self.device,
            ):
                encoded_stimuli += [encoded_stim]
            ###############################################################################
            # END CONTEXT LOOP
            ###############################################################################

            # Flatten activations across layers and package as xarray
            flattened_activations_and_layer_ids = [
                *map(flatten_activations_per_sample, encoded_stimuli)
            ]
            for f_as, l_ids in flattened_activations_and_layer_ids:
                flattened_activations += [f_as]
                layer_ids += [l_ids]
                assert len(f_as) == len(l_ids)  # Assert all layer lists are of equal length

        ###############################################################################
        # END ALL SAMPLES LOOP
        ###############################################################################

        # Stack flattened activations and layer ids to obtain [n_samples, emb_dim * n_layers]
        activations_2d = np.vstack(flattened_activations)
        layer_ids_1d = np.squeeze(np.unique(np.vstack(layer_ids), axis=0))

        # Post-process activations after obtaining them (or "pre-process" them before computing brainscore)
        if len(self._emb_preproc) > 0:
            for mode in self._emb_preproc:
                activations_2d, layer_ids_1d = postprocess_activations(
                    activations_2d=activations_2d,
                    layer_ids_1d=layer_ids_1d,
                    emb_preproc_mode=mode,
                )

        assert activations_2d.shape[1] == len(layer_ids_1d)
        assert activations_2d.shape[0] == len(stimuli)

        # Package activations as xarray and reapply metadata
        encoded_dataset: xr.DataArray = repackage_flattened_activations(
            activations_2d=activations_2d,
            layer_ids_1d=layer_ids_1d,
            dataset=dataset,
        )
        encoded_dataset: xr.DataArray = copy_metadata(
            encoded_dataset,
            dataset.contents,
            "sampleid",
        )

        to_return: EncoderRepresentations = self.get_encoder_representations_template()
        to_return.dataset = dataset
        to_return.representations = fix_xr_dtypes(encoded_dataset)

        if write_cache:
            to_return.to_cache(overwrite=True)

        return to_return
    def get_modelcard(self):
        """
        Returns the model card of the model (model-wise, not layer-wise).
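
        Example (illustrative; the exact keys come from `config_name_mappings` and
        vary by model class):
            >>> encoder.get_modelcard()
            {'num_layers': 12, 'vocab_size': 50257, ...}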
241 """
        # continuously update based on new model classes supported
        model_classes = [
            "gpt",
            "bert",
        ]

        # based on the model_id, figure out which model class it is
        model_class = next((x for x in model_classes if x in self._model_id), None)
        assert model_class is not None, f"model_id {self._model_id} not supported"

        config_specs_of_interest = config_name_mappings[model_class]

        model_specs = {}
        # key is the name we want to use in the model card,
        # value is the corresponding attribute name in the config
        for k_spec, v_spec in config_specs_of_interest.items():
            if v_spec is not None:
                model_specs[k_spec] = getattr(self.config, v_spec)
            else:
                model_specs[k_spec] = None

        self.model_specs = model_specs

        return model_specs
class PTEncoder(_ModelEncoder):
    def __init__(self, model_id: str) -> None:
        super().__init__(model_id)

    def encode(self, dataset: Dataset) -> xr.DataArray:
        # TODO
        ...
class EncoderCheck:
    """
    Class for checking whether embeddings obtained from an Encoder are correct and
    similar to those obtained from other encoder objects.
    """

    def __init__(self):
        pass

    def _load_cached_activations(self, encoded_ann_identifier: str):
        raise NotImplementedError
    def similiarity_metric_across_layers(
        self,
        sim_metric: str = "tol",
        enc1: xr.DataArray = None,
        enc2: xr.DataArray = None,
        tol: float = 1e-8,
        threshold: float = 1e-4,
    ) -> typing.Tuple[bool, set]:
302 """
303 Given two activations, iterate across layers and check np.allclose using different tolerance levels.
305 Parameters:
306 sim_metric: str
307 Similarity metric to use.
308 enc1: xr.DataArray
309 First encoder activations.
310 enc2: xr.DataArray
311 Second encoder activations.
312 tol: float
313 Tolerance level to start at (we will iterate upwards the tolerance level). Default is 1e-8.
315 Returns:
316 bool: whether the tolerance level was met (True) or not (False)
317 bad_stim: set of stimuli indices that did not meet tolerance level `threshold` (if any)
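
        Example (illustrative; `reprs1` and `reprs2` are placeholders for the
        `representations` of two independently computed encodings of the same dataset):
            >>> checker = EncoderCheck()
            >>> all_good, bad_stim = checker.similiarity_metric_across_layers(
            ...     sim_metric="tol", enc1=reprs1, enc2=reprs2
            ... )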
319 """
        # First check whether the number of layers / shapes match
        assert enc1.shape == enc2.shape
        assert (
            enc1.sampleid.values == enc2.sampleid.values
        ).all()  # ensure that we are looking at the same stimuli
        layer_ids = enc1.layer.values
        _, unique_ixs = np.unique(layer_ids, return_index=True)
        print(f"\n\nChecking similarity across layers using sim_metric: {sim_metric}")

        all_good = True
        bad_stim = set()  # store ids of stimuli that are not similar

        # Iterate across layers
        for layer_id in tqdm(layer_ids[np.sort(unique_ixs)]):
            enc1_layer = enc1.isel(neuroid=(enc1.layer == layer_id))  # .squeeze()
            enc2_layer = enc2.isel(neuroid=(enc2.layer == layer_id))  # .squeeze()

            # Check whether values match. If not, iteratively increase tolerance until they do
            if sim_metric in ("tol", "diff"):
                abs_diff = np.abs(enc1_layer - enc2_layer)
                abs_diff_per_stim = np.max(
                    abs_diff, axis=1
                )  # Obtain the biggest difference across neuroids (units)
                while (abs_diff_per_stim > tol).all():
                    tol *= 10

            elif "cos" in sim_metric:
                # Check cosine distance between each row, e.g., sentence vector
                cos_sim = cos_sim_matrix(enc1_layer, enc2_layer)
                cos_dist = (
                    1 - cos_sim
                )  # 0 means identical, 1 means orthogonal, 2 means opposite
                # We still want this as close to zero as possible for similar vectors.
                cos_dist_abs = np.abs(cos_dist)
                abs_diff_per_stim = cos_dist_abs

                # Check how close the cosine distance is to 0
                while (cos_dist_abs > tol).all():
                    tol *= 10
            else:
                raise NotImplementedError(f"Invalid `sim_metric`: {sim_metric}")

            print(f"Layer {layer_id}: Similarity at tolerance: {tol:.3e}")
            if tol > threshold:
                print("WARNING: tolerance exceeds threshold")
                all_good = False
                bad_stim.update(
                    enc1.sampleid[np.where(abs_diff_per_stim > tol)[0]]
                )  # get sampleids of stimuli that are not similar

        return all_good, bad_stim