Add new useful plot features #1544

Open · wants to merge 2 commits into master
87 changes: 58 additions & 29 deletions deepxde/utils/external.py
@@ -131,7 +131,6 @@ def uniformly_continuous_delta(X, Y, eps):
return delta - dx / 2
delta += dx


def saveplot(
loss_history,
train_state,
@@ -141,25 +140,22 @@ def saveplot(
train_fname="train.dat",
test_fname="test.dat",
output_dir=None,
save_format="png",
generate_statistics=False,
):
"""Save/plot the loss history and best trained result.

This function is used to quickly check your results. To better investigate your
result, use ``save_loss_history()`` and ``save_best_state()``.

Args:
loss_history: ``LossHistory`` instance. The first variable returned from
``Model.train()``.
train_state: ``TrainState`` instance. The second variable returned from
``Model.train()``.
issave (bool): Set ``True`` (default) to save the loss, training points,
and testing points.
isplot (bool): Set ``True`` (default) to plot loss, metric, and the predicted
solution.
loss_history: LossHistory instance. The first variable returned from Model.train().
@lululxvi (Owner) commented on Nov 22, 2023:
Why modifying this?

train_state: TrainState instance. The second variable returned from Model.train().
issave (bool): Set True (default) to save the loss, training points, and testing points.
isplot (bool): Set True (default) to plot loss, metric, and the predicted solution.
loss_fname (string): Name of the file to save the loss in.
train_fname (string): Name of the file to save the training points in.
test_fname (string): Name of the file to save the testing points in.
output_dir (string): If ``None``, use the current working directory.
output_dir (string): If None, use the current working directory.
save_format (string): File format for saving the plot (default is "png").
generate_statistics (bool): Set True to generate additional statistics
(mean and standard deviation of the train and test losses).
@lululxvi (Owner) commented:
max 88 characters per line

"""
if output_dir is None:
output_dir = os.getcwd()
@@ -175,29 +171,61 @@ def saveplot(
save_best_state(train_state, train_fname, test_fname)

if isplot:
plot_loss_history(loss_history)
plot_best_state(train_state)
plt.show()


def plot_loss_history(loss_history, fname=None):
"""Plot the training and testing loss history.

Note:
You need to call ``plt.show()`` to show the figure.
plot_style = {
'train_color': 'b',
'test_color': 'r',
'train_linestyle': '-',
'test_linestyle': '--'
}

plot_loss_history(
    loss_history,
    fname=os.path.join(output_dir, f"custom_style_loss.{save_format}"),
    plot_style=plot_style,
)

if generate_statistics:
@lululxvi (Owner) commented:
Why are these statistics useful?

average_loss_train = np.mean(loss_history.loss_train)
std_loss_train = np.std(loss_history.loss_train)
average_loss_test = np.mean(loss_history.loss_test)
std_loss_test = np.std(loss_history.loss_test)

print(f"Average Train Loss: {average_loss_train}")
print(f"Standard Deviation Train Loss: {std_loss_train}")
print(f"Average Test Loss: {average_loss_test}")
print(f"Standard Deviation Test Loss: {std_loss_test}")


def plot_loss_history(loss_history, fname=None, plot_style=None):
"""Plot the training and testing loss history with custom style.

Args:
loss_history: ``LossHistory`` instance. The first variable returned from
``Model.train()``.
fname (string): If `fname` is a string (e.g., 'loss_history.png'), then save the
figure to the file of the file name `fname`.
loss_history: LossHistory instance. The first variable returned from Model.train().
fname (string): If fname is a string (e.g., 'loss_history.png'), save the
figure to that file.
plot_style (dict): A dictionary of style options for the plot. Supported keys:
'train_color', 'test_color', 'train_linestyle', and 'test_linestyle'.
"""
loss_train = np.sum(loss_history.loss_train, axis=1)
loss_test = np.sum(loss_history.loss_test, axis=1)

plt.figure()
plt.semilogy(loss_history.steps, loss_train, label="Train loss")
plt.semilogy(loss_history.steps, loss_test, label="Test loss")

# Default plot style settings
default_style = {
'train_color': 'b',
'test_color': 'r',
'train_linestyle': '-',
'test_linestyle': '--'
}

# Merge user-defined style with default style
plot_style = {**default_style, **(plot_style or {})}

train_color = plot_style['train_color']
test_color = plot_style['test_color']
train_linestyle = plot_style['train_linestyle']
test_linestyle = plot_style['test_linestyle']

plt.semilogy(loss_history.steps, loss_train, label="Train loss", color=train_color, linestyle=train_linestyle)
plt.semilogy(loss_history.steps, loss_test, label="Test loss", color=test_color, linestyle=test_linestyle)

for i in range(len(loss_history.metrics_test[0])):
plt.semilogy(
loss_history.steps,
@@ -211,6 +239,7 @@ def plot_loss_history(loss_history, fname=None):
plt.savefig(fname)
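A quick sketch of calling the extended plot_loss_history on its own; the style dict below uses the four keys this implementation reads (extra keys are merged but ignored), and losshistory is assumed to come from a prior Model.train() call:

import matplotlib.pyplot as plt
from deepxde.utils.external import plot_loss_history

# Green dashed training curve, black solid test curve
style = {
    "train_color": "g",
    "train_linestyle": "--",
    "test_color": "k",
    "test_linestyle": "-",
}
plot_loss_history(losshistory, fname="loss.png", plot_style=style)
plt.show()  # plot_loss_history itself does not call plt.show()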



def save_loss_history(loss_history, fname):
"""Save the training and testing loss history to a file."""
print("Saving loss history to {} ...".format(fname))