Commit 3dff60f

committed: v1.6.0alpha
1. Add incomplete NuScenes support.
2. New Dataset API.
3. Lots of code changes.
4. Config format change, incompatible with previous versions.
1 parent: f980f3d


79 files changed: +7000 −2815 lines

README.md (+44 −23)
@@ -1,14 +1,17 @@
-# SECOND-V1.5 for KITTI object detection
-SECOND-V1.5 detector.
+# SECOND for KITTI/NuScenes object detection
+SECOND detector.
 
 ONLY support python 3.6+, pytorch 1.0.0+. Tested in Ubuntu 16.04/18.04/Windows 10.
 
 ## News
 
-2019-3-21: SECOND V1.5.1 (minor improvement and bug fix) released! See [release notes](RELEASE.md) for more details.
+2019-4-1: SECOND V1.6.0alpha released: New Data API, [NuScenes](https://www.nuscenes.org) support, [PointPillars](https://github.com/nutonomy/second.pytorch) support.
 
-2019-1-20: SECOND V1.5 released! See [release notes](RELEASE.md) for more details.
+2019-3-21: SECOND V1.5.1 (minor improvement and bug fix) released!
 
+2019-1-20: SECOND V1.5 released! Sparse convolution-based network.
+
+See [release notes](RELEASE.md) for more details.
 
 ### Performance in KITTI validation set (50/50 split)
 
@@ -68,6 +71,8 @@ pip install numba scikit-image scipy pillow
 
 Follow instructions in [spconv](https://github.com/traveller59/spconv) to install spconv.
 
+If you want to use the NuScenes dataset, you need to install [nuscenes-devkit](https://github.com/nutonomy/nuscenes-devkit). I recommend copying the nuscenes package in python-sdk to the second/.. folder (equivalent to adding it to PYTHONPATH) and installing its dependencies manually; installing the devkit with pip pulls in many fixed-version libraries.
+
 ### 3. Setup cuda for numba
 
 You need to add the following environment variables for numba.cuda; you can add them to ~/.bashrc:
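The devkit note added above suggests copying nuscenes from the devkit's python-sdk next to the second package instead of pip-installing it. A minimal sketch of that setup, assuming a hypothetical ~/workspace layout (all paths are illustrative, not from the commit):

```python
import sys
from pathlib import Path

# Hypothetical layout (illustrative):
#   ~/workspace/second.pytorch/second  <-- this repo
#   ~/workspace/nuscenes               <-- copied from nuscenes-devkit/python-sdk
workspace = Path.home() / "workspace"
sys.path.append(str(workspace / "second.pytorch"))  # makes `second` importable
sys.path.append(str(workspace))                     # makes the copied `nuscenes` importable

from nuscenes.nuscenes import NuScenes  # the devkit's main entry point

nusc = NuScenes(version="v1.0-trainval",
                dataroot="/path/to/NUSCENES_TRAINVAL_DATASET_ROOT",
                verbose=True)
```

The copy approach keeps the devkit's dependency versions under your control rather than the pinned versions pip would install.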
@@ -82,7 +87,7 @@ export NUMBAPRO_LIBDEVICE=/usr/local/cuda/nvvm/libdevice
 
 ## Prepare dataset
 
-* Dataset preparation
+* KITTI Dataset preparation
 
 Download KITTI dataset and create some directories first:
 
@@ -101,22 +106,32 @@ Download KITTI dataset and create some directories first:
        └── velodyne_reduced <-- empty directory
 ```
 
-* Create kitti infos:
-
+Then run
 ```bash
-python create_data.py create_kitti_info_file --data_path=KITTI_DATASET_ROOT
+python create_data.py kitti_data_prep --data_path=KITTI_DATASET_ROOT
 ```
 
-* Create reduced point cloud:
-
-```bash
-python create_data.py create_reduced_point_cloud --data_path=KITTI_DATASET_ROOT
-```
-
-* Create groundtruth-database infos:
+* [NuScenes](https://www.nuscenes.org) Dataset preparation
 
+Download NuScenes dataset:
+```plain
+└── NUSCENES_TRAINVAL_DATASET_ROOT
+    ├── samples       <-- key frames
+    ├── sweeps        <-- frames without annotation
+    ├── maps          <-- unused
+    └── v1.0-trainval <-- metadata and annotations
+└── NUSCENES_TEST_DATASET_ROOT
+    ├── samples       <-- key frames
+    ├── sweeps        <-- frames without annotation
+    ├── maps          <-- unused
+    └── v1.0-test     <-- metadata
+```
+Since the dataset is really large, you can download parts of the dataset.
+
+Then run
 ```bash
-python create_data.py create_groundtruth_database --data_path=KITTI_DATASET_ROOT
+python create_data.py nuscenes_data_prep --data_path=NUSCENES_TRAINVAL_DATASET_ROOT --version="v1.0-trainval"
+python create_data.py nuscenes_data_prep --data_path=NUSCENES_TEST_DATASET_ROOT --version="v1.0-test"
 ```
 
 * Modify config file
@@ -127,24 +142,30 @@ There is some path need to be configured in config file:
 train_input_reader: {
   ...
   database_sampler {
-    database_info_path: "/path/to/kitti_dbinfos_train.pkl"
+    database_info_path: "/path/to/dataset_dbinfos_train.pkl"
     ...
   }
-  kitti_info_path: "/path/to/kitti_infos_train.pkl"
-  kitti_root_path: "KITTI_DATASET_ROOT"
+  dataset: {
+    kitti_info_path: "/path/to/dataset_infos_train.pkl"
+    kitti_root_path: "DATASET_ROOT"
+  }
 }
 ...
 eval_input_reader: {
   ...
-  kitti_info_path: "/path/to/kitti_infos_val.pkl"
-  kitti_root_path: "KITTI_DATASET_ROOT"
+  dataset: {
+    kitti_info_path: "/path/to/dataset_infos_val.pkl"
+    kitti_root_path: "DATASET_ROOT"
+  }
 }
 ```
 
 ## Usage
 
 ### train
 
+I recommend using script.py to train and eval; see script.py for more details.
+
 ```bash
 python ./pytorch/train.py train --config_path=./configs/car.fhd.config --model_dir=/path/to/model_dir
 ```
@@ -169,7 +190,7 @@ You can download pretrained models in [google drive](https://drive.google.com/op
 
 Note that this pretrained model was trained before a sparse convolution bug was fixed, so the eval result may be slightly worse.
 
-## Docker (I don't have time to build docker for SECOND-V1.5)
+## Docker (Deprecated. I can't push a docker image due to network problems.)
 
 You can use a prebuilt docker for testing:
 ```
@@ -185,7 +206,7 @@ python ./pytorch/train.py evaluate --config_path=./configs/car.config --model_di
 
 ### Major step
 
-1. run ```python ./kittiviewer/backend.py main --port=xxxx``` in your server/local.
+1. run ```python ./kittiviewer/backend/main.py main --port=xxxx``` in your server/local.
 
 2. run ```cd ./kittiviewer/frontend && python -m http.server``` to launch a local web server.
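The config change above (per-dataset paths moved into a nested dataset block) is the backward-incompatible part of this release. A quick sanity check for a migrated config, sketched on the assumption that the proto layout matches what dataset_builder.py below reads and that configs load via pipeline_pb2 as in this repo's train.py:

```python
from google.protobuf import text_format
from second.protos import pipeline_pb2

# Parse a pipeline config the way train.py does.
config = pipeline_pb2.TrainEvalPipelineConfig()
with open("configs/car.fhd.config") as f:
    text_format.Merge(f.read(), config)

# With the new layout, dataset paths live under input_reader.dataset.
dataset_cfg = config.train_input_reader.dataset
print(dataset_cfg.dataset_class_name)  # e.g. "KittiDataset"
print(dataset_cfg.kitti_info_path)     # /path/to/dataset_infos_train.pkl
print(dataset_cfg.kitti_root_path)     # DATASET_ROOT
```

If this parse fails on an old config, the paths are still at the top level of the input reader and need to be moved into the dataset block.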

RELEASE.md (+30 −7)
@@ -1,12 +1,24 @@
-# Release 1.5
+# Release 1.6.0alpha
 
 ## Major Features and Improvements
+1. New dataset API (unstable during alpha) that removes almost all kitti-specific code. You can add a custom dataset with the following steps:
+(1) implement all Dataset API functions;
+(2) use the web visualization tool to check whether the boxes are correct;
+(3) add your dataset to all_dataset.py and change the dataset_class_name in the config file.
 
-1. New sparse convolution based models. VFE-based old models are deprecated. Now the model looks like this:
-points([N, 4])->voxels([N, 5, 4])->Features([N, 4])->Sparse Convolution Networks->RPN. See [this](https://github.com/traveller59/second.pytorch/blob/master/second/pytorch/models/middle.py) for more details of sparse conv networks.
-2. The [SparseConvNet](https://github.com/facebookresearch/SparseConvNet) is deprecated. New library [spconv](https://github.com/traveller59/spconv) is introduced.
-3. Super converge (from fastai) is implemented. Now all network can converge to a good result with only 50~80 epoch. For example. ```car.fhd.config``` only needs 50 epochs to reach 78.3 AP (car mod 3d).
-4. Target assigner now works correctly when using multi-class.
+2. Added [NuScenes](https://www.nuscenes.org) dataset support (incomplete in 1.6.0alpha); I plan to reproduce the NDS score in their paper.
+
+3. Added [pointpillars](https://github.com/nutonomy/second.pytorch) to this repo.
+
+4. Full Tensorboard support.
+
+## Minor Improvements and Bug fixes
+
+1. Moved all data-specific functions to their corresponding dataset file.
+
+2. Improved config file structure; removed some unused items.
+
+3. Removed much unused and deprecated code.
 
 # Release 1.5.1
 
@@ -19,4 +31,15 @@ points([N, 4])->voxels([N, 5, 4])->Features([N, 4])->Sparse Convolution Networks
 2. Better RPN, you can add custom block by inherit RPNBase and implement _make_layer method.
 3. Update pretrained model.
 4. Add a simple inference notebook. everyone should start this project by that notebook.
-5. Add windows support. Training on windows is slow than linux.
+5. Add windows support. Training on windows is slower than on linux.
+
+# Release 1.5
+
+## Major Features and Improvements
+
+1. New sparse convolution based models. VFE-based old models are deprecated. Now the model looks like this:
+points([N, 4])->voxels([N, 5, 4])->Features([N, 4])->Sparse Convolution Networks->RPN. See [this](https://github.com/traveller59/second.pytorch/blob/master/second/pytorch/models/middle.py) for more details of sparse conv networks.
+2. The [SparseConvNet](https://github.com/facebookresearch/SparseConvNet) is deprecated. New library [spconv](https://github.com/traveller59/spconv) is introduced.
+3. Super converge (from fastai) is implemented. Now all networks can converge to a good result with only 50~80 epochs. For example, ```car.fhd.config``` only needs 50 epochs to reach 78.3 AP (car mod 3d).
+4. Target assigner now works correctly when using multi-class.
+
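The three steps in the 1.6.0alpha notes above are the whole contract of the new Dataset API. A skeleton of step (1), grounded only in what dataset_builder.py below actually relies on (a NumPointFeatures class attribute and an info_path/root_path/class_names/prep_func constructor); the loading helper and any further required methods are assumptions of this sketch:

```python
import pickle


class MyDataset:
    """Hypothetical custom dataset for the new Dataset API (step 1)."""
    # dataset_builder.py asserts NumPointFeatures >= 3.
    NumPointFeatures = 4  # e.g. x, y, z, intensity

    def __init__(self, info_path, root_path, class_names=None, prep_func=None):
        # Signature matches the dataset_cls(...) call in dataset_builder.py.
        with open(info_path, "rb") as f:
            self._infos = pickle.load(f)
        self._root_path = root_path
        self._class_names = class_names
        self._prep_func = prep_func

    def __len__(self):
        return len(self._infos)

    def __getitem__(self, idx):
        # Read raw points/annotations, then hand them to the shared
        # preprocessing pipeline built in dataset_builder.py.
        input_dict = self._load_input(idx)
        return self._prep_func(input_dict=input_dict)

    def _load_input(self, idx):
        # Illustrative helper: load the point cloud and boxes for one frame.
        raise NotImplementedError
```

Step (3) then reduces to registering the class in all_dataset.py so that get_dataset_class can resolve it from dataset_class_name in the config.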

second/builder/dataset_builder.py (+77 −39)
@@ -1,16 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Input reader builder.
+
+Creates data sources for DetectionModels from an InputReader config. See
+input_reader.proto for options.
+
+Note: If users wish to also use their own InputReaders with the Object
+Detection configuration framework, they should define their own builder function
+that wraps the build function.
+"""
+
 from second.protos import input_reader_pb2
-from second.data.dataset import KittiDataset
+from second.data.all_dataset import get_dataset_class
 from second.data.preprocess import prep_pointcloud
+from second.core import box_np_ops
 import numpy as np
 from second.builder import dbsampler_builder
 from functools import partial
-from second.utils import config_tool
+from second.utils.config_tool import get_downsample_factor
 
 def build(input_reader_config,
           model_config,
           training,
           voxel_generator,
-          target_assigner=None):
+          target_assigner):
     """Builds a tensor dictionary based on the InputReader config.
 
     Args:
@@ -26,61 +51,74 @@ def build(input_reader_config,
     if not isinstance(input_reader_config, input_reader_pb2.InputReader):
         raise ValueError('input_reader_config not of type '
                          'input_reader_pb2.InputReader.')
-    generate_bev = model_config.use_bev
-    without_reflectivity = model_config.without_reflectivity
+    prep_cfg = input_reader_config.preprocess
+    dataset_cfg = input_reader_config.dataset
     num_point_features = model_config.num_point_features
-    downsample_factor = config_tool.get_downsample_factor(model_config)
+    out_size_factor = get_downsample_factor(model_config)
+    assert out_size_factor > 0
     cfg = input_reader_config
-    db_sampler_cfg = input_reader_config.database_sampler
+    db_sampler_cfg = prep_cfg.database_sampler
     db_sampler = None
     if len(db_sampler_cfg.sample_groups) > 0:  # enable sample
         db_sampler = dbsampler_builder.build(db_sampler_cfg)
-    u_db_sampler_cfg = input_reader_config.unlabeled_database_sampler
-    u_db_sampler = None
-    if len(u_db_sampler_cfg.sample_groups) > 0:  # enable sample
-        u_db_sampler = dbsampler_builder.build(u_db_sampler_cfg)
     grid_size = voxel_generator.grid_size
     # [352, 400]
-    feature_map_size = grid_size[:2] // downsample_factor
+    feature_map_size = grid_size[:2] // out_size_factor
     feature_map_size = [*feature_map_size, 1][::-1]
     print("feature_map_size", feature_map_size)
     assert all([n != '' for n in target_assigner.classes]), "you must specify class_name in anchor_generators."
+    dataset_cls = get_dataset_class(dataset_cfg.dataset_class_name)
+    assert dataset_cls.NumPointFeatures >= 3, "you must set this to correct value"
     prep_func = partial(
         prep_pointcloud,
-        root_path=cfg.kitti_root_path,
-        class_names=target_assigner.classes,
+        root_path=dataset_cfg.kitti_root_path,
         voxel_generator=voxel_generator,
         target_assigner=target_assigner,
         training=training,
-        max_voxels=cfg.max_number_of_voxels,
+        max_voxels=prep_cfg.max_number_of_voxels,
         remove_outside_points=False,
-        remove_unknown=cfg.remove_unknown_examples,
+        remove_unknown=prep_cfg.remove_unknown_examples,
         create_targets=training,
-        shuffle_points=cfg.shuffle_points,
-        gt_rotation_noise=list(cfg.groundtruth_rotation_uniform_noise),
-        gt_loc_noise_std=list(cfg.groundtruth_localization_noise_std),
-        global_rotation_noise=list(cfg.global_rotation_uniform_noise),
-        global_scaling_noise=list(cfg.global_scaling_uniform_noise),
+        shuffle_points=prep_cfg.shuffle_points,
+        gt_rotation_noise=list(prep_cfg.groundtruth_rotation_uniform_noise),
+        gt_loc_noise_std=list(prep_cfg.groundtruth_localization_noise_std),
+        global_rotation_noise=list(prep_cfg.global_rotation_uniform_noise),
+        global_scaling_noise=list(prep_cfg.global_scaling_uniform_noise),
         global_random_rot_range=list(
-            cfg.global_random_rotation_range_per_object),
+            prep_cfg.global_random_rotation_range_per_object),
+        global_translate_noise_std=list(prep_cfg.global_translate_noise_std),
         db_sampler=db_sampler,
-        unlabeled_db_sampler=u_db_sampler,
-        generate_bev=generate_bev,
-        without_reflectivity=without_reflectivity,
-        num_point_features=num_point_features,
-        anchor_area_threshold=cfg.anchor_area_threshold,
-        gt_points_drop=cfg.groundtruth_points_drop_percentage,
-        gt_drop_max_keep=cfg.groundtruth_drop_max_keep_points,
-        remove_points_after_sample=cfg.remove_points_after_sample,
-        remove_environment=cfg.remove_environment,
-        use_group_id=cfg.use_group_id,
-        downsample_factor=downsample_factor)
-    dataset = KittiDataset(
-        info_path=cfg.kitti_info_path,
-        root_path=cfg.kitti_root_path,
-        num_point_features=num_point_features,
-        target_assigner=target_assigner,
-        feature_map_size=feature_map_size,
+        num_point_features=dataset_cls.NumPointFeatures,
+        anchor_area_threshold=prep_cfg.anchor_area_threshold,
+        gt_points_drop=prep_cfg.groundtruth_points_drop_percentage,
+        gt_drop_max_keep=prep_cfg.groundtruth_drop_max_keep_points,
+        remove_points_after_sample=prep_cfg.remove_points_after_sample,
+        remove_environment=prep_cfg.remove_environment,
+        use_group_id=prep_cfg.use_group_id,
+        out_size_factor=out_size_factor)
+
+    ret = target_assigner.generate_anchors(feature_map_size)
+    class_names = target_assigner.classes
+    anchors_dict = target_assigner.generate_anchors_dict(feature_map_size)
+    anchors = ret["anchors"]
+    anchors = anchors.reshape([-1, 7])
+    matched_thresholds = ret["matched_thresholds"]
+    unmatched_thresholds = ret["unmatched_thresholds"]
+    anchors_bv = box_np_ops.rbbox2d_to_near_bbox(
+        anchors[:, [0, 1, 3, 4, 6]])
+    anchor_cache = {
+        "anchors": anchors,
+        "anchors_bv": anchors_bv,
+        "matched_thresholds": matched_thresholds,
+        "unmatched_thresholds": unmatched_thresholds,
+        "anchors_dict": anchors_dict,
+    }
+    prep_func = partial(prep_func, anchor_cache=anchor_cache)
+
+    dataset = dataset_cls(
+        info_path=dataset_cfg.kitti_info_path,
+        root_path=dataset_cfg.kitti_root_path,
+        class_names=class_names,
         prep_func=prep_func)
 
     return dataset
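The feature-map-size arithmetic near the top of build() can be checked by hand; a toy run using the 352x400 BEV grid from the inline comment (the out_size_factor value is illustrative):

```python
import numpy as np

# grid_size comes from the voxel generator; the comment in build() gives
# a 352x400 BEV grid as an example.
grid_size = np.array([352, 400, 1])
out_size_factor = 8  # assumed total downsample factor of the middle network

feature_map_size = grid_size[:2] // out_size_factor  # -> array([44, 50])
feature_map_size = [*feature_map_size, 1][::-1]      # -> [1, 50, 44]
print(feature_map_size)  # anchors are generated once for this size and cached
```

Because the anchors depend only on this size, build() computes them once, converts them to nearest BEV boxes, and closes the resulting anchor_cache over prep_func instead of regenerating anchors per example.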

second/builder/target_assigner_builder.py (−1)
@@ -38,4 +38,3 @@ def build(target_assigner_config, bv_range, box_coder):
         positive_fraction=positive_fraction,
         sample_size=target_assigner_config.sample_size)
     return target_assigner
-
second/builder/voxel_builder.py (+1 −9)
@@ -3,14 +3,6 @@
 from spconv.utils import VoxelGenerator
 from second.protos import voxel_generator_pb2
 
-class _VoxelGenerator(VoxelGenerator):
-    @property
-    def grid_size(self):
-        point_cloud_range = np.array(self.point_cloud_range)
-        voxel_size = np.array(self.voxel_size)
-        g_size = (point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size
-        g_size = np.round(g_size).astype(np.int64)
-        return g_size
 
 def build(voxel_config):
     """Builds a tensor dictionary based on the InputReader config.
@@ -28,7 +20,7 @@ def build(voxel_config):
     if not isinstance(voxel_config, (voxel_generator_pb2.VoxelGenerator)):
         raise ValueError('input_reader_config not of type '
                          'input_reader_pb2.InputReader.')
-    voxel_generator = _VoxelGenerator(
+    voxel_generator = VoxelGenerator(
         voxel_size=list(voxel_config.voxel_size),
         point_cloud_range=list(voxel_config.point_cloud_range),
         max_num_points=voxel_config.max_number_of_points_per_voxel,
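The deleted _VoxelGenerator wrapper existed only to compute grid_size, which newer spconv versions expose on VoxelGenerator directly. The same arithmetic, reproduced standalone with an illustrative KITTI-style range (values are not from this commit):

```python
import numpy as np

# Same computation as the removed _VoxelGenerator.grid_size property:
# number of voxels per axis = extent of the point cloud range / voxel size.
point_cloud_range = np.array([0.0, -40.0, -3.0, 70.4, 40.0, 1.0])  # x/y/z min, x/y/z max
voxel_size = np.array([0.2, 0.2, 4.0])

grid_size = (point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size
grid_size = np.round(grid_size).astype(np.int64)
print(grid_size)  # [352 400   1], the BEV grid used in dataset_builder.py
```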
