Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ea2285b

Browse filesBrowse files
committed
add bivariate random walk as transition model
1 parent 59e987c commit ea2285b
Copy full SHA for ea2285b

File tree

Expand file treeCollapse file tree

2 files changed

+87
-0
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+87
-0
lines changed

‎bayesloop/transitionModels.py

Copy file name to clipboardExpand all lines: bayesloop/transitionModels.py
+73Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from __future__ import division, print_function
1212
import numpy as np
1313
from scipy.signal import fftconvolve
14+
from scipy.signal import convolve2d
1415
from scipy.ndimage.filters import gaussian_filter1d
1516
from scipy.ndimage.interpolation import shift
17+
from scipy.stats import multivariate_normal
1618
from collections import Iterable
1719
from inspect import getargspec
1820
from copy import deepcopy
@@ -833,3 +835,74 @@ def __init__(self, name='tBreak', value=None, prior=None):
833835

834836
def __str__(self):
835837
return 'Break-point'
838+
839+
840+
class BivariateRandomWalk(TransitionModel):
841+
"""
842+
Correlated Gaussian parameter fluctuations. This model assumes that parameter changes follow a bivariate Gaussian
843+
distribution.
844+
"""
845+
def __init__(self, name1='sigma1', value1=None,
846+
name2='sigma2', value2=None,
847+
name3='rho', value3=None,
848+
prior=(None, None, None)):
849+
850+
if isinstance(value1, (list, tuple)):
851+
value1 = np.array(value1)
852+
if isinstance(value2, (list, tuple)):
853+
value2 = np.array(value2)
854+
if isinstance(value3, (list, tuple)):
855+
value2 = np.array(value3)
856+
857+
self.study = None
858+
self.latticeConstant = None
859+
self.hyperParameterNames = [name1, name2, name3]
860+
self.hyperParameterValues = [value1, value2, value3]
861+
self.prior = prior
862+
self.kernel = None
863+
self.kernelParameters = None
864+
self.tOffset = 0 # is set to the time of the last Breakpoint by SerialTransition model
865+
866+
def __str__(self):
867+
return 'Bivariate random walk'
868+
869+
def computeForwardPrior(self, posterior, t):
870+
"""
871+
Compute new prior from old posterior (moving forwards in time).
872+
873+
Args:
874+
posterior(ndarray): Parameter distribution from current time step
875+
t(int): integer time step
876+
877+
Returns:
878+
ndarray: Prior parameter distribution for subsequent time step
879+
"""
880+
881+
# if hyper-parameter values have changed, a new convolution kernel needs to be created
882+
if not self.kernelParameters == self.hyperParameterValues:
883+
normedSigma1 = self.hyperParameterValues[0] / self.latticeConstant[0]
884+
normedSigma2 = self.hyperParameterValues[1] / self.latticeConstant[1]
885+
886+
self.kernel = self.createKernel(normedSigma1, normedSigma2, self.hyperParameterValues[2])
887+
self.kernelParameters = deepcopy(self.hyperParameterValues)
888+
889+
newPrior = convolve2d(posterior, self.kernel, mode='same')
890+
newPrior /= np.sum(newPrior)
891+
return newPrior
892+
893+
def computeBackwardPrior(self, posterior, t):
894+
return self.computeForwardPrior(posterior, t - 1)
895+
896+
@staticmethod
897+
def createKernel(sigma1, sigma2, rho):
898+
rv = multivariate_normal(cov=[[sigma1 ** 2., rho * sigma1 * sigma2],
899+
[rho * sigma1 * sigma2, sigma2 ** 2.]])
900+
901+
x = np.arange(-3 * np.ceil(sigma1), 3 * np.ceil(sigma1) + 1)
902+
y = np.arange(-3 * np.ceil(sigma2), 3 * np.ceil(sigma2) + 1)
903+
904+
xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')
905+
906+
kernel = rv.pdf(np.array([xv, yv]).T).T
907+
kernel /= np.sum(kernel)
908+
return kernel

‎tests/test_transitionmodels.py

Copy file name to clipboardExpand all lines: tests/test_transitionmodels.py
+14Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,20 @@ def test_gaussianrandomwalk(self):
5151
np.testing.assert_almost_equal(S.logEvidence, -10.323144246611964, decimal=5,
5252
err_msg='Erroneous log-evidence value.')
5353

54+
def test_bivariaterandomwalk(self):
55+
S = bl.Study()
56+
S.loadData(np.array([1, 2, 3, 4, 5]))
57+
58+
L = bl.om.Gaussian('mu', bl.oint(0, 6, 20), 'sigma', bl.oint(0, 2, 20))
59+
T = bl.tm.BivariateRandomWalk('sigma1', 1., 'sigma2', 0.1, 'rho', 0.5)
60+
S.set(L, T)
61+
62+
S.fit()
63+
64+
# test model evidence value
65+
np.testing.assert_almost_equal(S.logEvidence, -7.330706514472251, decimal=5,
66+
err_msg='Erroneous log-evidence value.')
67+
5468
def test_alphastablerandomwalk(self):
5569
S = bl.Study()
5670
S.loadData(np.array([1, 2, 3, 4, 5]))

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.