Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added center-method focus to sdfstudio-data #196

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
"""

import math
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple

import numpy as np
import torch
from jaxtyping import Float
from numpy.typing import NDArray
from torchtyping import TensorType
from typing_extensions import Literal
from torch import Tensor

_EPS = np.finfo(float).eps * 4.0

Expand Down Expand Up @@ -406,36 +409,82 @@ def rotation_matrix(a: TensorType[3], b: TensorType[3]) -> TensorType[3, 3]:
)
return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8))

def focus_of_attention(poses: Float[Tensor, "*num_poses 4 4"], initial_focus: Float[Tensor, "3"]) -> Float[Tensor, "3"]:
"""Compute the focus of attention of a set of cameras. Only cameras
that have the focus of attention in front of them are considered.

Args:
poses: The poses to orient.
initial_focus: The 3D point views to decide which cameras are initially activated.

Returns:
The 3D position of the focus of attention.
"""
# References to the same method in third-party code:
# https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145
# https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197
active_directions = -poses[:, :3, 2:3]
active_origins = poses[:, :3, 3:4]
# initial value for testing if the focus_pt is in front or behind
focus_pt = initial_focus
# Prune cameras which have the current have the focus_pt behind them.
active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
done = False
# We need at least two active cameras, else fallback on the previous solution.
# This may be the "poses" solution if no cameras are active on first iteration, e.g.
# they are in an outward-looking configuration.
while torch.sum(active.int()) > 1 and not done:
active_directions = active_directions[active]
active_origins = active_origins[active]
# https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions
m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1)
mt_m = torch.transpose(m, -2, -1) @ m
focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0]
active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
if active.all():
# the set of active cameras did not change, so we're done.
done = True
return focus_pt

def auto_orient_and_center_poses(
poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "none"] = "up", center_poses: bool = True
) -> TensorType["num_poses":..., 3, 4]:
poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "none"] = "up", center_method: Literal["poses", "focus", "none"] = "poses",
) -> Tuple[Float [Tensor, "*num_poses 3 4"], Float[Tensor, "3 4"]]:
"""Orients and centers the poses. We provide two methods for orientation: pca and up.

pca: Orient the poses so that the principal component of the points is aligned with the axes.
This method works well when all of the cameras are in the same plane.
up: Orient the poses so that the average up vector is aligned with the z axis.
This method works well when images are not at arbitrary angles.

There are two centering methods:
poses: The poses are centered around the origin.
focus: The origin is set to the focus of attention of all cameras (the
closest point to cameras optical axes). Recommended for inward-looking
camera configurations.


Args:
poses: The poses to orient.
method: The method to use for orientation.
center_poses: If True, the poses are centered around the origin.
center_method: The method to use to center poses

Returns:
The oriented poses.
Tuple of the oriented poses and the transform matrix.
"""

translation = poses[..., :3, 3]
origin = poses[..., :3, 3]

mean_translation = torch.mean(translation, dim=0)
translation_diff = translation - mean_translation
mean_origin = torch.mean(origin, dim=0)
translation_diff = origin - mean_origin

if center_poses:
translation = mean_translation
if center_method == "poses":
translation = mean_origin
elif center_method == "focus":
translation = focus_of_attention(poses, mean_origin)
elif center_method == "none":
translation = torch.zeros_like(mean_origin)
else:
translation = torch.zeros_like(mean_translation)
raise ValueError(f"Unknown value for center_method: {center_method}")

if method == "pca":
_, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff)
Expand Down
8 changes: 5 additions & 3 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ class NerfstudioDataParserConfig(DataParserConfig):
"""How much to scale the region of interest by."""
orientation_method: Literal["pca", "up", "none"] = "up"
"""The method to use for orientation."""
center_poses: bool = True
"""Whether to center the poses."""
# center_poses: bool = True
# """Whether to center the poses."""
center_method: Literal["poses", "focus", "none"] = "poses"
"""The method to use to center the poses"""
auto_scale_poses: bool = True
"""Whether to automatically scale the poses to fit in +/- 1 bounding box."""
train_split_percentage: float = 0.9
Expand Down Expand Up @@ -189,7 +191,7 @@ def _generate_dataparser_outputs(self, split="train"):
poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
poses,
method=orientation_method,
center_poses=self.config.center_poses,
center_method=self.config.center_method,
)

# Scale poses
Expand Down
48 changes: 46 additions & 2 deletions nerfstudio/data/dataparsers/sdfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Data parser for friends dataset"""
from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Optional, Type
Expand All @@ -34,6 +35,7 @@
DataparserOutputs,
)
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.data.utils.data_utils import create_masked_img
from nerfstudio.utils.images import BasicImages
from nerfstudio.utils.io import load_from_json

Expand Down Expand Up @@ -157,7 +159,7 @@ class SDFStudioDataParserConfig(DataParserConfig):
# """How much to downscale images. If not set, images are chosen such that the max dimension is <1600px."""
orientation_method: Literal["up", "none"] = "up"
"""The method to use for orientation."""
center_poses: bool = False
center_method: Literal["focus", "none"] = "focus"
"""Whether to center the poses."""
auto_scale_poses: bool = False
"""Whether to automatically scale the poses to fit in +/- 1 bounding box."""
Expand All @@ -176,6 +178,12 @@ class SDFStudioDataParserConfig(DataParserConfig):
"""automatically orient the scene such that the up direction is the same as the viewer's up direction"""
load_dtu_highres: bool = False
"""load high resolution images from DTU dataset, should only be used for the preprocessed DTU dataset"""
train_with_masked_imgs: bool = False
"""whether or not to mask out objects using foreground masks and train with masked images"""
sample_from_mask: bool = False
"""if true, pixels are sampled only from masked regions"""
masked_img_dir: str = "masked_images"
"""name of the folder where masked images are stored if train_with_masked_imgs is true"""


def filter_list(list_to_filter, indices):
Expand Down Expand Up @@ -209,6 +217,7 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused

image_filenames = []
depth_images = []
mask_filenames = []
normal_images = []
sensor_depth_images = []
foreground_mask_images = []
Expand All @@ -221,6 +230,35 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
for i, frame in enumerate(meta["frames"]):
image_filename = self.config.data / frame["rgb_path"]


if (
self.config.train_with_masked_imgs
or self.config.include_foreground_mask
or self.config.sample_from_mask
):
assert meta["has_foreground_mask"]
mask_filename = self.config.data / frame["foreground_mask"]
mask = np.array(Image.open(mask_filename), dtype=np.float32) / 255.0
if len(mask.shape) == 3:
mask = mask[..., 0]
if self.config.train_with_masked_imgs or self.config.sample_from_mask:
masked_img_dir_path = self.config.data / self.config.masked_img_dir
os.makedirs(str(masked_img_dir_path), exist_ok=True)

if self.config.train_with_masked_imgs:
image_filename = create_masked_img(image_filename, mask_filename, masked_img_dir_path)

if self.config.include_foreground_mask:
foreground_mask = mask[..., None]
foreground_mask_images.append(torch.from_numpy(foreground_mask).float())

if self.config.sample_from_mask:
# nerfstudio's pixel sampler requires single channel masks
mask_img = Image.fromarray((255.0 * mask).astype(np.uint8))
mask_filename = masked_img_dir_path / mask_filename.name
mask_img.save(mask_filename)
mask_filenames.append(mask_filename)

intrinsics = torch.tensor(frame["intrinsics"])
camtoworld = torch.tensor(frame["camtoworld"])

Expand All @@ -236,6 +274,8 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
assert meta["has_mono_prior"]
# load mono depth
depth = np.load(self.config.data / frame["mono_depth_path"])
if self.config.train_with_masked_imgs:
depth = depth * mask
depth_images.append(torch.from_numpy(depth).float())

# load mono normal
Expand All @@ -258,6 +298,9 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
assert meta["has_sensor_depth"]
# load sensor depth
sensor_depth = np.load(self.config.data / frame["sensor_depth_path"])
if self.config.train_with_masked_imgs:
# TODO: Maybe set background depth to very large value instead of 0?
sensor_depth = sensor_depth * mask
sensor_depth_images.append(torch.from_numpy(sensor_depth).float())

if self.config.include_foreground_mask:
Expand Down Expand Up @@ -310,7 +353,7 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
camera_to_worlds, transform = camera_utils.auto_orient_and_center_poses(
camera_to_worlds,
method=orientation_method,
center_poses=self.config.center_poses,
center_method=self.config.center_method,
)

# we should also transform normal accordingly
Expand Down Expand Up @@ -418,6 +461,7 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused
dataparser_outputs = DataparserOutputs(
image_filenames=filter_list(image_filenames, indices),
cameras=cameras,
mask_filenames=mask_filenames if self.config.sample_from_mask else None,
scene_box=scene_box,
additional_inputs=additional_inputs_dict,
depths=filter_list(depth_images, indices),
Expand Down
26 changes: 26 additions & 0 deletions nerfstudio/data/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Utility functions to allow easy re-use of common operations across dataloaders"""
import os
from pathlib import Path
from typing import List, Tuple, Union

Expand Down Expand Up @@ -51,3 +52,28 @@ def get_semantics_and_mask_tensors_from_path(
semantics = torch.from_numpy(np.array(pil_image, dtype="int64"))[..., None]
mask = torch.sum(semantics == mask_indices, dim=-1, keepdim=True) == 0
return semantics, mask


def create_masked_img(img_filepath: Path, mask_filepath: Path, output_dir: Path) -> Path:
"""
Utility function to mask an image using provided mask and store it on disk.
Output_dir is absolute path where to store the masked image.
"""
img = np.array(Image.open(img_filepath), dtype=np.float32)
mask = np.array(Image.open(mask_filepath), dtype=np.float32) / 255.0
assert len(img.shape) == 3
if img.shape[-1] == 4:
img = img[:, :, :3]

# in case the mask comes with alpha channel
if mask.shape[-1] == 4:
mask = mask[:, :, :3]

if len(mask.shape) == 2:
mask = mask[..., np.newaxis]

masked_image = Image.fromarray((img * mask).astype(np.uint8))
masked_image_filename = output_dir / (img_filepath.stem + "_masked" + img_filepath.suffix)
masked_image.save(masked_image_filename)

return masked_image_filename
3 changes: 1 addition & 2 deletions nerfstudio/data/utils/nerfstudio_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import torch
import torch.utils.data
from torch._six import string_classes

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.utils.images import BasicImages
Expand Down Expand Up @@ -120,7 +119,7 @@ def nerfstudio_collate(
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
elif isinstance(elem, str):
return batch
elif isinstance(elem, collections.abc.Mapping):
try:
Expand Down
Loading