Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8f71b7a

Browse filesBrowse files
added matrix distance exaple as a test
1 parent df29565 commit 8f71b7a
Copy full SHA for 8f71b7a

File tree

1 file changed

+157
-0
lines changed
Filter options

1 file changed

+157
-0
lines changed

‎testthat/test-distance.R

Copy file name to clipboard
+157Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
context( "distance" )
2+
3+
code <- '
4+
#include <Rcpp.h>
5+
using namespace Rcpp;
6+
7+
#include <cmath>
8+
#include <algorithm>
9+
10+
// generic function for kl_divergence
11+
template <typename InputIterator1, typename InputIterator2>
12+
inline double kl_divergence(InputIterator1 begin1, InputIterator1 end1,
13+
InputIterator2 begin2) {
14+
15+
// value to return
16+
double rval = 0;
17+
18+
// set iterators to beginning of ranges
19+
InputIterator1 it1 = begin1;
20+
InputIterator2 it2 = begin2;
21+
22+
// for each input item
23+
while (it1 != end1) {
24+
25+
// take the value and increment the iterator
26+
double d1 = *it1++;
27+
double d2 = *it2++;
28+
29+
// accumulate if appropirate
30+
if (d1 > 0 && d2 > 0)
31+
rval += std::log(d1 / d2) * d1;
32+
}
33+
return rval;
34+
}
35+
36+
// helper function for taking the average of two numbers
37+
inline double average(double val1, double val2) {
38+
return (val1 + val2) / 2;
39+
}
40+
41+
// [[Rcpp::export]]
42+
NumericMatrix rcpp_js_distance(NumericMatrix mat) {
43+
44+
// allocate the matrix we will return
45+
NumericMatrix rmat(mat.nrow(), mat.nrow());
46+
47+
for (int i = 0; i < rmat.nrow(); i++) {
48+
for (int j = 0; j < i; j++) {
49+
50+
// rows we will operate on
51+
NumericMatrix::Row row1 = mat.row(i);
52+
NumericMatrix::Row row2 = mat.row(j);
53+
54+
// compute the average using std::tranform from the STL
55+
std::vector<double> avg(row1.size());
56+
std::transform(row1.begin(), row1.end(), // input range 1
57+
row2.begin(), // input range 2
58+
avg.begin(), // output range
59+
average); // function to apply
60+
61+
// calculate divergences
62+
double d1 = kl_divergence(row1.begin(), row1.end(), avg.begin());
63+
double d2 = kl_divergence(row2.begin(), row2.end(), avg.begin());
64+
65+
// write to output matrix
66+
rmat(i,j) = std::sqrt(.5 * (d1 + d2));
67+
}
68+
}
69+
70+
return rmat;
71+
}
72+
73+
// [[Rcpp::depends(RcppParallel)]]
74+
#include <RcppParallel.h>
75+
using namespace RcppParallel;
76+
77+
struct JsDistance : public Worker {
78+
79+
// input matrix to read from
80+
const RMatrix<double> mat;
81+
82+
// output matrix to write to
83+
RMatrix<double> rmat;
84+
85+
// initialize from Rcpp input and output matrixes (the RMatrix class
86+
// can be automatically converted to from the Rcpp matrix type)
87+
JsDistance(const NumericMatrix mat, NumericMatrix rmat)
88+
: mat(mat), rmat(rmat) {}
89+
90+
// function call operator that work for the specified range (begin/end)
91+
void operator()(std::size_t begin, std::size_t end) {
92+
for (std::size_t i = begin; i < end; i++) {
93+
for (std::size_t j = 0; j < i; j++) {
94+
95+
// rows we will operate on
96+
RMatrix<double>::Row row1 = mat.row(i);
97+
RMatrix<double>::Row row2 = mat.row(j);
98+
99+
// compute the average using std::tranform from the STL
100+
std::vector<double> avg(row1.length());
101+
std::transform(row1.begin(), row1.end(), // input range 1
102+
row2.begin(), // input range 2
103+
avg.begin(), // output range
104+
average); // function to apply
105+
106+
// calculate divergences
107+
double d1 = kl_divergence(row1.begin(), row1.end(), avg.begin());
108+
double d2 = kl_divergence(row2.begin(), row2.end(), avg.begin());
109+
110+
// write to output matrix
111+
rmat(i,j) = sqrt(.5 * (d1 + d2));
112+
}
113+
}
114+
}
115+
};
116+
117+
// [[Rcpp::export]]
118+
NumericMatrix rcpp_parallel_js_distance(NumericMatrix mat) {
119+
120+
// allocate the matrix we will return
121+
NumericMatrix rmat(mat.nrow(), mat.nrow());
122+
123+
// create the worker
124+
JsDistance jsDistance(mat, rmat);
125+
126+
// call it with parallelFor
127+
parallelFor(0, mat.nrow(), jsDistance);
128+
129+
return rmat;
130+
}
131+
'
132+
133+
test_that( "sum works with Rcpp", {
134+
Rcpp::sourceCpp( code = code )
135+
136+
n = 1000
137+
m = matrix(runif(n*10), ncol = 10)
138+
m = m/rowSums(m)
139+
140+
expect_equal(
141+
rcpp_js_distance(m),
142+
rcpp_parallel_js_distance(m)
143+
)
144+
})
145+
146+
test_that( "sum works with Rcpp11", {
147+
attributes::sourceCpp( code = code)
148+
149+
n = 1000
150+
m = matrix(runif(n*10), ncol = 10)
151+
m = m/rowSums(m)
152+
153+
expect_equal(
154+
rcpp_js_distance(m),
155+
rcpp_parallel_js_distance(m)
156+
)
157+
})

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.