Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 58c7441

Browse filesBrowse files
committed
added tests for students tp, sparse gp
1 parent 594cf18 commit 58c7441
Copy full SHA for 58c7441

File tree

Expand file treeCollapse file tree

3 files changed

+288
-271
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+288
-271
lines changed

‎README.rst

Copy file name to clipboardExpand all lines: README.rst
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ project: https://github.com/parsing-science/pymc3_models.
262262
.. |Travis| image:: https://travis-ci.com/pymc-learn/pymc-learn.svg?branch=master
263263
:target: https://travis-ci.com/pymc-learn/pymc-learn
264264

265-
.. |Coverage| image:: https://coveralls.io/repos/github/pymc-learn/pymc-learn/badge.svg
266-
:target: https://coveralls.io/github/pymc-learn/pymc-learn
265+
.. |Coverage| image:: https://coveralls.io/repos/github/pymc-learn/pymc-learn/badge.svg?branch=master
266+
:target: https://coveralls.io/github/pymc-learn/pymc-learn?branch=master
267267

268268
.. |Python27| image:: https://img.shields.io/badge/python-2.7-blue.svg
269269
:target: https://badge.fury.io/py/pymc-learn
@@ -280,4 +280,4 @@ project: https://github.com/parsing-science/pymc3_models.
280280
:target: https://github.com/pymc-learn/pymc-learn/blob/master/LICENSE
281281

282282
.. |Pypi| image:: https://badge.fury.io/py/pymc-learn.svg
283-
:target: https://badge.fury.io/py/pymc-learn
283+
:target: https://badge.fury.io/py/pymc-learn

‎pmlearn/gaussian_process/gpr.py

Copy file name to clipboardExpand all lines: pmlearn/gaussian_process/gpr.py
+42-19Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ def load(self, file_prefix):
173173
self.num_training_samples = params['num_training_samples']
174174

175175

176-
class StudentsTProcessRegressor(GaussianProcessRegressor):
176+
class StudentsTProcessRegressor(BayesianModel,
177+
GaussianProcessRegressorMixin):
177178
""" StudentsT Process Regression built using PyMC3.
178179
179180
Fit a StudentsT process model and estimate model parameters using
@@ -204,8 +205,15 @@ class StudentsTProcessRegressor(GaussianProcessRegressor):
204205
Rasmussen and Williams (2006). Gaussian Processes for Machine Learning.
205206
"""
206207

207-
def __init__(self, prior_mean=0.0):
208-
super(StudentsTProcessRegressor, self).__init__(prior_mean=prior_mean)
208+
def __init__(self, prior_mean=None, kernel=None):
209+
self.ppc = None
210+
self.gp = None
211+
self.num_training_samples = None
212+
self.num_pred = None
213+
self.prior_mean = prior_mean
214+
self.kernel = kernel
215+
216+
super(StudentsTProcessRegressor, self).__init__()
209217

210218
def create_model(self):
211219
""" Creates and returns the PyMC3 model.
@@ -241,13 +249,17 @@ def create_model(self):
241249
degrees_of_freedom = pm.Gamma('degrees_of_freedom', alpha=2,
242250
beta=0.1, shape=1)
243251

244-
# cov_function = signal_variance**2 * pm.gp.cov.ExpQuad(
245-
# 1, length_scale)
246-
cov_function = signal_variance ** 2 * pm.gp.cov.Matern52(
247-
1, length_scale)
252+
if self.kernel is None:
253+
cov_function = signal_variance ** 2 * RBF(
254+
input_dim=self.num_pred,
255+
ls=length_scale)
256+
else:
257+
cov_function = self.kernel
248258

249-
# mean_function = pm.gp.mean.Zero()
250-
mean_function = pm.gp.mean.Constant(self.prior_mean)
259+
if self.prior_mean is None:
260+
mean_function = pm.gp.mean.Zero()
261+
else:
262+
mean_function = self.prior_mean
251263

252264
self.gp = pm.gp.Latent(mean_func=mean_function,
253265
cov_func=cov_function)
@@ -277,7 +289,8 @@ def load(self, file_prefix):
277289
self.num_training_samples = params['num_training_samples']
278290

279291

280-
class SparseGaussianProcessRegressor(GaussianProcessRegressor):
292+
class SparseGaussianProcessRegressor(BayesianModel,
293+
GaussianProcessRegressorMixin):
281294
""" Sparse Gaussian Process Regression built using PyMC3.
282295
283296
Fit a Sparse Gaussian process model and estimate model parameters using
@@ -308,9 +321,15 @@ class SparseGaussianProcessRegressor(GaussianProcessRegressor):
308321
Rasmussen and Williams (2006). Gaussian Processes for Machine Learning.
309322
"""
310323

311-
def __init__(self, prior_mean=0.0):
312-
super(SparseGaussianProcessRegressor, self).__init__(
313-
prior_mean=prior_mean)
324+
def __init__(self, prior_mean=None, kernel=None):
325+
self.ppc = None
326+
self.gp = None
327+
self.num_training_samples = None
328+
self.num_pred = None
329+
self.prior_mean = prior_mean
330+
self.kernel = kernel
331+
332+
super(SparseGaussianProcessRegressor, self).__init__()
314333

315334
def create_model(self):
316335
""" Creates and returns the PyMC3 model.
@@ -344,13 +363,17 @@ def create_model(self):
344363
noise_variance = pm.HalfCauchy('noise_variance', beta=5,
345364
shape=1)
346365

347-
# cov_function = signal_variance**2 * pm.gp.cov.ExpQuad(
348-
# 1, length_scale)
349-
cov_function = signal_variance ** 2 * pm.gp.cov.Matern52(
350-
1, length_scale)
366+
if self.kernel is None:
367+
cov_function = signal_variance ** 2 * RBF(
368+
input_dim=self.num_pred,
369+
ls=length_scale)
370+
else:
371+
cov_function = self.kernel
351372

352-
# mean_function = pm.gp.mean.Zero()
353-
mean_function = pm.gp.mean.Constant(self.prior_mean)
373+
if self.prior_mean is None:
374+
mean_function = pm.gp.mean.Zero()
375+
else:
376+
mean_function = self.prior_mean
354377

355378
self.gp = pm.gp.MarginalSparse(mean_func=mean_function,
356379
cov_func=cov_function,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.