diff --git a/src/ophydregistry/registry.py b/src/ophydregistry/registry.py index c3dbe64..310cdbc 100644 --- a/src/ophydregistry/registry.py +++ b/src/ophydregistry/registry.py @@ -1,10 +1,9 @@ import logging import time -import warnings from collections import OrderedDict from itertools import chain from typing import Hashable, List, Mapping, Optional, Sequence, Tuple -from weakref import WeakSet, WeakValueDictionary +from weakref import WeakSet from ophyd import ophydobj @@ -120,15 +119,11 @@ class Registry: If false, items will be dropped from this registry if the only reference comes from this registry. Relies on the garbage collector, so to force cleanup use ``gc.collect()``. - warn_duplicates - If true, a warning will be issued if this device is - overwriting a previous device with the same name. """ use_typhos: bool keep_references: bool - warn_duplicates: bool _auto_register: bool _valid_classes: Tuple[type] = ( ophydobj.OphydObject, @@ -145,7 +140,6 @@ def __init__( auto_register: bool = True, use_typhos: bool = False, keep_references: bool = True, - warn_duplicates: bool = True, ): # Check that Typhos is installed if needed if use_typhos and not typhos_available: @@ -155,7 +149,6 @@ def __init__( self.use_typhos = use_typhos self.clear() self.auto_register = auto_register - self.warn_duplicates = warn_duplicates @property def auto_register(self): @@ -239,10 +232,7 @@ def clear(self, clear_typhos: bool = True): """ self._objects_by_label = OrderedDict() - if self.keep_references: - self._objects_by_name = OrderedDict() - else: - self._objects_by_name = WeakValueDictionary() + self._objects_by_name = OrderedDict() if clear_typhos and self.use_typhos: typhos.plugins.core.signal_registry.clear() @@ -284,16 +274,15 @@ def component_names(self): @property def root_devices(self): """Only return root devices, those without parents.""" - return set( - dev for name, dev in self._objects_by_name.items() if dev.parent is None - ) + all_devices = [ + dev for devices in self._objects_by_name.values() for dev in devices + ] + return {device for device in all_devices if device.parent is None} @property def device_names(self): """Only return root devices, those without parents.""" - return set( - [name for name, dev in self._objects_by_name.items() if dev.parent is None] - ) + return {device.name for device in self.root_devices} def find( self, @@ -346,21 +335,24 @@ def find( ``self.findall()`` method. """ - results = list( - self.findall(any_of=any_of, label=label, name=name, allow_none=allow_none) + devices = self.findall( + any_of=any_of, label=label, name=name, allow_none=allow_none ) - if len(results) == 1: - result = results[0] - elif len(results) > 1: + # Remove any direct ancestors + devices = set(dev for dev in devices if dev.parent not in devices) + # Make sure we have only 1 result + if len(devices) == 1: + device = list(devices)[0] + elif len(devices) > 1: raise MultipleComponentsFound( - f"Found {len(results)} components matching query " + f"Found {len(devices)} components matching query " f"[any_of={any_of}, label={label}, name={name}]. " "Consider using ``findall()``. " - f"{results}" + f"{devices}" ) else: - result = None - return result + device = None + return device def _is_resolved(self, obj): """Is the object already resolved into an ophyd device, etc. @@ -406,33 +398,31 @@ def _findall_by_name(self, name): if self._is_resolved(name): yield name return - # Check for an edge case with EpicsMotor objects (user_readback name is same as parent) - try: - is_user_readback = name[-13:] == "user_readback" - except TypeError: - is_user_readback = False - if is_user_readback: - parentname = name[:-14].strip("_") - yield self.find(name=parentname).user_readback - elif is_iterable(name): + # Check for an iterable of names instead of a single name + if is_iterable(name): for n in name: yield from self.findall(name=n) + return + # Split off any dot notation parameters for later filtering + try: + name, *attrs = name.split(".") + except AttributeError: + attrs = [] + # Find the matching components + print(self._objects_by_name) + try: + devices = self._objects_by_name[name] + except KeyError: + pass else: - # Split off any dot notation parameters for later filtering - try: - name, *attrs = name.split(".") - except AttributeError: - attrs = [] - # Find the matching components - try: - cpt_ = self._objects_by_name[name] - except KeyError: - pass - else: - # Re-apply dot-notation filter - for attr in attrs: - cpt_ = getattr(cpt_, attr) - yield cpt_ + # Re-apply dot-notation filter + for device in devices: + try: + for attr in attrs: + device = getattr(device, attr) + except AttributeError: + continue + yield device def findall( self, @@ -528,7 +518,6 @@ def register( self, component: ophydobj.OphydObject, labels: Optional[Sequence] = None, - warn_duplicates=None, ) -> ophydobj.OphydObject: """Register a device, component, etc so that it can be retrieved later. @@ -543,97 +532,68 @@ def register( labels Device labels to use for registration. If `None` (default), the devices *_ophyd_labels_* parameter will be used. - warn_duplicates - If true, a warning will be issued if this device is - overwriting a previous device with the same name. - If None, defaults to the value of the same-named class attribute. """ - if warn_duplicates is None: - warn_duplicates = self.warn_duplicates # Determine how to register the device if isinstance(component, type): # A class was given, so instances should be auto-registered component.__new__ = self.__new__wrapper - else: # An instance was given, so just save it in the register - try: - name = component.name - except AttributeError: - msg = f"Skipping unnamed component {component}" - if isinstance(component, _AggregateSignalState): - log.debug(msg) - else: - log.info(msg) - return component - # Register this object with Typhos - if self.use_typhos: - register_typhos_signal(component) - # Ignore any instances with the same name as a previous component - # (Needed for some sub-components that are just readback - # values of the parent) - # Check if we're adding a duplicate component name - is_duplicate = False - if name in self._objects_by_name.keys(): - old_obj = self._objects_by_name[name] - is_readback = component in [ - getattr(old_obj, "readback", None), - getattr(old_obj, "user_readback", None), - getattr(old_obj, "val", None), - ] - if is_readback: - msg = f"Ignoring readback with duplicate name: '{name}'" - log.debug(msg) - return component - elif old_obj is component: - msg = f"Ignoring previously registered component: '{name}'" - log.debug(msg) - return component - else: - msg = f"Ignoring component with duplicate name: '{name}'" - is_duplicate = True - if warn_duplicates: - log.warning(msg) - warnings.warn(msg) - else: - log.debug(msg) - # Register this component - log.debug(f"Registering {name}") - # Check if this device was previously registered with a - # different name - old_keys = [ - key for key, val in self._objects_by_name.items() if val is component - ] - for old_key in old_keys: - del self._objects_by_name[old_key] - # Register by name - if component.name != "": - self._objects_by_name[component.name] = component - # Create a set for this device's labels if it doesn't exist - if labels is None: - ophyd_labels = getattr(component, "_ophyd_labels_", []) - else: - ophyd_labels = labels - for label in ophyd_labels: - if label not in self._objects_by_label.keys(): - if self.keep_references: - self._objects_by_label[label] = set() - else: - self._objects_by_label[label] = WeakSet() - self._objects_by_label[label].add(component) - # Register this object with Typhos - if self.use_typhos: - import typhos - - typhos.plugins.register_signal(component) - # Recusively register sub-components - if hasattr(component, "_signals"): - # Vanilla ophyd device - sub_signals = component._signals.items() - elif hasattr(component, "children"): - # Ophyd-async device - sub_signals = component.children() + return component + # An instance was given, so just save it in the register + try: + name = component.name + except AttributeError: + msg = f"Skipping unnamed component {component}" + if isinstance(component, _AggregateSignalState): + log.debug(msg) else: - sub_signals = [] - for cpt_name, cpt in sub_signals: - self.register(cpt, warn_duplicates=not is_duplicate and warn_duplicates) + log.info(msg) + return component + # Register this object with Typhos + if self.use_typhos: + register_typhos_signal(component) + # Register this component + log.debug(f"Registering {name}") + # Check if this device was previously registered with a + # different name + old_keys = [ + key for key, val in self._objects_by_name.items() if val is component + ] + for old_key in old_keys: + del self._objects_by_name[old_key] + # Register by name + if self.keep_references: + new_set = set + else: + new_set = WeakSet + if component.name != "": + name = component.name + if name not in self._objects_by_name.keys(): + self._objects_by_name[name] = new_set() + self._objects_by_name[name].add(component) + # Create a set for this device's labels if it doesn't exist + if labels is None: + ophyd_labels = getattr(component, "_ophyd_labels_", []) + else: + ophyd_labels = labels + for label in ophyd_labels: + if label not in self._objects_by_label.keys(): + self._objects_by_label[label] = new_set() + self._objects_by_label[label].add(component) + # Register this object with Typhos + if self.use_typhos: + import typhos + + typhos.plugins.register_signal(component) + # Recusively register sub-components + if hasattr(component, "_signals"): + # Vanilla ophyd device + sub_signals = component._signals.items() + elif hasattr(component, "children"): + # Ophyd-async device + sub_signals = component.children() + else: + sub_signals = [] + for cpt_name, cpt in sub_signals: + self.register(cpt) return component diff --git a/src/ophydregistry/tests/test_instrument_registry.py b/src/ophydregistry/tests/test_instrument_registry.py index 09cc621..15bcc1e 100644 --- a/src/ophydregistry/tests/test_instrument_registry.py +++ b/src/ophydregistry/tests/test_instrument_registry.py @@ -1,5 +1,4 @@ import gc -import logging import time from concurrent.futures import ThreadPoolExecutor from unittest import mock @@ -300,16 +299,6 @@ def test_find_by_list_of_names(registry): assert cptC not in result -def test_user_readback(registry): - """Edge case where EpicsMotor.user_readback is named the same as the motor itself.""" - device = sim.instantiate_fake_device( - EpicsMotor, prefix="255idVME:m1", name="epics_motor" - ) - registry.register(device) - # See if requesting the device.user_readback returns the proper signal - registry.find("epics_motor_user_readback") - - def test_auto_register(): """Ensure the registry gets devices that aren't explicitly registered. @@ -391,29 +380,17 @@ def test_getitem(registry): def test_duplicate_device(caplog, registry): - """Check that a device doesn't get added twice.""" + """Check what happens when a device gets added twice.""" # Two devices with the same name motor1 = sim.instantiate_fake_device(EpicsMotor, prefix="", name="motor") motor2 = sim.instantiate_fake_device(EpicsMotor, prefix="", name="motor") # Set up logging so that we can know what - caplog.clear() - with caplog.at_level(logging.DEBUG): - registry.register(motor1) - # Check for the edge case where motor and motor.user_readback have the same name - assert "Ignoring component with duplicate name" not in caplog.text - assert "Ignoring readback with duplicate name" in caplog.text - # Check that truly duplicated entries get a warning - caplog.clear() - with caplog.at_level(logging.WARNING): - with pytest.warns(UserWarning): - registry.register(motor2) - # Check for the edge case where motor and motor.user_readback have the same name - assert "Ignoring component with duplicate name" in caplog.text - print(caplog.text) - # Check that the warning is only issued for the top-level device, not all its children - assert "motor_user_setpoint" not in caplog.text - # Check that the correct second device is the one that wound up in the registry - assert registry["motor"] is motor2 + registry.register(motor1) + registry.register(motor2) + # Check that we retrieve the correct things from the registry + registered_motors = registry.findall(name="motor") + expected_motors = {motor1, motor1.user_readback, motor2, motor2.user_readback} + assert set(registered_motors) == expected_motors def test_delete_by_name(registry):