Skip to content

Commit 31d64ea

Browse files
committed
Store problem configuration in Problem
Introduces Problem.config which contains the info from the PEtab yaml file. Sometimes it is convenient to have the original filenames around. Closes PEtab-dev#324.
1 parent 45a3371 commit 31d64ea

File tree

2 files changed

+78
-27
lines changed

2 files changed

+78
-27
lines changed

Diff for: petab/v1/problem.py

+77-27
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from warnings import warn
1212

1313
import pandas as pd
14+
from pydantic import AnyUrl, BaseModel, Field, RootModel
1415

1516
from . import (
1617
conditions,
@@ -79,6 +80,7 @@ def __init__(
7980
observable_df: pd.DataFrame = None,
8081
mapping_df: pd.DataFrame = None,
8182
extensions_config: dict = None,
83+
config: ProblemConfig = None,
8284
):
8385
self.condition_df: pd.DataFrame | None = condition_df
8486
self.measurement_df: pd.DataFrame | None = measurement_df
@@ -113,6 +115,7 @@ def __init__(
113115

114116
self.model: Model | None = model
115117
self.extensions_config = extensions_config or {}
118+
self.config = config
116119

117120
def __getattr__(self, name):
118121
# For backward-compatibility, allow access to SBML model related
@@ -262,10 +265,14 @@ def from_yaml(
262265
yaml_config: PEtab configuration as dictionary or YAML file name
263266
base_path: Base directory or URL to resolve relative paths
264267
"""
268+
# path to the yaml file
269+
filepath = None
270+
265271
if isinstance(yaml_config, Path):
266272
yaml_config = str(yaml_config)
267273

268274
if isinstance(yaml_config, str):
275+
filepath = yaml_config
269276
if base_path is None:
270277
base_path = get_path_prefix(yaml_config)
271278
yaml_config = yaml.load_yaml(yaml_config)
@@ -297,59 +304,58 @@ def get_path(filename):
297304
DeprecationWarning,
298305
stacklevel=2,
299306
)
307+
config = ProblemConfig(
308+
**yaml_config, base_path=base_path, filepath=filepath
309+
)
310+
problem0 = config.problems[0]
311+
# currently required for handling PEtab v2 in here
312+
problem0_ = yaml_config["problems"][0]
300313

301-
problem0 = yaml_config["problems"][0]
302-
303-
if isinstance(yaml_config[PARAMETER_FILE], list):
314+
if isinstance(config.parameter_file, list):
304315
parameter_df = parameters.get_parameter_df(
305-
[get_path(f) for f in yaml_config[PARAMETER_FILE]]
316+
[get_path(f) for f in config.parameter_file]
306317
)
307318
else:
308319
parameter_df = (
309-
parameters.get_parameter_df(
310-
get_path(yaml_config[PARAMETER_FILE])
311-
)
312-
if yaml_config[PARAMETER_FILE]
320+
parameters.get_parameter_df(get_path(config.parameter_file))
321+
if config.parameter_file
313322
else None
314323
)
315-
316-
if yaml_config[FORMAT_VERSION] in [1, "1", "1.0.0"]:
317-
if len(problem0[SBML_FILES]) > 1:
324+
if config.format_version.root in [1, "1", "1.0.0"]:
325+
if len(problem0.sbml_files) > 1:
318326
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
319327
raise NotImplementedError(
320328
"Support for multiple models is not yet implemented."
321329
)
322330

323331
model = (
324332
model_factory(
325-
get_path(problem0[SBML_FILES][0]),
333+
get_path(problem0.sbml_files[0]),
326334
MODEL_TYPE_SBML,
327335
model_id=None,
328336
)
329-
if problem0[SBML_FILES]
337+
if problem0.sbml_files
330338
else None
331339
)
332340
else:
333-
if len(problem0[MODEL_FILES]) > 1:
341+
if len(problem0_[MODEL_FILES]) > 1:
334342
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
335343
raise NotImplementedError(
336344
"Support for multiple models is not yet implemented."
337345
)
338-
if not problem0[MODEL_FILES]:
346+
if not problem0_[MODEL_FILES]:
339347
model = None
340348
else:
341349
model_id, model_info = next(
342-
iter(problem0[MODEL_FILES].items())
350+
iter(problem0_[MODEL_FILES].items())
343351
)
344352
model = model_factory(
345353
get_path(model_info[MODEL_LOCATION]),
346354
model_info[MODEL_LANGUAGE],
347355
model_id=model_id,
348356
)
349357

350-
measurement_files = [
351-
get_path(f) for f in problem0.get(MEASUREMENT_FILES, [])
352-
]
358+
measurement_files = [get_path(f) for f in problem0.measurement_files]
353359
# If there are multiple tables, we will merge them
354360
measurement_df = (
355361
core.concat_tables(
@@ -359,9 +365,7 @@ def get_path(filename):
359365
else None
360366
)
361367

362-
condition_files = [
363-
get_path(f) for f in problem0.get(CONDITION_FILES, [])
364-
]
368+
condition_files = [get_path(f) for f in problem0.condition_files]
365369
# If there are multiple tables, we will merge them
366370
condition_df = (
367371
core.concat_tables(condition_files, conditions.get_condition_df)
@@ -370,7 +374,7 @@ def get_path(filename):
370374
)
371375

372376
visualization_files = [
373-
get_path(f) for f in problem0.get(VISUALIZATION_FILES, [])
377+
get_path(f) for f in problem0.visualization_files
374378
]
375379
# If there are multiple tables, we will merge them
376380
visualization_df = (
@@ -379,17 +383,15 @@ def get_path(filename):
379383
else None
380384
)
381385

382-
observable_files = [
383-
get_path(f) for f in problem0.get(OBSERVABLE_FILES, [])
384-
]
386+
observable_files = [get_path(f) for f in problem0.observable_files]
385387
# If there are multiple tables, we will merge them
386388
observable_df = (
387389
core.concat_tables(observable_files, observables.get_observable_df)
388390
if observable_files
389391
else None
390392
)
391393

392-
mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])]
394+
mapping_files = [get_path(f) for f in problem0_.get(MAPPING_FILES, [])]
393395
# If there are multiple tables, we will merge them
394396
mapping_df = (
395397
core.concat_tables(mapping_files, mapping.get_mapping_df)
@@ -406,6 +408,7 @@ def get_path(filename):
406408
visualization_df=visualization_df,
407409
mapping_df=mapping_df,
408410
extensions_config=yaml_config.get(EXTENSIONS, {}),
411+
config=config,
409412
)
410413

411414
@staticmethod
@@ -1184,3 +1187,50 @@ def add_measurement(
11841187
if self.measurement_df is not None
11851188
else tmp_df
11861189
)
1190+
1191+
1192+
class VersionNumber(RootModel):
1193+
root: str | int
1194+
1195+
1196+
class ListOfFiles(RootModel):
1197+
"""List of files."""
1198+
1199+
root: list[str | AnyUrl] = Field(..., description="List of files.")
1200+
1201+
def __iter__(self):
1202+
return iter(self.root)
1203+
1204+
def __len__(self):
1205+
return len(self.root)
1206+
1207+
def __getitem__(self, index):
1208+
return self.root[index]
1209+
1210+
1211+
class SubProblem(BaseModel):
1212+
"""A `problems` object in the PEtab problem configuration."""
1213+
1214+
sbml_files: ListOfFiles = []
1215+
measurement_files: ListOfFiles = []
1216+
condition_files: ListOfFiles = []
1217+
observable_files: ListOfFiles = []
1218+
visualization_files: ListOfFiles = []
1219+
1220+
1221+
class ProblemConfig(BaseModel):
1222+
"""The PEtab problem configuration."""
1223+
1224+
filepath: str | AnyUrl | None = Field(
1225+
None,
1226+
description="The path to the PEtab problem configuration.",
1227+
exclude=True,
1228+
)
1229+
base_path: str | AnyUrl | None = Field(
1230+
None,
1231+
description="The base path to resolve relative paths.",
1232+
exclude=True,
1233+
)
1234+
format_version: VersionNumber = 1
1235+
parameter_file: str | AnyUrl | None = None
1236+
problems: list[SubProblem] = []

Diff for: pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"pyyaml",
2323
"jsonschema",
2424
"antlr4-python3-runtime==4.13.1",
25+
"pydantic>=2.10",
2526
]
2627
license = {text = "MIT License"}
2728
authors = [

0 commit comments

Comments
 (0)