Skip to content

Commit 575c9bd

Browse files
authored
refactor!: mpi hdf5 (#902)
* fix: cleanup jobpool with mpi * fix: make Engine.recognizes static and pass MPIService comm as attribute. * fix: use MPILock instead of single rank selection. * fix: change variable name to match parent class * fix: clear of stored data * fix: change recognizes signature to use MPI communicator instead of MPILock * fix: test_connectivity.py to take into account new clear system * fix: bump pull request to allow test for breaking changes
1 parent 553f8b3 commit 575c9bd

File tree

7 files changed

+21
-18
lines changed

7 files changed

+21
-18
lines changed

.github/workflows/pull_request.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
runs-on: ubuntu-latest
1010
steps:
1111
- name: PR Conventional Commit Validation
12-
uses: ytanikin/PRConventionalCommits@1.2.0
12+
uses: ytanikin/PRConventionalCommits@1.3.0
1313
with:
1414
task_types: '["feat","fix","docs","test","ci","refactor","perf","revert"]'
1515

bsb/core.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
import os
33
import sys
4+
import time
45
import typing
56

67
import numpy as np
@@ -201,8 +202,6 @@ def _bootstrap(self, config, storage, clear=False):
201202
# Then, `storage` is initialized for the scaffold, and `config` is stored (happens
202203
# inside the `storage` property).
203204
self.storage = storage
204-
# Synchronize the JobPool static variable so that each core use the same ids.
205-
JobPool._next_pool_id = self._comm.bcast(JobPool._next_pool_id, root=0)
206205

207206
storage_cfg = _config_property("storage")
208207
for attr in _cfg_props:
@@ -388,8 +387,7 @@ def compile(
388387
report("Clearing data", level=2)
389388
# Clear the placement and connectivity data, but leave any cached files
390389
# and morphologies intact.
391-
self.clear_placement()
392-
self.clear_connectivity()
390+
self.clear()
393391
elif redo:
394392
# In order to properly redo things, we clear some placement and connection
395393
# data, but since multiple placement/connection strategies can contribute
@@ -791,8 +789,9 @@ def _load_cs_types(
791789
return cs
792790

793791
def create_job_pool(self, fail_fast=None, quiet=False):
792+
id_pool = self._comm.bcast(int(time.time()), root=0)
794793
pool = JobPool(
795-
self, fail_fast=fail_fast, workflow=getattr(self, "_workflow", None)
794+
id_pool, self, fail_fast=fail_fast, workflow=getattr(self, "_workflow", None)
796795
)
797796
try:
798797
# Check whether stdout is a TTY, and that it is larger than 0x0

bsb/services/pool.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,9 @@ class JobPool:
540540
_pool_owners = {}
541541
_tmp_folders = {}
542542

543-
def __init__(self, scaffold, fail_fast=False, workflow: "Workflow" = None):
543+
def __init__(self, id, scaffold, fail_fast=False, workflow: "Workflow" = None):
544544
self._schedulers: list[concurrent.futures.Future] = []
545-
self.id: int = None
545+
self.id: int = id
546546
self._scaffold = scaffold
547547
self._comm = scaffold._comm
548548
self._unhandled_errors = []
@@ -563,8 +563,6 @@ def __enter__(self):
563563
self._context = ExitStack()
564564
tmp_dirname = self._context.enter_context(tempfile.TemporaryDirectory())
565565

566-
self.id = JobPool._next_pool_id
567-
JobPool._next_pool_id += 1
568566
JobPool._pool_owners[self.id] = self._scaffold
569567
JobPool._pools[self.id] = self
570568
JobPool._tmp_folders[self.id] = tmp_dirname

bsb/storage/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from .. import plugins
2121
from ..exceptions import UnknownStorageEngineError
22+
from ..services import MPILock
2223
from ..services.mpi import MPIService
2324

2425
if typing.TYPE_CHECKING:
@@ -389,7 +390,7 @@ def open_storage(root, comm=None):
389390
"""
390391
engines = get_engines()
391392
for name, engine in engines.items():
392-
if engine.peek_exists(root) and engine.recognizes(root):
393+
if engine.peek_exists(root) and engine.recognizes(root, comm):
393394
return Storage(name, root, comm, missing_ok=False)
394395
else:
395396
for name, engine in engines.items():
@@ -417,7 +418,7 @@ def view_support(engine=None):
417418
"""
418419
if engine is None:
419420
return {
420-
# Loop over all enginges
421+
# Loop over all engines
421422
engine_name: {
422423
# Loop over all features, check whether they're supported
423424
feature_name: not isinstance(feature, NotSupported)

bsb/storage/fs/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def versions(self):
2929
path = Path(self._root) / "versions.txt"
3030
return json.loads(path.read_text())
3131

32-
@classmethod
33-
def recognizes(cls, root):
32+
@staticmethod
33+
def recognizes(root, comm):
3434
try:
3535
return os.path.exists(root) and os.path.isdir(root)
3636
except Exception:

bsb/storage/interfaces.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,15 @@ def root_slug(self):
114114
"""
115115
pass
116116

117+
@staticmethod
117118
@abc.abstractmethod
118-
def recognizes(self, root):
119+
def recognizes(root, comm):
119120
"""
120-
Must return whether the given argument is recognized as a valid storage object.
121+
Must return whether the given root argument is recognized as a valid storage object.
122+
123+
:param root: The unique identifier for the storage
124+
:param mpi4py.MPI.Comm comm: MPI communicator that shares control
125+
over the Storage.
121126
"""
122127
pass
123128

tests/test_connectivity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ def test_contacts(self):
639639
conns = len(self.network.get_connectivity_set("intersect"))
640640
self.assertGreater(conns, 0, "no connections formed")
641641
self.network.connectivity.intersect.contacts = 2
642-
self.network.compile(clear=True)
642+
self.network.compile(redo=True)
643643
new_conns = len(self.network.get_connectivity_set("intersect"))
644644
self.assertEqual(conns * 2, new_conns, "Expected double contacts")
645645

@@ -651,7 +651,7 @@ def test_zero_contacts(self):
651651
conns = len(self.network.get_connectivity_set("intersect"))
652652
self.assertEqual(0, conns, "expected no contacts")
653653
self.network.connectivity.intersect.contacts = -3
654-
self.network.compile(clear=True)
654+
self.network.compile(redo=True)
655655
conns = len(self.network.get_connectivity_set("intersect"))
656656
self.assertEqual(0, conns, "expected no contacts")
657657

0 commit comments

Comments
 (0)