@@ -173,7 +173,8 @@ def load(self, file_prefix):
173
173
self .num_training_samples = params ['num_training_samples' ]
174
174
175
175
176
- class StudentsTProcessRegressor (GaussianProcessRegressor ):
176
+ class StudentsTProcessRegressor (BayesianModel ,
177
+ GaussianProcessRegressorMixin ):
177
178
""" StudentsT Process Regression built using PyMC3.
178
179
179
180
Fit a StudentsT process model and estimate model parameters using
@@ -204,8 +205,15 @@ class StudentsTProcessRegressor(GaussianProcessRegressor):
204
205
Rasmussen and Williams (2006). Gaussian Processes for Machine Learning.
205
206
"""
206
207
207
- def __init__ (self , prior_mean = 0.0 ):
208
- super (StudentsTProcessRegressor , self ).__init__ (prior_mean = prior_mean )
208
+ def __init__ (self , prior_mean = None , kernel = None ):
209
+ self .ppc = None
210
+ self .gp = None
211
+ self .num_training_samples = None
212
+ self .num_pred = None
213
+ self .prior_mean = prior_mean
214
+ self .kernel = kernel
215
+
216
+ super (StudentsTProcessRegressor , self ).__init__ ()
209
217
210
218
def create_model (self ):
211
219
""" Creates and returns the PyMC3 model.
@@ -241,13 +249,17 @@ def create_model(self):
241
249
degrees_of_freedom = pm .Gamma ('degrees_of_freedom' , alpha = 2 ,
242
250
beta = 0.1 , shape = 1 )
243
251
244
- # cov_function = signal_variance**2 * pm.gp.cov.ExpQuad(
245
- # 1, length_scale)
246
- cov_function = signal_variance ** 2 * pm .gp .cov .Matern52 (
247
- 1 , length_scale )
252
+ if self .kernel is None :
253
+ cov_function = signal_variance ** 2 * RBF (
254
+ input_dim = self .num_pred ,
255
+ ls = length_scale )
256
+ else :
257
+ cov_function = self .kernel
248
258
249
- # mean_function = pm.gp.mean.Zero()
250
- mean_function = pm .gp .mean .Constant (self .prior_mean )
259
+ if self .prior_mean is None :
260
+ mean_function = pm .gp .mean .Zero ()
261
+ else :
262
+ mean_function = self .prior_mean
251
263
252
264
self .gp = pm .gp .Latent (mean_func = mean_function ,
253
265
cov_func = cov_function )
@@ -277,7 +289,8 @@ def load(self, file_prefix):
277
289
self .num_training_samples = params ['num_training_samples' ]
278
290
279
291
280
- class SparseGaussianProcessRegressor (GaussianProcessRegressor ):
292
+ class SparseGaussianProcessRegressor (BayesianModel ,
293
+ GaussianProcessRegressorMixin ):
281
294
""" Sparse Gaussian Process Regression built using PyMC3.
282
295
283
296
Fit a Sparse Gaussian process model and estimate model parameters using
@@ -308,9 +321,15 @@ class SparseGaussianProcessRegressor(GaussianProcessRegressor):
308
321
Rasmussen and Williams (2006). Gaussian Processes for Machine Learning.
309
322
"""
310
323
311
- def __init__ (self , prior_mean = 0.0 ):
312
- super (SparseGaussianProcessRegressor , self ).__init__ (
313
- prior_mean = prior_mean )
324
+ def __init__ (self , prior_mean = None , kernel = None ):
325
+ self .ppc = None
326
+ self .gp = None
327
+ self .num_training_samples = None
328
+ self .num_pred = None
329
+ self .prior_mean = prior_mean
330
+ self .kernel = kernel
331
+
332
+ super (SparseGaussianProcessRegressor , self ).__init__ ()
314
333
315
334
def create_model (self ):
316
335
""" Creates and returns the PyMC3 model.
@@ -344,13 +363,17 @@ def create_model(self):
344
363
noise_variance = pm .HalfCauchy ('noise_variance' , beta = 5 ,
345
364
shape = 1 )
346
365
347
- # cov_function = signal_variance**2 * pm.gp.cov.ExpQuad(
348
- # 1, length_scale)
349
- cov_function = signal_variance ** 2 * pm .gp .cov .Matern52 (
350
- 1 , length_scale )
366
+ if self .kernel is None :
367
+ cov_function = signal_variance ** 2 * RBF (
368
+ input_dim = self .num_pred ,
369
+ ls = length_scale )
370
+ else :
371
+ cov_function = self .kernel
351
372
352
- # mean_function = pm.gp.mean.Zero()
353
- mean_function = pm .gp .mean .Constant (self .prior_mean )
373
+ if self .prior_mean is None :
374
+ mean_function = pm .gp .mean .Zero ()
375
+ else :
376
+ mean_function = self .prior_mean
354
377
355
378
self .gp = pm .gp .MarginalSparse (mean_func = mean_function ,
356
379
cov_func = cov_function ,
0 commit comments