Skip to content

Commit

Permalink
Add get_render() method for client handles
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Sep 20, 2023
1 parent 8ac4a81 commit 26e6322
Show file tree
Hide file tree
Showing 12 changed files with 381 additions and 156 deletions.
59 changes: 30 additions & 29 deletions docs/source/examples/01_image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,38 @@ NeRFs), or images to render as 3D textures.
import viser
server = viser.ViserServer()
# Add a background image.
server.set_background_image(
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
format="png",
)
# Add main image.
server.add_image(
"/img",
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
4.0,
4.0,
format="png",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(2.0, 2.0, 0.0),
)
while True:
if __name__ == "__main__":
server = viser.ViserServer()
# Add a background image.
server.set_background_image(
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
format="png",
)
# Add main image.
server.add_image(
"/noise",
onp.random.randint(
0,
256,
size=(400, 400, 3),
dtype=onp.uint8,
),
"/img",
iio.imread(Path(__file__).parent / "assets/Cal_logo.png"),
4.0,
4.0,
format="jpeg",
format="png",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(2.0, 2.0, -1e-2),
position=(2.0, 2.0, 0.0),
)
time.sleep(0.2)
while True:
server.add_image(
"/noise",
onp.random.randint(
0,
256,
size=(400, 400, 3),
dtype=onp.uint8,
),
4.0,
4.0,
format="jpeg",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(2.0, 2.0, -1e-2),
)
time.sleep(0.2)
3 changes: 2 additions & 1 deletion docs/source/examples/07_record3d_visualizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ Parse and stream record3d captures. To get the demo data, see ``./assets/downloa
data_path: Path = Path(__file__).parent / "assets/record3d_dance",
downsample_factor: int = 4,
max_frames: int = 100,
share: bool = False,
) -> None:
server = viser.ViserServer()
server = viser.ViserServer(share=share)
print("Loading frames!")
loader = viser.extras.Record3dLoader(data_path)
Expand Down
3 changes: 2 additions & 1 deletion docs/source/examples/08_smplx_visualizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ parameters to run this script:
num_betas: int = 10,
num_expression_coeffs: int = 10,
ext: Literal["npz", "pkl"] = "npz",
share: bool = False,
) -> None:
server = viser.ViserServer()
server = viser.ViserServer(share=share)
server.configure_theme(control_layout="collapsible", dark_mode=True)
model = smplx.create(
model_path=str(model_path),
Expand Down
47 changes: 47 additions & 0 deletions examples/19_get_renders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Get Renders
Example for getting renders from a client's viewport to the Python API."""

import time

import imageio.v3 as iio
import numpy as onp

import viser


def main():
server = viser.ViserServer()

button = server.add_gui_button("Render a GIF")

@button.on_click
def _(event: viser.GuiEvent) -> None:
client = event.client
assert client is not None

client.reset_scene()

images = []

for i in range(20):
positions = onp.random.normal(size=(30, 3)) * 3.0
client.add_spline_catmull_rom(
f"/catmull_{i}",
positions,
tension=0.5,
line_width=3.0,
color=onp.random.uniform(size=3),
)
images.append(client.get_render(height=1080, width=1920))

print("Writing GIF...")
iio.imwrite("saved.gif", images)
print("Wrote GIF!")

while True:
time.sleep(10.0)


if __name__ == "__main__":
main()
17 changes: 17 additions & 0 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,20 @@ class CubicBezierSplineMessage(Message):
control_points: Tuple[Tuple[float, float, float], ...]
line_width: float
color: int


@dataclasses.dataclass
class GetRenderRequestMessage(Message):
"""Message from server->client requesting a render of the current viewport."""

format: Literal["image/jpeg", "image/png"]
height: int
width: int
quality: int


@dataclasses.dataclass
class GetRenderResponseMessage(Message):
"""Message from client->server carrying a render."""

payload: bytes
57 changes: 55 additions & 2 deletions src/viser/_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

import contextlib
import dataclasses
import io
import threading
import time
from pathlib import Path
from typing import Callable, Dict, Generator, List, Tuple
from typing import Callable, Dict, Generator, List, Optional, Tuple

import imageio.v3 as iio
import numpy as onp
import numpy.typing as npt
import rich
from rich import box, style
from rich.panel import Panel
from rich.table import Table
from typing_extensions import override
from typing_extensions import Literal, override

from . import _client_autobuild, _messages, infra
from . import transforms as tf
Expand Down Expand Up @@ -196,6 +198,57 @@ def _queue_unsafe(self, message: _messages.Message) -> None:
"""Define how the message API should send messages."""
self._state.connection.send(message)

def get_render(
self, height: int, width: int, transport_format: Literal["png", "jpeg"] = "jpeg"
) -> onp.ndarray:
"""Request a render from a client, block until it's done and received, then
return it as a numpy array.
Args:
height: Height of rendered image. Should be <= the browser height.
width: Width of rendered image. Should be <= the browser width.
transport_format: Image transport format. JPEG will return a lossy (H, W, 3) RGB array. PNG will
return a lossless (H, W, 4) RGBA array, but can cause memory issues on the frontend if called
too quickly for higher-resolution images.
"""

# Listen for a render reseponse message, which should contain the rendered
# image.
render_ready_event = threading.Event()
out: Optional[onp.ndarray] = None

def got_render_cb(
client_id: int, message: _messages.GetRenderResponseMessage
) -> None:
del client_id
self._state.connection.unregister_handler(
_messages.GetRenderResponseMessage, got_render_cb
)
nonlocal out
out = iio.imread(
io.BytesIO(message.payload),
extension=f".{transport_format}",
)
render_ready_event.set()

self._state.connection.register_handler(
_messages.GetRenderResponseMessage, got_render_cb
)
self._queue(
_messages.GetRenderRequestMessage(
"image/jpeg" if transport_format == "jpeg" else "image/png",
height=height,
width=width,
# Only used for JPEG. The main reason to use a lower quality version
# value is (unfortunately) to make life easier for the Javascript
# garbage collector.
quality=80,
)
)
render_ready_event.wait()
assert out is not None
return out

@contextlib.contextmanager
def atomic(self) -> Generator[None, None, None]:
"""Returns a context where:
Expand Down
9 changes: 8 additions & 1 deletion src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import {
import { Titlebar } from "./Titlebar";
import { ViserModal } from "./Modal";
import { useSceneTreeState } from "./SceneTreeState";
import { Message } from "./WebsocketMessages";
import { GetRenderRequestMessage, Message } from "./WebsocketMessages";

export type ViewerContextContents = {
// Zustand hooks.
Expand All @@ -61,6 +61,11 @@ export type ViewerContextContents = {
};
}>;
messageQueueRef: React.MutableRefObject<Message[]>;
// Requested a render.
getRenderRequestState: React.MutableRefObject<
"ready" | "triggered" | "pause" | "in_progress"
>;
getRenderRequest: React.MutableRefObject<null | GetRenderRequestMessage>;
};
export const ViewerContext = React.createContext<null | ViewerContextContents>(
null
Expand Down Expand Up @@ -99,6 +104,8 @@ function ViewerRoot() {
// Scene node attributes that aren't placed in the zustand state for performance reasons.
nodeAttributesFromName: React.useRef({}),
messageQueueRef: React.useRef([]),
getRenderRequestState: React.useRef("ready"),
getRenderRequest: React.useRef(null),
};

return (
Expand Down
49 changes: 10 additions & 39 deletions src/viser/client/src/SceneTree.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,47 +40,16 @@ function SceneNodeThreeChildren(props: {
parent: THREE.Object3D;
}) {
const viewer = React.useContext(ViewerContext)!;
const [children, setChildren] = React.useState<string[]>([]);

// De-bounce updates to children.
React.useEffect(() => {
let readyToUpdate = true;

let updateChildrenTimeout: NodeJS.Timeout | undefined = undefined;

function updateChildren() {
const newChildren =
viewer.useSceneTree.getState().nodeFromName[props.name]?.children;
if (newChildren === undefined || children == newChildren) {
return;
}
if (readyToUpdate) {
setChildren(newChildren!);
readyToUpdate = false;
updateChildrenTimeout = setTimeout(() => {
readyToUpdate = true;
updateChildren();
}, 50);
}
}
const unsubscribe = viewer.useSceneTree.subscribe(
(state) => state.nodeFromName[props.name],
updateChildren
);
updateChildren();

return () => {
clearTimeout(updateChildrenTimeout);
unsubscribe();
};
}, [children]);
const children =
viewer.useSceneTree((state) => state.nodeFromName[props.name]?.children);

// Create a group of children inside of the parent object.
return createPortal(
<group>
{children.map((child_id) => {
return <SceneNodeThreeObject key={child_id} name={child_id} />;
})}
{children &&
children.map((child_id) => {
return <SceneNodeThreeObject key={child_id} name={child_id} />;
})}
<SceneNodeLabel name={props.name} />
</group>,
props.parent
Expand Down Expand Up @@ -129,8 +98,10 @@ export function SceneNodeThreeObject(props: { name: string }) {
// For not-fully-understood reasons, wrapping makeObject with useMemo() fixes
// stability issues (eg breaking runtime errors) associated with
// PivotControls.
const objNode =
makeObject && React.useMemo(() => makeObject(setRef), [makeObject]);
const objNode = React.useMemo(
() => makeObject && makeObject(setRef),
[makeObject]
);
const children =
obj === null ? null : (
<SceneNodeThreeChildren name={props.name} parent={obj} />
Expand Down
Loading

0 comments on commit 26e6322

Please sign in to comment.