Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Automatic circular camera path rendering #3314

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
142 changes: 142 additions & 0 deletions nerfstudio/viewer/render_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import viser.transforms as tf
from scipy import interpolate

from nerfstudio.models.base_model import Model
from nerfstudio.viewer.control_panel import ControlPanel


Expand Down Expand Up @@ -520,6 +521,7 @@
server: viser.ViserServer,
config_path: Path,
datapath: Path,
viewer_model: Model,
control_panel: Optional[ControlPanel] = None,
) -> RenderTabState:
from nerfstudio.viewer.viewer import VISER_NERFSTUDIO_SCALE_RATIO
Expand Down Expand Up @@ -588,6 +590,7 @@
initial_value="Perspective",
hint="Camera model to render with. This is applied to all keyframes.",
)

add_button = server.gui.add_button(
"Add Keyframe",
icon=viser.Icon.PLUS,
Expand Down Expand Up @@ -726,6 +729,144 @@
def _(_) -> None:
camera_path.show_spline = show_spline_checkbox.value
camera_path.update_spline()

auto_camera_folder = server.gui.add_folder("Automatic Camera Path")
with auto_camera_folder:
click_position = np.array([1.0, 0.0, 0.0])
select_center_button = server.gui.add_button(
"Select Center",
icon=viser.Icon.CROSSHAIR,
hint="Choose center point to generate camera path around.",
)

@select_center_button.on_click
def _(event: viser.GuiEvent) -> None:
select_center_button.disabled = True

@event.client.scene.on_pointer_event(event_type="click")
def _(event: viser.ScenePointerEvent) -> None:
# Code mostly borrowed from garfield.studio!
import torch
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.model_components.losses import scale_gradients_by_distance_squared

Check failure on line 753 in nerfstudio/viewer/render_panel.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

nerfstudio/viewer/render_panel.py:749:1: I001 Import block is un-sorted or un-formatted
origin = torch.tensor(event.ray_origin).view(1, 3)
direction = torch.tensor(event.ray_direction).view(1, 3)

# Get intersection
bundle = RayBundle(
origins=origin,
directions=direction,
pixel_area=torch.tensor(0.001).view(1, 1),
camera_indices=torch.tensor(0).view(1, 1),
nears=torch.tensor(0.05).view(1, 1),
fars=torch.tensor(100).view(1, 1),
).to("cuda")

# Get the distance/depth to the intersection --> calculate 3D position of the click
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good start, but this part should use the more general get_outputs_for_camera() function (much of the sampling and scaling logic here is specific to nerfacto). As long as the method outputs a 'depth' value, it should work with this ray deprojection approach.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way, you would render a depth image from the viewer camera and deproject the click point with the intrinsics matrix + rendered depth, meaning it doesn't matter what the rendering backend is.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

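One thing to note is that the 'depth' in splatfacto is actually z-depth and not ray-depth, so the math would need to be different for the two methods. With ray-depth, the clicked point is origin + depth * direction (for a unit-length ray direction); with z-depth, the depth must first be divided by the cosine of the angle between the ray and the camera's optical axis to obtain the distance along the ray.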

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, that makes sense! Thanks for the feedback, I appreciate it!

ray_samples, _, _ = viewer_model.proposal_sampler(bundle, density_fns=viewer_model.density_fns)
field_outputs = viewer_model.field.forward(ray_samples, compute_normals=viewer_model.config.predict_normals)
if viewer_model.config.use_gradient_scaling:
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)
weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
with torch.no_grad():
depth = viewer_model.renderer_depth(weights=weights, ray_samples=ray_samples)
distance = depth[0, 0].detach().cpu().numpy()

nonlocal click_position
click_position = np.array(origin + direction * distance).reshape(3,)

server.scene.add_icosphere(
f"/render_center_pos",

Check failure on line 781 in nerfstudio/viewer/render_panel.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F541)

nerfstudio/viewer/render_panel.py:781:25: F541 f-string without any placeholders
radius=0.1,
color=(200, 10, 30),
position=click_position,
)

event.client.scene.remove_pointer_callback()

@event.client.scene.on_pointer_callback_removed
def _():
select_center_button.disabled = False

num_cameras_handle = server.gui.add_number(
label="Number of Cameras",
initial_value=10,
hint="Total number of cameras generated in path, placed equidistant from neighboring ones.",
)

radius_handle = server.gui.add_number(
label="Radius",
initial_value=2,
hint="Radius of circular camera path.",
)

camera_height_handle = server.gui.add_number(
label="Height",
initial_value=2,
hint="Height of cameras with respect to chosen origin.",
)

circular_camera_path_button = server.gui.add_button(
"Generate Circular Camera Path",
icon=viser.Icon.CAMERA,
hint="Automatically generate a circular camera path around selected point.",
)

@circular_camera_path_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Add keyframes evenly spaced on a circle around the selected center.

    Reads the camera count, radius and height from their GUI handles, places
    one keyframe per camera on a circle in the XY plane around
    ``click_position`` (offset by the height along +Z), orients each camera
    towards the center, then refreshes the path duration and spline.

    Args:
        event: GUI click event; the clicking client's camera supplies the
            field of view stored on each keyframe.
    """
    # Snapshot the center so every keyframe in this batch uses the same
    # target even if another handler rebinds `click_position` meanwhile.
    center = np.array(click_position, dtype=np.float64).reshape(3)
    # Cast defensively: `range`/`arange` need an integer count, and the
    # number input may hand back a float -- TODO confirm against viser.
    num_cameras = int(num_cameras_handle.value)
    radius = float(radius_handle.value)
    camera_height = float(camera_height_handle.value)
    if num_cameras < 1:
        # Nothing to generate (also avoids dividing by zero below).
        return

    def wxyz_helper(camera_position: np.ndarray) -> np.ndarray:
        """Return a (w, x, y, z) quaternion that points a camera at `center`.

        Falls back to the identity rotation when the orientation is
        undefined: the camera sits on the center, or directly above/below
        it so that the view axis is parallel to the global up axis.
        """
        identity = np.array([1.0, 0.0, 0.0, 0.0])

        # Unit vector from the center to the camera, i.e. the direction
        # opposite to where the camera looks.
        offset = camera_position - center
        offset_norm = np.linalg.norm(offset)
        # `not (x > eps)` rather than `x <= eps` so NaN is rejected too.
        if not offset_norm > 1e-8:
            return identity
        away = offset / offset_norm

        global_up = np.array([0.0, 0.0, 1.0])
        camera_right = np.cross(away, global_up)
        right_norm = np.linalg.norm(camera_right)
        if not right_norm > 1e-8:
            return identity
        camera_right = camera_right / right_norm
        camera_up = np.cross(camera_right, away)

        # Columns are the camera axes expressed in world coordinates; the
        # third axis is the viewing direction (camera -> center).
        # NOTE(review): the first two axes are cross(away, up) and the
        # upward-pointing vector. If keyframe orientations are expected in
        # the OpenCV convention (+X right, +Y down, +Z forward), this frame
        # is rolled 180 degrees about the view axis -- confirm against how
        # the other keyframes in this file are created.
        rotation = np.stack([camera_right, camera_up, -away], axis=1)

        # The closed form w = sqrt(1 + trace) / 2 with x, y, z divided by
        # 4 * w breaks down when the trace approaches -1 (w -> 0), which
        # occurs for ordinary poses here (e.g. height 0, camera on the -Y
        # side of the center). Delegate to the library conversion, which
        # handles every branch.
        wxyz = np.asarray(tf.SO3.from_matrix(rotation).wxyz, dtype=np.float64)
        # Keep the w >= 0 sign convention of the previous closed form.
        if wxyz[0] < 0:
            wxyz = -wxyz
        return wxyz

    fov = event.client.camera.fov
    aspect = resolution.value[0] / resolution.value[1]
    angles = 2 * np.pi * np.arange(num_cameras) / num_cameras
    for angle in angles:
        position = center + np.array(
            [radius * np.cos(angle), radius * np.sin(angle), camera_height]
        )
        camera_path.add_camera(
            keyframe=Keyframe(
                position=position,
                wxyz=wxyz_helper(position),
                override_fov_enabled=False,
                override_fov_rad=fov,
                override_time_enabled=False,
                override_time_val=0.0,
                aspect=aspect,
                override_transition_enabled=False,
                override_transition_sec=None,
            )
        )

    # Refresh the duration and spline once, after all keyframes are added.
    duration_number.value = camera_path.compute_duration()
    camera_path.update_spline()

playback_folder = server.gui.add_folder("Playback")
with playback_folder:
Expand Down Expand Up @@ -1178,6 +1319,7 @@
server=viser.ViserServer(),
config_path=Path("."),
datapath=Path("."),
viewer_model=Model,
)
while True:
time.sleep(10.0)
3 changes: 2 additions & 1 deletion nerfstudio/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,10 @@ def __init__(
default_composite_depth=self.config.default_composite_depth,
)
config_path = self.log_filename.parents[0] / "config.yml"

with tabs.add_tab("Render", viser.Icon.CAMERA):
self.render_tab_state = populate_render_tab(
self.viser_server, config_path, self.datapath, self.control_panel
self.viser_server, config_path, self.datapath, self.pipeline.model, self.control_panel
)

with tabs.add_tab("Export", viser.Icon.PACKAGE_EXPORT):
Expand Down
Loading