|
11 | 11 | from functools import reduce |
12 | 12 | from sklearn import metrics |
13 | 13 |
|
| 14 | +import seaborn as sns |
14 | 15 | import pandas as pd |
15 | 16 | import numpy as np |
16 | 17 |
|
@@ -323,10 +324,12 @@ def cdf(x, save=None, binsave=None, subsave='', choice='standard_normal'): |
323 | 324 |
|
324 | 325 | cdf_name = 'cdf' |
325 | 326 | parity_name = 'cdf_parity' |
| 327 | + dist_name = 'distribution' |
326 | 328 | if binsave is not None: |
327 | 329 | save = os.path.join(save, 'each_bin') |
328 | 330 | cdf_name = '{}_{}'.format(cdf_name, binsave) |
329 | 331 | parity_name = '{}_{}'.format(parity_name, binsave) |
| 332 | + dist_name = '{}_{}'.format(dist_name, binsave) |
330 | 333 |
|
331 | 334 | os.makedirs(save, exist_ok=True) |
332 | 335 |
|
@@ -472,6 +475,71 @@ def cdf(x, save=None, binsave=None, subsave='', choice='standard_normal'): |
472 | 475 | ), 'w') as handle: |
473 | 476 | json.dump(data, handle) |
474 | 477 |
|
| 478 | + fig, ax = pl.subplots() |
| 479 | + |
| 480 | + sns.histplot( |
| 481 | + z, |
| 482 | + kde=True, |
| 483 | + stat='density', |
| 484 | + color='g', |
| 485 | + ax=ax, |
| 486 | + label='Standard Normal Distribution', |
| 487 | + ) |
| 488 | + |
| 489 | + sns.histplot( |
| 490 | + x, |
| 491 | + kde=True, |
| 492 | + stat='density', |
| 493 | + color='r', |
| 494 | + ax=ax, |
| 495 | + label='Observed Distribution', |
| 496 | + ) |
| 497 | + |
| 498 | + ax.set_xlabel('z') |
| 499 | + ax.set_ylabel('Fraction') |
| 500 | + |
| 501 | + fig.tight_layout() |
| 502 | + |
| 503 | + fig_legend, ax_legend = pl.subplots() |
| 504 | + ax_legend.axis(False) |
| 505 | + legend = ax_legend.legend( |
| 506 | + *ax.get_legend_handles_labels(), |
| 507 | + frameon=False, |
| 508 | + loc='center', |
| 509 | + bbox_to_anchor=(0.5, 0.5) |
| 510 | + ) |
| 511 | + ax_legend.spines['top'].set_visible(False) |
| 512 | + ax_legend.spines['bottom'].set_visible(False) |
| 513 | + ax_legend.spines['left'].set_visible(False) |
| 514 | + ax_legend.spines['right'].set_visible(False) |
| 515 | + |
| 516 | + fig.savefig(os.path.join( |
| 517 | + save, |
| 518 | + '{}{}.png'.format(dist_name, subsave), |
| 519 | + ), bbox_inches='tight') |
| 520 | + |
| 521 | + fig_legend.savefig(os.path.join( |
| 522 | + save, |
| 523 | + '{}{}_legend.png'.format( |
| 524 | + dist_name, |
| 525 | + subsave |
| 526 | + ), |
| 527 | + ), bbox_inches='tight') |
| 528 | + |
| 529 | + pl.close(fig) |
| 530 | + pl.close(fig_legend) |
| 531 | + |
| 532 | + data = {} |
| 533 | + data['x'] = list(eval_points) |
| 534 | + data['y'] = list(y) |
| 535 | + data['y_pred'] = list(y_pred) |
| 536 | + data['Area'] = areacdf |
| 537 | + with open(os.path.join( |
| 538 | + save, |
| 539 | + '{}{}.json'.format(cdf_name, subsave), |
| 540 | + ), 'w') as handle: |
| 541 | + json.dump(data, handle) |
| 542 | + |
475 | 543 | return y, y_pred, areaparity, areacdf |
476 | 544 |
|
477 | 545 |
|
|
0 commit comments