Skip to content

Commit 29967f2

Browse files
committed
refactor: 🎨 Move validation to the base configuration
Addresses part of #7
1 parent 025e520 commit 29967f2

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

{{cookiecutter.project_slug}}/config/config.yaml

+20-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,23 @@ gt:
1818

1919

2020
train:
21-
experiment_dir: ??? # TODO specify the path to save the run results, will also be used to store checkpoints etc.
21+
experiment_dir: ??? # TODO specify the path to save the run results, will also be used to store checkpoints etc.
22+
input_shape: # TODO specify the input shape of the model
23+
- ???
24+
- ???
25+
- ???
26+
27+
validate:
28+
batch_size: ???
29+
container: ${data.container}
30+
dataset: ${data.dataset}
31+
num_transmitters: ${gt.num_transmitters}
32+
experiment_dir: ${train.experiment_dir}
33+
checkpoint_number: ???
34+
input_shape: ${train.input_shape}
35+
num_workers: 12
36+
nt_name: ${gt.nt_name}
37+
point_id: ${gt.point_id}
38+
num_partitions: 10
39+
partition_id: 1
40+
output_dir: ??? # TODO specify the path to save the validation results, optional

{{cookiecutter.project_slug}}/scripts/03_inference.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def validate(
253253
save_output(results, identifiers, class_names_ordered, output_file)
254254
logging.info(f"Intermediate results saved to {output_file}")
255255

256-
@hydra.main(config_path="../config", config_name="validate")
256+
@hydra.main(config_path="../config", config_name="config")
257257
def main(cfg: DictConfig):
258258
logging.basicConfig(
259259
level=logging.INFO,
@@ -267,25 +267,12 @@ def main(cfg: DictConfig):
267267
# training data
268268
validation_data = cfg.gt.val
269269

270-
num_partitions = 1
271-
partition_id = 1
272-
if "num_partitions" in cfg:
273-
num_partitions = cfg.num_partitions
274-
if "partition_id" in cfg:
275-
partition_id = cfg.partition_id
276-
output_dir = None
277-
if "output_dir" in cfg:
278-
output_dir = cfg.output_dir
279-
280270
# run validation
281271
validate(
282272
model=model,
283273
# Get the data location
284274
val_gt_location=validation_data,
285275
**cfg.validate,
286-
num_partitions=num_partitions,
287-
partition_id=partition_id,
288-
output_dir=output_dir,
289276
)
290277

291278

0 commit comments

Comments
 (0)