Fixed some accidental mistake xs
This commit is contained in:
@@ -343,7 +343,39 @@ class Trainer:
|
|||||||
features[:96], target, predictions, show_legend=(0 == 0)
|
features[:96], target, predictions, show_legend=(0 == 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if epoch > 0:
|
||||||
|
task.get_logger().report_matplotlib_figure(
|
||||||
|
title="Training" if train else "Testing",
|
||||||
|
series=f"Sample {actual_idx}",
|
||||||
|
iteration=epoch,
|
||||||
|
figure=fig,
|
||||||
|
)
|
||||||
|
|
||||||
|
task.get_logger().report_matplotlib_figure(
|
||||||
|
title="Training Samples" if train else "Testing Samples",
|
||||||
|
series=f"Sample {actual_idx} samples",
|
||||||
|
iteration=epoch,
|
||||||
|
figure=fig2,
|
||||||
|
report_interactive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Saving final plots")
|
||||||
|
# fig to PIL image
|
||||||
|
fig.savefig(f"sample_{actual_idx}_plot.png")
|
||||||
|
task.get_logger().report_image(
|
||||||
|
title="Final Training Plot",
|
||||||
|
series=f"Sample {actual_idx}",
|
||||||
|
iteration=epoch,
|
||||||
|
image_path=f"sample_{actual_idx}_plot.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
fig2.savefig(f"sample_{actual_idx}_samples_plot.png")
|
||||||
|
task.get_logger().report_image(
|
||||||
|
title="Final Training Samples Plot",
|
||||||
|
series=f"Sample {actual_idx} samples",
|
||||||
|
iteration=epoch,
|
||||||
|
image_path=f"sample_{actual_idx}_samples_plot.png",
|
||||||
)
|
)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user