1
- #%%
2
- """
3
- Created on Thu Nov 28 2018
4
- Bivariate normal PDF
5
- @author: Lech A. Grzelak
6
- """
7
1
import numpy as np
8
2
import matplotlib .pyplot as plt
9
- from matplotlib . mlab import bivariate_normal
3
+ from scipy . stats import multivariate_normal
10
4
11
- def BivariateNormalPDFPlot ():
12
5
6
+ def BivariateNormalPDFPlot ():
13
7
# Number of points in each direction
14
8
15
- n = 40 ;
16
-
9
+ n = 40 ;
10
+
17
11
# Parameters
18
12
19
- mu_1 = 0 ;
20
- mu_2 = 0 ;
13
+ mu_1 = 0 ;
14
+ mu_2 = 0 ;
21
15
sigma_1 = 1 ;
22
16
sigma_2 = 0.5 ;
23
- rho1 = 0.0
24
- rho2 = - 0.8
25
- rho3 = 0.8
26
-
17
+ rho1 = 0.0
18
+ rho2 = - 0.8
19
+ rho3 = 0.8
20
+
27
21
# Create a grid and a multivariate normal
28
22
29
- x = np .linspace (- 3.0 ,3.0 ,n )
30
- y = np .linspace (- 3.0 ,3.0 ,n )
31
- X , Y = np .meshgrid (x ,y )
32
- Z = lambda rho : bivariate_normal (X ,Y ,sigma_1 ,sigma_2 ,mu_1 ,mu_2 ,rho * sigma_1 * sigma_2 )
33
-
23
+ x = np .linspace (- 3.0 , 3.0 , n )
24
+ y = np .linspace (- 3.0 , 3.0 , n )
25
+ X , Y = np .meshgrid (x , y )
26
+ pos = np .empty (X .shape + (2 ,))
27
+ pos [:, :, 0 ] = X
28
+ pos [:, :, 1 ] = Y
29
+ Z = lambda rho : multivariate_normal ([mu_1 , mu_2 ], [[sigma_1 , rho ],
30
+ [rho , sigma_2 ]])
31
+
34
32
# Make a 3D plot- rho = 0.0
35
33
36
- fig = plt .figure (1 )
34
+ fig = plt .figure (1 )
37
35
ax = fig .gca (projection = '3d' )
38
- ax .plot_surface (X , Y , Z (rho1 ), cmap = 'viridis' ,linewidth = 0 )
36
+ ax .plot_surface (X , Y , Z (rho1 ). pdf ( pos ), cmap = 'viridis' , linewidth = 0 )
39
37
ax .set_xlabel ('X axis' )
40
38
ax .set_ylabel ('Y axis' )
41
39
ax .set_zlabel ('Z axis' )
42
40
plt .show ()
43
-
41
+
44
42
# Make a 3D plot- rho = -0.8
45
43
46
- fig = plt .figure (2 )
44
+ fig = plt .figure (2 )
47
45
ax = fig .gca (projection = '3d' )
48
- ax .plot_surface (X , Y , Z (rho2 ), cmap = 'viridis' ,linewidth = 0 )
46
+ ax .plot_surface (X , Y , Z (rho2 ). pdf ( pos ), cmap = 'viridis' , linewidth = 0 )
49
47
ax .set_xlabel ('X axis' )
50
48
ax .set_ylabel ('Y axis' )
51
49
ax .set_zlabel ('Z axis' )
52
50
plt .show ()
53
-
51
+
54
52
# Make a 3D plot- rho = 0.8
55
53
56
- fig = plt .figure (3 )
54
+ fig = plt .figure (3 )
57
55
ax = fig .gca (projection = '3d' )
58
- ax .plot_surface (X , Y , Z (rho3 ), cmap = 'viridis' ,linewidth = 0 )
56
+ ax .plot_surface (X , Y , Z (rho3 ). pdf ( pos ), cmap = 'viridis' , linewidth = 0 )
59
57
ax .set_xlabel ('X axis' )
60
58
ax .set_ylabel ('Y axis' )
61
59
ax .set_zlabel ('Z axis' )
62
60
plt .show ()
63
-
64
- BivariateNormalPDFPlot ()
61
+
62
+
63
+ BivariateNormalPDFPlot ()
0 commit comments