diff --git a/environment.yml b/environment.yml index 1c78f817..7c2d3b56 100644 --- a/environment.yml +++ b/environment.yml @@ -59,6 +59,7 @@ dependencies: # - pyqt >=5.12, <5.13 - qt >=5.12 - qtawesome + - qasync # --- general support packages - bitshuffle diff --git a/pytest.ini b/pytest.ini index c33540b1..facb7555 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] mongodb_fixture_dir = - tests/fixtures + src/haven/tests/fixtures diff --git a/src/conftest.py b/src/conftest.py index 48a0811d..08502c93 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -1,14 +1,18 @@ +import gc +import asyncio import os import subprocess from pathlib import Path from unittest import mock import psutil +from qasync import QEventLoop, DefaultQEventLoopPolicy # from pydm.data_plugins import plugin_modules, add_plugin import pydm import pytest -from bluesky import RunEngine +import numpy as np +import pandas as pd from ophyd import DynamicDeviceComponent as DDC from ophyd import Kind from ophyd.sim import ( @@ -18,6 +22,12 @@ make_fake_device, ) from pytestqt.qt_compat import qt_api +from tiled.adapters.mapping import MapAdapter +from tiled.adapters.xarray import DatasetAdapter +from tiled.adapters.table import TableAdapter +from tiled.client import Context, from_context +from tiled.server.app import build_app + import haven from firefly.application import FireflyApplication @@ -35,6 +45,7 @@ from haven.instrument.slits import ApertureSlits, BladeSlits from haven.instrument.xspress import Xspress3Detector from haven.instrument.xspress import add_mcas as add_xspress_mcas +from haven.catalog import Catalog top_dir = Path(__file__).parent.resolve() haven_dir = top_dir / "haven" @@ -67,53 +78,6 @@ def beamline_connected(): yield -class RunEngineStub(RunEngine): - def __repr__(self): - return "" - - -@pytest.fixture() -def RE(event_loop): - return RunEngineStub(call_returns_result=True) - - -@pytest.fixture(scope="session") -def qapp_cls(): - return FireflyApplication - - -# def pytest_configure(config): -# app = QtWidgets.QApplication.instance() -# assert app is None -# app = FireflyApplication() -# app = QtWidgets.QApplication.instance() -# assert isinstance(app, FireflyApplication) -# # # Create event loop for asyncio stuff -# # loop = asyncio.new_event_loop() -# # asyncio.set_event_loop(loop) - - -def tiled_is_running(port, match_command=True): - lsof = subprocess.run(["lsof", "-i", f":{port}", "-F"], capture_output=True) - assert lsof.stderr.decode() == "" - stdout = lsof.stdout.decode().split("\n") - is_running = len(stdout) >= 3 - if match_command: - is_running = is_running and stdout[3] == "ctiled" - return is_running - - -def kill_process(process_name): - processes = [] - for proc in psutil.process_iter(): - # check whether the process name matches - if proc.name() == process_name: - proc.kill() - processes.append(proc) - # Wait for them all the terminate - [proc.wait(timeout=5) for proc in processes] - - @pytest.fixture() def sim_registry(monkeypatch): # mock out Ophyd connections so devices can be created @@ -307,68 +271,124 @@ def shutters(sim_registry): yield shutters -@pytest.fixture(scope="session") -def pydm_ophyd_plugin(): - return pydm.data_plugins.plugin_for_address("haven://") - - -qs_status = { - "msg": "RE Manager v0.0.18", - "items_in_queue": 0, - "items_in_history": 0, - "running_item_uid": None, - "manager_state": "idle", - "queue_stop_pending": False, - "worker_environment_exists": False, - "worker_environment_state": "closed", - "worker_background_tasks": 0, - "re_state": None, - "pause_pending": False, - "run_list_uid": "4f2d48cc-980d-4472-b62b-6686caeb3833", - "plan_queue_uid": "2b99ccd8-f69b-4a44-82d0-947d32c5d0a2", - "plan_history_uid": "9af8e898-0f00-4e7a-8d97-0964c8d43f47", - "devices_existing_uid": "51d8b88d-7457-42c4-b67f-097b168be96d", - "plans_existing_uid": "65f11f60-0049-46f5-9eb3-9f1589c4a6dd", - "devices_allowed_uid": "a5ddff29-917c-462e-ba66-399777d2442a", - "plans_allowed_uid": "d1e907cd-cb92-4d68-baab-fe195754827e", - "plan_queue_mode": {"loop": False}, - "task_results_uid": "159e1820-32be-4e01-ab03-e3478d12d288", - "lock_info_uid": "c7fe6f73-91fc-457d-8db0-dfcecb2f2aba", - "lock": {"environment": False, "queue": False}, +# holds a global QApplication instance created in the qapp fixture; keeping +# this reference alive avoids it being garbage collected too early +_ffapp_instance = None + + +# Tiled data to use for testing +# Some mocked test data +run1 = pd.DataFrame( + { + "energy_energy": np.linspace(8300, 8400, num=100), + "It_net_counts": np.abs(np.sin(np.linspace(0, 4 * np.pi, num=100))), + "I0_net_counts": np.linspace(1, 2, num=100), + } +).to_xarray() + +grid_scan = pd.DataFrame( + { + 'CdnIPreKb': np.linspace(0, 104, num=105), + "It_net_counts": np.linspace(0, 104, num=105), + "aerotech_horiz": np.linspace(0, 104, num=105), + "aerotech_vert": np.linspace(0, 104, num=105), + } +).to_xarray() + +hints = { + "energy": {"fields": ["energy_energy", "energy_id_energy_readback"]}, +} + +bluesky_mapping = { + "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f": MapAdapter( + { + "primary": MapAdapter( + { + "data": DatasetAdapter.from_dataset(run1), + }, + metadata={"descriptors": [{"hints": hints}]}, + ), + }, + metadata={ + "plan_name": "xafs_scan", + "start": { + "plan_name": "xafs_scan", + "uid": "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f", + "hints": {"dimensions": [[["energy_energy"], "primary"]]}, + }, + }, + ), + "9d33bf66-9701-4ee3-90f4-3be730bc226c": MapAdapter( + { + "primary": MapAdapter( + { + "data": DatasetAdapter.from_dataset(run1), + }, + metadata={"descriptors": [{"hints": hints}]}, + ), + }, + metadata={ + "start": { + "plan_name": "rel_scan", + "uid": "9d33bf66-9701-4ee3-90f4-3be730bc226c", + "hints": {"dimensions": [[["pitch2"], "primary"]]}, + } + }, + ), + # 2D grid scan map data + "85573831-f4b4-4f64-b613-a6007bf03a8d": MapAdapter( + { + "primary": MapAdapter( + { + "data": DatasetAdapter.from_dataset(grid_scan), + }, metadata={ + "descriptors": [{"hints": {'Ipreslit': {'fields': ['Ipreslit_net_counts']}, + 'CdnIPreKb': {'fields': ['CdnIPreKb_net_counts']}, + 'I0': {'fields': ['I0_net_counts']}, + 'CdnIt': {'fields': ['CdnIt_net_counts']}, + 'aerotech_vert': {'fields': ['aerotech_vert']}, + 'aerotech_horiz': {'fields': ['aerotech_horiz']}, + 'Ipre_KB': {'fields': ['Ipre_KB_net_counts']}, + 'CdnI0': {'fields': ['CdnI0_net_counts']}, + 'It': {'fields': ['It_net_counts']}}}] + }), + }, + metadata={ + "start": { + "plan_name": "grid_scan", + "uid": "85573831-f4b4-4f64-b613-a6007bf03a8d", + "hints": { + 'dimensions': [[['aerotech_vert'], 'primary'], + [['aerotech_horiz'], 'primary']], + 'gridding': 'rectilinear' + }, + "shape": [5, 21], + "extents": [[-80, 80], [-100, 100]], + }, + }, + ), } +mapping = { + "255id_testing": MapAdapter(bluesky_mapping), +} + +tree = MapAdapter(mapping) + + @pytest.fixture(scope="session") -def ffapp(pydm_ophyd_plugin, qapp_cls, qapp_args, pytestconfig): - # Get an instance of the application - app = qt_api.QtWidgets.QApplication.instance() - if app is None: - # New Application - global _ffapp_instance - _ffapp_instance = qapp_cls(qapp_args) - app = _ffapp_instance - name = pytestconfig.getini("qt_qapp_name") - app.setApplicationName(name) - # Make sure there's at least one Window, otherwise things get weird - if getattr(app, "_dummy_main_window", None) is None: - # Set up the actions and other boildplate stuff - app.setup_window_actions() - app.setup_runengine_actions() - app._dummy_main_window = FireflyMainWindow() - # Sanity check to make sure a QApplication was not created by mistake - assert isinstance(app, FireflyApplication) - # Yield the finalized application object - try: - yield app - finally: - if hasattr(app, "_queue_thread"): - app._queue_thread.quit() - app._queue_thread.wait(msecs=5000) +def tiled_client(): + app = build_app(tree) + with Context.from_app(app) as context: + client = from_context(context) + yield client["255id_testing"] -# holds a global QApplication instance created in the qapp fixture; keeping -# this reference alive avoids it being garbage collected too early -_ffapp_instance = None +@pytest.fixture(scope="session") +def catalog(tiled_client): + return Catalog(client=tiled_client) + # ----------------------------------------------------------------------------- diff --git a/src/firefly/application.py b/src/firefly/application.py index 22da6779..d95f31b8 100644 --- a/src/firefly/application.py +++ b/src/firefly/application.py @@ -1,3 +1,4 @@ +import asyncio import logging import subprocess from collections import OrderedDict @@ -133,7 +134,7 @@ def load_instrument(self): # Actions for controlling the bluesky run engine self.setup_runengine_actions() # Prepare the client for interacting with the queue server - self.prepare_queue_client() + # self.prepare_queue_client() # Launch the default display show_default_window = getattr(self, f"show_{self.default_display}_window") default_window = show_default_window() @@ -216,7 +217,7 @@ def setup_window_actions(self): self._setup_window_action( action_name="show_run_browser_action", text="Browse Runs", - slot=self.show_run_browser, + slot=self.show_run_browser_window, ) # Action for launch queue-monitor self._setup_window_action( @@ -529,6 +530,7 @@ def show_window(self, WindowClass, ui_file, name=None, macros={}): if (w := self.windows.get(name)) is None: # Window is not yet created, so create one w = self.create_window(WindowClass, ui_dir / ui_file, macros=macros) + # return self.windows[name] = w # Connect signals to remove the window when it closes w.destroyed.connect(partial(self.forget_window, name=name)) @@ -591,7 +593,7 @@ def show_status_window(self, stylesheet_path=None): ) @QtCore.Slot() - def show_run_browser(self): + def show_run_browser_window(self): return self.show_window( PlanMainWindow, ui_dir / "run_browser.py", name="run_browser" ) diff --git a/src/firefly/conftest.py b/src/firefly/conftest.py new file mode 100644 index 00000000..cde9815f --- /dev/null +++ b/src/firefly/conftest.py @@ -0,0 +1,137 @@ +import asyncio +import pytest +import subprocess +import psutil +import gc + +import pydm +from qasync import QEventLoop, DefaultQEventLoopPolicy + +from firefly import FireflyApplication +from firefly.main_window import FireflyMainWindow + + +@pytest.fixture(scope="session") +def qapp_cls(): + return FireflyApplication + + +# def pytest_configure(config): +# app = qt_api.QtWidgets.QApplication.instance() +# assert app is None +# app = FireflyApplication() +# app = qt_api.QtWidgets.QApplication.instance() +# assert isinstance(app, FireflyApplication) +# # # Create event loop for asyncio stuff +# # loop = asyncio.new_event_loop() +# # asyncio.set_event_loop(loop) + + +def tiled_is_running(port, match_command=True): + lsof = subprocess.run(["lsof", "-i", f":{port}", "-F"], capture_output=True) + assert lsof.stderr.decode() == "" + stdout = lsof.stdout.decode().split("\n") + is_running = len(stdout) >= 3 + if match_command: + is_running = is_running and stdout[3] == "ctiled" + return is_running + + +def kill_process(process_name): + processes = [] + for proc in psutil.process_iter(): + # check whether the process name matches + if proc.name() == process_name: + proc.kill() + processes.append(proc) + # Wait for them all the terminate + [proc.wait(timeout=5) for proc in processes] + + +@pytest.fixture(scope="session") +def pydm_ophyd_plugin(): + return pydm.data_plugins.plugin_for_address("haven://") + + +qs_status = { + "msg": "RE Manager v0.0.18", + "items_in_queue": 0, + "items_in_history": 0, + "running_item_uid": None, + "manager_state": "idle", + "queue_stop_pending": False, + "worker_environment_exists": False, + "worker_environment_state": "closed", + "worker_background_tasks": 0, + "re_state": None, + "pause_pending": False, + "run_list_uid": "4f2d48cc-980d-4472-b62b-6686caeb3833", + "plan_queue_uid": "2b99ccd8-f69b-4a44-82d0-947d32c5d0a2", + "plan_history_uid": "9af8e898-0f00-4e7a-8d97-0964c8d43f47", + "devices_existing_uid": "51d8b88d-7457-42c4-b67f-097b168be96d", + "plans_existing_uid": "65f11f60-0049-46f5-9eb3-9f1589c4a6dd", + "devices_allowed_uid": "a5ddff29-917c-462e-ba66-399777d2442a", + "plans_allowed_uid": "d1e907cd-cb92-4d68-baab-fe195754827e", + "plan_queue_mode": {"loop": False}, + "task_results_uid": "159e1820-32be-4e01-ab03-e3478d12d288", + "lock_info_uid": "c7fe6f73-91fc-457d-8db0-dfcecb2f2aba", + "lock": {"environment": False, "queue": False}, +} + + +class FireflyQEventLoopPolicy(DefaultQEventLoopPolicy): + def new_event_loop(self): + return QEventLoop(FireflyApplication.instance()) + + +@pytest.fixture() +def event_loop_policy(request, ffapp): + """Make sure pytest-asyncio uses the QEventLoop.""" + return FireflyQEventLoopPolicy() + + +@pytest.fixture() +def ffapp(pydm_ophyd_plugin, qapp_cls, qapp_args, pytestconfig): + # Get an instance of the application + # app = qt_api.QtWidgets.QApplication.instance() + app = qapp_cls.instance() + if app is None: + # New Application + global _ffapp_instance + app = qapp_cls(qapp_args) + # _ffapp_instance = app + name = pytestconfig.getini("qt_qapp_name") + app.setApplicationName(name) + # Make sure there's at least one Window, otherwise things get weird + if getattr(app, "_dummy_main_window", None) is None: + # Set up the actions and other boildplate stuff + app.setup_window_actions() + app.setup_runengine_actions() + app._dummy_main_window = FireflyMainWindow() + yield app + # try: + # yield app + # finally: + # del app + # gc.collect() + + +@pytest.fixture() +def affapp(event_loop, ffapp): + # Prepare the event loop + asyncio.set_event_loop(event_loop) + # Sanity check to make sure a QApplication was not created by mistake + assert isinstance(ffapp, FireflyApplication) + # Yield the finalized application object + try: + yield ffapp + finally: + # Cancel remaining async tasks + pending = asyncio.all_tasks(event_loop) + event_loop.run_until_complete(asyncio.gather(*pending)) + assert all(task.done() for task in pending), "Shutdown tasks not complete." + # if hasattr(app, "_queue_thread"): + # app._queue_thread.quit() + # app._queue_thread.wait(msecs=5000) + # del app + # gc.collect() diff --git a/src/firefly/launcher.py b/src/firefly/launcher.py index cbe37ecb..9838392f 100644 --- a/src/firefly/launcher.py +++ b/src/firefly/launcher.py @@ -1,3 +1,4 @@ +import asyncio import argparse import cProfile import logging @@ -6,6 +7,8 @@ import time from pathlib import Path +from qasync import QEventLoop + import haven @@ -174,17 +177,27 @@ def main(default_fullscreen=False, default_display="status"): stylesheet_path=pydm_args.stylesheet, ) + # Make it asynchronous + event_loop = QEventLoop(app) + asyncio.set_event_loop(event_loop) + app_close_event = asyncio.Event() + app.aboutToQuit.connect(app_close_event.set) + # Define devices on the beamline (slow!) if not pydm_args.no_instrument: haven.load_instrument() app.load_instrument() - FireflyApplication.processEvents() + # FireflyApplication.processEvents() - # Show the first window - first_window = list(app.windows.values())[0] - splash.finish(first_window) + # Show the first window (breaks asyncio) + # first_window = list(app.windows.values())[0] + # splash.finish(first_window) + splash.close() - exit_code = app.exec_() + event_loop.run_until_complete(app_close_event.wait()) + # event_loop.run_until_complete(app.exec_) + # exit_code = app.exec_() + event_loop.close() if pydm_args.profile: profile.disable() @@ -194,7 +207,7 @@ def main(default_fullscreen=False, default_display="status"): ).sort_stats(pstats.SortKey.CUMULATIVE) stats.print_stats() - sys.exit(exit_code) + # sys.exit(exit_code) def cameras(): diff --git a/src/firefly/run_browser.py b/src/firefly/run_browser.py index 0b4173f2..a38a1a1c 100644 --- a/src/firefly/run_browser.py +++ b/src/firefly/run_browser.py @@ -1,13 +1,17 @@ import logging from itertools import count -from typing import Sequence +from typing import Sequence, Mapping +import time +import asyncio +from functools import partial, wraps import numpy as np import qtawesome as qta import yaml +from qasync import asyncSlot from matplotlib.colors import TABLEAU_COLORS from pydantic.error_wrappers import ValidationError -from pyqtgraph import PlotItem, PlotWidget +from pyqtgraph import PlotItem, PlotWidget, ImageView, GraphicsLayoutWidget from qtpy.QtCore import Qt, QThread, Signal from qtpy.QtGui import QStandardItem, QStandardItemModel from qtpy.QtWidgets import QWidget @@ -22,6 +26,17 @@ colors = list(TABLEAU_COLORS.values()) +def cancellable(fn): + @wraps(fn) + async def inner(*args, **kwargs): + try: + return await fn(*args, **kwargs) + except asyncio.exceptions.CancelledError: + log.warning(f"Cancelled task {fn}") + return inner + + + class FiltersWidget(QWidget): returnPressed = Signal() @@ -47,11 +62,162 @@ def hoverEvent(self, event): self.hover_coords_changed.emit(pos_str) +class BrowserMultiPlotWidget(GraphicsLayoutWidget): + _multiplot_items: Mapping + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._multiplot_items = {} + + def multiplot_items(self, n_cols: int = 3): + view = self + item0 = None + for idx in count(): + row = int(idx / n_cols) + col = idx % n_cols + # Make a new plot item if one doesn't exist + if (row, col) not in self._multiplot_items: + self._multiplot_items[(row, col)] = view.addPlot(row=row, col=col) + new_item = self._multiplot_items[(row, col)] + # Link the X-axes together + if item0 is None: + item0 = new_item + else: + new_item.setXLink(item0) + # Resize the viewing area to fit the contents + width = view.width() + plot_width = width / n_cols + # view.resize(int(width), int(plot_width * row)) + view.setFixedHeight(1200) + yield new_item + + def plot_runs(self, runs: Mapping, xsignal: str): + """Take loaded run data and plot small multiples. + + Parameters + ========== + runs + Dictionary with pandas series for each curve. The keys + should be the curve labels, the series' indexes are the x + values and the series' values are the y data. + xsignal + The name of the signal to use for the common horizontal + axis. + + """ + # Use all the data columns as y signals + ysignals = [] + for run in runs.values(): + ysignals.extend(run.columns) + # Remove the x-signal from the list of y signals + ysignals = sorted(list(dict.fromkeys(ysignals))) + # Plot the runs + self.clear() + for label, data in runs.items(): + # Figure out which signals to plot + try: + xdata = data[xsignal] + except KeyError: + log.warning(f"Cannot plot x='{xsignal}' for {list(data.keys())}") + continue + # Plot each y signal on a separate plot + for ysignal, plot_item in zip(ysignals, self.multiplot_items()): + try: + plot_item.plot(xdata, data[ysignal]) + except KeyError: + log.warning(f"No signal {ysignal} in data.") + else: + log.debug(f"Plotted {ysignal} vs. {xsignal} for {data}") + plot_item.setTitle(ysignal) + + class Browser1DPlotWidget(PlotWidget): def __init__(self, parent=None, background="default", plotItem=None, **kargs): plot_item = Browser1DPlotItem(**kargs) super().__init__(parent=parent, background=background, plotItem=plot_item) + def plot_runs(self, runs: Mapping, ylabel="", xlabel=""): + """Take loaded run data and plot it. + + Parameters + ========== + runs + Dictionary with pandas series for each curve. The keys + should be the curve labels, the series' indexes are the x + values and the series' values are the y data. + + """ + plot_item = self.getPlotItem() + plot_item.clear() + # Plot this run's data + cursor_needed = True + for idx, (label, series) in enumerate(runs.items()): + color = colors[idx % len(colors)] + plot_item.plot( + x=series.index, + y=series.values, + pen=color, + name=label, + clear=False, + ) + # Cursor to drag around on the data + if cursor_needed: + plot_item.addLine( + x=np.median(series.index), movable=True, label="{value:.3f}" + ) + cursor_needed = False + # Axis formatting + plot_item.setLabels(left=ylabel, bottom=xlabel) + + +class Browser2DPlotWidget(ImageView): + + """A plot widget for 2D maps.""" + def __init__(self, *args, view=None, **kwargs): + if view is None: + view = PlotItem() + super().__init__(*args, view=view, **kwargs) + + def plot_runs(self, runs: Mapping, xlabel: str = "", ylabel: str = "", extents=None): + """Take loaded 2D or 3D mapping data and plot it. + + Parameters + ========== + runs + Dictionary with pandas series for each curve. The keys + should be the curve labels, the series' indexes are the x + values and the series' values are the y data. + xlabel + The label for the horizontal axis. + ylabel + The label for the vertical axis. + extents + Spatial extents for the map as ((-y, +y), (-x, +x)). + + """ + images = np.asarray(list(runs.values())) + # Combine the different runs into one image + # To-do: make this respond to the combobox selection + image = np.mean(images, axis=0) + # To-do: Apply transformations + + # # Plot the image + if 2 <= image.ndim <= 3: + self.setImage(image.T, autoRange=False) + else: + log.info(f"Could not plot image of dataset with shape {image.shape}.") + return + # Determine the axes labels + self.view.setLabel(axis="bottom", text=xlabel) + self.view.setLabel(axis="left", text=ylabel) + # Set axes extent + yextent, xextent = extents + x = xextent[0] + y = yextent[0] + w = xextent[1] - xextent[0] + h = yextent[1] - yextent[0] + self.getImageItem().setRect(x, y, w, h) + class RunBrowserDisplay(display.FireflyDisplay): runs_model: QStandardItemModel @@ -69,64 +235,148 @@ class RunBrowserDisplay(display.FireflyDisplay): ] _multiplot_items = {} - # Signals - runs_selected = Signal(list) - runs_model_changed = Signal(QStandardItemModel) - plot_1d_changed = Signal(object) - filters_changed = Signal(dict) + selected_runs: list + _running_db_tasks: Mapping def __init__(self, root_node=None, args=None, macros=None, **kwargs): super().__init__(args=args, macros=macros, **kwargs) - self.start_run_client(root_node=root_node) - - def start_run_client(self, root_node): - """Set up the database client in a separate thread.""" - # Create the thread and worker - thread = QThread(parent=self) - self._thread = thread - worker = DatabaseWorker(root_node=root_node) - self._db_worker = worker - worker.moveToThread(thread) - # Connect signals/slots - thread.started.connect(worker.load_all_runs) - worker.all_runs_changed.connect(self.set_runs_model_items) - worker.selected_runs_changed.connect(self.update_metadata) - worker.selected_runs_changed.connect(self.update_1d_signals) - worker.selected_runs_changed.connect(self.update_1d_plot) - worker.selected_runs_changed.connect(self.update_multi_plot) - worker.db_op_started.connect(self.disable_run_widgets) - worker.db_op_ended.connect(self.enable_run_widgets) - self.runs_selected.connect(worker.load_selected_runs) - self.ui.refresh_runs_button.clicked.connect(worker.load_all_runs) - self.filters_changed.connect(worker.set_filters) - worker.new_message.connect(self.show_message) - # Make sure filters are current - self.update_filters() - # Start the thread - thread.start() - # Get distinct fields so we can populate the comboboxes - worker.distinct_fields_changed.connect(self.update_combobox_items) - worker.load_distinct_fields() - - def update_combobox_items(self, fields): - for field_name, cb in [ - ("proposal_users", self.ui.filter_proposal_combobox), - ("proposal_id", self.ui.filter_user_combobox), - ("esaf_id", self.ui.filter_esaf_combobox), - ("sample_name", self.ui.filter_sample_combobox), - ("plan_name", self.ui.filter_plan_combobox), - ("edge", self.ui.filter_edge_combobox), - ]: - if field_name in fields.keys(): - cb.clear() - cb.addItems(fields[field_name]) + self.selected_runs = [] + self._running_db_tasks = {} + self.db = DatabaseWorker(catalog=root_node) + # Load the list of all runs for the selection widget + self.db_task(self.load_runs()) + + + + def db_task(self, coro, name="default task"): + """Executes a co-routine as a database task. Existing database + tasks get cancelled. + + """ + # Check for existing tasks + has_previous_task = name in self._running_db_tasks.keys() + task_is_running = has_previous_task and not self._running_db_tasks[name].done() + if task_is_running: + self._running_db_tasks[name].cancel("New database task started.") + # Wait on this task to be done + new_task = asyncio.ensure_future(coro) + self._running_db_tasks[name] = new_task + return new_task + + @asyncSlot() + async def reload_runs(self): + """A simple wrapper to make load_runs a slot.""" + await self.load_runs() + + @cancellable + async def load_runs(self): + """Get the list of available runs based on filters.""" + runs = await self.db_task( + self.db.load_all_runs(self.filters()), + name="load all runs", + ) + # Update the table view data model + self.runs_model.clear() + self.runs_model.setHorizontalHeaderLabels(self._run_col_names) + for run in runs: + items = [QStandardItem(val) for val in run.values()] + self.ui.runs_model.appendRow(items) + # Adjust the layout of the data table + sort_col = self._run_col_names.index("Datetime") + self.ui.run_tableview.sortByColumn(sort_col, Qt.DescendingOrder) + self.ui.run_tableview.resizeColumnsToContents() + # Let slots know that the model data have changed + self.runs_total_label.setText(str(self.ui.runs_model.rowCount())) + + # # def start_run_client(self, root_node): + # # """Set up the database client in a separate thread.""" + # # # Create the thread and worker + # # thread = QThread(parent=self) + # # self._thread = thread + # # worker = DatabaseWorker(root_node=root_node) + # # self._db_worker = worker + # # worker.moveToThread(thread) + # # # Set up filters + # # worker.new_message.connect(self.show_message) + # # self.filters_changed.connect(worker.set_filters) + # # # Connect signals/slots + # # thread.started.connect(worker.load_all_runs) + # # worker.all_runs_changed.connect(self.set_runs_model_items) + # # worker.selected_runs_changed.connect(self.update_metadata) + # # worker.selected_runs_changed.connect(self.update_1d_signals) + # # worker.selected_runs_changed.connect(self.update_2d_signals) + # # worker.selected_runs_changed.connect(self.update_1d_plot) + # # worker.selected_runs_changed.connect(self.update_2d_plot) + # # worker.selected_runs_changed.connect(self.update_multi_plot) + # # worker.db_op_started.connect(self.disable_run_widgets) + # # worker.db_op_ended.connect(self.enable_run_widgets) + # # # Make sure filters are current + # # self.update_filters() + # # # Start the thread + # # thread.start() + # # # Get distinct fields so we can populate the comboboxes + # # self.load_distinct_fields.connect(worker.load_distinct_fields) + # # worker.distinct_fields_changed.connect(self.update_combobox_items) + # # self.load_distinct_fields.emit() + + def clear_filters(self): + self.ui.filter_proposal_combobox.setCurrentText("") + self.ui.filter_esaf_combobox.setCurrentText("") + self.ui.filter_sample_combobox.setCurrentText("") + self.ui.filter_exit_status_combobox.setCurrentText("") + self.ui.filter_current_proposal_checkbox.setChecked(False) + self.ui.filter_current_esaf_checkbox.setChecked(False) + self.ui.filter_plan_combobox.setCurrentText("") + self.ui.filter_full_text_lineedit.setText("") + self.ui.filter_edge_combobox.setCurrentText("") + self.ui.filter_user_combobox.setCurrentText("") + + # def update_combobox_items(self, fields): + # for field_name, cb in [ + # ("proposal_users", self.ui.filter_proposal_combobox), + # ("proposal_id", self.ui.filter_user_combobox), + # ("esaf_id", self.ui.filter_esaf_combobox), + # ("sample_name", self.ui.filter_sample_combobox), + # ("plan_name", self.ui.filter_plan_combobox), + # ("edge", self.ui.filter_edge_combobox), + # ]: + # if field_name in fields.keys(): + # old_text = cb.currentText() + # cb.clear() + # cb.addItems(fields[field_name]) + # cb.setCurrentText(old_text) + + @asyncSlot() + @cancellable + async def sleep_slot(self): + print("Sleeping") + await self.db_task(self.print_sleep()) + + async def print_sleep(self): + label = self.ui.sleep_label + label.setText(f"3...") + await asyncio.sleep(1) + old_text = label.text() + label.setText(f"{old_text}2...") + await asyncio.sleep(1) + old_text = label.text() + label.setText(f"{old_text}1...") + await asyncio.sleep(1) + old_text = label.text() + label.setText(f"{old_text}done!") def customize_ui(self): self.load_models() # Setup controls for select which run to show + self.ui.run_tableview.selectionModel().selectionChanged.connect( self.update_selected_runs ) + self.ui.refresh_runs_button.setIcon(qta.icon("fa5s.sync")) + self.ui.refresh_runs_button.clicked.connect(self.reload_runs) + # Sleep controls for testing async timing + self.ui.sleep_button.clicked.connect(self.sleep_slot) + # Respond to changes in displaying the 1d plot self.ui.signal_y_combobox.currentTextChanged.connect(self.update_1d_plot) self.ui.signal_x_combobox.currentTextChanged.connect(self.update_1d_plot) self.ui.signal_r_combobox.currentTextChanged.connect(self.update_1d_plot) @@ -135,80 +385,49 @@ def customize_ui(self): self.ui.invert_checkbox.stateChanged.connect(self.update_1d_plot) self.ui.gradient_checkbox.stateChanged.connect(self.update_1d_plot) self.ui.plot_1d_hints_checkbox.stateChanged.connect(self.update_1d_signals) - self.ui.refresh_runs_button.setIcon(qta.icon("fa5s.sync")) + # Respond to changes in displaying the 2d plot + self.ui.signal_value_combobox.currentTextChanged.connect(self.update_2d_plot) + self.ui.logarithm_checkbox_2d.stateChanged.connect(self.update_2d_plot) + self.ui.invert_checkbox_2d.stateChanged.connect(self.update_2d_plot) + self.ui.gradient_checkbox_2d.stateChanged.connect(self.update_2d_plot) + self.ui.plot_2d_hints_checkbox.stateChanged.connect(self.update_2d_signals) # Respond to filter controls getting updated - self.ui.filter_user_combobox.currentTextChanged.connect(self.update_filters) - self.ui.filter_proposal_combobox.currentTextChanged.connect(self.update_filters) - self.ui.filter_sample_combobox.currentTextChanged.connect(self.update_filters) - self.ui.filter_exit_status_combobox.currentTextChanged.connect( - self.update_filters - ) - self.ui.filter_esaf_combobox.currentTextChanged.connect(self.update_filters) - self.ui.filter_current_proposal_checkbox.stateChanged.connect( - self.update_filters - ) - self.ui.filter_current_esaf_checkbox.stateChanged.connect(self.update_filters) - self.ui.filter_plan_combobox.currentTextChanged.connect(self.update_filters) - self.ui.filter_full_text_lineedit.textChanged.connect(self.update_filters) - self.ui.filter_edge_combobox.currentTextChanged.connect(self.update_filters) self.ui.filters_widget.returnPressed.connect(self.refresh_runs_button.click) # Set up 1D plotting widgets self.plot_1d_item = self.ui.plot_1d_view.getPlotItem() + self.plot_2d_item = self.ui.plot_2d_view.getImageItem() self.plot_1d_item.addLegend() self.plot_1d_item.hover_coords_changed.connect( self.ui.hover_coords_label.setText ) - def get_signals(self, run, hinted_only=False): - if hinted_only: - xsignals = run.metadata["start"]["hints"]["dimensions"][0][0] - ysignals = [] - hints = run["primary"].metadata["descriptors"][0]["hints"] - for device, dev_hints in hints.items(): - ysignals.extend(dev_hints["fields"]) - else: - xsignals = ysignals = run["primary"]["data"].keys() - return xsignals, ysignals - - def set_runs_model_items(self, runs): - self.runs_model.clear() - self.runs_model.setHorizontalHeaderLabels(self._run_col_names) - for run in runs: - items = [QStandardItem(val) for val in run.values()] - self.ui.runs_model.appendRow(items) - # Adjust the layout of the data table - sort_col = self._run_col_names.index("Datetime") - self.ui.run_tableview.sortByColumn(sort_col, Qt.DescendingOrder) - self.ui.run_tableview.resizeColumnsToContents() - # Let slots know that the model data have changed - self.runs_model_changed.emit(self.ui.runs_model) - self.runs_total_label.setText(str(self.ui.runs_model.rowCount())) - - def disable_run_widgets(self): - self.show_message("Loading...") - widgets = [ - self.ui.run_tableview, - self.ui.refresh_runs_button, - self.ui.detail_tabwidget, - self.ui.runs_total_layout, - self.ui.filters_widget, - ] - for widget in widgets: - widget.setEnabled(False) - self.disabled_widgets = widgets - self.setCursor(Qt.WaitCursor) - - def enable_run_widgets(self, exceptions=[]): - if any(exceptions): - self.show_message(exceptions[0]) - else: - self.show_message("Done", 5000) - # Re-enable the widgets - for widget in self.disabled_widgets: - widget.setEnabled(True) - self.setCursor(Qt.ArrowCursor) - - def update_1d_signals(self, *args): + # def disable_run_widgets(self): + # self.show_message("Loading...") + # widgets = [ + # self.ui.run_tableview, + # self.ui.refresh_runs_button, + # self.ui.detail_tabwidget, + # self.ui.runs_total_layout, + # self.ui.filters_widget, + # ] + # for widget in widgets: + # widget.setEnabled(False) + # self.disabled_widgets = widgets + # self.setCursor(Qt.WaitCursor) + + # def enable_run_widgets(self, exceptions=[]): + # if any(exceptions): + # self.show_message(exceptions[0]) + # else: + # self.show_message("Done", 5000) + # # Re-enable the widgets + # for widget in self.disabled_widgets: + # widget.setEnabled(True) + # self.setCursor(Qt.ArrowCursor) + + @asyncSlot() + @cancellable + async def update_1d_signals(self, *args): # Store old values for restoring later comboboxes = [ self.ui.signal_x_combobox, @@ -217,22 +436,11 @@ def update_1d_signals(self, *args): ] old_values = [cb.currentText() for cb in comboboxes] # Determine valid list of columns to choose from - xcols = set() - ycols = set() - runs = self._db_worker.selected_runs use_hints = self.ui.plot_1d_hints_checkbox.isChecked() - for run in runs: - try: - _xcols, _ycols = self.get_signals(run, hinted_only=use_hints) - except KeyError: - continue - else: - xcols.update(_xcols) - ycols.update(_ycols) - # Update the UI with the list of controls - xcols = sorted(list(set(xcols))) - ycols = sorted(list(set(ycols))) + signals_task = self.db_task(self.db.signal_names(hinted_only=use_hints), "1D signals") + xcols, ycols = await signals_task self.multi_y_signals = ycols + # Update the comboboxes with new signals for cb in [self.ui.multi_signal_x_combobox, self.ui.signal_x_combobox]: cb.clear() cb.addItems(xcols) @@ -246,123 +454,76 @@ def update_1d_signals(self, *args): for val, cb in zip(old_values, comboboxes): cb.setCurrentText(val) - def calculate_ydata( - self, - x_data, - y_data, - r_data, - x_signal, - y_signal, - r_signal, - use_reference=False, - use_log=False, - use_invert=False, - use_grad=False, - ): - """Take raw y and reference data and calculate a new y_data signal.""" - # Apply transformations - y = y_data - y_string = f"[{y_signal}]" - try: - if use_reference: - y = y / r_data - y_string = f"{y_string}/[{r_signal}]" - if use_log: - y = np.log(y) - y_string = f"ln({y_string})" - if use_invert: - y *= -1 - y_string = f"-{y_string}" - if use_grad: - y = np.gradient(y, x_data) - y_string = f"d({y_string})/d[{r_signal}]" - except TypeError as exc: - msg = f"Could not calculate transformation: {exc}" - log.warning(msg) - raise exceptions.InvalidTransformation(msg) - return y, y_string - - def load_run_data(self, run, x_signal, y_signal, r_signal, use_reference=True): - if "" in [x_signal, y_signal] or (use_reference and r_signal == ""): - log.debug( - f"Empty signal name requested: x='{x_signal}', y='{y_signal}'," - f" r='{r_signal}'" - ) - raise exceptions.EmptySignalName - signals = [x_signal, y_signal] - if use_reference: - signals.append(r_signal) - try: - data = run["primary"]["data"] - y_data = data[y_signal] - x_data = data[x_signal] - if use_reference: - r_data = data[r_signal] - else: - r_data = 1 - except KeyError as e: - # No data, so nothing to plot - msg = f"Cannot find key {e} in {run}." - log.warning(msg) - raise exceptions.SignalNotFound(msg) - except ValidationError: - print("Pydantic error:", run) - raise - return x_data, y_data, r_data - - def multiplot_items(self, n_cols: int = 3): - view = self.ui.plot_multi_view - item0 = None - for idx in count(): - row = int(idx / n_cols) - col = idx % n_cols - # Make a new plot item if one doesn't exist - if (row, col) not in self._multiplot_items: - self._multiplot_items[(row, col)] = view.addPlot(row=row, col=col) - new_item = self._multiplot_items[(row, col)] - # Link the X-axes together - if item0 is None: - item0 = new_item - else: - new_item.setXLink(item0) - # Resize the viewing area to fit the contents - width = view.width() - plot_width = width / n_cols - # view.resize(int(width), int(plot_width * row)) - view.setFixedHeight(1200) - yield new_item - - def update_multi_plot(self, *args): + @asyncSlot() + @cancellable + async def update_2d_signals(self, *args): + # Store current selection for restoring later + val_cb = self.ui.signal_value_combobox + old_value = val_cb.currentText() + # Determine valid list of dependent signals to choose from + use_hints = self.ui.plot_2d_hints_checkbox.isChecked() + xcols, vcols = await self.db_task(self.db.signal_names(hinted_only=use_hints), "2D signals") + # Update the UI with the list of controls + val_cb.clear() + val_cb.addItems(vcols) + # Restore previous selection + val_cb.setCurrentText(old_value) + + + # def calculate_ydata( + # self, + # x_data, + # y_data, + # r_data, + # x_signal, + # y_signal, + # r_signal, + # use_reference=False, + # use_log=False, + # use_invert=False, + # use_grad=False, + # ): + # """Take raw y and reference data and calculate a new y_data signal.""" + # # Make sure we have numpy arrays + # x = np.asarray(x_data) + # y = np.asarray(y_data) + # r = np.asarray(r_data) + # # Apply transformations + # y_string = f"[{y_signal}]" + # try: + # if use_reference: + # y = y / r + # y_string = f"{y_string}/[{r_signal}]" + # if use_log: + # y = np.log(y) + # y_string = f"ln({y_string})" + # if use_invert: + # y *= -1 + # y_string = f"-{y_string}" + # if use_grad: + # y = np.gradient(y, x) + # y_string = f"d({y_string})/d[{r_signal}]" + # except TypeError as exc: + # msg = f"Could not calculate transformation: {exc}" + # log.warning(msg) + # raise + # raise exceptions.InvalidTransformation(msg) + # return y, y_string + + + @asyncSlot() + @cancellable + async def update_multi_plot(self, *args): x_signal = self.ui.multi_signal_x_combobox.currentText() if x_signal == "": return - y_signals = self.multi_y_signals - all_signals = set((x_signal, *y_signals)) - view = self.ui.plot_multi_view - view.clear() - self._multiplot_items = {} - n_cols = 3 - runs = self._db_worker.selected_runs - for run in runs: - data = run["primary"]["data"].read(all_signals) - try: - xdata = data[x_signal] - except KeyError: - log.warning(f"Cannot plot x='{x_signal}' for {list(data.keys())}") - continue - for y_signal, plot_item in zip(y_signals, self.multiplot_items()): - # Get data from the database - try: - plot_item.plot(xdata, data[y_signal]) - except KeyError: - log.warning(f"Cannot plot y='{y_signal}' for {list(data.keys())}") - continue - else: - log.debug(f"Plotted {y_signal} vs. {x_signal} for {data}") - finally: - plot_item.setTitle(y_signal) + use_hints = self.ui.plot_1d_hints_checkbox.isChecked() + runs = await self.db_task(self.db.all_signals(hinted_only=use_hints), "multi-plot") + self.ui.plot_multi_view.plot_runs(runs, xsignal=x_signal) - def update_1d_plot(self, *args): + @asyncSlot() + @cancellable + async def update_1d_plot(self, *args): self.plot_1d_item.clear() # Figure out which signals to plot y_signal = self.ui.signal_y_combobox.currentText() @@ -375,85 +536,68 @@ def update_1d_plot(self, *args): use_log = self.ui.logarithm_checkbox.isChecked() use_invert = self.ui.invert_checkbox.isChecked() use_grad = self.ui.gradient_checkbox.isChecked() - # Do the plotting for each run - y_string = "" - x_data = None - for idx, run in enumerate(self._db_worker.selected_runs): - # Load datasets from the database - try: - x_data, y_data, r_data = self.load_run_data( - run, x_signal, y_signal, r_signal, use_reference=use_reference - ) - except exceptions.SignalNotFound as e: - self.show_message(str(e), 0) - continue - except exceptions.EmptySignalName: - continue - # Screen out non-numeric data types - try: - np.isfinite(x_data) - np.isfinite(y_data) - np.isfinite(r_data) - except TypeError as e: - msg = str(e) - log.warning(msg) - self.show_message(msg) - continue - # Calculate plotting data - try: - y_data, y_string = self.calculate_ydata( - x_data, - y_data, - r_data, - x_signal, - y_signal, - r_signal, - use_reference=use_reference, - use_log=use_log, - use_invert=use_invert, - use_grad=use_grad, - ) - except exceptions.InvalidTransformation as e: - self.show_message(str(e)) - continue - # Plot this run's data - color = colors[idx % len(colors)] - self.plot_1d_item.plot( - x=x_data, - y=y_data, - pen=color, - name=run.metadata["start"]["uid"], - clear=False, - ) - # Axis formatting - self.plot_1d_item.setLabels(left=y_string, bottom=x_signal) - if x_data is not None: - self.plot_1d_item.addLine( - x=np.median(x_data), movable=True, label="{value:.3f}" - ) - self.plot_1d_changed.emit(self.plot_1d_item) + task = self.db_task(self.db.signals(x_signal, y_signal, r_signal, use_log=use_log, use_invert=use_invert, use_grad=use_grad), "1D plot") + runs = await task + self.ui.plot_1d_view.plot_runs(runs) + + @asyncSlot() + @cancellable + async def update_2d_plot(self): + """Change the 2D map plot based on desired signals, etc.""" + # Figure out which signals to plot + value_signal = self.ui.signal_value_combobox.currentText() + use_log = self.ui.logarithm_checkbox_2d.isChecked() + use_invert = self.ui.invert_checkbox_2d.isChecked() + use_grad = self.ui.gradient_checkbox_2d.isChecked() + images = await self.db_task(self.db.images(value_signal), "2D plot") + # Get axis labels + # Eventually this will be replaced with robus choices for plotting multiple images + metadata = await self.db_task(self.db.metadata(), "2D plot") + metadata = list(metadata.values())[0] + dimensions = metadata['start']['hints']['dimensions'] + try: + xlabel = dimensions[-1][0][0] + ylabel = dimensions[-2][0][0] + except IndexError: + # Not a 2D scan + return + # Get spatial extent + extents = metadata['start']['extents'] + self.ui.plot_2d_view.plot_runs(images, xlabel=xlabel, ylabel=ylabel, extents=extents) - def update_metadata(self, *args): + @asyncSlot() + async def update_metadata(self, *args): """Render metadata for the runs into the metadata widget.""" # Combine the metadata in a human-readable output text = "" - runs = self._db_worker.selected_runs - for run in runs: - md_dict = dict(**run.metadata) - text += yaml.dump(md_dict) + all_md = await self.db_task(self.db.metadata(), "metadata") + for uid, md in all_md.items(): + text += f"# {uid}" + text += yaml.dump(md) text += f"\n\n{'=' * 20}\n\n" # Update the widget with the rendered metadata self.ui.metadata_textedit.document().setPlainText(text) - def update_selected_runs(self, *args): + @asyncSlot() + @cancellable + async def update_selected_runs(self, *args): """Get the current runs from the database and stash them.""" # Get UID's from the selection col_idx = self._run_col_names.index("UID") indexes = self.ui.run_tableview.selectedIndexes() uids = [i.siblingAtColumn(col_idx).data() for i in indexes] - self.runs_selected.emit(uids) - - def update_filters(self, *args): + # Get selected runs from the database + task = self.db_task(self.db.load_selected_runs(uids), "update selected runs") + self.selected_runs = await task + # Update the necessary UI elements + await self.update_1d_signals() + await self.update_2d_signals() + await self.update_metadata() + await self.update_1d_plot() + await self.update_2d_plot() + await self.update_multi_plot() + + def filters(self, *args): new_filters = { "proposal": self.ui.filter_proposal_combobox.currentText(), "esaf": self.ui.filter_esaf_combobox.currentText(), @@ -470,7 +614,7 @@ def update_filters(self, *args): } null_values = ["", False] new_filters = {k: v for k, v in new_filters.items() if v not in null_values} - self.filters_changed.emit(new_filters) + return new_filters def load_models(self): # Set up the model diff --git a/src/firefly/run_browser.ui b/src/firefly/run_browser.ui index bbcec3c1..d1bd4ed2 100644 --- a/src/firefly/run_browser.ui +++ b/src/firefly/run_browser.ui @@ -20,7 +20,7 @@ Qt::Horizontal - + @@ -163,6 +163,40 @@ + + + + 4 + + + + + Sleep + + + + + + + <- Press the button + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + @@ -212,11 +246,14 @@ + + The sample composition to look for. Supports regular expression. E.g. Sb.*Te + true - Sb.*Te + @@ -229,11 +266,14 @@ + + The X-ray absorption edge, or energy in electron-volts, that an energy scan was collected. Supports regular expressions. E.g. Ni_K + true - Sb + @@ -250,7 +290,7 @@ true - wolfman + @@ -388,7 +428,7 @@ - 2 + 1 @@ -439,10 +479,13 @@ - + Use Hints + + true + @@ -476,8 +519,8 @@ 0 0 - 84 - 28 + 788 + 460 @@ -488,7 +531,7 @@ - + 0 @@ -745,11 +788,156 @@ - false + true Map + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + true + + + Limits the selectable signals to those marked as "hinted". + + + Use Hints + + + true + + + + + + + + + Value: + + + + + + + true + + + + + + + + + true + + + Plot on a logarithmic scale + + + Log + + + + + + + true + + + Invert the values in the map + + + Invert + + + + + + + true + + + Show the magnitude of the gradient + + + Grad + + + + + + + + 0 + 0 + + + + What comparison to use if multiple compatible scans have been selected. + + + + Median + + + + + Mean + + + + + Sum + + + + + StDev + + + + + Max + + + + + Min + + + + + + + + + + + 0 + 0 + + + + + @@ -785,13 +973,19 @@ 1 - GraphicsLayoutWidget + Browser1DPlotWidget QWidget -
pyqtgraph
+
firefly.run_browser
1
- Browser1DPlotWidget + Browser2DPlotWidget + QWidget +
firefly.run_browser
+ 1 +
+ + BrowserMultiPlotWidget QWidget
firefly.run_browser
1 diff --git a/src/firefly/run_client.py b/src/firefly/run_client.py index 83e57465..7fff36eb 100644 --- a/src/firefly/run_client.py +++ b/src/firefly/run_client.py @@ -1,19 +1,22 @@ +import asyncio import datetime as dt import logging from collections import OrderedDict -from typing import Sequence +from typing import Sequence, Mapping from qtpy.QtCore import QObject, Signal, Slot from tiled import queries +import pandas as pd +import numpy as np -from haven import tiled_client +from haven.catalog import Catalog +from haven import exceptions log = logging.getLogger(__name__) -class DatabaseWorker(QObject): +class DatabaseWorker(): selected_runs: Sequence = [] - _filters = {"exit_status": "success"} # Signals all_runs_changed = Signal(list) @@ -23,20 +26,14 @@ class DatabaseWorker(QObject): db_op_started = Signal() db_op_ended = Signal(list) # (list of exceptions thrown) - def __init__(self, root_node, *args, **kwargs): - if root_node is None: - root_node = tiled_client() - self.root = root_node + def __init__(self, catalog=None, *args, **kwargs): + if catalog is None: + catalog = Catalog() + self.catalog = catalog super().__init__(*args, **kwargs) - def set_filters(self, filters): - log.debug(f"Setting new filters: {filters}") - self._filters = filters - - def filtered_nodes(self): + async def filtered_nodes(self, filters: Mapping): case_sensitive = False - runs = self.root - filters = self._filters log.debug(f"Filtering nodes: {filters}") filter_params = [ # (filter_name, query type, metadata key) @@ -48,19 +45,20 @@ def filtered_nodes(self): ("plan", queries.Regex, "plan_name"), ("edge", queries.Regex, "edge"), ] + # Apply filters + runs = self.catalog for filter_name, Query, md_name in filter_params: val = filters.get(filter_name, "") if val != "": - runs = runs.search(Query(md_name, val, case_sensitive=case_sensitive)) + runs = await runs.search(Query(md_name, val, case_sensitive=case_sensitive)) full_text = filters.get("full_text", "") if full_text != "": - runs = runs.search( + runs = await runs.search( queries.FullText(full_text, case_sensitive=case_sensitive) ) return runs - @Slot() - def load_distinct_fields(self): + async def load_distinct_fields(self): """Get distinct metadata fields for filterable metadata. Emits @@ -81,87 +79,174 @@ def load_distinct_fields(self): "edge", ] # Get fields from the database - response = self.root.distinct(*target_fields) + response = await self.catalog.distinct(*target_fields) # Build into a new dictionary for key, result in response["metadata"].items(): field = key.split(".")[-1] new_fields[field] = [r["value"] for r in result] - self.distinct_fields_changed.emit(new_fields) + return new_fields - @Slot() - def load_all_runs(self): + async def load_all_runs(self, filters: Mapping = {}): all_runs = [] - nodes = self.filtered_nodes() - self.db_op_started.emit() - try: - for uid, node in nodes.items(): - # Get meta-data documents - metadata = node.metadata - start_doc = metadata.get("start") - if start_doc is None: - log.debug(f"Skipping run with no start doc: {uid}") - continue - stop_doc = node.metadata.get("stop") - if stop_doc is None: - stop_doc = {} - # Get a human-readable timestamp for the run - timestamp = start_doc.get("time") - if timestamp is None: - run_datetime = "" - else: - run_datetime = dt.datetime.fromtimestamp(timestamp) - run_datetime = run_datetime.strftime("%Y-%m-%d %H:%M:%S") - # Get the X-ray edge scanned - edge = start_doc.get("edge") - E0 = start_doc.get("E0") - E0_str = "" if E0 is None else str(E0) - if edge and E0: - edge_str = f"{edge} ({E0} eV)" - elif edge: - edge_str = edge - elif E0: - edge_str = E0_str - else: - edge_str = "" - # Build the table item - # Get sample data from: dd80f432-c849-4749-a8f3-bdeec6f9c1f0 - run_data = OrderedDict( - plan_name=start_doc.get("plan_name", ""), - sample_name=start_doc.get("sample_name", ""), - edge=edge_str, - E0=E0_str, - exit_status=stop_doc.get("exit_status", ""), - run_datetime=run_datetime, - uid=uid, - proposal_id=start_doc.get("proposal_id", ""), - esaf_id=start_doc.get("esaf_id", ""), - esaf_users=start_doc.get("esaf_users", ""), - ) - all_runs.append(run_data) - except Exception as exc: - self.db_op_ended.emit([exc]) - raise - else: - self.db_op_ended.emit([]) - self.all_runs_changed.emit(all_runs) - - @Slot(list) - def load_selected_runs(self, uids): + nodes = await self.filtered_nodes(filters=filters) + async for uid, node in nodes.items(): + # Get meta-data documents + metadata = await node.metadata + start_doc = metadata.get("start") + if start_doc is None: + log.debug(f"Skipping run with no start doc: {uid}") + continue + stop_doc = metadata.get("stop") + if stop_doc is None: + stop_doc = {} + # Get a human-readable timestamp for the run + timestamp = start_doc.get("time") + if timestamp is None: + run_datetime = "" + else: + run_datetime = dt.datetime.fromtimestamp(timestamp) + run_datetime = run_datetime.strftime("%Y-%m-%d %H:%M:%S") + # Get the X-ray edge scanned + edge = start_doc.get("edge") + E0 = start_doc.get("E0") + E0_str = "" if E0 is None else str(E0) + if edge and E0: + edge_str = f"{edge} ({E0} eV)" + elif edge: + edge_str = edge + elif E0: + edge_str = E0_str + else: + edge_str = "" + # Build the table item + # Get sample data from: dd80f432-c849-4749-a8f3-bdeec6f9c1f0 + run_data = OrderedDict( + plan_name=start_doc.get("plan_name", ""), + sample_name=start_doc.get("sample_name", ""), + edge=edge_str, + E0=E0_str, + exit_status=stop_doc.get("exit_status", ""), + run_datetime=run_datetime, + uid=uid, + proposal_id=start_doc.get("proposal_id", ""), + esaf_id=start_doc.get("esaf_id", ""), + esaf_users=start_doc.get("esaf_users", ""), + ) + all_runs.append(run_data) + return all_runs + + async def signal_names(self, hinted_only: bool = False): + """Get a list of valid signal names (data columns) for selected runs. + + Parameters + ========== + hinted_only + If true, only signals with the kind="hinted" parameter get + picked. + + """ + xsignals, ysignals = [], [] + for run in self.selected_runs: + if hinted_only: + xsig, ysig = await run.hints() + else: + df = await run.to_dataframe() + xsig = ysig = df.columns + xsignals.extend(xsig) + ysignals.extend(ysig) + # Remove duplicates + xsignals = list(dict.fromkeys(xsignals)) + ysignals = list(dict.fromkeys(ysignals)) + return list(xsignals), list(ysignals) + + async def metadata(self): + """Get all metadata for the selected runs in one big dictionary.""" + md = {} + for run in self.selected_runs: + md[run.uid] = await run.metadata + return md + + async def load_selected_runs(self, uids): + # Prepare the query for finding the runs + uids = list(dict.fromkeys(uids)) # Retrieve runs from the database - uids = list(set(uids)) - self.db_op_started.emit() - # Download each item, maybe we can find a more efficient way to do this - try: - runs = [self.root[uid] for uid in uids] - except Exception as exc: - self.db_op_ended.emit([exc]) - raise - else: - self.db_op_ended.emit([]) - # Save and inform clients of the run data + runs = [await self.catalog[uid] for uid in uids] + # runs = await asyncio.gather(*run_coros) self.selected_runs = runs - self.selected_runs_changed.emit(runs) + async def images(self, signal): + """Load the selected runs as 2D or 3D images suitable for plotting.""" + images = OrderedDict() + for idx, run in enumerate(self.selected_runs): + # Load datasets from the database + try: + image = await run[signal] + except KeyError: + log.warning(f"Signal {signal} not found in run {run}.") + else: + images[run.uid] = image + return images + + async def all_signals(self, hinted_only=False): + """Produce dataframe with all signals for each run. + + The keys of the dictionary are the labels for each curve, and + the corresponding value is a pandas dataframe with the scan data. + + """ + xsignals, ysignals = await self.signal_names(hinted_only=hinted_only) + # Build the dataframes + dfs = OrderedDict() + for run in self.selected_runs: + # Get data from the database + df = await run.to_dataframe(signals=xsignals + ysignals) + dfs[run.uid] = df + return dfs + + async def signals(self, x_signal, y_signal, r_signal=None, use_log=False, use_invert=False, use_grad=False) -> Mapping: + """Produce a dictionary with the 1D datasets for plotting. + + The keys of the dictionary are the labels for each curve, and + the corresponding value is a pandas dataset with the data for + each signal. + + """ + # Check for sensible inputs + use_reference = r_signal is not None + if "" in [x_signal, y_signal] or (use_reference and r_signal == ""): + msg = (f"Empty signal name requested: x={repr(x_signal)}, y={repr(y_signal)}," + f" r={repr(r_signal)}") + log.debug(msg) + raise exceptions.EmptySignalName(msg) + signals = [x_signal, y_signal] + if use_reference: + signals.append(r_signal) + # Remove duplicates + signals = list(dict.fromkeys(signals).keys()) + # Build the dataframes + dfs = OrderedDict() + for run in self.selected_runs: + # Get data from the database + df = await run.to_dataframe(signals=signals) + # Check for missing signals + missing_x = x_signal not in df.columns + missing_y = y_signal not in df.columns + missing_r = r_signal not in df.columns + if missing_x or missing_y or (use_reference and missing_r): + log.warning("Could not find signals {x_signal}, {y_signal} and {r_signal}") + continue + # Apply transformations + if use_reference: + df[y_signal] /= df[r_signal] + if use_log: + df[y_signal] = np.log(df[y_signal]) + if use_invert: + df[y_signal] *= -1 + if use_grad: + df[y_signal] = np.gradient(df[y_signal], df[x_signal]) + series = pd.Series(df[y_signal].values, index=df[x_signal].values) + dfs[run.uid] = series + return dfs # ----------------------------------------------------------------------------- # :author: Mark Wolfman diff --git a/src/firefly/tests/test_run_browser.py b/src/firefly/tests/test_run_browser.py index c0a7168b..33caed71 100644 --- a/src/firefly/tests/test_run_browser.py +++ b/src/firefly/tests/test_run_browser.py @@ -1,108 +1,40 @@ import logging from unittest.mock import MagicMock +import asyncio import numpy as np import pandas as pd import pytest -from pyqtgraph import PlotItem, PlotWidget +from pyqtgraph import PlotItem, PlotWidget, ImageView, ImageItem from qtpy.QtCore import Qt -from tiled.adapters.mapping import MapAdapter -from tiled.adapters.xarray import DatasetAdapter -from tiled.client import Context, from_context -from tiled.server.app import build_app +from haven.catalog import Catalog from firefly.run_browser import RunBrowserDisplay from firefly.run_client import DatabaseWorker -log = logging.getLogger(__name__) - - -def wait_for_runs_model(display, qtbot): - with qtbot.waitSignal(display.runs_model_changed): - pass - - -# Some mocked test data -run1 = pd.DataFrame( - { - "energy_energy": np.linspace(8300, 8400, num=100), - "It_net_counts": np.abs(np.sin(np.linspace(0, 4 * np.pi, num=100))), - "I0_net_counts": np.linspace(1, 2, num=100), - } -) - -hints = { - "energy": {"fields": ["energy_energy", "energy_id_energy_readback"]}, -} - -bluesky_mapping = { - "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f": MapAdapter( - { - "primary": MapAdapter( - { - "data": DatasetAdapter.from_dataset(run1.to_xarray()), - }, - metadata={"descriptors": [{"hints": hints}]}, - ), - }, - metadata={ - "plan_name": "xafs_scan", - "start": { - "plan_name": "xafs_scan", - "uid": "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f", - "hints": {"dimensions": [[["energy_energy"], "primary"]]}, - }, - }, - ), - "9d33bf66-9701-4ee3-90f4-3be730bc226c": MapAdapter( - { - "primary": MapAdapter( - { - "data": DatasetAdapter.from_dataset(run1.to_xarray()), - }, - metadata={"descriptors": [{"hints": hints}]}, - ), - }, - metadata={ - "start": { - "plan_name": "rel_scan", - "uid": "9d33bf66-9701-4ee3-90f4-3be730bc226c", - "hints": {"dimensions": [[["pitch2"], "primary"]]}, - } - }, - ), -} - - -mapping = { - "255id_testing": MapAdapter(bluesky_mapping), -} - -tree = MapAdapter(mapping) - - -@pytest.fixture(scope="module") -def client(): - app = build_app(tree) - with Context.from_app(app) as context: - client = from_context(context) - yield client["255id_testing"] + +# pytest.skip("Need to migrate the module to gemviz fork", allow_module_level=True) @pytest.fixture() -def display(ffapp, client, qtbot): - display = RunBrowserDisplay(root_node=client) - wait_for_runs_model(display, qtbot) +def display(affapp, catalog): + display = RunBrowserDisplay(root_node=catalog) + display.clear_filters() + # Flush pending async coroutines + loop = asyncio.get_event_loop() + pending = asyncio.all_tasks(loop) + loop.run_until_complete(asyncio.gather(*pending)) + assert all(task.done() for task in pending), "Init tasks not complete." + # Run the test + # yield display try: yield display finally: - display._thread.quit() - display._thread.wait(msecs=5000) - assert not display._thread.isRunning() - - -def test_client_fixture(client): - """Does the client fixture load without stalling the test runner?""" + # Cancel remaining tasks + loop = asyncio.get_event_loop() + pending = asyncio.all_tasks(loop) + loop.run_until_complete(asyncio.gather(*pending)) + assert all(task.done() for task in pending), "Shutdown tasks not complete." def test_run_viewer_action(ffapp, monkeypatch): @@ -112,50 +44,46 @@ def test_run_viewer_action(ffapp, monkeypatch): assert isinstance(ffapp.windows["run_browser"], MagicMock) -def test_load_runs(display): +@pytest.mark.asyncio +async def test_load_runs(display): assert display.runs_model.rowCount() > 0 assert display.ui.runs_total_label.text() == str(display.runs_model.rowCount()) -def test_update_selected_runs(qtbot, display): +@pytest.mark.asyncio +async def test_update_selected_runs(display): # Change the proposal item selection_model = display.ui.run_tableview.selectionModel() item = display.runs_model.item(0, 1) assert item is not None - rect = display.run_tableview.visualRect(item.index()) - with qtbot.waitSignal(display._db_worker.selected_runs_changed): - qtbot.mouseClick( - display.run_tableview.viewport(), Qt.LeftButton, pos=rect.center() - ) + display.ui.run_tableview.selectRow(0) + # Update the runs + await display.update_selected_runs() # Check that the runs were saved - assert len(display._db_worker.selected_runs) > 0 + assert len(display.db.selected_runs) > 0 -def test_metadata(qtbot, display): +@pytest.mark.asyncio +async def test_metadata(display): # Change the proposal item - selection_model = display.ui.run_tableview.selectionModel() - item = display.runs_model.item(0, 1) - assert item is not None - rect = display.run_tableview.visualRect(item.index()) - with qtbot.waitSignal(display._db_worker.selected_runs_changed): - qtbot.mouseClick( - display.run_tableview.viewport(), Qt.LeftButton, pos=rect.center() - ) + display.ui.run_tableview.selectRow(0) + await display.update_selected_runs() # Check that the metadata was set properly in the Metadata tab metadata_doc = display.ui.metadata_textedit.document() text = display.ui.metadata_textedit.document().toPlainText() assert "xafs_scan" in text -def test_1d_plot_signals(client, display): +@pytest.mark.asyncio +async def test_1d_plot_signals(catalog, display): # Check that the 1D plot was created plot_widget = display.ui.plot_1d_view plot_item = display.plot_1d_item assert isinstance(plot_widget, PlotWidget) assert isinstance(plot_item, PlotItem) - # Update the list of runs and see if the controsl get updated - display._db_worker.selected_runs = client.values() - display._db_worker.selected_runs_changed.emit([]) + # Update the list of runs and see if the controls get updated + display.ui.run_tableview.selectColumn(0) + await display.update_selected_runs() # Check signals in checkboxes for combobox in [ display.ui.multi_signal_x_combobox, @@ -167,8 +95,8 @@ def test_1d_plot_signals(client, display): combobox.findText("energy_energy") > -1 ), f"energy_energy signal not in {combobox.objectName()}." - -def test_1d_plot_signal_memory(client, display): +@pytest.mark.asyncio +async def test_1d_plot_signal_memory(catalog, display): """Do we remember the signals that were previously selected.""" # Check that the 1D plot was created plot_widget = display.ui.plot_1d_view @@ -176,19 +104,20 @@ def test_1d_plot_signal_memory(client, display): assert isinstance(plot_widget, PlotWidget) assert isinstance(plot_item, PlotItem) # Update the list of runs and see if the controls get updated - display._db_worker.selected_runs = client.values() - display.update_1d_signals() + display.ui.run_tableview.selectRow(1) + await display.update_selected_runs() # Check signals in comboboxes cb = display.ui.signal_y_combobox assert cb.currentText() == "energy_energy" cb.setCurrentIndex(1) assert cb.currentText() == "energy_id_energy_readback" # Update the combobox signals and make sure the text didn't change - display.update_1d_signals() + await display.update_1d_signals() assert cb.currentText() == "energy_id_energy_readback" -def test_1d_hinted_signals(client, display): +@pytest.mark.asyncio +async def test_1d_hinted_signals(catalog, display, ffapp): display.ui.plot_1d_hints_checkbox.setChecked(True) # Check that the 1D plot was created plot_widget = display.ui.plot_1d_view @@ -196,8 +125,9 @@ def test_1d_hinted_signals(client, display): assert isinstance(plot_widget, PlotWidget) assert isinstance(plot_item, PlotItem) # Update the list of runs and see if the controsl get updated - display._db_worker.selected_runs = client.values() - display.update_1d_signals() + display.db.selected_runs = [run async for run in catalog.values()] + await display.update_1d_signals() + return # Check signals in checkboxes combobox = display.ui.signal_x_combobox assert ( @@ -207,16 +137,16 @@ def test_1d_hinted_signals(client, display): combobox.findText("It_net_counts") == -1 ), f"unhinted signal found in {combobox.objectName()}." - -@pytest.mark.skip(reason="Need to figure out why tiled fails with this test.") -def test_update_1d_plot(client, display, qtbot): - run = client.values()[0] - run_data = run["primary"]["data"].read() +@pytest.mark.asyncio +async def test_update_1d_plot(catalog, display, ffapp): + # Set up some fake data + run = [run async for run in catalog.values()][0] + display.db.selected_runs = [run] + await display.update_1d_signals() + run_data = await run.to_dataframe() expected_xdata = run_data.energy_energy expected_ydata = np.log(run_data.I0_net_counts / run_data.It_net_counts) expected_ydata = np.gradient(expected_ydata, expected_xdata) - with qtbot.waitSignal(display.plot_1d_changed): - display._db_worker.selected_runs_changed.emit([]) # Set the controls to describe the data we want to test x_combobox = display.ui.signal_x_combobox x_combobox.addItem("energy_energy") @@ -232,8 +162,7 @@ def test_update_1d_plot(client, display, qtbot): display.ui.invert_checkbox.setChecked(True) display.ui.gradient_checkbox.setChecked(True) # Update the plots - display._db_worker.selected_runs = [run] - display.update_1d_plot() + await display.update_1d_plot() # Check that the data were added data_item = display.plot_1d_item.listDataItems()[0] xdata, ydata = data_item.getData() @@ -241,21 +170,64 @@ def test_update_1d_plot(client, display, qtbot): np.testing.assert_almost_equal(ydata, expected_ydata) -def test_update_multi_plot(client, display, qtbot): - run = client.values()[0] - run_data = run["primary"]["data"].read() - expected_xdata = run_data.energy_energy - expected_ydata = np.log(run_data.I0_net_counts / run_data.It_net_counts) +@pytest.mark.asyncio +async def test_2d_plot_signals(catalog, display): + # Check that the 1D plot was created + plot_widget = display.ui.plot_2d_view + plot_item = display.plot_2d_item + assert isinstance(plot_widget, ImageView) + assert isinstance(plot_item, ImageItem) + # Update the list of runs and see if the controls get updated + display.db.selected_runs = [await catalog["85573831-f4b4-4f64-b613-a6007bf03a8d"]] + await display.update_2d_signals() + # Check signals in checkboxes + combobox = display.ui.signal_value_combobox + assert combobox.findText("It_net_counts") > -1 + +@pytest.mark.asyncio +async def test_update_2d_plot(catalog, display): + display.plot_2d_item.setRect = MagicMock() + # Load test data + run = await catalog["85573831-f4b4-4f64-b613-a6007bf03a8d"] + display.db.selected_runs = [run] + await display.update_1d_signals() + # Set the controls to describe the data we want to test + val_combobox = display.ui.signal_value_combobox + val_combobox.addItem("It_net_counts") + val_combobox.setCurrentText("It_net_counts") + display.ui.logarithm_checkbox_2d.setChecked(True) + display.ui.invert_checkbox_2d.setChecked(True) + display.ui.gradient_checkbox_2d.setChecked(True) + # Update the plots + await display.update_2d_plot() + # Determine what the image data should look like + expected_data = await run["It_net_counts"] + expected_data = expected_data.reshape((5, 21)).T + # Check that the data were added + image = display.plot_2d_item.image + np.testing.assert_almost_equal(image, expected_data) + # Check that the axes were formatted correctly + axes = display.plot_2d_view.view.axes + xaxis = axes['bottom']['item'] + yaxis = axes['left']['item'] + assert xaxis.labelText == "aerotech_horiz" + assert yaxis.labelText == "aerotech_vert" + display.plot_2d_item.setRect.assert_called_with(-100, -80, 200, 160) + + +@pytest.mark.asyncio +async def test_update_multi_plot(catalog, display): + run = await catalog["7d1daf1d-60c7-4aa7-a668-d1cd97e5335f"] + expected_xdata = await run['energy_energy'] + expected_ydata = np.log(await run['I0_net_counts'] / await run['It_net_counts']) expected_ydata = np.gradient(expected_ydata, expected_xdata) - with qtbot.waitSignal(display.plot_1d_changed): - display._db_worker.selected_runs_changed.emit([]) # Configure signals display.ui.multi_signal_x_combobox.addItem("energy_energy") display.ui.multi_signal_x_combobox.setCurrentText("energy_energy") display.multi_y_signals = ["energy_energy"] - display._db_worker.selected_runs = [run] + display.db.selected_runs = [run] # Update the plots - display.update_multi_plot() + await display.update_multi_plot() # Check that the data were added # data_item = display._multiplot_items[0].listDataItems()[0] # xdata, ydata = data_item.getData() @@ -263,59 +235,50 @@ def test_update_multi_plot(client, display, qtbot): # np.testing.assert_almost_equal(ydata, expected_ydata) -def test_filter_controls(client, display, qtbot): - # Does editing text change the filters? - display.ui.filter_user_combobox.setCurrentText("") - with qtbot.waitSignal(display.filters_changed): - qtbot.keyClicks(display.ui.filter_user_combobox, "wolfman") - # Set some values for the rest of the controls - display.ui.filter_proposal_combobox.setCurrentText("12345") - display.ui.filter_esaf_combobox.setCurrentText("678901") - display.ui.filter_current_proposal_checkbox.setChecked(True) - display.ui.filter_current_esaf_checkbox.setChecked(True) - display.ui.filter_plan_combobox.addItem("cake") - display.ui.filter_plan_combobox.setCurrentText("cake") - display.ui.filter_full_text_lineedit.setText("Aperature Science") - display.ui.filter_edge_combobox.setCurrentText("U-K") - display.ui.filter_sample_combobox.setCurrentText("Pb.*") - with qtbot.waitSignal(display.filters_changed) as blocker: - display.update_filters() - # Check if the filters were update correctly - filters = blocker.args[0] - assert filters == { - "user": "wolfman", - "proposal": "12345", - "esaf": "678901", - "use_current_proposal": True, - "use_current_esaf": True, - "exit_status": "success", - "plan": "cake", - "full_text": "Aperature Science", - "edge": "U-K", - "sample": "Pb.*", - } - - -def test_filter_runs(client, qtbot): - worker = DatabaseWorker(root_node=client) - worker._filters["plan"] = "xafs_scan" - with qtbot.waitSignal(worker.all_runs_changed) as blocker: - worker.load_all_runs() +@pytest.mark.asyncio +async def test_filter_runs(catalog): + worker = DatabaseWorker(catalog=catalog) + runs = await worker.load_all_runs(filters={"plan": "xafs_scan"}) # Check that the runs were filtered - runs = blocker.args[0] assert len(runs) == 1 -def test_distinct_fields(client, qtbot, display): - worker = DatabaseWorker(root_node=client) - with qtbot.waitSignal(worker.distinct_fields_changed) as blocker: - worker.load_distinct_fields() +@pytest.mark.asyncio +async def test_distinct_fields(catalog, display): + worker = DatabaseWorker(catalog=catalog) + distinct_fields = await worker.load_distinct_fields() # Check that the dictionary has the right structure - distinct_fields = blocker.args[0] for key in ["sample_name"]: assert key in distinct_fields.keys() +@pytest.mark.asyncio +async def test_db_task(display): + async def test_coro(): + return 15 + + result = await display.db_task(test_coro()) + assert result == 15 + + +@pytest.mark.asyncio +async def test_db_task_interruption(display, event_loop): + async def test_coro(sleep_time): + await asyncio.sleep(sleep_time) + return sleep_time + + # Create an existing task that will be cancelled + task_1 = display.db_task(test_coro(1.0), name="testing") + # Now execute another task + result = await display.db_task(test_coro(0.01), name="testing") + assert result == 0.01 + # Check that the first one was cancelled + with pytest.raises(asyncio.exceptions.CancelledError): + await task_1 + assert task_1.done() + assert task_1.cancelled() + + # ----------------------------------------------------------------------------- # :author: Mark Wolfman # :email: wolfman@anl.gov diff --git a/src/firefly/xrf_detector.ui b/src/firefly/xrf_detector.ui index 5e6f3fb2..232ede0d 100644 --- a/src/firefly/xrf_detector.ui +++ b/src/firefly/xrf_detector.ui @@ -196,12 +196,12 @@ true - - haven://${DEV}.acquire - + + haven://${DEV}.acquire +
@@ -433,12 +433,12 @@ 14.0 - - haven://${DEV}.dead_time_min - Lowest dead time across all elements. + + haven://${DEV}.dead_time_min +
@@ -488,12 +488,12 @@ 24.0 - - haven://${DEV}.dead_time_max - Highest dead time across all elements. + + haven://${DEV}.dead_time_max + @@ -1127,7 +1127,7 @@ 0 0 1023 - 364 + 394 @@ -1199,17 +1199,6 @@ 0 - - - 0 - - - 0 - - - 0 - - @@ -1366,8 +1355,8 @@ 0 0 - 863 - 363 + 877 + 393 diff --git a/src/haven/__init__.py b/src/haven/__init__.py index 6f3954f6..6915e8d2 100644 --- a/src/haven/__init__.py +++ b/src/haven/__init__.py @@ -5,7 +5,8 @@ from ._iconfig import load_config # noqa: F401 # Top-level imports -from .catalog import load_catalog, load_data, load_result, tiled_client # noqa: F401 +# from .catalog import load_catalog, load_data, load_result, tiled_client # noqa: F401 +from .catalog import catalog from .constants import edge_energy # noqa: F401 from .energy_ranges import ERange, KRange, merge_ranges # noqa: F401 from .instrument import ( # noqa: F401 diff --git a/src/haven/catalog.py b/src/haven/catalog.py index 72b4dc16..52fc0b6b 100644 --- a/src/haven/catalog.py +++ b/src/haven/catalog.py @@ -1,10 +1,51 @@ +import threading +import asyncio +from functools import partial, lru_cache +import logging + +import pandas as pd +import numpy as np import databroker +import sqlite3 from tiled.client import from_uri from tiled.client.cache import Cache from ._iconfig import load_config +log = logging.getLogger(__name__) + + +def unsnake(arr: np.ndarray, snaking: list) -> np.ndarray: + """Unsnake a nump array. + + For each axis in *arr*, there should be a corresponding True/False + in *snaking* whether that axis should have alternating rows. The + first entry is ignored as it doesn't make sense to snake the first + axis. + + Returns + ======= + unsnaked + A copy of *arr* with the odd-numbered axes flipped (if indicated + by *snaking*). + + """ + # arr = np.copy(arr) + # Create some slice object for easier manipulation + full_axis = slice(None) + alternating = slice(None, None, 2) + flipped = slice(None, None, -1) + # Flip each axis if necessary (skipping the first axis) + for axis, is_snaked in enumerate(snaking[1:]): + if not is_snaked: + continue + slices = (full_axis,) * axis + slices += (alternating,) + arr[slices] = arr[slices + (flipped,)] + return arr + + def load_catalog(name: str = "bluesky"): """Load a databroker catalog for retrieving data. @@ -82,17 +123,230 @@ def load_data(uid, catalog_name="bluesky", stream="primary"): return data -def tiled_client(entry_node=None, uri=None): +def with_thread_lock(fn): + """Makes sure the function isn't accessed concurrently.""" + def wrapper(obj, *args, **kwargs): + obj._lock.acquire() + try: + fn(obj, *args, **kwargs) + finally: + obj._lock.release() + return wrapper + + +class ThreadSafeCache(Cache): + """Equivalent to the regular cache, but thread-safe. + + Ensures that sqlite3 is built with concurrency features, and + ensures that no two write operations happen concurrently. + + """ + def __init__(self, *args, **kwargs, ): + super().__init__(*args, **kwargs) + self._lock = threading.Lock() + + def write_safe(self): + """ + Check that it is safe to write. + + SQLite is not threadsafe for concurrent _writes_. + """ + is_main_thread = threading.current_thread().ident == self._owner_thread + sqlite_is_safe = sqlite3.threadsafety == 1 + return is_main_thread or sqlite_is_safe + + # Wrap the accessor methods so they wait for the lock + clear = with_thread_lock(Cache.clear) + set = with_thread_lock(Cache.set) + get = with_thread_lock(Cache.get) + delete = with_thread_lock(Cache.delete) + + +def tiled_client(entry_node=None, uri=None, cache_filepath=None): config = load_config() + # Create a cache for saving local copies + if cache_filepath is None: + cache_filepath = config['database']['tiled'].get("cache_filepath", "") + cache_filepath = cache_filepath or None + cache = ThreadSafeCache(filepath=cache_filepath) + # Create the client if uri is None: uri = config["database"]["tiled"]["uri"] - client_ = from_uri(uri, cache=Cache()) + client_ = from_uri(uri, "dask", cache=cache) if entry_node is None: entry_node = config["database"]["tiled"]["entry_node"] client_ = client_[entry_node] return client_ + +class CatalogScan(): + """A single scan from the tiled API with some convenience methods. + + Parameters + ========== + A tiled container on which to operate.""" + + def __init__(self, container): + self.container = container + + def _read_data(self, signals): + # Fetch data if needed + data = self.container['primary']['data'] + return data.read(signals) + + def _read_metadata(self, keys=None): + container = self.container + if keys is not None: + container = container[keys] + return container.metadata + + @property + def uid(self): + return self.container._item['id'] + + async def to_dataframe(self, signals=None): + """Convert the dataset into a pandas dataframe.""" + xarray = await self.loop.run_in_executor(None, self._read_data, signals) + if len(xarray) > 0: + df = xarray.to_dataframe() + else: + df = pd.DataFrame() + return df + + @property + def loop(self): + return asyncio.get_running_loop() + + async def hints(self): + """Retrieve the data hints for this scan. + + Returns + ======= + independent + The hints for the independent scanning axis. + dependent + The hints for the dependent scanning axis. + """ + metadata = await self.metadata + # Get hints for the independent (X) + independent = metadata["start"]["hints"]["dimensions"][0][0] + # Get hints for the dependent (X) + dependent = [] + primary_metadata = await self.loop.run_in_executor(None, self._read_metadata, "primary") + hints = primary_metadata["descriptors"][0]["hints"] + for device, dev_hints in hints.items(): + dependent.extend(dev_hints["fields"]) + return independent, dependent + + @property + async def metadata(self): + metadata = await self.loop.run_in_executor(None, self._read_metadata) + return metadata + + async def __getitem__(self, signal): + """Retrieve a signal from the dataset, with reshaping etc.""" + loop = asyncio.get_running_loop() + arr = await loop.run_in_executor(None, self._read_data, tuple([signal])) + arr = np.asarray(arr[signal]) + # Re-shape to match the scan dimensions + metadata = await self.metadata + try: + shape = metadata['start']['shape'] + except KeyError: + log.warning(f"No shape found for {repr(signal)}.") + else: + arr = np.reshape(arr, shape) + # Flip alternating rows if snaking is enabled + if "snaking" in metadata['start']: + arr = unsnake(arr, metadata['start']['snaking']) + return arr + + +class Catalog(): + """An asynchronous wrapper around the tiled client. + + This class has a more intelligent understanding of how *our* data + are structured, so can make some assumptions and takes care of + boiler-plate code (e.g. reshaping maps, etc). + + """ + _client = None + + def __init__(self, client=None): + self._client = client + + @property + def loop(self): + return asyncio.get_running_loop() + + @property + async def client(self): + if self._client is None: + self._client = await self.loop.run_in_executor(None, tiled_client) + return self._client + + async def __getitem__(self, uid) -> CatalogScan: + client = await self.client + container = await self.loop.run_in_executor(None, client.__getitem__, uid) + scan = CatalogScan(container=container) + return scan + + async def items(self): + client = await self.client + for key, value in await self.loop.run_in_executor(None, client.items): + yield key, CatalogScan(container=value) + + async def values(self): + client = await self.client + containers = await self.loop.run_in_executor(None, client.values) + for container in containers: + yield CatalogScan(container) + + async def __len__(self): + client = await self.client + length = await self.loop.run_in_executor(None, client.__len__) + return length + + async def search(self, query): + """ + Make a Node with a subset of this Node's entries, filtered by query. + + Examples + -------- + + >>> from tiled.queries import FullText + >>> await tree.search(FullText("hello")) + """ + loop = asyncio.get_running_loop() + client = await self.client + return Catalog(await loop.run_in_executor(None, client.search, query)) + + async def distinct(self, *metadata_keys, structure_families=False, specs=False, counts=False): + """Get the unique values and optionally counts of metadata_keys, + structure_families, and specs in this Node's entries + + Examples + -------- + + Query all the distinct values of a key. + + >>> await catalog.distinct("foo", counts=True) + + Query for multiple keys at once. + + >>> await catalog.distinct("foo", "bar", counts=True) + + """ + loop = asyncio.get_running_loop() + client = await self.client + query = partial(client.distinct, *metadata_keys, structure_families=structure_families, specs=specs, counts=counts) + return await loop.run_in_executor(None, query) + +# Create a default catalog for basic usage +catalog = Catalog() + + # ----------------------------------------------------------------------------- # :author: Mark Wolfman # :email: wolfman@anl.gov diff --git a/src/haven/conftest.py b/src/haven/conftest.py new file mode 100644 index 00000000..7fe22305 --- /dev/null +++ b/src/haven/conftest.py @@ -0,0 +1,13 @@ +import pytest + +from bluesky import RunEngine + + +class RunEngineStub(RunEngine): + def __repr__(self): + return "" + + +@pytest.fixture() +def RE(event_loop): + return RunEngineStub(call_returns_result=True) diff --git a/src/haven/instrument/motor.py b/src/haven/instrument/motor.py index 8d7a99ea..fe5e7de7 100644 --- a/src/haven/instrument/motor.py +++ b/src/haven/instrument/motor.py @@ -85,7 +85,10 @@ async def load_motor(prefix: str, motor_num: int, ioc_name: str = None): """Create the requested motor if it is reachable.""" pv = f"{prefix}:m{motor_num+1}" # Check that we're not duplicating a motor somewhere else (e.g. KB mirrors) - existing_pvs = [m.prefix for m in registry.findall(label="motors", allow_none=True)] + existing_pvs = [] + for m in registry.findall(label="motors", allow_none=True): + if hasattr(m, "prefix"): + existing_pvs.append(m.prefix) if pv in existing_pvs: log.info(f"Motor for prefix {pv} already exists. Skipping.") return @@ -104,7 +107,6 @@ async def load_motor(prefix: str, motor_num: int, ioc_name: str = None): return else: log.debug(f"Resolved motor {pv} to '{name}'") - # Create the motor device unused_motor_names = [f"motor {motor_num+1}", ""] if name in unused_motor_names: diff --git a/tests/fixtures/motor_positions.yaml b/src/haven/tests/fixtures/motor_positions.yaml similarity index 100% rename from tests/fixtures/motor_positions.yaml rename to src/haven/tests/fixtures/motor_positions.yaml diff --git a/src/haven/tests/test_catalog.py b/src/haven/tests/test_catalog.py new file mode 100644 index 00000000..032ef58c --- /dev/null +++ b/src/haven/tests/test_catalog.py @@ -0,0 +1,115 @@ +import logging +from unittest.mock import MagicMock + +import pandas as pd +import numpy as np +import pandas as pd +import pytest +from pyqtgraph import PlotItem, PlotWidget, ImageView, ImageItem +from qtpy.QtCore import Qt +from tiled import queries +from tiled.adapters.mapping import MapAdapter +from tiled.adapters.xarray import DatasetAdapter +from tiled.client import Context, from_context +from tiled.server.app import build_app + + +from haven.catalog import Catalog, CatalogScan, unsnake + + +@pytest.fixture() +def scan(tiled_client): + uid = "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f" + return CatalogScan(tiled_client[uid]) + + +@pytest.fixture() +def grid_scan(tiled_client): + uid = "85573831-f4b4-4f64-b613-a6007bf03a8d" + return CatalogScan(tiled_client[uid]) + + +def test_unsnake(): + # Make a snaked array + arr = np.arange(27).reshape((3, 3, 3)) + snaked = np.copy(arr) + snaked[::2] = snaked[::2, ::-1] + snaked[:, ::2] = snaked[:, ::2, ::-1] + # Do the unsnaking + unsnaked = unsnake(snaked, [False, True, True]) + # Check the result + np.testing.assert_equal(arr, unsnaked) + + +@pytest.mark.asyncio +async def test_client_fixture(tiled_client): + """Does the client fixture load without stalling the test runner?""" + + +@pytest.mark.asyncio +async def test_load_scan(catalog): + """Check that scans can be loaded from the catalog.""" + uid = "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f" + scan = await catalog[uid] + assert isinstance(scan, CatalogScan) + + +@pytest.mark.asyncio +async def test_dataframe(scan): + """Check that the catalogscan can produce a pandas dataframe.""" + df = await scan.to_dataframe() + assert isinstance(df, pd.DataFrame) + +@pytest.mark.asyncio +async def test_load_nd_data(grid_scan): + """Check that the catalog scan can convert e.g. grid_scan results.""" + arr = await grid_scan["It_net_counts"] + assert arr.ndim == 2 + assert arr.shape == (5, 21) + + +@pytest.mark.asyncio +async def test_distinct(catalog, tiled_client): + distinct = tiled_client.distinct("plan_name") + assert await catalog.distinct("plan_name") == distinct + +@pytest.mark.asyncio +async def test_search(catalog, tiled_client): + """Make sure we can query to database properly.""" + query = queries.Regex("plan_name", "xafs_scan") + expected = tiled_client.search(query) + response = await catalog.search(query) + assert len(expected) == await response.__len__() + + +@pytest.mark.asyncio +async def test_values(catalog, tiled_client): + """Get the individual scans in the catalog.""" + expected = [uid for uid in tiled_client.keys()] + response = [val.uid async for val in catalog.values()] + assert expected == response + +# ----------------------------------------------------------------------------- +# :author: Mark Wolfman +# :email: wolfman@anl.gov +# :copyright: Copyright © 2023, UChicago Argonne, LLC +# +# Distributed under the terms of the 3-Clause BSD License +# +# The full license is in the file LICENSE, distributed with this software. +# +# DISCLAIMER +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ----------------------------------------------------------------------------- diff --git a/tests/README.md b/test_iocs/README.md similarity index 100% rename from tests/README.md rename to test_iocs/README.md diff --git a/tests/conftest.py b/test_iocs/conftest.py similarity index 100% rename from tests/conftest.py rename to test_iocs/conftest.py diff --git a/tests/dxp_3px_4elem_Fe55.txt b/test_iocs/dxp_3px_4elem_Fe55.txt similarity index 100% rename from tests/dxp_3px_4elem_Fe55.txt rename to test_iocs/dxp_3px_4elem_Fe55.txt diff --git a/tests/test_simulated_ioc.py b/test_iocs/test_simulated_ioc.py similarity index 100% rename from tests/test_simulated_ioc.py rename to test_iocs/test_simulated_ioc.py