Source code for flutes.structure

from functools import lru_cache
from typing import Callable, Collection, Dict, List, Sequence, Set, Type, TypeVar, no_type_check

# Public names exported by a star-import of this module.
__all__ = [
    "reverse_map",
    "register_no_map_class",
    "no_map_instance",
    "map_structure",
    "map_structure_zip",
]

# Generic type variables used in the annotations below:
# ``T`` stands for the type of the input elements / containers,
# ``R`` for the type returned by the mapped function.
T = TypeVar('T')
R = TypeVar('R')


def reverse_map(d: Dict[T, int]) -> List[T]:
    r"""Invert an ``item -> id`` mapping into a list ordered by ``id``.

    The keys of ``d`` are returned sorted in ascending order of the ``id``
    they map to, so that the ``id``-th element of the result is ``item``.

    .. note::
        It is assumed that the ``id``\ s form a permutation.

    .. code:: python

        >>> words = ['a', 'aardvark', 'abandon', ...]
        >>> word_to_id = {word: idx for idx, word in enumerate(words)}
        >>> id_to_word = reverse_map(word_to_id)
        >>> (words == id_to_word)
        True

    :param d: The dictionary mapping ``item`` to ``id``.
    :return: The items of ``d``, ordered by their ``id``.
    """
    # Iterating a dict yields its keys; each key is ranked by the id stored
    # for it. The sort is stable, so keys sharing an id keep dict order.
    return sorted(d, key=d.__getitem__)
# Container classes registered through `register_no_map_class`; an object
# whose exact class is in this set is passed to the mapped function whole.
_NO_MAP_TYPES: Set[type] = set()
# Name of the marker attribute placed on individual non-mappable instances.
# It is not a valid identifier, so it can only be reached through
# `setattr` / `hasattr` / `getattr`, never through ``obj.name`` syntax.
_NO_MAP_INSTANCE_ATTR = "--no-map--"
def register_no_map_class(container_type: Type[T]) -> None:
    r"""Mark a container type as `non-mappable`.

    Instances of a registered class are treated as singleton objects by
    :func:`map_structure` and :func:`map_structure_zip`: the mapped function
    receives the instance itself and its contents are not traversed.

    This is useful for certain types that subclass built-in container types,
    such as ``torch.Size``.

    :param container_type: The type of the container, e.g. :py:class:`list`,
        :py:class:`dict`.
    """
    # Registering the same type more than once is harmless: the registry is
    # a set.
    _NO_MAP_TYPES.add(container_type)
@lru_cache(maxsize=None)
def _no_map_type(container_type: Type[T]) -> Type[T]:
    r"""Build a subclass of ``container_type`` whose instances carry the
    no-map marker.

    The subclass defines the (normally inaccessible) special marker attribute
    at class level, so every instance of it exposes the marker. This detour
    is necessary because `setattr` does not work on instances of built-in
    types (e.g. `list`).

    Results are memoized, so each container type gets a single subclass.

    :param container_type: The container class to derive from.
    :return: The marker-carrying subclass of ``container_type``.
    """
    subtype_name = "_no_map" + container_type.__name__
    class_namespace = {_NO_MAP_INSTANCE_ATTR: True}
    return type(subtype_name, (container_type,), class_namespace)
@no_type_check
def no_map_instance(instance: T) -> T:
    r"""Register a container instance as `non-mappable`, i.e., it will be
    treated as a singleton object in :func:`map_structure` and
    :func:`map_structure_zip`, its contents will not be traversed.

    If the marker attribute can be set on ``instance``, the same object is
    returned. Otherwise (e.g. for built-in containers such as `list` or
    `tuple`) a new object is returned: a copy of ``instance`` built with a
    marker-carrying subclass of its type.

    :param instance: The container instance.
    :return: ``instance`` itself, or a marked copy of it when the marker
        cannot be set in place.
    """
    try:
        setattr(instance, _NO_MAP_INSTANCE_ATTR, True)
    except AttributeError:
        marked_type = _no_map_type(type(instance))
        if isinstance(instance, tuple) and hasattr(instance, '_fields'):
            # namedtuple: the constructor takes one positional argument per
            # field, not a single iterable, so the fields must be unpacked.
            return marked_type(*instance)
        return marked_type(instance)
    return instance
@no_type_check
def map_structure(fn: Callable[[T], R], obj: Collection[T]) -> Collection[R]:
    r"""Apply a function to every element of a (possibly nested) collection.

    Lists, tuples (including namedtuples), dicts and sets are traversed
    recursively and rebuilt; for dicts only the values are mapped, the keys
    are kept. Anything else, as well as objects registered as `non-mappable`,
    is passed to ``fn`` as a single element.

    :param fn: The function to call on elements.
    :param obj: The collection to map function over.
    :return: The collection in the same structure, with elements mapped.
    """
    # Non-mappable classes / instances are leaves even if they are containers.
    if obj.__class__ in _NO_MAP_TYPES or hasattr(obj, _NO_MAP_INSTANCE_ATTR):
        return fn(obj)

    if isinstance(obj, list):
        mapped = [map_structure(fn, item) for item in obj]
    elif isinstance(obj, tuple):
        if hasattr(obj, '_fields'):
            # namedtuple: rebuild with one positional argument per field.
            field_values = [map_structure(fn, item) for item in obj]
            mapped = type(obj)(*field_values)
        else:
            mapped = tuple(map_structure(fn, item) for item in obj)
    elif isinstance(obj, dict):
        # Rebuild with the original class, which could be `OrderedDict`.
        mapped = type(obj)((key, map_structure(fn, value)) for key, value in obj.items())
    elif isinstance(obj, set):
        mapped = {map_structure(fn, item) for item in obj}
    else:
        mapped = fn(obj)
    return mapped
@no_type_check
def map_structure_zip(fn: Callable[..., R], objs: Sequence[Collection[T]]) -> Collection[R]:
    r"""Map a function over tuples formed by taking one element from each
    (possibly nested) collection. Each collection must have identical
    structures.

    .. note::
        Although identical structures are required, it is not enforced by
        assertions. The structure of the first collection is assumed to be
        the structure for all collections.

    :param fn: The function to call on elements. It receives one positional
        argument per collection in ``objs``.
    :param objs: The list of collections to map function over.
    :return: A collection with the same structure, with elements mapped.
    :raises ValueError: If a `set` is encountered in the structure.
    """
    # Only the first collection is inspected to decide how to traverse.
    template = objs[0]
    if template.__class__ in _NO_MAP_TYPES or hasattr(template, _NO_MAP_INSTANCE_ATTR):
        return fn(*objs)

    if isinstance(template, list):
        mapped = [map_structure_zip(fn, group) for group in zip(*objs)]
    elif isinstance(template, tuple):
        if hasattr(template, '_fields'):
            # namedtuple: rebuild with one positional argument per field.
            field_values = [map_structure_zip(fn, group) for group in zip(*objs)]
            mapped = type(template)(*field_values)
        else:
            mapped = tuple(map_structure_zip(fn, group) for group in zip(*objs))
    elif isinstance(template, dict):
        # Rebuild with the original class, which could be `OrderedDict`;
        # every collection is indexed with the keys of the first one.
        mapped = type(template)(
            (key, map_structure_zip(fn, [collection[key] for collection in objs]))
            for key in template.keys()
        )
    elif isinstance(template, set):
        raise ValueError("Structures cannot contain `set` because it's unordered")
    else:
        mapped = fn(*objs)
    return mapped