@@ -7,23 +7,25 @@ from scipy.special import gammaln
7
7
import numpy as np
8
8
cimport numpy as cnp
9
9
10
- cnp.import_array()
11
- ctypedef cnp.float64_t DOUBLE
12
10
13
-
14
def expected_mutual_information(contingency, cnp.int64_t n_samples):
    """Calculate the expected mutual information for two labelings.

    Parameters
    ----------
    contingency : array-like of shape (n_rows, n_cols)
        Contingency table of the two labelings; row sums and column sums
        are the per-cluster counts ``a`` and ``b``.
    n_samples : int
        Total number of samples N (the sum of the contingency table).

    Returns
    -------
    emi : float
        The expected mutual information, computed in log space to avoid
        overflow in the factorial terms.
    """
    cdef:
        cnp.float64_t emi = 0
        cnp.int64_t n_rows, n_cols
        cnp.float64_t term2, term3, gln
        cnp.int64_t[::1] a_view, b_view
        cnp.float64_t[::1] nijs_view, term1
        cnp.float64_t[::1] gln_a, gln_b, gln_Na, gln_Nb, gln_Nnij, log_Nnij
        cnp.float64_t[::1] log_a, log_b
        Py_ssize_t i, j, nij
        cnp.int64_t start, end

    n_rows, n_cols = contingency.shape
    # Marginal counts: a = row sums, b = column sums.  int64 so the
    # typed memoryviews below can bind without a copy.
    a = np.ravel(contingency.sum(axis=1).astype(np.int64, copy=False))
    b = np.ravel(contingency.sum(axis=0).astype(np.int64, copy=False))
    a_view = a
    b_view = b

    # any labelling with zero entropy implies EMI = 0
    # NOTE(review): this early return sits in diff context that was elided
    # between hunks; restored as 0.0 per the comment above — confirm.
    if a.size == 1 or b.size == 1:
        return 0.0

    # There are three major terms to the EMI equation, which are multiplied
    # together and then summed over varying nij values.
    # While nijs[0] will never be used, having it simplifies the indexing.
    nijs = np.arange(0, max(np.max(a), np.max(b)) + 1, dtype='float')
    nijs[0] = 1  # Stops divide by zero warnings. As its not used, no issue.
    nijs_view = nijs

    # term1 is nij / N
    term1 = nijs / n_samples
    # term2 is log((N*nij) / (a * b)) == log(N * nij) - log(a * b)
    log_a = np.log(a)
    log_b = np.log(b)
    # term2 uses log(N * nij) = log(N) + log(nij)
    log_Nnij = np.log(n_samples) + np.log(nijs)
    # term3 is large, and involved many factorials. Calculate these in log
    # space to stop overflows.
    gln_a = gammaln(a + 1)
    gln_b = gammaln(b + 1)
    gln_Na = gammaln(n_samples - a + 1)
    gln_Nb = gammaln(n_samples - b + 1)
    # gammaln(N + 1) folded into the per-nij table so the inner loop does a
    # single lookup instead of lookup + separate subtraction of a scalar.
    gln_Nnij = gammaln(nijs + 1) + gammaln(n_samples + 1)

    # emi itself is a summation over the various values.
    for i in range(n_rows):
        for j in range(n_cols):
            # nij ranges over the feasible cell counts for marginals
            # (a[i], b[j]): max(1, a+b-N) .. min(a, b) inclusive.
            start = max(1, a_view[i] - n_samples + b_view[j])
            end = min(a_view[i], b_view[j]) + 1
            for nij in range(start, end):
                term2 = log_Nnij[nij] - log_a[i] - log_b[j]
                # Numerators are positive, denominators are negative.
                gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j]
                       - gln_Nnij[nij] - lgamma(a_view[i] - nij + 1)
                       - lgamma(b_view[j] - nij + 1)
                       - lgamma(n_samples - a_view[i] - b_view[j] + nij + 1))
                term3 = exp(gln)
                emi += (term1[nij] * term2 * term3)
    return emi
0 commit comments