Module langbrainscore.interface.cacheable

Expand source code
import pickle
import typing
from abc import ABC, abstractclassmethod, abstractmethod
from numbers import Number
from pathlib import Path

import xarray as xr
import numpy as np
import yaml
from langbrainscore.utils.cache import get_cache_directory, pathify
from langbrainscore.utils.logging import log

# from langbrainscore.interface.dryrunnable import _DryRunnable

T = typing.TypeVar("T")


@typing.runtime_checkable
class _Cacheable(typing.Protocol):
    """
    A class used to define a common interface for Object caching in LangBrainscore
    """

    def __eq__(o1: "_Cacheable", o2: "_Cacheable") -> bool:
        def checkattr(key) -> bool:
            """helper function to check if an attribute is the same between two objects
            and handles AttributeError while at it. if the attributes differ (or does
            not exist on one or the other object), returns False.
            """
            try:
                if getattr(o1, key) != getattr(o2, key):
                    return False
            except AttributeError:
                return False
            return True

        for key, ob in vars(o1).items():
            if isinstance(ob, (str, Number, bool, _Cacheable, tuple, type(None))):
                if not checkattr(key):
                    log(f"{o1} and {o2} differ on {key}", cmap="ERR")
                    return False
            elif isinstance(ob, xr.DataArray):
                x1 = getattr(o1, key)
                x2 = getattr(o2, key)
                if (not np.allclose(x1.data, x2.data, equal_nan=True, atol=1e-4)) or (
                    x1.attrs != x2.attrs
                ):
                    log(f"{o1} and {o2} differ on {key}", cmap="ERR")
                    return False
        # the loop completed without finding a mismatch, so the objects are equal
        return True
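
    # e.g., two instances of a hypothetical `_Cacheable` subclass compare equal iff all of
    # their scalar attributes match and all of their `xr.DataArray` attributes are
    # numerically close (atol=1e-4) with identical attrs (illustrative names only):
    #
    #     a = SomeCacheableSubclass(model_id="gpt2", layer=6)
    #     b = SomeCacheableSubclass(model_id="gpt2", layer=6)
    #     assert a == b
    #     b.layer = 7
    #     assert a != b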

    # @abstractclassmethod
    # @classmethod
    def _get_xarray_objects(self) -> typing.Iterable[str]:
        """
        returns the *names* of all attributes of self that are instances of xarray
        NOTE: this method should be implemented by any subclass irrespective of instance
            state so that in the future we can support loading from cache without having
            to re-run the pipeline (and thereby assign attributes as appropriate)
        by default, just goes over all the objects and returns their names if they are instances
        of `xr.DataArray`
        """
        keys = []
        for key, ob in vars(self).items():
            if isinstance(ob, xr.DataArray):
                keys += [key]
        return keys
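
    # e.g., a hypothetical instance with `self.scores` and `self.embeddings` holding
    # `xr.DataArray`s and `self.model_id` holding a str would return
    # ["scores", "embeddings"] (illustrative attribute names only)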

    @property
    def params(self) -> dict:
        """ """
        params = {}
        for key in sorted(vars(self)):
            ob = getattr(self, key)
            if isinstance(ob, (str, Number, bool, _Cacheable, tuple, dict, type(None))):
                # if isinstance(ob, (str, Number, bool, _Cacheable, tuple)):
                if isinstance(ob, _Cacheable):
                    params[key] = ob.identifier_string
                elif isinstance(ob, dict):
                    for k in ob:
                        params[f"{key}_{k}"] = ob[k]
                    pass  # TODO!!
                else:
                    params[key] = ob
        return params
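
    # e.g., a hypothetical instance with attributes
    #     self.model_id = "gpt2", self.emb_case = None, self.config = {"layer": 6}
    # would yield
    #     {"config_layer": 6, "emb_case": None, "model_id": "gpt2"}
    # (illustrative names only; dict attributes are flattened one level)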

    def __repr__(self) -> str:
        """
        default, broad implementation to support our use case.
        constructs a string by concatenating all str, numeric, boolean
        attributes of self, as well as all the representations of Cacheable
        instances that are attributes of self.
        """
        left = "("
        right = ")"
        sep = "?"
        rep = f"{left}{self.__class__.__name__}"
        params = self.params
        for key in sorted([*params.keys()]):
            val = params[key]
            rep += f"{sep}{key}={val}"
        return rep + f"{right}"

    @property
    def identifier_string(self):
        """
        This property aims to return an unambiguous representation of this _Cacheable
        instance, complete with all scalar parameters used to initialize it, and any
        _Cacheable instances that are attributes of this object.

        Unless overridden, makes a call to `repr`
        """
        return repr(self)
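
    # e.g., a hypothetical subclass instance with class name `ANNEmbeddingConfig` and
    # params {"aggregation": "mean", "layer": 6} would have the identifier string
    #     "(ANNEmbeddingConfig?aggregation=mean?layer=6)"
    # (illustrative class and parameter names only)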

    def to_cache(
        self,
        identifier_string: str = None,
        overwrite=True,
        cache_dir=None,
        xarray_serialization_backend="to_zarr",
    ) -> Path:
        """
        dump this object to cache. this method implementation will serve
        as the default implementation. it is recommended that this be left
        as-is for compatibility with caching across the library.

        Args:
            identifier_string (str): a unique identifier string to identify this cache
                instance by (optional; by default, the .identifier_string property is used)
            overwrite (bool): whether to overwrite existing cache by the same identity,
                if it exists. if False, an exce
        """
        if cache_dir:
            cache = get_cache_directory(
                cache_dir, calling_class=self.__class__.__name__
            )
        else:
            cache = get_cache_directory(calling_class=self.__class__.__name__)

        root, subdir = cache.root, cache.subdir
        # now we use "subdir" to be our working directory to dump this cache object
        subdir /= identifier_string or self.identifier_string
        subdir.mkdir(parents=True, exist_ok=overwrite)
        log(f"caching {self} to {subdir}")

        with (subdir / "xarray_object_names.yml").open("w") as f:
            yaml.dump(self._get_xarray_objects(), f, yaml.SafeDumper)
        with (subdir / "id.txt").open("w") as f:
            f.write(self.identifier_string)

        kwargs = {}
        if overwrite and "zarr" in xarray_serialization_backend:
            kwargs.update({"mode": "w"})
        for ob_name in self._get_xarray_objects():
            ob = getattr(self, ob_name)
            tgt_dir = subdir / (ob_name + ".xr")
            dump_object_fn = getattr(
                ob.to_dataset(name="data"), xarray_serialization_backend
            )
            dump_object_fn(tgt_dir, **kwargs)

        cacheable_ptrs = {}
        meta_attributes = {}
        for key, ob in vars(self).items():
            if isinstance(ob, _Cacheable):
                dest = ob.to_cache(
                    identifier_string=identifier_string,
                    overwrite=overwrite,
                    xarray_serialization_backend=xarray_serialization_backend,
                    cache_dir=cache_dir,
                )
                cacheable_ptrs[key] = str(dest)
            elif isinstance(ob, (str, Number, bool, _Cacheable, type(None))):
                meta_attributes[key] = ob
        with (subdir / "meta_attributes.yml").open("w") as f:
            yaml.dump(meta_attributes, f, yaml.SafeDumper)
        with (subdir / "cacheable_object_pointers.yml").open("w") as f:
            yaml.dump(cacheable_ptrs, f, yaml.SafeDumper)

        return subdir
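
    # a call such as `obj.to_cache()` on an initialized subclass instance (hypothetical)
    # produces, per the code above, a directory named after `obj.identifier_string`
    # containing:
    #     xarray_object_names.yml          # names of the xr.DataArray attributes
    #     id.txt                           # the identifier string
    #     <attribute_name>.xr/             # one serialized dataset per xarray attribute
    #     meta_attributes.yml              # scalar (non-xarray) attributes
    #     cacheable_object_pointers.yml    # paths to caches of nested _Cacheable attributes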

    def load_cache(
        self,
        identifier_string: str = None,
        overwrite: bool = True,
        xarray_deserialization_backend="open_zarr",
        cache_dir=None,
    ) -> None:
        """load attribute objects from cache onto the existing initialized object (self)"""

        if cache_dir:
            cache = get_cache_directory(
                cache_dir, calling_class=self.__class__.__name__
            )
        else:
            cache = get_cache_directory(calling_class=self.__class__.__name__)

        root, subdir = cache.root, cache.subdir
        # now we use "subdir" as our working directory to dump this cache object
        subdir /= identifier_string or self.identifier_string
        log(f"attempt loading attributes of {self} from {subdir.parent}")

        with (subdir / "xarray_object_names.yml").open("r") as f:
            self_xarray_objects = yaml.load(f, yaml.SafeLoader)

        with (subdir / "id.txt").open("r") as f:
            if (identifier_string or self.identifier_string) != (
                cached_identifier_str := f.read()
            ):
                if not overwrite:
                    raise ValueError(
                        f"mismatch in identifier string of self ({self.identifier_string}) and "
                        f"cached object ({cached_identifier_str}); overwriting is disabled."
                    )
                else:
                    log(
                        f"mismatch in identifier string of self ({self.identifier_string}) and "
                        f"cached object ({cached_identifier_str}); overwriting anyway."
                    )

        kwargs = {}
        for ob_name in self_xarray_objects:
            tgt_dir = subdir / (ob_name + ".xr")
            load_object_fn = getattr(xr, xarray_deserialization_backend)
            ob = load_object_fn(tgt_dir, **kwargs)
            setattr(self, ob_name, ob.data)

        with (subdir / "cacheable_object_pointers.yml").open("r") as f:
            cacheable_ptrs: dict = yaml.load(f, yaml.SafeLoader)

        # calls `load_cache` on all attributes that are also `_Cacheable` instances
        # and thus implement the `load_cache` method
        for key, ptr in cacheable_ptrs.items():
            try:
                ob = getattr(self, key)
                ob.load_cache(
                    identifier_string=identifier_string,
                    overwrite=overwrite,
                    xarray_deserialization_backend=xarray_deserialization_backend,
                    cache_dir=cache_dir,
                )
            except AttributeError:
                log(
                    f"failed to load cached attribute `{key}`: `load_cache` currently only "
                    "supports loading xarray objects or already-initialized `_Cacheable` objects"
                )

        with (subdir / "meta_attributes.yml").open("r") as f:
            meta_attributes: dict = yaml.load(f, yaml.SafeLoader)
        for key, ob in meta_attributes.items():
            setattr(self, key, ob)
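
    # usage sketch (hypothetical subclass; illustration only): re-instantiate the object
    # with the same constructor parameters so that its `identifier_string` matches the
    # cached one, then repopulate its heavy attributes from cache instead of re-running
    # the pipeline:
    #
    #     enc = SomeCacheableSubclass(model_id="gpt2", layer=6)
    #     enc.load_cache()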

    # NB comment from Guido: https://github.com/python/typing/issues/58#issuecomment-194569410
    @classmethod
    def from_cache(
        cls: typing.Callable[..., T],
        identifier_string: str,
        xarray_deserialization_backend="open_zarr",
        cache_dir=None,
    ) -> T:
        """
        construct an object from cache. subclasses must start with the
        object returned by a call to this method like so:

            ob = super().from_cache(filename)
            # further implementation, such as initializing
            # member classes based on metadata
            return ob

        """

        # create a throwaway subclass with a no-op __init__ so that an instance can be
        # obtained without supplying constructor arguments; its attributes are then
        # populated entirely from the cache by `load_cache`
        Duck = type(cls.__name__, (cls,), {"__init__": (lambda _: None)})
        duck = Duck()
        duck.load_cache(
            identifier_string,
            overwrite=True,
            xarray_deserialization_backend=xarray_deserialization_backend,
            cache_dir=cache_dir,
        )
        return duck
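
A minimal round-trip sketch of how the pieces above are intended to compose. `ToyCacheable`, its attributes, and the cache directory below are hypothetical and serve only as an illustration; they are not part of the module.

    import xarray as xr
    from langbrainscore.interface.cacheable import _Cacheable

    class ToyCacheable(_Cacheable):
        def __init__(self, model_id: str = "gpt2"):
            self.model_id = model_id
            self.scores = xr.DataArray([1.0, 2.0], dims=("sampleid",))

    obj = ToyCacheable()
    cache_path = obj.to_cache(cache_dir="/tmp/lbs_cache")  # hypothetical cache root
    restored = ToyCacheable.from_cache(
        identifier_string=obj.identifier_string, cache_dir="/tmp/lbs_cache"
    )
    assert restored == obj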