Coverage for langbrainscore/interface/cacheable.py: 18%

137 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-06-07 21:22 +0000

1import pickle 

2import typing 

3from abc import ABC, abstractclassmethod, abstractmethod 

4from numbers import Number 

5from pathlib import Path 

6 

7import xarray as xr 

8import numpy as np 

9import yaml 

10from langbrainscore.utils.cache import get_cache_directory, pathify 

11from langbrainscore.utils.logging import log 

12 

13# from langbrainscore.interface.dryrunnable import _DryRunnable 

14 

T = typing.TypeVar("T")


@typing.runtime_checkable
class _Cacheable(typing.Protocol):
    """
    A class used to define a common interface for Object caching in LangBrainscore
    """

    def __eq__(o1: "_Cacheable", o2: "_Cacheable") -> bool:
        """
        Compare two objects attribute-by-attribute, driven by the attributes of `o1`.

        Scalar-like attributes (str, numbers, bool, tuple, None, nested `_Cacheable`)
        are compared with `!=`; `xr.DataArray` attributes are compared on shape,
        on values via `np.allclose` (atol=1e-4, NaNs considered equal) and on
        `.attrs`. Attributes of any other type are ignored.

        Returns:
            bool: False as soon as one compared attribute differs or is missing
                on `o2`; True otherwise.
        """

        def checkattr(key) -> bool:
            """helper function to check if an attribute is the same between two objects
            and handles AttributeError while at it. if the attributes differ (or does
            not exist on one or the other object), returns False.
            """
            try:
                if getattr(o1, key) != getattr(o2, key):
                    return False
            except AttributeError:
                return False
            return True

        for key, ob in vars(o1).items():
            if isinstance(ob, (str, Number, bool, _Cacheable, tuple, type(None))):
                if not checkattr(key):
                    log(f"{o1} and {o2} differ on {key}", cmap="ERR")
                    return False
            elif isinstance(ob, xr.DataArray):
                x1 = ob
                # `o2` may lack the attribute or hold something that is not a
                # DataArray; either case is a difference, not a crash.
                x2 = getattr(o2, key, None)
                if (
                    not isinstance(x2, xr.DataArray)
                    # np.allclose broadcasts (or raises) on differing shapes,
                    # so shapes are compared explicitly first.
                    or x1.shape != x2.shape
                    or not np.allclose(x1.data, x2.data, equal_nan=True, atol=1e-4)
                    or x1.attrs != x2.attrs
                ):
                    log(f"{o1} and {o2} differ on {key}", cmap="ERR")
                    return False
        # every compared attribute matched
        return True

    # @abstractclassmethod
    # @classmethod
    def _get_xarray_objects(self) -> typing.Iterable[str]:
        """
        returns the *names* of all attributes of self that are instances of xarray
        NOTE: this method should be implemented by any subclass irrespective of instance
        state so that in the future we can support loading from cache without having
        to re-run the pipeline (and thereby assign attributes as appropriate)
        by default, just goes over all the objects and returns their names if they are instances
        of `xr.DataArray`
        """
        return [key for key, ob in vars(self).items() if isinstance(ob, xr.DataArray)]

    @property
    def params(self) -> dict:
        """
        Flat mapping of the scalar-like instance attributes of self, visited in
        sorted attribute-name order.

        - nested `_Cacheable` attributes are represented by their `identifier_string`
        - dict attributes are flattened one level into `"{attribute}_{key}"` entries
        - str, numeric, bool, tuple and None attributes are stored as-is
        - attributes of any other type are left out
        """
        supported = (str, Number, bool, _Cacheable, tuple, dict, type(None))
        params = {}
        for key in sorted(vars(self)):
            ob = getattr(self, key)
            if not isinstance(ob, supported):
                continue
            if isinstance(ob, _Cacheable):
                params[key] = ob.identifier_string
            elif isinstance(ob, dict):
                for k, v in ob.items():
                    params[f"{key}_{k}"] = v
            else:
                params[key] = ob
        return params

    def __repr__(self) -> str:
        """
        default, broad implementation to support our use case.
        constructs a string by concatenating all str, numeric, boolean
        attributes of self, as well as all the representations of Cacheable
        instances that are attributes of self.

        The result has the form `(ClassName?key1=val1?key2=val2)` with keys
        in sorted order, as produced by the `params` property.
        """
        left = "("
        right = ")"
        sep = "?"
        params = self.params
        body = "".join(f"{sep}{key}={params[key]}" for key in sorted(params))
        return f"{left}{self.__class__.__name__}{body}{right}"

    @property
    def identifier_string(self):
        """
        This property aims to return an unambiguous representation of this _Cacheable
        instance, complete with all scalar parameters used to initialize it, and any
        _Cacheable instances that are attributes of this object.

        Unless overridden, makes a call to `repr`
        """
        return repr(self)

    def to_cache(
        self,
        identifier_string: str = None,
        overwrite=True,
        cache_dir=None,
        xarray_serialization_backend="to_zarr",
    ) -> Path:
        """
        dump this object to cache. this method implementation will serve
        as the default implementation. it is recommended that this be left
        as-is for compatibility with caching across the library.

        The cache directory receives: `xarray_object_names.yml`, `id.txt`,
        one `<attribute>.xr` entry per xarray attribute, `meta_attributes.yml`
        (str / numeric / bool / None attributes) and
        `cacheable_object_pointers.yml` (paths of nested `_Cacheable` caches).

        Args:
            identifier_string (str): a unique identifier string to identify this cache
                instance by (optional; by default, the .identifier_string property is used)
            overwrite (bool): whether to overwrite existing cache by the same identity,
                if it exists. if False, the cache directory is created with
                `exist_ok=False`, so an already existing one raises an exception.
            cache_dir: optional location forwarded to `get_cache_directory`; when
                falsy, `get_cache_directory` is called without it.
            xarray_serialization_backend (str): name of the `xr.Dataset` method used
                to write each xarray attribute (called with the target path).

        Returns:
            Path: the directory this object was cached to.

        Raises:
            FileExistsError: if `overwrite` is False and the directory already exists.
        """
        class_name = self.__class__.__name__
        if cache_dir:
            cache = get_cache_directory(cache_dir, calling_class=class_name)
        else:
            cache = get_cache_directory(calling_class=class_name)

        # now we use "subdir" to be our working directory to dump this cache object
        subdir = cache.subdir / (identifier_string or self.identifier_string)
        subdir.mkdir(parents=True, exist_ok=overwrite)
        log(f"caching {self} to {subdir}")

        xarray_names = list(self._get_xarray_objects())
        with (subdir / "xarray_object_names.yml").open("w") as f:
            yaml.dump(xarray_names, f, yaml.SafeDumper)
        # NOTE(review): the object's own identifier is recorded even when a custom
        # `identifier_string` is passed, whereas `load_cache` compares against the
        # custom one -- confirm which of the two is intended.
        with (subdir / "id.txt").open("w") as f:
            f.write(self.identifier_string)

        kwargs = {}
        if overwrite and "zarr" in xarray_serialization_backend:
            # replace an existing store instead of failing on it
            kwargs.update({"mode": "w"})
        for ob_name in xarray_names:
            ob = getattr(self, ob_name)
            tgt_dir = subdir / (ob_name + ".xr")
            # wrapped in a Dataset under the variable name "data";
            # `load_cache` reads the same name back
            dump_object_fn = getattr(
                ob.to_dataset(name="data"), xarray_serialization_backend
            )
            dump_object_fn(tgt_dir, **kwargs)

        cacheable_ptrs = {}
        meta_attributes = {}
        for key, ob in vars(self).items():
            if isinstance(ob, _Cacheable):
                # nested cacheables are cached on their own; only a pointer is kept
                dest = ob.to_cache(
                    identifier_string=identifier_string,
                    overwrite=overwrite,
                    xarray_serialization_backend=xarray_serialization_backend,
                    cache_dir=cache_dir,
                )
                cacheable_ptrs[key] = str(dest)
            elif isinstance(ob, (str, Number, bool, type(None))):
                meta_attributes[key] = ob
        with (subdir / "meta_attributes.yml").open("w") as f:
            yaml.dump(meta_attributes, f, yaml.SafeDumper)
        with (subdir / "cacheable_object_pointers.yml").open("w") as f:
            yaml.dump(cacheable_ptrs, f, yaml.SafeDumper)

        return subdir

    def load_cache(
        self,
        identifier_string: str = None,
        overwrite: bool = True,
        xarray_deserialization_backend="open_zarr",
        cache_dir=None,
    ) -> Path:
        """load attribute objects from cache onto the existing initialized object (self)

        Args:
            identifier_string (str): identifier of the cache entry to read
                (optional; by default, the .identifier_string property is used)
            overwrite (bool): when the identifier stored in the cache does not match
                the requested one, True logs the mismatch and loads anyway,
                False raises.
            xarray_deserialization_backend (str): name of the function in the
                `xarray` module used to read each cached xarray attribute.
            cache_dir: optional location forwarded to `get_cache_directory`; when
                falsy, `get_cache_directory` is called without it.

        Returns:
            Path: the directory the attributes were loaded from.

        Raises:
            ValueError: on identifier mismatch while `overwrite` is False.
        """
        class_name = self.__class__.__name__
        if cache_dir:
            cache = get_cache_directory(cache_dir, calling_class=class_name)
        else:
            cache = get_cache_directory(calling_class=class_name)

        # now we use "subdir" as our working directory to load this cache object from
        subdir = cache.subdir / (identifier_string or self.identifier_string)
        log(f"attempt loading attributes of {self} from {subdir.parent}")

        with (subdir / "xarray_object_names.yml").open("r") as f:
            # SafeLoader only builds plain YAML types; an empty file loads as None
            self_xarray_objects = yaml.load(f, yaml.SafeLoader) or []

        with (subdir / "id.txt").open("r") as f:
            cached_identifier_str = f.read()
        if (identifier_string or self.identifier_string) != cached_identifier_str:
            if not overwrite:
                raise ValueError(
                    f"mismatch in identifier string of self ({self.identifier_string}) and "
                    f"cached object ({cached_identifier_str}); overwriting is disabled."
                )
            log(
                f"mismatch in identifier string of self ({self.identifier_string}) and "
                f"cached object ({cached_identifier_str}); overwriting anyway."
            )

        kwargs = {}
        load_object_fn = getattr(xr, xarray_deserialization_backend)
        for ob_name in self_xarray_objects:
            tgt_dir = subdir / (ob_name + ".xr")
            ob = load_object_fn(tgt_dir, **kwargs)
            # `to_cache` stored the array under the dataset variable "data"
            setattr(self, ob_name, ob.data)

        with (subdir / "cacheable_object_pointers.yml").open("r") as f:
            cacheable_ptrs: dict = yaml.load(f, yaml.SafeLoader) or {}

        # calls `load_cache` on all attributes that are also `_Cacheable` instances
        # and thus implement the `load_cache` method
        for key, ptr in cacheable_ptrs.items():
            try:
                ob = getattr(self, key)
                ob.load_cache(
                    identifier_string=identifier_string,
                    overwrite=overwrite,
                    xarray_deserialization_backend=xarray_deserialization_backend,
                    cache_dir=cache_dir,
                )
            except AttributeError:
                # best-effort: the attribute is not initialized on self
                # (or cannot load itself), so it is skipped with a log message
                log(
                    "`load_cache` currently only supports loading xarray objects or initialized `_Cacheable` objects"
                )

        with (subdir / "meta_attributes.yml").open("r") as f:
            meta_attributes: dict = yaml.load(f, yaml.SafeLoader) or {}
        for key, ob in meta_attributes.items():
            setattr(self, key, ob)

        return subdir

    # NB comment from Guido: https://github.com/python/typing/issues/58#issuecomment-194569410
    @classmethod
    def from_cache(
        cls: typing.Callable[..., T],
        identifier_string: str,
        xarray_deserialization_backend="open_zarr",
        cache_dir=None,
    ) -> T:
        """
        construct an object from cache. subclasses must start with the
        object returned by a call to this method like so:

            ob = super().from_cache(filename)
            # further implementation, such as initializing
            # member classes based on metadata
            return ob

        The returned object is an instance of a dynamically created subclass of
        `cls` (same class name) whose `__init__` does nothing; its attributes are
        populated by `load_cache`.

        Args:
            identifier_string (str): identifier of the cache entry to load.
            xarray_deserialization_backend (str): forwarded to `load_cache`.
            cache_dir: forwarded to `load_cache`.
        """

        # subclass with a no-op constructor so that no initialization
        # arguments are needed to obtain an empty instance
        Duck = type(cls.__name__, (cls,), {"__init__": (lambda _: None)})
        duck = Duck()
        duck.load_cache(
            identifier_string,
            overwrite=True,
            xarray_deserialization_backend=xarray_deserialization_backend,
            cache_dir=cache_dir,
        )
        return duck

279 ) 

280 return duck