1
1
import gpjax
2
- import gpjax .core as gpx
3
2
from gpjax .parameters import initialise
4
3
import matplotlib .pyplot as plt
5
4
from multipledispatch import dispatch
13
12
14
13
@dispatch (Kernel , Array , dict )
15
14
def plot (kernel : Kernel , X : Array , params : dict , ax = None ):
15
+ """
16
+ Plot the kernel's Gram matrix.
17
+
18
+ :param kernel: The kernel function that generates the Gram matrix
19
+ :param X: The data points for which the Gram matrix is computed on.
20
+ :param params: A dictionary containing the kernel parameters
21
+ :param ax: An optional matplotlib axes
22
+ :return:
23
+ """
16
24
if dict is None :
17
25
params = initialise (kernel )
18
26
@@ -27,6 +35,16 @@ def plot(kernel: Kernel, X: Array, params: dict, ax = None):
27
35
28
36
@dispatch (Kernel , Array , Array , dict )
29
37
def plot (kernel : Kernel , X : Array , Y : Array , params : dict = None , ax = None ):
38
+ """
39
+ Plot the kernel's cross-covariance matrix.
40
+
41
+ :param kernel: The kernel function that generates the covariance matrix
42
+ :param X: The first set of data points for which the covariance matrix is computed on.
43
+ :param Y: The second set of data points for which the covariance matrix is computed on.
44
+ :param params: A dictionary containing the kernel parameters
45
+ :param ax: An optional matplotlib axes
46
+ :return:
47
+ """
30
48
if dict is None :
31
49
params = initialise (kernel )
32
50
@@ -41,6 +59,15 @@ def plot(kernel: Kernel, X: Array, Y: Array, params: dict = None, ax=None):
41
59
42
60
@dispatch (Kernel )
43
61
def plot (kernel : Kernel , params : dict = None , ax = None , xrange : Tuple [float , float ] = (- 10 , 10. )):
62
+ """
63
+ Plot the kernel's shape.
64
+
65
+ :param kernel: The kernel function
66
+ :param params: A dictionary containing the kernel parameters
67
+ :param ax: An optional matplotlib axes
68
+ :param xrange The tuple pair lower and upper values over which the kernel should be evaluated.
69
+ :return:
70
+ """
44
71
if dict is None :
45
72
params = initialise (kernel )
46
73
@@ -54,4 +81,3 @@ def plot(kernel: Kernel, params: dict = None, ax=None, xrange: Tuple[float, floa
54
81
K = gpjax .kernels .cross_covariance (kernel , X , x1 , params )
55
82
ax .plot (X , K .T , color = cols ['base' ])
56
83
mplcyberpunk .add_underglow (ax = ax )
57
-
0 commit comments