Skip to content
This repository was archived by the owner on Apr 21, 2023. It is now read-only.

Commit 58c3486

Browse files
committed
Update docs
1 parent ccd992b commit 58c3486

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

docs/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ Demo
2323
:maxdepth: 1
2424
:caption: GPViz Demo
2525

26-
demo
26+
nbs/demo

gpviz/kernel.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import gpjax
2-
import gpjax.core as gpx
32
from gpjax.parameters import initialise
43
import matplotlib.pyplot as plt
54
from multipledispatch import dispatch
@@ -13,6 +12,15 @@
1312

1413
@dispatch(Kernel, Array, dict)
1514
def plot(kernel: Kernel, X: Array, params: dict, ax = None):
15+
"""
16+
Plot the kernel's Gram matrix.
17+
18+
:param kernel: The kernel function that generates the Gram matrix
19+
:param X: The data points for which the Gram matrix is computed on.
20+
:param params: A dictionary containing the kernel parameters
21+
:param ax: An optional matplotlib axes
22+
:return:
23+
"""
1624
if dict is None:
1725
params = initialise(kernel)
1826

@@ -27,6 +35,16 @@ def plot(kernel: Kernel, X: Array, params: dict, ax = None):
2735

2836
@dispatch(Kernel, Array, Array, dict)
2937
def plot(kernel: Kernel, X: Array, Y: Array, params: dict = None, ax=None):
38+
"""
39+
Plot the kernel's cross-covariance matrix.
40+
41+
:param kernel: The kernel function that generates the covariance matrix
42+
:param X: The first set of data points for which the covariance matrix is computed on.
43+
:param Y: The second set of data points for which the covariance matrix is computed on.
44+
:param params: A dictionary containing the kernel parameters
45+
:param ax: An optional matplotlib axes
46+
:return:
47+
"""
3048
if dict is None:
3149
params = initialise(kernel)
3250

@@ -41,6 +59,15 @@ def plot(kernel: Kernel, X: Array, Y: Array, params: dict = None, ax=None):
4159

4260
@dispatch(Kernel)
4361
def plot(kernel: Kernel, params: dict = None, ax=None, xrange: Tuple[float, float] = (-10, 10.)):
62+
"""
63+
Plot the kernel's shape.
64+
65+
:param kernel: The kernel function
66+
:param params: A dictionary containing the kernel parameters
67+
:param ax: An optional matplotlib axes
68+
:param xrange The tuple pair lower and upper values over which the kernel should be evaluated.
69+
:return:
70+
"""
4471
if dict is None:
4572
params = initialise(kernel)
4673

@@ -54,4 +81,3 @@ def plot(kernel: Kernel, params: dict = None, ax=None, xrange: Tuple[float, floa
5481
K = gpjax.kernels.cross_covariance(kernel, X, x1, params)
5582
ax.plot(X, K.T, color=cols['base'])
5683
mplcyberpunk.add_underglow(ax=ax)
57-

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
gpjax>=0.3.7
22
matplotlib
3+
mplcyberpunk==0.1.11

0 commit comments

Comments
 (0)