Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensorboard 2.3.0+ supported #77

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions jupyter_tensorboard/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def get(self, name, path):
else:
raise web.HTTPError(404)

@web.authenticated
def post(self, name, path):
return self.get(name, path)

class TensorboardErrorHandler(IPythonHandler):
pass
78 changes: 72 additions & 6 deletions jupyter_tensorboard/tensorboard_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,77 @@ def reload_multiplexer(multiplexer, path_to_run):
multiplexer.Reload()
application.reload_multiplexer = reload_multiplexer

if not hasattr(application, 'standard_tensorboard_wsgi'):
# Tensorflow 2.3+ removed reload_multiplexer, patch it
from tensorboard.backend.event_processing import (
data_provider as event_data_provider,
)
from tensorboard.backend.event_processing import (
plugin_event_multiplexer as event_multiplexer,
)
from tensorboard.backend.event_processing.data_ingester import (
DEFAULT_TENSOR_SIZE_GUIDANCE
)
def _apply_tensor_size_guidance(sampling_hints):
"""Apply user per-summary size guidance overrides."""
tensor_size_guidance = dict(DEFAULT_TENSOR_SIZE_GUIDANCE)
tensor_size_guidance.update(sampling_hints)
return tensor_size_guidance
def _get_event_file_active_filter(flags):
"""Returns a predicate for whether an event file load timestamp is active.

Returns:
A predicate function accepting a single UNIX timestamp float argument, or
None if multi-file loading is not enabled.
"""
if not flags.reload_multifile:
return None
inactive_secs = flags.reload_multifile_inactive_secs
if inactive_secs == 0:
return None
if inactive_secs < 0:
return lambda timestamp: True
return lambda timestamp: timestamp + inactive_secs >= time.time()
def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
data_provider = None
multiplexer = None
reload_interval = flags.reload_interval
# Regular logdir loading mode.
sampling_hints = flags.samples_per_plugin
multiplexer = event_multiplexer.EventMultiplexer(
tensor_size_guidance=_apply_tensor_size_guidance(sampling_hints),
purge_orphaned_data=flags.purge_orphaned_data,
max_reload_threads=flags.max_reload_threads,
event_file_active_filter=_get_event_file_active_filter(flags),
)
if reload_interval >= 0:
# We either reload the multiplexer once when TensorBoard starts up, or we
# continuously reload the multiplexer.
if flags.logdir:
path_to_run = {os.path.expanduser(flags.logdir): None}
else:
path_to_run = parse_event_files_spec(flags.logdir_spec)
application.reload_multiplexer(multiplexer, path_to_run)
#start_reloading_multiplexer(
# multiplexer, path_to_run, reload_interval
#)
data_provider = event_data_provider.MultiplexerDataProvider(
multiplexer, flags.logdir or flags.logdir_spec
)

return application.TensorBoardWSGIApp(
flags, plugin_loaders, data_provider, assets_zip_provider, multiplexer
)
application.standard_tensorboard_wsgi = standard_tensorboard_wsgi

if not hasattr(application, 'parse_event_files_spec'):
def parse_event_files_spec(logdir_spec):
from tensorboard.backend.event_processing.data_ingester import (
_parse_event_files_spec
)
return _parse_event_files_spec(logdir_spec)
application.parse_event_files_spec = parse_event_files_spec

if hasattr(default, 'PLUGIN_LOADERS') or hasattr(default, '_PLUGINS'):
# Tensorflow 1.10 or above series
logging.debug("Tensorboard 1.10 or above series detected")
Expand Down Expand Up @@ -147,16 +218,11 @@ def TensorBoardWSGIApp_2x(
application.reload_multiplexer(multiplexer, path_to_run)
thread = None

db_uri = None
db_connection_provider = None

plugin_name_to_instance = {}

from tensorboard.plugins import base_plugin
context = base_plugin.TBContext(
data_provider=data_provider,
db_connection_provider=db_connection_provider,
db_uri=db_uri,
flags=flags,
logdir=flags.logdir,
multiplexer=deprecated_multiplexer,
Expand All @@ -172,7 +238,7 @@ def TensorBoardWSGIApp_2x(
tbplugins.append(plugin)
plugin_name_to_instance[plugin.plugin_name] = plugin

tb_app = application.TensorBoardWSGI(tbplugins)
tb_app = application.TensorBoardWSGI(tbplugins, data_provider=data_provider)
manager.add_instance(logdir, tb_app, thread)
return tb_app

Expand Down
30 changes: 30 additions & 0 deletions tests/test_tensorboard_integration.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
# -*- coding:utf-8 -*-

import os
import sys
import time
import logging
import json
import binascii

import pytest
from tornado.testing import AsyncHTTPTestCase

def encode_multipart_formdata(fields):
boundary = binascii.hexlify(os.urandom(16)).decode('ascii')

body = (
"".join("--%s\r\n"
"Content-Disposition: form-data; name=\"%s\"\r\n"
"\r\n"
"%s\r\n" % (boundary, field, value)
for field, value in fields.items()) +
"--%s--\r\n" % boundary
)

content_type = "multipart/form-data; boundary=%s" % boundary

return body, content_type

@pytest.fixture(scope="session")
def tf_logs(tmpdir_factory):
Expand Down Expand Up @@ -98,6 +115,19 @@ def test_tensorboard(self):
response = self.fetch('/tensorboard/1/#graphs')
assert response.code == 200

response = self.fetch(
'/tensorboard/1/data/plugin/scalars/tags',
method='GET')
assert response.code == 200

body, content_type = encode_multipart_formdata({'tag':'loss', 'runs':['.']})
response = self.fetch(
'/tensorboard/1/data/plugin/scalars/scalars_multirun',
method='POST',
body=body,
headers={'Content-Type': content_type})
assert response.code == 200

response = self.fetch('/tensorboard/1/data/plugins_listing')
plugins_list = json.loads(response.body.decode())
assert plugins_list["graphs"]
Expand Down