Skip to content

Commit

Permalink
Merge pull request #342 from spc-group/run_browser
Browse files Browse the repository at this point in the history
Updated run browser for new relational database
  • Loading branch information
canismarko authored Jan 21, 2025
2 parents 54a3866 + 7399815 commit 1ac6308
Show file tree
Hide file tree
Showing 12 changed files with 1,089 additions and 693 deletions.
16 changes: 12 additions & 4 deletions src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,20 @@ async def filters(sim_registry):
),
},
metadata={
"plan_name": "xafs_scan",
"start": {
"plan_name": "xafs_scan",
"esaf_id": "1337",
"proposal_id": "158839",
"beamline_id": "255-ID-Z",
"sample_name": "NMC-532",
"sample_formula": "LiNi0.5Mn0.3Co0.2O2",
"edge": "Ni-K",
"uid": "7d1daf1d-60c7-4aa7-a668-d1cd97e5335f",
"hints": {"dimensions": [[["energy_energy"], "primary"]]},
},
"stop": {
"exit_status": "success",
},
},
),
"9d33bf66-9701-4ee3-90f4-3be730bc226c": MapAdapter(
Expand Down Expand Up @@ -347,6 +355,7 @@ async def filters(sim_registry):

mapping = {
"255id_testing": MapAdapter(bluesky_mapping),
"255bm_testing": MapAdapter(bluesky_mapping),
}

tree = MapAdapter(mapping)
Expand All @@ -357,13 +366,12 @@ def tiled_client():
app = build_app(tree)
with Context.from_app(app) as context:
client = from_context(context)
yield client["255id_testing"]
yield client


@pytest.fixture()
def catalog(tiled_client):
cat = Catalog(client=tiled_client)
# cat = mock.AsyncMock()
cat = Catalog(client=tiled_client["255id_testing"])
return cat


Expand Down
19 changes: 10 additions & 9 deletions src/firefly/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from qtpy.QtGui import QIcon, QKeySequence
from qtpy.QtWidgets import QAction, QErrorMessage

from haven import beamline, load_config
from haven import beamline, load_config, tiled_client
from haven.exceptions import ComponentNotFound, InvalidConfiguration
from haven.utils import titleize

Expand Down Expand Up @@ -333,11 +333,18 @@ async def finalize_new_window(self, action):
# Send the current devices to the window
await action.window.update_devices(self.registry)

def finalize_run_browser_window(self, action):
"""Connect up signals that are specific to the run browser window."""
    @asyncSlot(QAction)
    async def finalize_run_browser_window(self, action):
        """Connect up run browser signals and load initial data.

        Wires the controller's run-update signals to the run browser
        display, then points the display at the default Tiled catalog
        named in the loaded configuration.

        Parameters
        ==========
        action
          The QAction whose *display* attribute is the run browser
          display being finalized.
        """
        display = action.display
        # Keep the run browser in sync with runs as they update/stop
        self.run_updated.connect(display.update_running_scan)
        self.run_stopped.connect(display.update_running_scan)
        # Set initial state for the run_browser
        client = tiled_client(catalog=None)
        config = load_config()["tiled"]
        await display.setup_database(
            tiled_client=client, catalog_name=config["default_catalog"]
        )

def finalize_status_window(self, action):
"""Connect up signals that are specific to the voltmeters window."""
Expand Down Expand Up @@ -652,12 +659,6 @@ async def add_queue_item(self, item):
if getattr(self, "_queue_client", None) is not None:
await self._queue_client.add_queue_item(item)

@QtCore.Slot()
def show_sample_viewer_window(self):
return self.show_window(
FireflyMainWindow, ui_dir / "sample_viewer.ui", name="sample_viewer"
)

@QtCore.Slot(bool)
def set_open_environment_action_state(self, is_open: bool):
"""Update the readback value for opening the queueserver environment."""
Expand Down
3 changes: 2 additions & 1 deletion src/firefly/kafka_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
import warnings
from uuid import uuid4

import msgpack
from aiokafka import AIOKafkaConsumer
Expand Down Expand Up @@ -40,7 +41,7 @@ async def consumer_loop(self):
self.kafka_consumer = AIOKafkaConsumer(
config["queueserver"]["kafka_topic"],
bootstrap_servers="fedorov.xray.aps.anl.gov:9092",
group_id="my-group",
group_id=str(uuid4()),
value_deserializer=msgpack.loads,
)
consumer = self.kafka_consumer
Expand Down
118 changes: 78 additions & 40 deletions src/firefly/run_browser/client.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,97 @@
import asyncio
import datetime as dt
import logging
import warnings
from collections import OrderedDict
from functools import partial
from typing import Mapping, Sequence

import numpy as np
import pandas as pd
from qasync import asyncSlot
from tiled import queries

from haven import exceptions
from haven.catalog import Catalog
from haven.catalog import Catalog, run_in_executor

log = logging.getLogger(__name__)


class DatabaseWorker:
selected_runs: Sequence = []
catalog: Catalog = None

def __init__(self, catalog=None, *args, **kwargs):
if catalog is None:
catalog = Catalog()
self.catalog = catalog
def __init__(self, tiled_client, *args, **kwargs):
self.client = tiled_client
super().__init__(*args, **kwargs)

@asyncSlot(str)
async def change_catalog(self, catalog_name: str):
"""Change the catalog being used for pulling data.
*catalog_name* should be an entry in *worker.tiled_client()*.
"""

def get_catalog(name):
return Catalog(self.client[catalog_name])

loop = asyncio.get_running_loop()
self.catalog = await loop.run_in_executor(None, get_catalog, catalog_name)

@run_in_executor
def catalog_names(self):
return list(self.client.keys())

async def stream_names(self):
awaitables = [scan.stream_names() for scan in self.selected_runs]
all_streams = await asyncio.gather(*awaitables)
# Flatten the lists
streams = [stream for streams in all_streams for stream in streams]
return list(set(streams))

async def filtered_nodes(self, filters: Mapping):
case_sensitive = False
log.debug(f"Filtering nodes: {filters}")
filter_params = [
# (filter_name, query type, metadata key)
("user", queries.Contains, "start.proposal_users"),
("proposal", queries.Eq, "start.proposal_id"),
("esaf", queries.Eq, "start.esaf_id"),
("sample", queries.Contains, "start.sample_name"),
("exit_status", queries.Eq, "stop.exit_status"),
("plan", queries.Eq, "start.plan_name"),
("edge", queries.Contains, "start.edge"),
]
filter_params = {
# filter_name: (query type, metadata key)
"plan": (queries.Eq, "start.plan_name"),
"sample": (queries.Contains, "start.sample_name"),
"formula": (queries.Contains, "start.sample_formula"),
"edge": (queries.Contains, "start.edge"),
"exit_status": (queries.Eq, "stop.exit_status"),
"user": (queries.Contains, "start.proposal_users"),
"proposal": (queries.Eq, "start.proposal_id"),
"esaf": (queries.Eq, "start.esaf_id"),
"beamline": (queries.Eq, "start.beamline_id"),
"before": (partial(queries.Comparison, "le"), "end.time"),
"after": (partial(queries.Comparison, "ge"), "start.time"),
"full_text": (queries.FullText, ""),
"standards_only": (queries.Eq, "start.is_standard"),
}
# Apply filters
runs = self.catalog
for filter_name, Query, md_name in filter_params:
val = filters.get(filter_name, "")
if val != "":
runs = await runs.search(Query(md_name, val))
full_text = filters.get("full_text", "")
if full_text != "":
runs = await runs.search(
queries.FullText(full_text, case_sensitive=case_sensitive)
)
for filter_name, filter_value in filters.items():
if filter_name not in filter_params:
continue
Query, md_name = filter_params[filter_name]
if Query is queries.FullText:
runs = await runs.search(Query(filter_value), case_sensitive=False)
else:
runs = await runs.search(Query(md_name, filter_value))
return runs

async def load_distinct_fields(self):
"""Get distinct metadata fields for filterable metadata."""
new_fields = {}
target_fields = [
"sample_name",
"proposal_users",
"proposal_id",
"esaf_id",
"sample_name",
"plan_name",
"edge",
"start.plan_name",
"start.sample_name",
"start.sample_formula",
"start.edge",
"stop.exit_status",
"start.proposal_id",
"start.esaf_id",
"start.beamline_id",
]
# Get fields from the database
response = await self.catalog.distinct(*target_fields)
Expand Down Expand Up @@ -118,11 +150,13 @@ async def load_all_runs(self, filters: Mapping = {}):
all_runs.append(run_data)
return all_runs

async def signal_names(self, hinted_only: bool = False):
async def signal_names(self, stream: str, *, hinted_only: bool = False):
"""Get a list of valid signal names (data columns) for selected runs.
Parameters
==========
stream
The Tiled stream name to fetch.
hinted_only
If true, only signals with the kind="hinted" parameter get
picked.
Expand All @@ -131,9 +165,9 @@ async def signal_names(self, hinted_only: bool = False):
xsignals, ysignals = [], []
for run in self.selected_runs:
if hinted_only:
xsig, ysig = await run.hints()
xsig, ysig = await run.hints(stream=stream)
else:
df = await run.data()
df = await run.data(stream=stream)
xsig = ysig = df.columns
xsignals.extend(xsig)
ysignals.extend(ysig)
Expand All @@ -160,32 +194,34 @@ async def load_selected_runs(self, uids):
self.selected_runs = runs
return runs

async def images(self, signal):
async def images(self, signal: str, stream: str):
"""Load the selected runs as 2D or 3D images suitable for plotting."""
images = OrderedDict()
for idx, run in enumerate(self.selected_runs):
# Load datasets from the database
try:
image = await run[signal]
image = await run.__getitem__(signal, stream=stream)
except KeyError as exc:
log.exception(exc)
else:
images[run.uid] = image
return images

async def all_signals(self, hinted_only=False) -> dict:
async def all_signals(self, stream: str, *, hinted_only=False) -> dict:
"""Produce dataframes with all signals for each run.
The keys of the dictionary are the labels for each curve, and
the corresponding value is a pandas dataframe with the scan data.
"""
xsignals, ysignals = await self.signal_names(hinted_only=hinted_only)
xsignals, ysignals = await self.signal_names(
hinted_only=hinted_only, stream=stream
)
# Build the dataframes
dfs = OrderedDict()
for run in self.selected_runs:
# Get data from the database
df = await run.data(signals=xsignals + ysignals)
df = await run.data(signals=xsignals + ysignals, stream=stream)
dfs[run.uid] = df
return dfs

Expand All @@ -194,6 +230,8 @@ async def signals(
x_signal,
y_signal,
r_signal=None,
*,
stream: str,
use_log=False,
use_invert=False,
use_grad=False,
Expand Down Expand Up @@ -233,7 +271,7 @@ async def signals(
if uids is not None and run.uid not in uids:
break
# Get data from the database
df = await run.data(signals=signals)
df = await run.data(signals=signals, stream=stream)
# Check for missing signals
missing_x = x_signal not in df.columns and df.index.name != x_signal
missing_y = y_signal not in df.columns
Expand Down
Loading

0 comments on commit 1ac6308

Please sign in to comment.