Coverage for langbrainscore/interface/cacheable.py: 18%
137 statements
coverage.py v6.4, created at 2022-06-07 21:22 +0000
1import pickle
2import typing
3from abc import ABC, abstractclassmethod, abstractmethod
4from numbers import Number
5from pathlib import Path
7import xarray as xr
8import numpy as np
9import yaml
10from langbrainscore.utils.cache import get_cache_directory, pathify
11from langbrainscore.utils.logging import log
13# from langbrainscore.interface.dryrunnable import _DryRunnable
# Generic type variable so that `_Cacheable.from_cache` is typed as returning an
# instance of whichever class it is called on.
T = typing.TypeVar("T")


@typing.runtime_checkable
class _Cacheable(typing.Protocol):
    """
    A class used to define a common interface for object caching in LangBrainscore.

    Provides default implementations of attribute-wise equality, parameter
    collection, `repr` / `identifier_string`, and saving/restoring of attributes
    to a per-class cache directory obtained from `get_cache_directory`.
    """

    def __eq__(o1: "_Cacheable", o2: "_Cacheable") -> bool:
        """
        Compare two objects attribute by attribute, driven by `vars(o1)`.

        - str / Number / bool / tuple / None / `_Cacheable` attributes must
          compare equal (a missing attribute on either side is a mismatch);
        - `xr.DataArray` attributes must exist on `o2` as DataArrays of the same
          shape, with values equal under
          `np.allclose(..., equal_nan=True, atol=1e-4)` and identical `attrs`;
        - attributes of any other type (e.g. dicts) are not compared.

        The first mismatch is logged and makes the result False. Attributes that
        exist only on `o2` are not inspected.
        """

        def checkattr(key) -> bool:
            """helper function to check if an attribute is the same between two
            objects and handles AttributeError while at it. if the attributes
            differ (or do not exist on one or the other object), returns False.
            """
            try:
                if getattr(o1, key) != getattr(o2, key):
                    return False
            except AttributeError:
                return False
            return True

        for key, ob in vars(o1).items():
            if isinstance(ob, (str, Number, bool, _Cacheable, tuple, type(None))):
                if not checkattr(key):
                    log(f"{o1} and {o2} differ on {key}", cmap="ERR")
                    return False
            elif isinstance(ob, xr.DataArray):
                x1 = getattr(o1, key)
                x2 = getattr(o2, key, None)
                # Check type and shape before comparing values. Previously, a
                # missing counterpart raised AttributeError. Also, `np.allclose`
                # broadcasts its inputs, so e.g. shapes (1,) and (3,) holding the
                # same value would otherwise compare as equal.
                if (
                    not isinstance(x2, xr.DataArray)
                    or x1.shape != x2.shape
                    or not np.allclose(x1.data, x2.data, equal_nan=True, atol=1e-4)
                    or x1.attrs != x2.attrs
                ):
                    log(f"{o1} and {o2} differ on {key}", cmap="ERR")
                    return False
        return True

    def _get_xarray_objects(self) -> typing.Iterable[str]:
        """
        Return the *names* of all attributes of self that are `xr.DataArray`s.

        NOTE: subclasses should ideally implement this irrespective of instance
        state, so that in the future we can support loading from cache without
        having to re-run the pipeline (and thereby assign attributes as
        appropriate). By default, this scans the instance's current attributes.
        """
        return [
            name for name, value in vars(self).items()
            if isinstance(value, xr.DataArray)
        ]

    @property
    def params(self) -> dict:
        """
        Collect this object's parameters into a flat dict (used by `__repr__`).

        Instance attributes are visited in sorted-name order:
        - nested `_Cacheable` attributes contribute their `identifier_string`;
        - dict attributes are flattened one level into `"{attr}_{key}"` entries
          (the dict's values are stored as-is);
        - str / Number / bool / tuple / None attributes are stored as-is;
        - attributes of any other type (e.g. `xr.DataArray`) are skipped.
        """
        params = {}
        for key in sorted(vars(self)):
            ob = getattr(self, key)
            if not isinstance(
                ob, (str, Number, bool, _Cacheable, tuple, dict, type(None))
            ):
                continue
            if isinstance(ob, _Cacheable):
                params[key] = ob.identifier_string
            elif isinstance(ob, dict):
                for k in ob:
                    params[f"{key}_{k}"] = ob[k]
            else:
                params[key] = ob
        return params

    def __repr__(self) -> str:
        """
        Default, broad implementation to support our use case.

        Produces `(ClassName?key1=val1?key2=val2...)` from every entry of
        `self.params`, in sorted key order. Nested `_Cacheable` attributes
        appear through their identifier strings.
        """
        left, right, sep = "(", ")", "?"
        params = self.params
        pairs = "".join(f"{sep}{key}={params[key]}" for key in sorted(params))
        return f"{left}{self.__class__.__name__}{pairs}{right}"

    @property
    def identifier_string(self):
        """
        This property aims to return an unambiguous representation of this
        _Cacheable instance, complete with all scalar parameters used to
        initialize it, and any _Cacheable instances that are attributes of this
        object.

        Unless overridden, makes a call to `repr`.
        """
        return repr(self)

    def to_cache(
        self,
        identifier_string: str = None,
        overwrite=True,
        cache_dir=None,
        xarray_serialization_backend="to_zarr",
    ) -> Path:
        """
        Dump this object to cache. This method implementation serves as the
        default implementation; it is recommended that it be left as-is for
        compatibility with caching across the library.

        Layout written under `<class cache subdir>/<identifier>/`:
        `xarray_object_names.yml`, `id.txt`, one `<name>.xr` store per DataArray
        attribute, `meta_attributes.yml` (scalar attributes) and
        `cacheable_object_pointers.yml` (paths of nested `_Cacheable` caches).

        Args:
            identifier_string (str): a unique identifier string to identify this
                cache instance by (optional; by default, the .identifier_string
                property is used). It names the cache directory and is recorded
                in `id.txt`.
            overwrite (bool): whether to overwrite an existing cache by the same
                identity. If False and the cache directory already exists,
                `mkdir` raises `FileExistsError`.
            cache_dir: optional cache root forwarded to `get_cache_directory`
                (the default location is used when falsy).
            xarray_serialization_backend (str): name of the `xr.Dataset` method
                used to write each DataArray attribute (default "to_zarr").

        Returns:
            Path: the directory this object was cached into.
        """
        if cache_dir:
            cache = get_cache_directory(
                cache_dir, calling_class=self.__class__.__name__
            )
        else:
            cache = get_cache_directory(calling_class=self.__class__.__name__)

        # The same identifier names the directory AND is written to id.txt.
        # `load_cache` compares id.txt against `identifier_string or
        # self.identifier_string`, so an explicitly given identifier_string must
        # be recorded here. Otherwise reloading would always report a mismatch.
        ident = identifier_string or self.identifier_string
        subdir = cache.subdir / ident
        subdir.mkdir(parents=True, exist_ok=overwrite)
        log(f"caching {self} to {subdir}")

        xarray_names = self._get_xarray_objects()
        with (subdir / "xarray_object_names.yml").open("w") as f:
            yaml.dump(xarray_names, f, yaml.SafeDumper)
        with (subdir / "id.txt").open("w") as f:
            f.write(ident)

        # presumably mode="w" lets a zarr store be replaced when overwriting
        kwargs = {}
        if overwrite and "zarr" in xarray_serialization_backend:
            kwargs.update({"mode": "w"})
        for ob_name in xarray_names:
            ob = getattr(self, ob_name)
            tgt_dir = subdir / (ob_name + ".xr")
            dump_object_fn = getattr(
                ob.to_dataset(name="data"), xarray_serialization_backend
            )
            dump_object_fn(tgt_dir, **kwargs)

        cacheable_ptrs = {}
        meta_attributes = {}
        for name, ob in vars(self).items():
            if isinstance(ob, _Cacheable):
                dest = ob.to_cache(
                    identifier_string=identifier_string,
                    overwrite=overwrite,
                    xarray_serialization_backend=xarray_serialization_backend,
                    cache_dir=cache_dir,
                )
                cacheable_ptrs[name] = str(dest)
            elif isinstance(ob, (str, Number, bool, type(None))):
                # numpy scalars (e.g. np.float64, np.str_) pass the checks above
                # but cannot be represented by yaml.SafeDumper; store the
                # equivalent Python builtin instead.
                if isinstance(ob, np.generic):
                    ob = ob.item()
                meta_attributes[name] = ob
        with (subdir / "meta_attributes.yml").open("w") as f:
            yaml.dump(meta_attributes, f, yaml.SafeDumper)
        with (subdir / "cacheable_object_pointers.yml").open("w") as f:
            yaml.dump(cacheable_ptrs, f, yaml.SafeDumper)

        return subdir

    def load_cache(
        self,
        identifier_string: str = None,
        overwrite: bool = True,
        xarray_deserialization_backend="open_zarr",
        cache_dir=None,
    ) -> Path:
        """
        Load attribute objects from cache onto the existing object (self).

        Reads the layout written by `to_cache`:
        - DataArray attributes are opened with `xr.<xarray_deserialization_backend>`;
        - `load_cache` is called on nested `_Cacheable` attributes that already
          exist on self (missing ones are logged and skipped);
        - scalar meta attributes are assigned last.

        Args:
            identifier_string (str): name of the cache directory (defaults to
                self.identifier_string).
            overwrite (bool): if False, raise ValueError when the identifier
                recorded in `id.txt` differs from the one being loaded; if True,
                only log the mismatch.
            xarray_deserialization_backend (str): name of the `xr` function used
                to open each stored DataArray (default "open_zarr").
            cache_dir: optional cache root forwarded to `get_cache_directory`.

        Returns:
            Path: the directory the attributes were loaded from.
        """
        if cache_dir:
            cache = get_cache_directory(
                cache_dir, calling_class=self.__class__.__name__
            )
        else:
            cache = get_cache_directory(calling_class=self.__class__.__name__)

        # computed once, before any attribute of self is overwritten
        ident = identifier_string or self.identifier_string
        subdir = cache.subdir / ident
        log(f"attempt loading attributes of {self} from {subdir.parent}")

        with (subdir / "xarray_object_names.yml").open("r") as f:
            self_xarray_objects = yaml.load(f, yaml.SafeLoader)

        with (subdir / "id.txt").open("r") as f:
            cached_identifier_str = f.read()
        if ident != cached_identifier_str:
            mismatch = (
                f"mismatch in identifier string of self ({self.identifier_string}) and "
                f"cached object ({cached_identifier_str}); "
            )
            if not overwrite:
                raise ValueError(mismatch + "overwriting is disabled.")
            log(mismatch + "overwriting anyway.")

        for ob_name in self_xarray_objects:
            tgt_dir = subdir / (ob_name + ".xr")
            load_object_fn = getattr(xr, xarray_deserialization_backend)
            dataset = load_object_fn(tgt_dir)
            # to_cache stored each DataArray under the variable name "data"
            setattr(self, ob_name, dataset["data"])

        with (subdir / "cacheable_object_pointers.yml").open("r") as f:
            cacheable_ptrs: dict = yaml.load(f, yaml.SafeLoader)

        # Only the *lookup* of the nested object / its `load_cache` is allowed to
        # fail softly. Errors raised inside a nested load now propagate instead
        # of being swallowed and mislabeled.
        for key in cacheable_ptrs:
            nested_loader = getattr(getattr(self, key, None), "load_cache", None)
            if nested_loader is None:
                log(
                    "`load_cache` currently only supports loading xarray objects or initialized `_Cacheable` objects"
                )
                continue
            nested_loader(
                identifier_string=identifier_string,
                overwrite=overwrite,
                xarray_deserialization_backend=xarray_deserialization_backend,
                cache_dir=cache_dir,
            )

        with (subdir / "meta_attributes.yml").open("r") as f:
            meta_attributes: dict = yaml.load(f, yaml.SafeLoader)
        for key, ob in meta_attributes.items():
            setattr(self, key, ob)

        return subdir

    # NB comment from Guido: https://github.com/python/typing/issues/58#issuecomment-194569410
    @classmethod
    def from_cache(
        cls: typing.Callable[..., T],
        identifier_string: str,
        xarray_deserialization_backend="open_zarr",
        cache_dir=None,
    ) -> T:
        """
        Construct an object from cache without calling `__init__`.

        An uninitialized instance of `cls` is created and populated via
        `load_cache(identifier_string, overwrite=True, ...)`. Subclasses that
        need further setup should start from the object returned by this method,
        like so:

            ob = super().from_cache(identifier_string)
            # further implementation, such as initializing
            # member classes based on metadata
            return ob
        """
        # Bypass __init__ but return a genuine instance of `cls`. A throwaway
        # `type(...)` subclass would break `type(ob) is cls` checks, and pickling
        # would fail because that class is not importable by its qualified name.
        ob = cls.__new__(cls)
        ob.load_cache(
            identifier_string,
            overwrite=True,
            xarray_deserialization_backend=xarray_deserialization_backend,
            cache_dir=cache_dir,
        )
        return ob