
Commit 81e402e

Create split tool (#2)
* Change to use NotImplementedError
* Pass dataset argument to split
* Create ratio extraction and error handling
* Create pipeline to create the new data
* Install pandas
* Split images successfully
* Create exporting functionality
* Finish split tool
* Clean up the path management
* Clean up based on PR review
* Fix map object not a list issue
* Make image distribution consistent
1 parent fceb762 commit 81e402e

7 files changed, +165 -8 lines


.vscode/settings.json (+3)

@@ -0,0 +1,3 @@
+{
+    "python.pythonPath": "venv/bin/python"
+}

coco_tools/__init__.py (+1)

@@ -1,2 +1,3 @@
 from coco_tools.merge import merge
 from coco_tools.split import split
+from coco_tools.error import COCOToolsError

coco_tools/error.py (+2)

@@ -0,0 +1,2 @@
+class COCOToolsError(BaseException):
+    pass

coco_tools/merge.py (+1 -1)

@@ -1,2 +1,2 @@
 def merge():
-    print("Unimplemented!")
+    raise NotImplementedError()

coco_tools/split.py (+129 -2)

@@ -1,2 +1,129 @@
-def split():
-    print("Unimplemented!")
+import json
+import pandas as pd
+import numpy as np
+from pathlib import Path
+from coco_tools.error import COCOToolsError
+
+
+def split(dataset_path, ratio, names):
+    """Splits the dataset into multiple parts based on the given ratio.
+
+    Within the dataset, one image can have multiple annotations. `split` splits
+    the dataset by the number of images, and splits the annotations based on the
+    images that they belong to.
+
+    For example, in a dataset of 1000 images, a ratio of `70:20:10` would split
+    the dataset into three datasets containing `700`, `200` and `100` images
+    respectively.
+    """
+
+    # Extract and validate the inputs.
+    dataset_path = Path(dataset_path)
+    ratio = __extract_ratio(ratio)
+    names = __extract_names(names)
+
+    # Some additional input validation.
+    if len(ratio) != len(names):
+        raise COCOToolsError("ratio and names should be of same length")
+
+    # Load the dataset from `dataset_path`.
+    raw_data = None
+    try:
+        with open(str(dataset_path), "r") as dataset_file:
+            raw_data = json.load(dataset_file)
+    except FileNotFoundError:
+        raise COCOToolsError(f"file \"{dataset_path}\" not found")
+
+    # Extract `images` and `annotations`.
+    images = raw_data.pop("images")
+    annotations = raw_data.pop("annotations")
+
+    # Initialize the new datas.
+    new_datas = [raw_data.copy() for _ in ratio]
+
+    # Split the data.
+    __split_data(new_datas, ratio, images, annotations)
+
+    # Output the results to the corresponding files.
+    for (i, new_data) in enumerate(new_datas):
+        with open(__derive_path(dataset_path, names[i]), "w") as output_file:
+            json.dump(new_data, output_file)
+
+
+def __split_data(datas, ratio, images, annotations):
+    """Sets `images` and `annotations` on the `datas` based on `ratio`.
+
+    Take note that this method mutates `datas`. It is done this way because
+    `datas` should contain the additional data as part of a COCO dataset.
+
+    `pandas` is used here to perform the splitting/partitioning.
+    """
+
+    # Create data frames.
+    images = pd.DataFrame(images)
+    annotations = pd.DataFrame(annotations)
+
+    # Create the base mask.
+    base_mask = np.arange(0, 1, 1 / len(images))
+    np.random.shuffle(base_mask)
+
+    # Track the current sum of ratios. This is used when finding the range to
+    # compare to.
+    ratio_sum = 0
+
+    # Iterate through each ratio and split the data.
+    for (i, ration) in enumerate(ratio):
+        data = datas[i]
+
+        # Create the mask.
+        mask = (base_mask >= ratio_sum) & (base_mask < ratio_sum + ration)
+        ratio_sum += ration
+
+        # Set the images on the data.
+        data["images"] = images[mask].to_dict("records")
+
+        # Set the annotations on the data.
+        common = images[mask].merge(
+            annotations, left_on="id", right_on="image_id", how="inner")
+        data["annotations"] = annotations[annotations.image_id.isin(
+            common.image_id)].to_dict("records")
+
+
+def __derive_path(dataset_path, name):
+    """Derives the output path given `dataset_path` and `name`.
+    """
+
+    output_filename = Path(f"{str(dataset_path.stem)}_{name}.json")
+    output_path = dataset_path.parent / output_filename
+    return output_path
+
+
+def __extract_ratio(ratio):
+    """Splits, verifies and normalizes the ratio.
+
+    For example, a ratio of `70:20:30` will become `[0.58, 0.17, 0.25]`. The
+    total does not need to add up to `100`.
+    """
+
+    # Split and strip.
+    ratio = [ration.strip() for ration in ratio.split(":")]
+
+    # Verify length of ratio.
+    if len(ratio) != 3:
+        raise COCOToolsError("ratio should have length 3")
+
+    # Parse, and hence, verify.
+    try:
+        ratio = list(map(float, ratio))
+    except ValueError:
+        raise COCOToolsError("ratio should be floats")
+
+    # Normalize based on sum.
+    return list(map(lambda ration: ration / sum(ratio), ratio))
+
+
+def __extract_names(names):
+    """Splits the names.
+    """
+
+    return [name.strip() for name in names.split(":")]
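
For reference, a minimal sketch of driving the new tool directly from Python (rather than through the CLI) is shown below. The input file name `data.json` is only an example; `split` derives the output paths from the input path via `__derive_path`, so this call would write `data_train.json`, `data_validation.json` and `data_test.json` next to the input file.

from coco_tools import split

# Hypothetical input: a COCO-style JSON file containing "images" and
# "annotations" keys, as expected by split().
split("data.json", "70:20:10", "train:validation:test")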

main.py (+16 -5)

@@ -1,21 +1,32 @@
 from argparse import ArgumentParser
-import coco_tools
+from coco_tools import COCOToolsError, split


 def main():
     parser = ArgumentParser(description="Useful operations for COCO datasets")
-    subparsers = parser.add_subparsers(help="Possible operations")
+    subparsers = parser.add_subparsers(
+        help="Possible operations", dest="command")

     split_parser = subparsers.add_parser("split", help="Splits a dataset")
-    split_parser.add_argument("dataset", help="The dataset to split")
+    split_parser.add_argument(
+        "-i", "--dataset", help="The dataset to split", default="data.json")
+    split_parser.add_argument(
+        "-r", "--ratio", help="The ratio to split by (e.g. 70:20:10)", default="70:20:10")
+    split_parser.add_argument(
+        "-n", "--names", help="The names for each split (e.g. train:validation:test)", default="train:validation:test")

     merge_parser = subparsers.add_parser("merge", help="Merges datasets")
-    merge_parser.add_argument("datasets", nargs="+",
+    merge_parser.add_argument("--input", nargs="+",
                               help="The datasets to merge")

     args = parser.parse_args()

-    print(args)
+    try:
+        if args.command == "split":
+            split(args.dataset, args.ratio, args.names)
+    except COCOToolsError as e:
+        print(f'error: {e}')
+        exit(1)


 if __name__ == "__main__":
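
With the defaults above, running `python main.py split` with no flags would split a `data.json` in the working directory using a 70:20:10 ratio. An explicit example invocation (the file name is only illustrative, and assumes the dependencies from requirements.txt are installed) would be `python main.py split -i data.json -r 70:20:10 -n train:validation:test`.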

requirements.txt (+13)

@@ -0,0 +1,13 @@
+astroid==2.0.4
+autopep8==1.4.1
+isort==4.3.4
+lazy-object-proxy==1.3.1
+mccabe==0.6.1
+numpy==1.15.3
+pandas==0.23.4
+pycodestyle==2.4.0
+pylint==2.1.1
+python-dateutil==2.7.3
+pytz==2018.5
+six==1.11.0
+wrapt==1.10.11
