Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions xla/python/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ struct ProfilerSessionWrapper {
explicit ProfilerSessionWrapper(std::unique_ptr<tsl::ProfilerSession> session)
: session(std::move(session)) {}

ProfilerSessionWrapper(std::unique_ptr<tsl::ProfilerSession> session,
std::string session_id)
: session(std::move(session)), session_id(std::move(session_id)) {}

std::unique_ptr<tsl::ProfilerSession> session;
std::string session_id;
};

static std::string GetFdoProfile(const std::string& xspace,
Expand Down Expand Up @@ -172,17 +177,23 @@ NB_MODULE(_profiler, m) {
.def("__init__",
[](ProfilerSessionWrapper* wrapper,
const tensorflow::ProfileOptions& options) {
new (wrapper)
ProfilerSessionWrapper(tsl::ProfilerSession::Create(options));
new (wrapper) ProfilerSessionWrapper(
tsl::ProfilerSession::Create(options), options.session_id());
})
.def(
"stop_and_export",
[](ProfilerSessionWrapper* sess, const std::string& tensorboard_dir) {
tensorflow::profiler::XSpace xspace;
// Disables the ProfilerSession
xla::ThrowIfError(sess->session->CollectData(&xspace));
xla::ThrowIfError(tsl::profiler::ExportToTensorBoard(
xspace, tensorboard_dir, /* also_export_trace_json= */ true));
if (sess->session_id.empty()) {
xla::ThrowIfError(tsl::profiler::ExportToTensorBoard(
xspace, tensorboard_dir, /* also_export_trace_json= */ true));
} else {
xla::ThrowIfError(tsl::profiler::ExportToTensorBoard(
xspace, tensorboard_dir, sess->session_id,
/* also_export_trace_json= */ true));
}
},
nb::call_guard<nb::gil_scoped_release>())
.def("stop",
Expand Down Expand Up @@ -275,7 +286,10 @@ NB_MODULE(_profiler, m) {
"repository_path", &tensorflow::ProfileOptions::repository_path,
[](tensorflow::ProfileOptions* options, const std::string& path) {
options->set_repository_path(path);
});
})
.def_prop_rw("session_id", &tensorflow::ProfileOptions::session_id,
[](tensorflow::ProfileOptions* options,
const std::string& id) { options->set_session_id(id); });

nb::class_<TraceMeWrapper> traceme_class(m, "TraceMe");
traceme_class.def(nb::init<nb::str, nb::kwargs>())
Expand Down
10 changes: 9 additions & 1 deletion xla/tsl/profiler/rpc/client/capture_profile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,10 @@ absl::Status Monitor(const std::string& service_addr, int duration_ms,

absl::Status ExportToTensorBoard(const XSpace& xspace,
const std::string& logdir,
const std::string& run,
bool also_export_trace_json) {
std::string repository_root =
tsl::profiler::GetTensorBoardProfilePluginDir(logdir);
std::string run = tsl::profiler::GetCurrentTimeStampAsString();
std::string host = tsl::port::Hostname();
TF_RETURN_IF_ERROR(
tsl::profiler::SaveXSpace(repository_root, run, host, xspace));
Expand All @@ -275,6 +275,14 @@ absl::Status ExportToTensorBoard(const XSpace& xspace,
return absl::OkStatus();
}

absl::Status ExportToTensorBoard(const XSpace& xspace,
const std::string& logdir,
bool also_export_trace_json) {
return ExportToTensorBoard(xspace, logdir,
tsl::profiler::GetCurrentTimeStampAsString(),
also_export_trace_json);
}

absl::Status CaptureRemoteTrace(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
Expand Down
5 changes: 5 additions & 0 deletions xla/tsl/profiler/rpc/client/capture_profile.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ absl::Status ExportToTensorBoard(const tensorflow::profiler::XSpace& xspace,
const std::string& logdir,
bool also_export_trace_json = false);

absl::Status ExportToTensorBoard(const tensorflow::profiler::XSpace& xspace,
const std::string& logdir,
const std::string& run,
bool also_export_trace_json = false);

// Collects one sample of monitoring profile and shows user-friendly metrics.
// If timestamp flag is true, timestamp will be displayed in "%H:%M:%S" format.
absl::Status Monitor(const std::string& service_addr, int duration_ms,
Expand Down
Loading