Source code for himalaya.backend._utils

import types
import importlib
import warnings
from functools import wraps


ALL_BACKENDS = [
    "numpy",
    "cupy",
    "torch",
    "torch_cuda",
]

CURRENT_BACKEND = "numpy"

MATCHING_CPU_BACKEND = {
    "numpy": "numpy",
    "cupy": "numpy",
    "torch": "torch",
    "torch_cuda": "torch",
}


[docs]def set_backend(backend, on_error="raise"): """Set the backend using a global variable, and return the backend module. Parameters ---------- backend : str or module Name or module of the backend. on_error : str in {"raise", "warn"} Define what is done if the backend fails to be loaded. If "warn", this function only warns, and keeps the previous backend. If "raise", this function raises on errors. Returns ------- module : python module Module of the backend. """ global CURRENT_BACKEND try: if isinstance(backend, types.ModuleType): # get name from module backend = backend.name if backend not in ALL_BACKENDS: raise ValueError("Unknown backend=%r" % (backend, )) module = importlib.import_module(__package__ + "." + backend) CURRENT_BACKEND = backend except Exception as error: if on_error == "raise": raise error elif on_error == "warn": warnings.warn(f"Setting backend to {backend} failed: {str(error)}." f"Falling back to {CURRENT_BACKEND} backend.") module = get_backend() else: raise ValueError('Unknown value on_error=%r' % (on_error, )) return module
[docs]def get_backend(): """Get the current backend module. Returns ------- module : python module Module of the backend. """ module = importlib.import_module(__package__ + "." + CURRENT_BACKEND) return module
def _dtype_to_str(dtype): """Cast dtype to string, such as "float32", or "float64".""" if isinstance(dtype, str): return dtype elif hasattr(dtype, "name"): # works for numpy and cupy return dtype.name elif "torch." in str(dtype): # works for torch return str(dtype)[6:] elif dtype is None: return None else: raise NotImplementedError() def force_cpu_backend(func): """Decorator to force the use of a CPU backend.""" @wraps(func) def wrapper(*args, **kwargs): # skip if the object does not force cpu use if not hasattr(args[0], "force_cpu") or not args[0].force_cpu: return func(*args, **kwargs) # set corresponding cpu backend original_backend = get_backend().name temp_backend = MATCHING_CPU_BACKEND[original_backend] set_backend(temp_backend) # run function result = func(*args, **kwargs) # set back original backend set_backend(original_backend) return result return wrapper