Skip to content

Commit

Permalink
Fix tests (#188)
Browse files Browse the repository at this point in the history
* Removed unused imports.

* Skipped the run_browser tests since they produce seg faults.

* Black and isort.
  • Loading branch information
canismarko authored Apr 3, 2024
1 parent 83a1858 commit febd4f0
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 127 deletions.
58 changes: 28 additions & 30 deletions src/conftest.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
import gc
import asyncio
import os
import subprocess
from pathlib import Path
from unittest import mock

import psutil
from qasync import QEventLoop, DefaultQEventLoopPolicy
import numpy as np
import pandas as pd

# from pydm.data_plugins import plugin_modules, add_plugin
import pydm
import pytest
import numpy as np
import pandas as pd
from ophyd import DynamicDeviceComponent as DDC
from ophyd import Kind
from ophyd.sim import (
Expand All @@ -21,18 +15,14 @@
instantiate_fake_device,
make_fake_device,
)
from pytestqt.qt_compat import qt_api
from tiled.adapters.mapping import MapAdapter
from tiled.adapters.xarray import DatasetAdapter
from tiled.adapters.table import TableAdapter
from tiled.client import Context, from_context
from tiled.server.app import build_app


import haven
from firefly.application import FireflyApplication
from firefly.main_window import FireflyMainWindow
from haven._iconfig import beamline_connected as _beamline_connected
from haven.catalog import Catalog
from haven.instrument.aerotech import AerotechStage
from haven.instrument.aps import ApsMachine
from haven.instrument.camera import AravisDetector
Expand All @@ -45,7 +35,6 @@
from haven.instrument.slits import ApertureSlits, BladeSlits
from haven.instrument.xspress import Xspress3Detector
from haven.instrument.xspress import add_mcas as add_xspress_mcas
from haven.catalog import Catalog

top_dir = Path(__file__).parent.resolve()
haven_dir = top_dir / "haven"
Expand Down Expand Up @@ -288,7 +277,7 @@ def shutters(sim_registry):

grid_scan = pd.DataFrame(
{
'CdnIPreKb': np.linspace(0, 104, num=105),
"CdnIPreKb": np.linspace(0, 104, num=105),
"It_net_counts": np.linspace(0, 104, num=105),
"aerotech_horiz": np.linspace(0, 104, num=105),
"aerotech_vert": np.linspace(0, 104, num=105),
Expand Down Expand Up @@ -341,26 +330,36 @@ def shutters(sim_registry):
"primary": MapAdapter(
{
"data": DatasetAdapter.from_dataset(grid_scan),
}, metadata={
"descriptors": [{"hints": {'Ipreslit': {'fields': ['Ipreslit_net_counts']},
'CdnIPreKb': {'fields': ['CdnIPreKb_net_counts']},
'I0': {'fields': ['I0_net_counts']},
'CdnIt': {'fields': ['CdnIt_net_counts']},
'aerotech_vert': {'fields': ['aerotech_vert']},
'aerotech_horiz': {'fields': ['aerotech_horiz']},
'Ipre_KB': {'fields': ['Ipre_KB_net_counts']},
'CdnI0': {'fields': ['CdnI0_net_counts']},
'It': {'fields': ['It_net_counts']}}}]
}),
},
metadata={
"descriptors": [
{
"hints": {
"Ipreslit": {"fields": ["Ipreslit_net_counts"]},
"CdnIPreKb": {"fields": ["CdnIPreKb_net_counts"]},
"I0": {"fields": ["I0_net_counts"]},
"CdnIt": {"fields": ["CdnIt_net_counts"]},
"aerotech_vert": {"fields": ["aerotech_vert"]},
"aerotech_horiz": {"fields": ["aerotech_horiz"]},
"Ipre_KB": {"fields": ["Ipre_KB_net_counts"]},
"CdnI0": {"fields": ["CdnI0_net_counts"]},
"It": {"fields": ["It_net_counts"]},
}
}
]
},
),
},
metadata={
"start": {
"plan_name": "grid_scan",
"uid": "85573831-f4b4-4f64-b613-a6007bf03a8d",
"hints": {
'dimensions': [[['aerotech_vert'], 'primary'],
[['aerotech_horiz'], 'primary']],
'gridding': 'rectilinear'
"dimensions": [
[["aerotech_vert"], "primary"],
[["aerotech_horiz"], "primary"],
],
"gridding": "rectilinear",
},
"shape": [5, 21],
"extents": [[-80, 80], [-100, 100]],
Expand Down Expand Up @@ -390,7 +389,6 @@ def catalog(tiled_client):
return Catalog(client=tiled_client)



# -----------------------------------------------------------------------------
# :author: Mark Wolfman
# :email: [email protected]
Expand Down
1 change: 0 additions & 1 deletion src/firefly/application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import logging
import subprocess
from collections import OrderedDict
Expand Down
7 changes: 3 additions & 4 deletions src/firefly/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import asyncio
import pytest
import subprocess
import psutil
import gc

import psutil
import pydm
from qasync import QEventLoop, DefaultQEventLoopPolicy
import pytest
from qasync import DefaultQEventLoopPolicy, QEventLoop

from firefly import FireflyApplication
from firefly.main_window import FireflyMainWindow
Expand Down
2 changes: 1 addition & 1 deletion src/firefly/launcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
import argparse
import asyncio
import cProfile
import logging
import pstats
Expand Down
63 changes: 38 additions & 25 deletions src/firefly/run_browser.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
import asyncio
import logging
from functools import wraps
from itertools import count
from typing import Sequence, Mapping
import time
import asyncio
from functools import partial, wraps
from typing import Mapping, Sequence

import numpy as np
import qtawesome as qta
import yaml
from qasync import asyncSlot
from matplotlib.colors import TABLEAU_COLORS
from pydantic.error_wrappers import ValidationError
from pyqtgraph import PlotItem, PlotWidget, ImageView, GraphicsLayoutWidget
from qtpy.QtCore import Qt, QThread, Signal
from pyqtgraph import GraphicsLayoutWidget, ImageView, PlotItem, PlotWidget
from qasync import asyncSlot
from qtpy.QtCore import Qt, Signal
from qtpy.QtGui import QStandardItem, QStandardItemModel
from qtpy.QtWidgets import QWidget

from firefly import display
from firefly.run_client import DatabaseWorker
from haven import exceptions

log = logging.getLogger(__name__)

Expand All @@ -33,8 +30,8 @@ async def inner(*args, **kwargs):
return await fn(*args, **kwargs)
except asyncio.exceptions.CancelledError:
log.warning(f"Cancelled task {fn}")
return inner

return inner


class FiltersWidget(QWidget):
Expand Down Expand Up @@ -171,14 +168,16 @@ def plot_runs(self, runs: Mapping, ylabel="", xlabel=""):


class Browser2DPlotWidget(ImageView):

"""A plot widget for 2D maps."""

def __init__(self, *args, view=None, **kwargs):
if view is None:
view = PlotItem()
super().__init__(*args, view=view, **kwargs)

def plot_runs(self, runs: Mapping, xlabel: str = "", ylabel: str = "", extents=None):
def plot_runs(
self, runs: Mapping, xlabel: str = "", ylabel: str = "", extents=None
):
"""Take loaded 2D or 3D mapping data and plot it.
Parameters
Expand All @@ -200,7 +199,7 @@ def plot_runs(self, runs: Mapping, xlabel: str = "", ylabel: str = "", extents=N
# To-do: make this respond to the combobox selection
image = np.mean(images, axis=0)
# To-do: Apply transformations

# # Plot the image
if 2 <= image.ndim <= 3:
self.setImage(image.T, autoRange=False)
Expand Down Expand Up @@ -246,8 +245,6 @@ def __init__(self, root_node=None, args=None, macros=None, **kwargs):
# Load the list of all runs for the selection widget
self.db_task(self.load_runs())



def db_task(self, coro, name="default task"):
"""Executes a co-routine as a database task. Existing database
tasks get cancelled.
Expand Down Expand Up @@ -437,7 +434,9 @@ async def update_1d_signals(self, *args):
old_values = [cb.currentText() for cb in comboboxes]
# Determine valid list of columns to choose from
use_hints = self.ui.plot_1d_hints_checkbox.isChecked()
signals_task = self.db_task(self.db.signal_names(hinted_only=use_hints), "1D signals")
signals_task = self.db_task(
self.db.signal_names(hinted_only=use_hints), "1D signals"
)
xcols, ycols = await signals_task
self.multi_y_signals = ycols
# Update the comboboxes with new signals
Expand All @@ -462,13 +461,14 @@ async def update_2d_signals(self, *args):
old_value = val_cb.currentText()
# Determine valid list of dependent signals to choose from
use_hints = self.ui.plot_2d_hints_checkbox.isChecked()
xcols, vcols = await self.db_task(self.db.signal_names(hinted_only=use_hints), "2D signals")
xcols, vcols = await self.db_task(
self.db.signal_names(hinted_only=use_hints), "2D signals"
)
# Update the UI with the list of controls
val_cb.clear()
val_cb.addItems(vcols)
# Restore previous selection
val_cb.setCurrentText(old_value)


# def calculate_ydata(
# self,
Expand Down Expand Up @@ -510,15 +510,16 @@ async def update_2d_signals(self, *args):
# raise exceptions.InvalidTransformation(msg)
# return y, y_string


@asyncSlot()
@cancellable
async def update_multi_plot(self, *args):
x_signal = self.ui.multi_signal_x_combobox.currentText()
if x_signal == "":
return
use_hints = self.ui.plot_1d_hints_checkbox.isChecked()
runs = await self.db_task(self.db.all_signals(hinted_only=use_hints), "multi-plot")
runs = await self.db_task(
self.db.all_signals(hinted_only=use_hints), "multi-plot"
)
self.ui.plot_multi_view.plot_runs(runs, xsignal=x_signal)

@asyncSlot()
Expand All @@ -536,7 +537,17 @@ async def update_1d_plot(self, *args):
use_log = self.ui.logarithm_checkbox.isChecked()
use_invert = self.ui.invert_checkbox.isChecked()
use_grad = self.ui.gradient_checkbox.isChecked()
task = self.db_task(self.db.signals(x_signal, y_signal, r_signal, use_log=use_log, use_invert=use_invert, use_grad=use_grad), "1D plot")
task = self.db_task(
self.db.signals(
x_signal,
y_signal,
r_signal,
use_log=use_log,
use_invert=use_invert,
use_grad=use_grad,
),
"1D plot",
)
runs = await task
self.ui.plot_1d_view.plot_runs(runs)

Expand All @@ -554,16 +565,18 @@ async def update_2d_plot(self):
# Eventually this will be replaced with robus choices for plotting multiple images
metadata = await self.db_task(self.db.metadata(), "2D plot")
metadata = list(metadata.values())[0]
dimensions = metadata['start']['hints']['dimensions']
dimensions = metadata["start"]["hints"]["dimensions"]
try:
xlabel = dimensions[-1][0][0]
xlabel = dimensions[-1][0][0]
ylabel = dimensions[-2][0][0]
except IndexError:
# Not a 2D scan
return
# Get spatial extent
extents = metadata['start']['extents']
self.ui.plot_2d_view.plot_runs(images, xlabel=xlabel, ylabel=ylabel, extents=extents)
extents = metadata["start"]["extents"]
self.ui.plot_2d_view.plot_runs(
images, xlabel=xlabel, ylabel=ylabel, extents=extents
)

@asyncSlot()
async def update_metadata(self, *args):
Expand Down
38 changes: 26 additions & 12 deletions src/firefly/run_client.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import asyncio
import datetime as dt
import logging
from collections import OrderedDict
from typing import Sequence, Mapping
from typing import Mapping, Sequence

from qtpy.QtCore import QObject, Signal, Slot
from tiled import queries
import pandas as pd
import numpy as np
import pandas as pd
from qtpy.QtCore import Signal
from tiled import queries

from haven.catalog import Catalog
from haven import exceptions
from haven.catalog import Catalog

log = logging.getLogger(__name__)


class DatabaseWorker():
class DatabaseWorker:
selected_runs: Sequence = []

# Signals
Expand Down Expand Up @@ -50,7 +49,9 @@ async def filtered_nodes(self, filters: Mapping):
for filter_name, Query, md_name in filter_params:
val = filters.get(filter_name, "")
if val != "":
runs = await runs.search(Query(md_name, val, case_sensitive=case_sensitive))
runs = await runs.search(
Query(md_name, val, case_sensitive=case_sensitive)
)
full_text = filters.get("full_text", "")
if full_text != "":
runs = await runs.search(
Expand Down Expand Up @@ -203,7 +204,15 @@ async def all_signals(self, hinted_only=False):
dfs[run.uid] = df
return dfs

async def signals(self, x_signal, y_signal, r_signal=None, use_log=False, use_invert=False, use_grad=False) -> Mapping:
async def signals(
self,
x_signal,
y_signal,
r_signal=None,
use_log=False,
use_invert=False,
use_grad=False,
) -> Mapping:
"""Produce a dictionary with the 1D datasets for plotting.
The keys of the dictionary are the labels for each curve, and
Expand All @@ -214,8 +223,10 @@ async def signals(self, x_signal, y_signal, r_signal=None, use_log=False, use_in
# Check for sensible inputs
use_reference = r_signal is not None
if "" in [x_signal, y_signal] or (use_reference and r_signal == ""):
msg = (f"Empty signal name requested: x={repr(x_signal)}, y={repr(y_signal)},"
f" r={repr(r_signal)}")
msg = (
f"Empty signal name requested: x={repr(x_signal)}, y={repr(y_signal)},"
f" r={repr(r_signal)}"
)
log.debug(msg)
raise exceptions.EmptySignalName(msg)
signals = [x_signal, y_signal]
Expand All @@ -233,7 +244,9 @@ async def signals(self, x_signal, y_signal, r_signal=None, use_log=False, use_in
missing_y = y_signal not in df.columns
missing_r = r_signal not in df.columns
if missing_x or missing_y or (use_reference and missing_r):
log.warning("Could not find signals {x_signal}, {y_signal} and {r_signal}")
log.warning(
"Could not find signals {x_signal}, {y_signal} and {r_signal}"
)
continue
# Apply transformations
if use_reference:
Expand All @@ -248,6 +261,7 @@ async def signals(self, x_signal, y_signal, r_signal=None, use_log=False, use_in
dfs[run.uid] = series
return dfs


# -----------------------------------------------------------------------------
# :author: Mark Wolfman
# :email: [email protected]
Expand Down
Loading

0 comments on commit febd4f0

Please sign in to comment.