diff --git a/.gitignore b/.gitignore
index 2eff84a4..ebb9b497 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
/build
/doc
.idea/
+*.iml
.gradle
.settings
.classpath
diff --git a/README.md b/README.md
index 787946bd..b0291172 100644
--- a/README.md
+++ b/README.md
@@ -5,16 +5,28 @@ htm.java
* Build: [](https://travis-ci.org/numenta/htm.java)
* Unit Test Coverage: [](https://coveralls.io/r/numenta/htm.java?branch=master)
-Community-supported Java port of the [Numenta Platform for Intelligent Computing](https://github.com/numenta/nupic).
+Official community-supported Java implementation of [Hierarchical Temporal Memory (HTM)](http://numenta.org/htm-white-paper.html), ported from the [Numenta Platform for Intelligent Computing](https://github.com/numenta/nupic) python project.
-Holding area for the development of the Java port of Numenta's online prediction and anomaly detection systems, and implementation of the [Cortical Learning Algorithm (CLA)](https://github.com/numenta/nupic/wiki/Cortical-Learning-Algorithm)
+**NOTE: Minimum JavaSE version is 1.8**
-## In The News - see the latest [blog!](http://numenta.org/blog/2014/12/03/htm-on-the-jvm.html)
+### In The News - (We never updated with this latest change...)
+* HTM.java is now [**OFFICIAL!**](https://github.com/numenta/htm.java/issues/193) See the [_**announcement**_](http://lists.numenta.org/pipermail/nupic_lists.numenta.org/2015-February/010404.html) (02/25/2015)
+* [HTM.java Now Has Anomaly Detection & Anomaly Likelihood Prediction!](https://github.com/numenta/htm.java/wiki/Anomaly-Detection-Module) (02/22/2015)
+* [HTM.java Receives New Benchmarking Tools](http://numenta.org/blog/2015/02/10/htm-java-receives-benchmark-harness.html) (02/2015)
+* [HTM.java Reaches Functional Completion](http://numenta.org/blog/2014/12/03/htm-on-the-jvm.html) (12/2014)
For a more detailed discussion of htm.java see:
* [htm.java Wiki](https://github.com/numenta/htm.java/wiki)
* [Java Docs](http://numenta.org/docs/htm.java/)
+See the [Test Coverage Reports](https://coveralls.io/jobs/4164658) - For more information on where you can contribute! Extend the tests and get your name in bright lights!
+
+For answers to more detailed questions, email the [nupic-discuss](http://lists.numenta.org/mailman/listinfo/nupic_lists.numenta.org) mailing list, or chat with us on Gitter.
+
+[](https://gitter.im/numenta/public?utm_source=badge)
+
+***
+
### Call to Arms: [HTM.java needs you!](http://lists.numenta.org/pipermail/nupic-hackers_lists.numenta.org/2014-November/002819.html)
## Goals
@@ -31,6 +43,18 @@ An Eclipse IDE .project and .classpath file are provided so that the cloned proj
In addition, there are "launch configurations" for all of the tests and any runnable entities off of the "htm.java" main directory. These may be run directly in Eclipse by right-clicking them and choosing "run".
+## After download by clone or fork:
+
+Execute a quick sanity check by running all the tests from within the htm.java directory:
+```
+gradle check # Executes the tests and runs the benchmarks
+
+--or--
+
+gradle -Pskipbench check # Executes the tests w/o running the benchmarks
+```
+**Note:** Info on installing **gradle** can be found on the wiki (look at #3.) [here](https://github.com/numenta/htm.java/wiki/Eclipse-Setup-Tips)
+
## For Updates Follow
* [#HtmJavaDevUpdates](https://twitter.com/hashtag/HtmJavaDevUpdates?src=hash)
diff --git a/[ htm.java Unit Test ] SpatialPoolerTest.launch b/[ htm.java Unit Test ] SpatialPoolerTest.launch
index 920870cf..1a6f65fd 100644
--- a/[ htm.java Unit Test ] SpatialPoolerTest.launch
+++ b/[ htm.java Unit Test ] SpatialPoolerTest.launch
@@ -1,10 +1,10 @@
-
+
-
+
@@ -25,7 +25,7 @@
-
+
diff --git a/[ htm.java Unit Test ] TemporalMemoryTest.launch b/[ htm.java Unit Test ] TemporalMemoryTest.launch
index f40697bf..63ec3076 100644
--- a/[ htm.java Unit Test ] TemporalMemoryTest.launch
+++ b/[ htm.java Unit Test ] TemporalMemoryTest.launch
@@ -1,10 +1,10 @@
-
+
-
+
@@ -25,7 +25,7 @@
-
+
diff --git a/build.gradle b/build.gradle
index 2c5c3d40..8f7ed17b 100644
--- a/build.gradle
+++ b/build.gradle
@@ -10,7 +10,7 @@ targetCompatibility = 1.8
jar {
manifest {
- attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': 0.30
+ attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': 0.40
}
}
@@ -34,10 +34,57 @@ dependencies {
compile group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version:'2.4.4'
compile group: 'com.fasterxml.jackson.core', name: 'jackson-core', version:'2.4.4'
compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version:'2.4.4'
+ compile group: 'org.slf4j', name: 'slf4j-api', version:'1.7.10'
testCompile group: 'junit', name: 'junit', version:'4.11'
+ testCompile group: 'ch.qos.logback', name: 'logback-classic', version:'1.1.2'
}
+/////////////////////////////////////////////////////////////////
+// jmh Benchmarking Tool //
+/////////////////////////////////////////////////////////////////
+buildscript {
+ repositories {
+ jcenter()
+ mavenCentral()
+ maven {
+ name 'Shadow'
+ url 'http://dl.bintray.com/content/johnrengelman/gradle-plugins'
+ }
+
+ }
+
+ dependencies {
+ classpath 'me.champeau.gradle:jmh-gradle-plugin:0.1.3'
+ classpath 'com.github.jengelman.gradle.plugins:shadow:1.2.0'
+ classpath 'org.openjdk.jmh:jmh-generator-annprocess:1.5.2'
+ }
+}
+
+apply plugin: 'me.champeau.gradle.jmh'
+apply plugin: 'com.github.johnrengelman.shadow'
+
+jmh {
+ resultFormat = 'CSV'
+ fork = 1
+ warmupIterations = 5 // Number of warm up iterations to do.
+ iterations = 5 // Number of measurement iterations to do.
+}
+
+task runBench(dependsOn: ['parent.compile', 'test']) << {
+ if(!project.hasProperty('skipbench')) {
+ tasks.compileJmhJava.execute()
+ tasks.processJmhResources.execute()
+ tasks.jmhJar.execute()
+ tasks.jmh.execute()
+ }
+}
+
+check.doLast {
+ tasks.runBench.execute()
+}
+/////////////////// END jmh /////////////////////
+
// create Gradle wrapper with command line `gradle wrapper` in terminal
task wrapper(type: Wrapper) {
- gradleVersion = '2.0'
+ gradleVersion = '2.2'
}
diff --git a/libs/jmh-core-1.5.1.jar b/libs/jmh-core-1.5.1.jar
new file mode 100644
index 00000000..7cb080de
Binary files /dev/null and b/libs/jmh-core-1.5.1.jar differ
diff --git a/libs/joda-time-2.5.jar b/libs/joda-time-2.5.jar
new file mode 100644
index 00000000..4fe151d0
Binary files /dev/null and b/libs/joda-time-2.5.jar differ
diff --git a/libs/logback-classic-1.1.2.jar b/libs/logback-classic-1.1.2.jar
new file mode 100644
index 00000000..9230b2a7
Binary files /dev/null and b/libs/logback-classic-1.1.2.jar differ
diff --git a/libs/logback-core-1.1.2.jar b/libs/logback-core-1.1.2.jar
new file mode 100644
index 00000000..391da641
Binary files /dev/null and b/libs/logback-core-1.1.2.jar differ
diff --git a/libs/rxjava-1.0.0.jar b/libs/rxjava-1.0.0.jar
new file mode 100644
index 00000000..a5df433f
Binary files /dev/null and b/libs/rxjava-1.0.0.jar differ
diff --git a/libs/slf4j-api-1.7.10.jar b/libs/slf4j-api-1.7.10.jar
new file mode 100644
index 00000000..ac7da374
Binary files /dev/null and b/libs/slf4j-api-1.7.10.jar differ
diff --git a/pom.xml b/pom.xml
index de504514..5a62f768 100644
--- a/pom.xml
+++ b/pom.xml
@@ -70,7 +70,18 @@
com.fasterxml.jackson.corejackson-databind2.4.4
-
+
+
+ org.slf4j
+ slf4j-api
+ 1.7.10
+
+
+ ch.qos.logback
+ logback-classic
+ 1.1.2
+ test
+
diff --git a/src/jmh/java/org/numenta/nupic/benchmarks/AbstractAlgorithmBenchmark.java b/src/jmh/java/org/numenta/nupic/benchmarks/AbstractAlgorithmBenchmark.java
new file mode 100644
index 00000000..606accc4
--- /dev/null
+++ b/src/jmh/java/org/numenta/nupic/benchmarks/AbstractAlgorithmBenchmark.java
@@ -0,0 +1,90 @@
+package org.numenta.nupic.benchmarks;
+
+import org.numenta.nupic.Connections;
+import org.numenta.nupic.Parameters;
+import org.numenta.nupic.Parameters.KEY;
+import org.numenta.nupic.encoders.ScalarEncoder;
+import org.numenta.nupic.research.SpatialPooler;
+import org.numenta.nupic.research.TemporalMemory;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+
+@State(Scope.Benchmark)
+public abstract class AbstractAlgorithmBenchmark {
+
+ protected int[] SDR;
+ protected ScalarEncoder encoder;
+ protected SpatialPooler pooler;
+ protected TemporalMemory temporalMemory;
+ protected Connections memory;
+
+ @Setup
+ public void init() {
+ SDR = new int[2048];
+
+ //Layer components
+ ScalarEncoder.Builder dayBuilder = ScalarEncoder.builder()
+ .n(8)
+ .w(3)
+ .radius(1.0)
+ .minVal(1.0)
+ .maxVal(8)
+ .periodic(true)
+ .forced(true)
+ .resolution(1);
+ encoder = dayBuilder.build();
+ pooler = new SpatialPooler();
+
+ memory = new Connections();
+ Parameters params = getParameters();
+ params.apply(memory);
+
+ pooler = new SpatialPooler();
+ pooler.init(memory);
+
+ temporalMemory = new TemporalMemory();
+ temporalMemory.init(memory);
+ }
+
+ /**
+ * Create and return a {@link Parameters} object.
+ *
+ * @return
+ */
+ protected Parameters getParameters() {
+ Parameters parameters = Parameters.getAllDefaultParameters();
+ parameters.setParameterByKey(KEY.INPUT_DIMENSIONS, new int[] { 8 });
+ parameters.setParameterByKey(KEY.COLUMN_DIMENSIONS, new int[] { 2048 });
+ parameters.setParameterByKey(KEY.CELLS_PER_COLUMN, 32);
+
+ //SpatialPooler specific
+ parameters.setParameterByKey(KEY.POTENTIAL_RADIUS, 12);//3
+ parameters.setParameterByKey(KEY.POTENTIAL_PCT, 0.5);//0.5
+ parameters.setParameterByKey(KEY.GLOBAL_INHIBITIONS, false);
+ parameters.setParameterByKey(KEY.LOCAL_AREA_DENSITY, -1.0);
+ parameters.setParameterByKey(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0);
+ parameters.setParameterByKey(KEY.STIMULUS_THRESHOLD, 1.0);
+ parameters.setParameterByKey(KEY.SYN_PERM_INACTIVE_DEC, 0.01);
+ parameters.setParameterByKey(KEY.SYN_PERM_ACTIVE_INC, 0.1);
+ parameters.setParameterByKey(KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
+ parameters.setParameterByKey(KEY.SYN_PERM_CONNECTED, 0.1);
+ parameters.setParameterByKey(KEY.MIN_PCT_OVERLAP_DUTY_CYCLE, 0.1);
+ parameters.setParameterByKey(KEY.MIN_PCT_ACTIVE_DUTY_CYCLE, 0.1);
+ parameters.setParameterByKey(KEY.DUTY_CYCLE_PERIOD, 10);
+ parameters.setParameterByKey(KEY.MAX_BOOST, 10.0);
+ parameters.setParameterByKey(KEY.SEED, 42);
+ parameters.setParameterByKey(KEY.SP_VERBOSITY, 0);
+
+ //Temporal Memory specific
+ parameters.setParameterByKey(KEY.INITIAL_PERMANENCE, 0.4);
+ parameters.setParameterByKey(KEY.CONNECTED_PERMANENCE, 0.5);
+ parameters.setParameterByKey(KEY.MIN_THRESHOLD, 4);
+ parameters.setParameterByKey(KEY.MAX_NEW_SYNAPSE_COUNT, 4);
+ parameters.setParameterByKey(KEY.PERMANENCE_INCREMENT, 0.05);
+ parameters.setParameterByKey(KEY.PERMANENCE_DECREMENT, 0.05);
+ parameters.setParameterByKey(KEY.ACTIVATION_THRESHOLD, 4);
+
+ return parameters;
+ }
+}
diff --git a/src/jmh/java/org/numenta/nupic/benchmarks/SpatialPoolerGlobalInhibitionBenchmark.java b/src/jmh/java/org/numenta/nupic/benchmarks/SpatialPoolerGlobalInhibitionBenchmark.java
new file mode 100644
index 00000000..38db2e65
--- /dev/null
+++ b/src/jmh/java/org/numenta/nupic/benchmarks/SpatialPoolerGlobalInhibitionBenchmark.java
@@ -0,0 +1,46 @@
+package org.numenta.nupic.benchmarks;
+
+import java.util.concurrent.TimeUnit;
+
+import org.numenta.nupic.Parameters;
+import org.numenta.nupic.Parameters.KEY;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.infra.Blackhole;
+
+public class SpatialPoolerGlobalInhibitionBenchmark extends AbstractAlgorithmBenchmark {
+
+ private int[][] input;
+
+ @Setup
+ public void init() {
+ super.init();
+
+ input = new int[7][8];
+ for(int i = 0;i < 7;i++) {
+ input[i] = encoder.encode((double) i + 1);
+ }
+ }
+
+ @Override
+ protected Parameters getParameters() {
+ Parameters parameters = super.getParameters();
+ parameters.setParameterByKey(KEY.GLOBAL_INHIBITIONS, true);
+ return parameters;
+ }
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.MILLISECONDS)
+ public int[] measureAvgCompute_7_Times(Blackhole bh) throws InterruptedException {
+ for(int i = 0;i < 7;i++) {
+ pooler.compute(memory, input[i], SDR, true, false);
+ }
+
+ return SDR;
+ }
+
+}
diff --git a/src/jmh/java/org/numenta/nupic/benchmarks/SpatialPoolerLocalInhibitionBenchmark.java b/src/jmh/java/org/numenta/nupic/benchmarks/SpatialPoolerLocalInhibitionBenchmark.java
new file mode 100644
index 00000000..f0830248
--- /dev/null
+++ b/src/jmh/java/org/numenta/nupic/benchmarks/SpatialPoolerLocalInhibitionBenchmark.java
@@ -0,0 +1,37 @@
+package org.numenta.nupic.benchmarks;
+
+import java.util.concurrent.TimeUnit;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.infra.Blackhole;
+
+public class SpatialPoolerLocalInhibitionBenchmark extends AbstractAlgorithmBenchmark {
+
+ private int[][] input;
+
+ @Setup
+ public void init() {
+ super.init();
+
+ input = new int[7][8];
+ for(int i = 0;i < 7;i++) {
+ input[i] = encoder.encode((double) i + 1);
+ }
+ }
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.MILLISECONDS)
+ public int[] measureAvgCompute_7_Times(Blackhole bh) throws InterruptedException {
+ for(int i = 0;i < 7;i++) {
+ pooler.compute(memory, input[i], SDR, true, false);
+ }
+
+ return SDR;
+ }
+
+}
diff --git a/src/jmh/java/org/numenta/nupic/benchmarks/TemporalMemoryBenchmark.java b/src/jmh/java/org/numenta/nupic/benchmarks/TemporalMemoryBenchmark.java
new file mode 100644
index 00000000..357be37c
--- /dev/null
+++ b/src/jmh/java/org/numenta/nupic/benchmarks/TemporalMemoryBenchmark.java
@@ -0,0 +1,55 @@
+package org.numenta.nupic.benchmarks;
+
+import java.util.concurrent.TimeUnit;
+
+import org.numenta.nupic.research.ComputeCycle;
+import org.numenta.nupic.util.ArrayUtils;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.infra.Blackhole;
+
+@State(Scope.Benchmark)
+public class TemporalMemoryBenchmark extends AbstractAlgorithmBenchmark {
+
+ private int[][] input;
+ private int[][] SDRs;
+
+ public void init() {
+ super.init();
+
+ //Initialize the encoder's encoded output
+ input = new int[7][8];
+ for(int i = 0;i < 7;i++) {
+ input[i] = encoder.encode((double) i + 1);
+ }
+
+ SDRs = new int[7][];
+
+ for(int i = 0;i < 7;i++) {
+ pooler.compute(memory, input[i], SDR, true, false);
+ SDRs[i] = ArrayUtils.where(SDR, ArrayUtils.WHERE_1);
+ }
+
+ for(int j = 0;j < 100;j++) {
+ for(int i = 0;i < 7;i++) {
+ temporalMemory.compute(memory, SDRs[i], true);
+ }
+ }
+ }
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.MILLISECONDS)
+ public ComputeCycle measureAvgCompute_7_Times(Blackhole bh) throws InterruptedException {
+ ComputeCycle cc = null;
+ for(int i = 0;i < 7;i++) {
+ cc = temporalMemory.compute(memory, SDRs[i], true);
+ }
+
+ return cc;
+ }
+}
diff --git a/src/jmh/resources/jmh_defaults.txt b/src/jmh/resources/jmh_defaults.txt
new file mode 100644
index 00000000..b35faa0d
--- /dev/null
+++ b/src/jmh/resources/jmh_defaults.txt
@@ -0,0 +1,36 @@
+jmh {
+ include = 'some regular expression' // include pattern (regular expression) for benchmarks to be executed
+ exclude = 'some regular expression' // exclude pattern (regular expression) for benchmarks to be executed
+ iterations = 10 // Number of measurement iterations to do.
+ benchmarkMode = 'thrpt' // Benchmark mode. Available modes are: [Throughput/thrpt, AverageTime/avgt, SampleTime/sample, SingleShotTime/ss, All/all]
+ batchSize = 1 // Batch size: number of benchmark method calls per operation. (some benchmark modes can ignore this setting)
+ fork = 2 // How many times to forks a single benchmark. Use 0 to disable forking altogether
+ failOnError = false // Should JMH fail immediately if any benchmark had experienced the unrecoverable error?
+ forceGC = false // Should JMH force GC between iterations?
+ jvm = 'myjvm' // Custom JVM to use when forking.
+ jvmArgs = 'Custom JVM args to use when forking.'
+ jvmArgsAppend = 'Custom JVM args to use when forking (append these)'
+ jvmArgsPrepend = 'Custom JVM args to use when forking (prepend these)'
+ humanOutputFile = project.file("${project.buildDir}/reports/jmh/human.txt") // human-readable output file
+ resultsFile = project.file("${project.buildDir}/reports/jmh/results.txt") // results file
+ operationsPerInvocation = 10 // Operations per invocation.
+ benchmarkParameters = [:] // Benchmark parameters.
+ profilers = [] // Use profilers to collect additional data.
+ timeOnIteration = '1s' // Time to spend at each measurement iteration.
+ resultFormat = 'CSV' // Result format type (one of CSV, JSON, NONE, SCSV, TEXT)
+ synchronizeIterations = false // Synchronize iterations?
+ threads = 4 // Number of worker threads to run with.
+ threadGroups = [2,3,4] //Override thread group distribution for asymmetric benchmarks.
+ timeUnit = 'ms' // Output time unit. Available time units are: [m, s, ms, us, ns].
+ verbosity = 'NORMAL' // Verbosity mode. Available modes are: [SILENT, NORMAL, EXTRA]
+ warmup = '1s' // Time to spend at each warmup iteration.
+ warmupBatchSize = 10 // Warmup batch size: number of benchmark method calls per operation.
+ warmupForks = 0 // How many warmup forks to make for a single benchmark. 0 to disable warmup forks.
+ warmupIterations = 1 // Number of warmup iterations to do.
+ warmupMode = 'INDI' // Warmup mode for warming up selected benchmarks. Warmup modes are: [INDI, BULK, BULK_INDI].
+ warmupBenchmarks = ['.*Warmup'] // Warmup benchmarks to include in the run in addition to already selected. JMH will not measure these benchmarks, but only use them for the warmup.
+
+ zip64 = true // Use ZIP64 format for bigger archives
+ jmhVersion = 1.3.2 // Specifies JMH version
+ includeTests = false // Allows to include test sources into generate JMH jar, i.e. use it when benchmarks depend on the test classes.
+}
\ No newline at end of file
diff --git a/src/main/java/org/numenta/nupic/Connections.java b/src/main/java/org/numenta/nupic/Connections.java
index 15e6d6f9..8f20b24f 100644
--- a/src/main/java/org/numenta/nupic/Connections.java
+++ b/src/main/java/org/numenta/nupic/Connections.java
@@ -131,7 +131,6 @@ public class Connections {
private double[] minActiveDutyCycles;
private double[] boostFactors;
-
/////////////////////////////////////// Temporal Memory Vars ///////////////////////////////////////////
protected Set activeCells = new LinkedHashSet();
@@ -1038,7 +1037,7 @@ public int getProxSynCount() {
* High verbose output useful for debugging
*/
public void printParameters() {
- System.out.println("------------J SpatialPooler Parameters ------------------");
+ System.out.println("------------ SpatialPooler Parameters ------------------");
System.out.println("numInputs = " + getNumInputs());
System.out.println("numColumns = " + getNumColumns());
System.out.println("columnDimensions = " + getColumnDimensions());
diff --git a/src/main/java/org/numenta/nupic/encoders/AdaptiveScalarEncoder.java b/src/main/java/org/numenta/nupic/encoders/AdaptiveScalarEncoder.java
index 6bbf5fe2..a00cf8be 100644
--- a/src/main/java/org/numenta/nupic/encoders/AdaptiveScalarEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/AdaptiveScalarEncoder.java
@@ -1,31 +1,36 @@
package org.numenta.nupic.encoders;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class AdaptiveScalarEncoder extends ScalarEncoder {
+ private static final Logger LOGGER = LoggerFactory.getLogger(AdaptiveScalarEncoder.class);
+
/*
* This is an implementation of the scalar encoder that adapts the min and
* max of the scalar encoder dynamically. This is essential to the streaming
* model of the online prediction framework.
- *
+ *
* Initialization of an adapive encoder using resolution or radius is not
* supported; it must be intitialized with n. This n is kept constant while
* the min and max of the encoder changes.
- *
+ *
* The adaptive encoder must be have periodic set to false.
- *
+ *
* The adaptive encoder may be initialized with a minval and maxval or with
* `None` for each of these. In the latter case, the min and max are set as
* the 1st and 99th percentile over a window of the past 100 records.
- *
+ *
* *Note:** the sliding window may record duplicates of the values in the
* dataset, and therefore does not reflect the statistical distribution of
* the input data and may not be used to calculate the median, mean etc.
*/
-
+
public int recordNum = 0;
public boolean learningEnabled = true;
public Double[] slidingWindow = new Double[0];
@@ -34,7 +39,7 @@ public class AdaptiveScalarEncoder extends ScalarEncoder {
/*
* (non-Javadoc)
- *
+ *
* @see org.numenta.nupic.encoders.ScalarEncoder#init()
*/
@Override
@@ -42,10 +47,10 @@ public void init() {
this.setPeriodic(false);
super.init();
}
-
+
/*
* (non-Javadoc)
- *
+ *
* @see org.numenta.nupic.encoders.ScalarEncoder#initEncoder(int, double,
* double, int, double, double)
*/
@@ -71,7 +76,7 @@ public AdaptiveScalarEncoder() {
/**
* Returns a builder for building AdaptiveScalarEncoder. This builder may be
* reused to produce multiple builders
- *
+ *
* @return a {@code AdaptiveScalarEncoder.Builder}
*/
public static AdaptiveScalarEncoder.Builder adaptiveBuilder() {
@@ -129,7 +134,7 @@ private void setMinAndMax(Double input, boolean learn) {
slidingWindow = deleteItem(slidingWindow, 0);
}
slidingWindow = appendItem(slidingWindow, input);
-
+
if (this.minVal == this.maxVal) {
this.minVal = input;
this.maxVal = input + 1;
@@ -140,20 +145,16 @@ private void setMinAndMax(Double input, boolean learn) {
double minOverWindow = sorted[0];
double maxOverWindow = sorted[sorted.length - 1];
if (minOverWindow < this.minVal) {
- if (this.verbosity >= 2) {
- System.out.println(String.format("Input %s=%d smaller than minval %d. Adjusting minval to %d",
- this.name, input, this.minVal, minOverWindow));
- this.minVal = minOverWindow;
- setEncoderParams();
- }
+ LOGGER.trace("Input {}={} smaller than minVal {}. Adjusting minVal to {}",
+ this.name, input, this.minVal, minOverWindow);
+ this.minVal = minOverWindow;
+ setEncoderParams();
}
if (maxOverWindow > this.maxVal) {
- if (this.verbosity >= 2) {
- System.out.println(String.format("Input %s=%d greater than maxval %d. Adjusting maxval to %d",
- this.name, input, this.minVal, minOverWindow));
- this.maxVal = maxOverWindow;
- setEncoderParams();
- }
+ LOGGER.trace("Input {}={} greater than maxVal {}. Adjusting maxVal to {}",
+ this.name, input, this.minVal, minOverWindow);
+ this.maxVal = maxOverWindow;
+ setEncoderParams();
}
}
}
diff --git a/src/main/java/org/numenta/nupic/encoders/CategoryEncoder.java b/src/main/java/org/numenta/nupic/encoders/CategoryEncoder.java
index 4b7edb0c..fdba1c7d 100644
--- a/src/main/java/org/numenta/nupic/encoders/CategoryEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/CategoryEncoder.java
@@ -33,6 +33,8 @@
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
@@ -43,14 +45,14 @@
/**
* Encodes a list of discrete categories (described by strings), that aren't
* related to each other, so we never emit a mixture of categories.
- *
+ *
* The value of zero is reserved for "unknown category"
- *
+ *
* Internally we use a ScalarEncoder with a radius of 1, but since we only encode
* integers, we never get mixture outputs.
*
* The SDRCategoryEncoder (not yet implemented in Java) uses a different method to encode categories
- *
+ *
*
* Typical usage is as follows:
*
@@ -61,13 +63,13 @@
* .maxVal(8.0)
* .periodic(false)
* .forced(true);
- *
+ *
* CategoryEncoder encoder = builder.build();
- *
+ *
* Above values are not an example of "sane" values.
- *
+ *
*
- *
+ *
* @author David Ray
* @see ScalarEncoder
* @see Encoder
@@ -75,39 +77,42 @@
* @see Parameters
*/
public class CategoryEncoder extends Encoder {
+
+ private static final Logger LOG = LoggerFactory.getLogger(CategoryEncoder.class);
+
protected int ncategories;
-
+
protected TObjectIntMap categoryToIndex = new TObjectIntHashMap();
protected TIntObjectMap indexToCategory = new TIntObjectHashMap();
-
+
protected List categoryList;
-
+
protected int width;
private ScalarEncoder scalarEncoder;
-
+
/**
* Constructs a new {@code CategoryEncoder}
*/
private CategoryEncoder() {
}
-
+
/**
- * Returns a builder for building CategoryEncoders.
+ * Returns a builder for building CategoryEncoders.
* This builder may be reused to produce multiple builders
- *
+ *
* @return a {@code CategoryEncoder.Builder}
*/
public static Encoder.Builder builder() {
return new CategoryEncoder.Builder();
}
-
+
public void init() {
// number of categories includes zero'th category: "unknown"
ncategories = categoryList == null ? 0 : categoryList.size() + 1;
minVal = 0;
maxVal = ncategories - 1;
-
+
scalarEncoder = ScalarEncoder.builder()
.n(this.n)
.w(this.w)
@@ -116,7 +121,7 @@ public void init() {
.maxVal(this.maxVal)
.periodic(this.periodic)
.forced(this.forced).build();
-
+
indexToCategory.put(0, "");
if(categoryList != null && !categoryList.isEmpty()) {
int len = categoryList.size();
@@ -125,26 +130,26 @@ public void init() {
indexToCategory.put(i + 1, categoryList.get(i));
}
}
-
-
+
+
width = n = w * ncategories;
-
+
//TODO this is what the CategoryEncoder was doing before I added the ScalarEncoder delegate.
- //I'm concerned because we're changing n without calling init again on the scalar encoder.
+ //I'm concerned because we're changing n without calling init again on the scalar encoder.
//In other words, if I move the scalarEncoder = ...build() from to here, the test cases fail
//which indicates significant fragility and at some level a violation of encapsulation.
scalarEncoder.n = n;
-
-
-
+
+
+
if(getWidth() != width) {
throw new IllegalStateException(
"Width != w (num bits to represent output item) * #categories");
}
-
+
description.add(new Tuple(2, name, 0));
}
-
+
/**
* {@inheritDoc}
*/
@@ -152,7 +157,7 @@ public void init() {
public TDoubleList getScalars(T d) {
return new TDoubleArrayList(new double[] { categoryToIndex.get(d) });
}
-
+
/**
* {@inheritDoc}
*/
@@ -161,7 +166,7 @@ public int[] getBucketIndices(String input) {
if(input == null) return null;
return scalarEncoder.getBucketIndices(categoryToIndex.get(input));
}
-
+
/**
* {@inheritDoc}
*/
@@ -176,12 +181,9 @@ public void encodeIntoArray(String input, int[] output) {
value = value == categoryToIndex.getNoEntryValue() ? 0 : value;
scalarEncoder.encodeIntoArray(value, output);
}
-
- if(verbosity >= 2) {
- System.out.println(
- String.format("input: %s, val: %s, value: %d, output: %s",
- input, val, value, Arrays.toString(output)));
- }
+
+ LOG.trace("input: {}, val: {}, value: {}, output: {}",
+ input, val, value, Arrays.toString(output));
}
/**
@@ -191,16 +193,16 @@ public void encodeIntoArray(String input, int[] output) {
public DecodeResult decode(int[] encoded, String parentFieldName) {
// Get the scalar values from the underlying scalar encoder
DecodeResult result = scalarEncoder.decode(encoded, parentFieldName);
-
+
if(result.getFields().size() == 0) {
return result;
}
-
+
// Expect only 1 field
if(result.getFields().size() != 1) {
throw new IllegalStateException("Expecting only one field");
}
-
+
//Get the list of categories the scalar values correspond to and
// generate the description from the category name(s).
Map fieldRanges = result.getFields();
@@ -219,7 +221,7 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
minV += 1;
}
}
-
+
//Return result
String fieldName;
if(!parentFieldName.isEmpty()) {
@@ -227,13 +229,13 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
}else{
fieldName = name;
}
-
+
Map retVal = new HashMap();
retVal.put(fieldName, new RangeList(outRanges, desc.toString()));
-
+
return new DecodeResult(retVal, Arrays.asList(new String[] { fieldName }));
}
-
+
/**
* {@inheritDoc}
*/
@@ -241,26 +243,26 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
public TDoubleList closenessScores(TDoubleList expValues, TDoubleList actValues, boolean fractional) {
double expValue = expValues.get(0);
double actValue = actValues.get(0);
-
+
double closeness = expValue == actValue ? 1.0 : 0;
if(!fractional) closeness = 1.0 - closeness;
-
+
return new TDoubleArrayList(new double[]{ closeness });
}
-
+
/**
* Returns a list of items, one for each bucket defined by this encoder.
* Each item is the value assigned to that bucket, this is the same as the
* EncoderResult.value that would be returned by getBucketInfo() for that
* bucket and is in the same format as the input that would be passed to
* encode().
- *
+ *
* This call is faster than calling getBucketInfo() on each bucket individually
* if all you need are the bucket values.
*
* @param returnType class type parameter so that this method can return encoder
* specific value types
- *
+ *
* @return list of items, each item representing the bucket value for that
* bucket.
*/
@@ -275,10 +277,10 @@ public List getBucketValues(Class t) {
((List)bucketValues).add((String)getBucketInfo(new int[] { i }).get(0).getValue());
}
}
-
+
return (List)bucketValues;
}
-
+
/**
* {@inheritDoc}
*/
@@ -286,23 +288,23 @@ public List getBucketValues(Class t) {
public List getBucketInfo(int[] buckets) {
// For the category encoder, the bucket index is the category index
List bucketInfo = scalarEncoder.getBucketInfo(buckets);
-
+
int categoryIndex = (int)Math.round((double)bucketInfo.get(0).getValue());
String category = indexToCategory.get(categoryIndex);
-
+
bucketInfo.set(0, new EncoderResult(category, categoryIndex, bucketInfo.get(0).getEncoding()));
return bucketInfo;
}
-
+
/**
* {@inheritDoc}
*/
@Override
public List topDownCompute(int[] encoded) {
//Get/generate the topDown mapping table
- SparseObjectMatrix topDownMapping = scalarEncoder.getTopDownMapping();
+ SparseObjectMatrix topDownMapping = scalarEncoder.getTopDownMapping();
// See which "category" we match the closest.
- int category = ArrayUtils.argmax(rightVecProd(topDownMapping, encoded));
+ int category = ArrayUtils.argmax(rightVecProd(topDownMapping, encoded));
return getBucketInfo(new int[] { category });
}
@@ -313,30 +315,30 @@ public List getCategoryList() {
public void setCategoryList(List categoryList) {
this.categoryList = categoryList;
}
-
+
/**
* Returns a {@link EncoderBuilder} for constructing {@link CategoryEncoder}s
- *
+ *
* The base class architecture is put together in such a way where boilerplate
* initialization can be kept to a minimum for implementing subclasses, while avoiding
* the mistake-proneness of extremely long argument lists.
- *
+ *
* @see ScalarEncoder.Builder#setStuff(int)
*/
public static class Builder extends Encoder.Builder {
private List categoryList;
-
+
private Builder() {}
@Override
public CategoryEncoder build() {
- //Must be instantiated so that super class can initialize
+ //Must be instantiated so that super class can initialize
//boilerplate variables.
encoder = new CategoryEncoder();
-
+
//Call super class here
super.build();
-
+
////////////////////////////////////////////////////////
// Implementing classes would do setting of specific //
// vars here together with any sanity checking //
@@ -348,16 +350,16 @@ public CategoryEncoder build() {
((CategoryEncoder)encoder).setCategoryList(this.categoryList);
//Call init
((CategoryEncoder)encoder).init();
-
+
return (CategoryEncoder)encoder;
}
-
+
/**
- * Never called - just here as an example of specialization for a specific
+ * Never called - just here as an example of specialization for a specific
* subclass of Encoder.Builder
- *
+ *
* Example specific method!!
- *
+ *
* @param stuff
* @return
*/
diff --git a/src/main/java/org/numenta/nupic/encoders/DateEncoder.java b/src/main/java/org/numenta/nupic/encoders/DateEncoder.java
index 101cbcd4..bf51791b 100644
--- a/src/main/java/org/numenta/nupic/encoders/DateEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/DateEncoder.java
@@ -144,15 +144,6 @@ public void init() {
// TODO figure out how to remove this
setForced(true);
- // Adapted from MultiEncoder
- encoders = new LinkedHashMap<>();
-
- // Encoder.getEncoderTuple() can return null, but logic in Encoder.addEncoder()
- // expects at least one empty ArrayList, so we need to prepare such
- // an initial EncoderTuple.
- // TODO figure out whether the solution can be moved to Encoder
- encoders.put(new EncoderTuple("", this, 0), new ArrayList());
-
// Note: The order of adding encoders matters, must be in the following
// season, dayOfWeek, weekend, customDays, holiday, timeOfDay
@@ -455,7 +446,6 @@ public List getEncodedValues(Date inputData) {
* @param inputData the input value, in this case a date object
* @return a list of one input double
*/
- @SuppressWarnings({ "rawtypes", "unchecked" })
public TDoubleList getScalars(Date inputData) {
if(inputData == null) {
throw new IllegalArgumentException("DateEncoder requires a valid Date object but got null");
diff --git a/src/main/java/org/numenta/nupic/encoders/DeltaEncoder.java b/src/main/java/org/numenta/nupic/encoders/DeltaEncoder.java
index 6fab0d48..0b5a77f0 100644
--- a/src/main/java/org/numenta/nupic/encoders/DeltaEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/DeltaEncoder.java
@@ -38,7 +38,8 @@ public class DeltaEncoder extends AdaptiveScalarEncoder {
public DeltaEncoder() {
}
- /* (non-Javadoc)
+ /**
+ * {@inheritDoc}
* @see org.numenta.nupic.encoders.AdaptiveScalarEncoder#init()
*/
@Override
@@ -46,7 +47,8 @@ public void init() {
super.init();
}
- /* (non-Javadoc)
+ /**
+ * {@inheritDoc}
* @see org.numenta.nupic.encoders.AdaptiveScalarEncoder#initEncoder(int, double, double, int, double, double)
*/
@Override
@@ -77,7 +79,8 @@ public DeltaEncoder build() {
}
}
- /* (non-Javadoc)
+ /**
+ * {@inheritDoc}
* @see org.numenta.nupic.encoders.AdaptiveScalarEncoder#encodeIntoArray(java.lang.Double, int[])
*/
@Override
@@ -86,11 +89,7 @@ public void encodeIntoArray(Double input, int[] output) {
throw new IllegalArgumentException(
String.format("Expected a Double input but got input of type %s", input.toString()));
}
- boolean learn = false;
double delta = 0;
- if (this.encLearningEnabled == false) {
- learn = this.learningEnabled;
- }
if (input == DeltaEncoder.SENTINEL_VALUE_FOR_MISSING_DATA) {
output = new int[this.n];
Arrays.fill(output, 0);
@@ -130,7 +129,8 @@ public boolean isDelta() {
return true;
}
- /* (non-Javadoc)
+ /**
+ * {@inheritDoc}
* @see org.numenta.nupic.encoders.AdaptiveScalarEncoder#getBucketInfo(int[])
*/
@Override
@@ -138,7 +138,8 @@ public List getBucketInfo(int[] buckets) {
return super.getBucketInfo(buckets);
}
- /* (non-Javadoc)
+ /**
+ * {@inheritDoc}
* @see org.numenta.nupic.encoders.AdaptiveScalarEncoder#topDownCompute(int[])
*/
@Override
diff --git a/src/main/java/org/numenta/nupic/encoders/Encoder.java b/src/main/java/org/numenta/nupic/encoders/Encoder.java
index 8e63df24..c4f6671d 100644
--- a/src/main/java/org/numenta/nupic/encoders/Encoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/Encoder.java
@@ -38,6 +38,8 @@
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
*
@@ -79,6 +81,9 @@
* @author David Ray
*/
public abstract class Encoder {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(Encoder.class);
+
/** Value used to represent no data */
public static final double SENTINEL_VALUE_FOR_MISSING_DATA = Double.NaN;
protected List description = new ArrayList<>();
@@ -114,12 +119,11 @@ public abstract class Encoder {
/** if true, skip some safety checks (for compatibility reasons), default false */
protected boolean forced;
/** Encoder name - an optional string which will become part of the description */
- protected String name;
+ protected String name = "";
protected int padding;
protected int nInternal;
protected double rangeInternal;
protected double range;
- protected int verbosity;
protected boolean encLearningEnabled;
protected List flattenedFieldTypeList;
protected Map> decoderFieldTypes;
@@ -422,22 +426,6 @@ public String getName() {
return name;
}
- /**
- * Sets the verbosity of debug output
- * @param v
- */
- public void setVerbosity(int v) {
- this.verbosity = v;
- }
-
- /**
- * Returns the verbosity setting for an encoder.
- * @return
- */
- public int getVerbosity() {
- return verbosity;
- }
-
/**
* Adds a the specified {@link Encoder} to the list of the specified
* parent's {@code Encoder}s.
@@ -454,6 +442,11 @@ public void addEncoder(Encoder parent, String name, Encoder child, int off
}
EncoderTuple key = getEncoderTuple(parent);
+ // Insert a new Tuple for the parent if not yet added.
+ if(key == null) {
+ encoders.put(key = new EncoderTuple("", this, 0), new ArrayList());
+ }
+
List childEncoders = null;
if((childEncoders = encoders.get(key)) == null) {
encoders.put(key, childEncoders = new ArrayList());
@@ -920,7 +913,7 @@ public Tuple encodedBitDescription(int bitOffset, boolean formatted) {
* @param prefix
*/
public void pprintHeader(String prefix) {
- System.out.println(prefix == null ? "" : prefix);
+ LOGGER.info(prefix == null ? "" : prefix);
List description = getDescription();
description.add(new Tuple(2, "end", getWidth()));
@@ -934,22 +927,22 @@ public void pprintHeader(String prefix) {
StringBuilder pname = new StringBuilder(name);
if(name.length() > width) pname.setLength(width);
- System.out.println(String.format(formatStr, pname));
+ LOGGER.info(String.format(formatStr, pname));
}
len = getWidth() + (description.size() - 1)*3 - 1;
StringBuilder hyphens = new StringBuilder();
for(int i = 0;i < len;i++) hyphens.append("-");
- System.out.println(new StringBuilder(prefix).append(hyphens));
- }
+ LOGGER.info(new StringBuilder(prefix).append(hyphens).toString());
+ }
- /**
+ /**
* Pretty-print the encoded output using ascii art.
* @param output
* @param prefix
*/
public void pprint(int[] output, String prefix) {
- System.out.println(prefix == null ? "" : prefix);
+ LOGGER.info(prefix == null ? "" : prefix);
List description = getDescription();
description.add(new Tuple(2, "end", getWidth()));
@@ -959,17 +952,17 @@ public void pprint(int[] output, String prefix) {
int offset = (int)description.get(i).get(1);
int nextOffset = (int)description.get(i + 1).get(1);
- System.out.println(
- String.format("%s |",
- ArrayUtils.bitsToString(
- ArrayUtils.sub(output, ArrayUtils.range(offset, nextOffset))
- )
- )
- );
- }
- }
+ LOGGER.info(
+ String.format("%s |",
+ ArrayUtils.bitsToString(
+ ArrayUtils.sub(output, ArrayUtils.range(offset, nextOffset))
+ )
+ )
+ );
+ }
+ }
- /**
+ /**
* Takes an encoded output and does its best to work backwards and generate
* the input that would have generated it.
*
@@ -1265,7 +1258,6 @@ public int getDisplayWidth() {
public static abstract class Builder {
protected int n;
protected int w;
- protected int encVerbosity;
protected double minVal;
protected double maxVal;
protected double radius;
@@ -1284,7 +1276,6 @@ public E build() {
}
encoder.setN(n);
encoder.setW(w);
- encoder.setVerbosity(encVerbosity);
encoder.setMinVal(minVal);
encoder.setMaxVal(maxVal);
encoder.setRadius(radius);
@@ -1305,10 +1296,6 @@ public K w(int w) {
this.w = w;
return (K)this;
}
- public K verbosity(int verbosity) {
- this.encVerbosity = verbosity;
- return (K)this;
- }
public K minVal(double minVal) {
this.minVal = minVal;
return (K)this;
diff --git a/src/main/java/org/numenta/nupic/encoders/LogEncoder.java b/src/main/java/org/numenta/nupic/encoders/LogEncoder.java
index ec9ca334..31d5ee41 100644
--- a/src/main/java/org/numenta/nupic/encoders/LogEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/LogEncoder.java
@@ -35,15 +35,17 @@
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.Tuple;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* DOCUMENTATION TAKEN DIRECTLY FROM THE PYTHON VERSION:
- *
+ *
* This class wraps the ScalarEncoder class.
* A Log encoder represents a floating point value on a logarithmic scale.
- * valueToEncode = log10(input)
- *
+ * valueToEncode = log10(input)
+ *
* w -- number of bits to set in output
* minval -- minimum input value. must be greater than 0. Lower values are
* reset to this value
@@ -51,7 +53,7 @@
* periodic -- If true, then the input value "wraps around" such that minval =
* maxval For a periodic value, the input must be strictly less than
* maxval, otherwise maxval is a true upper bound.
- *
+ *
* Exactly one of n, radius, resolution must be set. "0" is a special
* value that means "not set".
* n -- number of bits in the representation (must be > w)
@@ -63,29 +65,31 @@
* in the output. In terms of the original input values, this
* means 10^1 (1) and 10^1.1 (1.25) will be distinguishable.
* name -- an optional string which will become part of the description
- * verbosity -- level of debugging output you want the encoder to provide.
* clipInput -- if true, non-periodic inputs smaller than minval or greater
* than maxval will be clipped to minval/maxval
* forced -- (default False), if True, skip some safety checks
*/
public class LogEncoder extends Encoder {
+
+ private static final Logger LOG = LoggerFactory.getLogger(LogEncoder.class);
+
private ScalarEncoder encoder;
private double minScaledValue, maxScaledValue;
/**
* Constructs a new {@code LogEncoder}
*/
LogEncoder() {}
-
+
/**
- * Returns a builder for building LogEncoders.
+ * Returns a builder for building LogEncoders.
* This builder may be reused to produce multiple builders
- *
+ *
* @return a {@code LogEncoder.Builder}
*/
public static Encoder.Builder builder() {
return new LogEncoder.Builder();
}
-
+
/**
* w -- number of bits to set in output
* minval -- minimum input value. must be greater than 0. Lower values are
@@ -94,7 +98,7 @@ public static Encoder.Builder builder() {
* periodic -- If true, then the input value "wraps around" such that minval =
* maxval For a periodic value, the input must be strictly less than
* maxval, otherwise maxval is a true upper bound.
- *
+ *
* Exactly one of n, radius, resolution must be set. "0" is a special
* value that means "not set".
* n -- number of bits in the representation (must be > w)
@@ -106,7 +110,6 @@ public static Encoder.Builder builder() {
* in the output. In terms of the original input values, this
* means 10^1 (1) and 10^1.1 (1.25) will be distinguishable.
* name -- an optional string which will become part of the description
- * verbosity -- level of debugging output you want the encoder to provide.
* clipInput -- if true, non-periodic inputs smaller than minval or greater
* than maxval will be clipped to minval/maxval
* forced -- (default False), if True, skip some safety checks
@@ -118,28 +121,28 @@ public void init() {
if (getW() == 0) {
setW(5);
}
-
+
// maxVal defaults to 10000.
if (getMaxVal() == 0.0) {
setMaxVal(10000.);
}
-
+
if (getMinVal() < lowLimit) {
setMinVal(lowLimit);
}
-
+
if (getMinVal() >= getMaxVal()) {
- throw new IllegalStateException("Max val must be larger than min val or the lower limit " +
+ throw new IllegalStateException("Max val must be larger than min val or the lower limit " +
"for this encoder " + String.format("%.7f", lowLimit));
}
-
+
minScaledValue = Math.log10(getMinVal());
maxScaledValue = Math.log10(getMaxVal());
-
+
if(minScaledValue >= maxScaledValue) {
throw new IllegalStateException("Max val must be larger, in log space, than min val.");
}
-
+
// There are three different ways of thinking about the representation. Handle
// each case here.
encoder = ScalarEncoder.builder()
@@ -150,17 +153,16 @@ public void init() {
.n(getN())
.radius(getRadius())
.resolution(getResolution())
- .verbosity(getVerbosity())
.clipInput(clipInput())
.forced(isForced())
.name(getName())
.build();
-
+
setN(encoder.getN());
setResolution(encoder.getResolution());
setRadius(encoder.getRadius());
}
-
+
@Override
public int getWidth() {
@@ -176,7 +178,7 @@ public boolean isDelta() {
public List getDescription() {
return encoder.getDescription();
}
-
+
/**
* {@inheritDoc}
*/
@@ -184,7 +186,7 @@ public List getDescription() {
public List getDecoderOutputFieldTypes() {
return encoder.getDecoderOutputFieldTypes();
}
-
+
/**
* Convert the input, which is in normal space, into log space
* @param input Value in normal space.
@@ -204,23 +206,23 @@ private Double getScaledValue(double input) {
return Math.log10(val);
}
}
-
+
/**
* Returns the bucket indices.
- *
- * @param input
+ *
+ * @param input
*/
@Override
public int[] getBucketIndices(double input) {
Double scaledVal = getScaledValue(input);
-
+
if (scaledVal == null) {
return new int[]{};
} else {
return encoder.getBucketIndices(scaledVal);
}
}
-
+
/**
* Encodes inputData and puts the encoded value into the output array,
* which is a 1-D array of length returned by {@link Connections#getW()}.
@@ -228,23 +230,21 @@ public int[] getBucketIndices(double input) {
* Note: The output array is reused, so clear it before updating it.
* @param inputData Data to encode. This should be validated by the encoder.
* @param output 1-D array of same length returned by {@link Connections#getW()}
- *
+ *
* @return
*/
@Override
public void encodeIntoArray(Double input, int[] output) {
Double scaledVal = getScaledValue(input);
-
+
if (scaledVal == null) {
Arrays.fill(output, 0);
} else {
encoder.encodeIntoArray(scaledVal, output);
-
- if (getVerbosity() >= 2) {
- System.out.print("input: " + input);
- System.out.print(" scaledVal: " + scaledVal);
- System.out.println(" output: " + Arrays.toString(output));
- }
+
+ LOG.trace("input: " + input);
+ LOG.trace(" scaledVal: " + scaledVal);
+ LOG.trace(" output: " + Arrays.toString(output));
}
}
@@ -255,22 +255,22 @@ public void encodeIntoArray(Double input, int[] output) {
public DecodeResult decode(int[] encoded, String parentFieldName) {
// Get the scalar values from the underlying scalar encoder
DecodeResult decodeResult = encoder.decode(encoded, parentFieldName);
-
+
Map fields = decodeResult.getFields();
-
+
if (fields.keySet().size() == 0) {
return decodeResult;
}
-
+
// Convert each range into normal space
RangeList inRanges = (RangeList) fields.values().toArray()[0];
RangeList outRanges = new RangeList(new ArrayList(), "");
for (MinMax minMax : inRanges.getRanges()) {
- MinMax scaledMinMax = new MinMax( Math.pow(10, minMax.min()),
+ MinMax scaledMinMax = new MinMax( Math.pow(10, minMax.min()),
Math.pow(10, minMax.max()));
outRanges.add(scaledMinMax);
}
-
+
// Generate a text description of the ranges
String desc = "";
int numRanges = outRanges.size();
@@ -286,20 +286,20 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
}
}
outRanges.setDescription(desc);
-
+
String fieldName;
if (!parentFieldName.equals("")) {
fieldName = String.format("%s.%s", parentFieldName, getName());
} else {
fieldName = getName();
}
-
+
Map outFields = new HashMap();
outFields.put(fieldName, outRanges);
-
+
List fieldNames = new ArrayList();
fieldNames.add(fieldName);
-
+
return new DecodeResult(outFields, fieldNames);
}
@@ -309,11 +309,11 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
@SuppressWarnings("unchecked")
@Override
public List getBucketValues(Class t) {
- // Need to re-create?
+ // Need to re-create?
if(bucketValues == null) {
List scaledValues = encoder.getBucketValues(t);
bucketValues = new ArrayList();
-
+
for (S scaledValue : scaledValues) {
double value = Math.pow(10, (Double)scaledValue);
((List)bucketValues).add(value);
@@ -321,7 +321,7 @@ public List getBucketValues(Class t) {
}
return (List)bucketValues;
}
-
+
/**
* {@inheritDoc}
*/
@@ -330,10 +330,10 @@ public List getBucketInfo(int[] buckets) {
EncoderResult scaledResult = encoder.getBucketInfo(buckets).get(0);
double scaledValue = (Double)scaledResult.getValue();
double value = Math.pow(10, scaledValue);
-
+
return Arrays.asList(new EncoderResult(value, value, scaledResult.getEncoding()));
}
-
+
/**
* {@inheritDoc}
*/
@@ -342,17 +342,17 @@ public List topDownCompute(int[] encoded) {
EncoderResult scaledResult = encoder.topDownCompute(encoded).get(0);
double scaledValue = (Double)scaledResult.getValue();
double value = Math.pow(10, scaledValue);
-
+
return Arrays.asList(new EncoderResult(value, value, scaledResult.getEncoding()));
}
-
+
/**
* {@inheritDoc}
*/
@Override
public TDoubleList closenessScores(TDoubleList expValues, TDoubleList actValues, boolean fractional) {
TDoubleList retVal = new TDoubleArrayList();
-
+
double expValue, actValue;
if (expValues.get(0) > 0) {
expValue = Math.log10(expValues.get(0));
@@ -364,7 +364,7 @@ public TDoubleList closenessScores(TDoubleList expValues, TDoubleList actValues,
} else {
actValue = minScaledValue;
}
-
+
double closeness;
if (fractional) {
double err = Math.abs(expValue - actValue);
@@ -374,18 +374,18 @@ public TDoubleList closenessScores(TDoubleList expValues, TDoubleList actValues,
} else {
closeness = Math.abs(expValue - actValue);;
}
-
+
retVal.add(closeness);
return retVal;
}
-
+
/**
* Returns a {@link EncoderBuilder} for constructing {@link ScalarEncoder}s
- *
+ *
* The base class architecture is put together in such a way where boilerplate
* initialization can be kept to a minimum for implementing subclasses, while avoiding
* the mistake-proneness of extremely long argument lists.
- *
+ *
* @see ScalarEncoder.Builder#setStuff(int)
*/
public static class Builder extends Encoder.Builder {
@@ -393,20 +393,20 @@ private Builder() {}
@Override
public LogEncoder build() {
- //Must be instantiated so that super class can initialize
+ //Must be instantiated so that super class can initialize
//boilerplate variables.
encoder = new LogEncoder();
-
+
//Call super class here
super.build();
-
+
////////////////////////////////////////////////////////
// Implementing classes would do setting of specific //
// vars here together with any sanity checking //
////////////////////////////////////////////////////////
-
+
((LogEncoder)encoder).init();
-
+
return (LogEncoder)encoder;
}
}
diff --git a/src/main/java/org/numenta/nupic/encoders/MultiEncoder.java b/src/main/java/org/numenta/nupic/encoders/MultiEncoder.java
index 61cc79ab..c0830e92 100644
--- a/src/main/java/org/numenta/nupic/encoders/MultiEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/MultiEncoder.java
@@ -27,7 +27,6 @@
import java.util.ArrayList;
import java.util.Collections;
-import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -38,29 +37,29 @@
* A MultiEncoder encodes a dictionary or object with
* multiple components. A MultiEncode contains a number
* of sub-encoders, each of which encodes a separate component.
- *
+ *
* @see Encoder
* @see EncoderResult
* @see Parameters
- *
+ *
* @author wlmiller
*/
public class MultiEncoder extends Encoder
- *
+ *
* @author Anubhav Chaturvedi
*/
public static class Builder
@@ -671,7 +668,6 @@ public static class Builder
private Builder() {
this.n(400);
this.w(21);
- this.verbosity(0);
seed = 42;
maxBuckets = 1000;
maxOverlap = 2;
diff --git a/src/main/java/org/numenta/nupic/encoders/SDRCategoryEncoder.java b/src/main/java/org/numenta/nupic/encoders/SDRCategoryEncoder.java
index 3fcf7381..77b0a356 100644
--- a/src/main/java/org/numenta/nupic/encoders/SDRCategoryEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/SDRCategoryEncoder.java
@@ -31,6 +31,8 @@
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
@@ -54,6 +56,9 @@
* @see EncoderResult
*/
public class SDRCategoryEncoder extends Encoder {
+
+ private static final Logger LOG = LoggerFactory.getLogger(SDRCategoryEncoder.class);
+
private Random random;
private int thresholdOverlap;
private final SDRByCategoryMap sdrByCategory = new SDRByCategoryMap();
@@ -120,7 +125,7 @@ private SDRCategoryEncoder() {
def __init__(self, n, w, categoryList = None, name="category", verbosity=0,
encoderSeed=1, forced=False):
*/
- private void init(int n, int w, List categoryList, String name, int verbosity,
+ private void init(int n, int w, List categoryList, String name,
int encoderSeed, boolean forced) {
/*Python ref: n is total bits in output
@@ -167,7 +172,6 @@ private void init(int n, int w, List categoryList, String name, int verb
if (this.thresholdOverlap < this.w - 3) {
this.thresholdOverlap = this.w - 3;
}
- this.verbosity = verbosity;
this.description.add(new Tuple(2, name, 0));
this.name = name;
/*
@@ -215,11 +219,8 @@ public void encodeIntoArray(String input, int[] output) {
int[] categoryEncoding = sdrByCategory.getSdr(index);
System.arraycopy(categoryEncoding, 0, output, 0, categoryEncoding.length);
}
- if (verbosity >= 2) {
- System.out.println("input:" + input + ", index:" + index + ", output:" + ArrayUtils.intArrayToString(
- output));
- System.out.println("decoded:" + decodedToStr(decode(output, "")));
- }
+ LOG.trace("input:" + input + ", index:" + index + ", output:" + ArrayUtils.intArrayToString(output));
+ LOG.trace("decoded:" + decodedToStr(decode(output, "")));
}
/**
@@ -287,11 +288,11 @@ public boolean eval(int i) {
}
}
}
- if (verbosity >= 2) {
- System.out.println("Overlaps for decoding:");
+ LOG.trace("Overlaps for decoding:");
+ if (LOG.isTraceEnabled()){
int inx = 0;
for (String category : sdrByCategory.keySet()) {
- System.out.println(overlap[inx] + " " + category);
+ LOG.trace(overlap[inx] + " " + category);
inx++;
}
}
@@ -472,7 +473,7 @@ public SDRCategoryEncoder build() {
throw new IllegalStateException("\"W\" should be set");
}
SDRCategoryEncoder sdrCategoryEncoder = new SDRCategoryEncoder();
- sdrCategoryEncoder.init(n, w, categoryList, name, encVerbosity, encoderSeed, forced);
+ sdrCategoryEncoder.init(n, w, categoryList, name, encoderSeed, forced);
return sdrCategoryEncoder;
diff --git a/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java b/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java
index 6f6bb0ce..1d8a66ed 100644
--- a/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java
@@ -31,6 +31,8 @@
import org.numenta.nupic.util.MinMax;
import org.numenta.nupic.util.SparseObjectMatrix;
import org.numenta.nupic.util.Tuple;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
@@ -41,7 +43,7 @@
/**
* DOCUMENTATION TAKEN DIRECTLY FROM THE PYTHON VERSION:
- *
+ *
* A scalar encoder encodes a numeric (floating point) value into an array
* of bits. The output is 0's except for a contiguous block of 1's. The
* location of this contiguous block varies continuously with the input value.
@@ -149,25 +151,28 @@
* resolution = radius / w
* n = w * range/radius (periodic)
* n = w * range/radius + 2 * h (non-periodic)
- *
+ *
* @author metaware
*/
public class ScalarEncoder extends Encoder {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(ScalarEncoder.class);
+
/**
* Constructs a new {@code ScalarEncoder}
*/
ScalarEncoder() {}
-
+
/**
- * Returns a builder for building ScalarEncoders.
+ * Returns a builder for building ScalarEncoders.
* This builder may be reused to produce multiple builders
- *
+ *
* @return a {@code ScalarEncoder.Builder}
*/
public static Encoder.Builder builder() {
return new ScalarEncoder.Builder();
}
-
+
/**
* Returns true if the underlying encoder works on deltas
*/
@@ -175,7 +180,7 @@ public static Encoder.Builder builder() {
public boolean isDelta() {
return false;
}
-
+
/**
* w -- number of bits to set in output
* minval -- minimum input value
@@ -189,7 +194,7 @@ public boolean isDelta() {
* representations
* resolution -- inputs separated by more than, or equal to this distance will have different
* representations
- *
+ *
* name -- an optional string which will become part of the description
*
* clipInput -- if true, non-periodic inputs smaller than minval or greater
@@ -202,29 +207,29 @@ public void init() {
throw new IllegalStateException(
"W must be an odd number (to eliminate centering difficulty)");
}
-
+
setHalfWidth((getW() - 1) / 2);
-
+
// For non-periodic inputs, padding is the number of bits "outside" the range,
// on each side. I.e. the representation of minval is centered on some bit, and
// there are "padding" bits to the left of that centered bit; similarly with
// bits to the right of the center bit of maxval
setPadding(isPeriodic() ? 0 : getHalfWidth());
-
- if(!Double.isNaN(getMinVal()) && !Double.isNaN(getMinVal())) {
+
+ if(!Double.isNaN(getMinVal()) && !Double.isNaN(getMaxVal())) {
if(getMinVal() >= getMaxVal()) {
throw new IllegalStateException("maxVal must be > minVal");
}
setRangeInternal(getMaxVal() - getMinVal());
}
-
+
// There are three different ways of thinking about the representation. Handle
// each case here.
initEncoder(getW(), getMinVal(), getMaxVal(), getN(), getRadius(), getResolution());
-
+
//nInternal represents the output area excluding the possible padding on each side
setNInternal(getN() - 2 * getPadding());
-
+
if(getName() == null) {
if((getMinVal() % ((int)getMinVal())) > 0 ||
(getMaxVal() % ((int)getMaxVal())) > 0) {
@@ -233,18 +238,18 @@ public void init() {
setName("[" + (int)getMinVal() + ":" + (int)getMaxVal() + "]");
}
}
-
+
//Checks for likely mistakes in encoder settings
if(!isForced()) {
checkReasonableSettings();
}
description.add(new Tuple(2, (name = getName()).equals("None") ? "[" + (int)getMinVal() + ":" + (int)getMaxVal() + "]" : name, 0));
}
-
+
/**
- * There are three different ways of thinking about the representation.
+ * There are three different ways of thinking about the representation.
* Handle each case here.
- *
+ *
* @param c
* @param minVal
* @param maxVal
@@ -254,15 +259,15 @@ public void init() {
*/
public void initEncoder(int w, double minVal, double maxVal, int n, double radius, double resolution) {
if(n != 0) {
- if(minVal != 0 && maxVal != 0) {
+ if(!Double.isNaN(minVal) && !Double.isNaN(maxVal)) {
if(!isPeriodic()) {
setResolution(getRangeInternal() / (getN() - getW()));
}else{
setResolution(getRangeInternal() / getN());
}
-
+
setRadius(getW() * getResolution());
-
+
if(isPeriodic()) {
setRange(getRangeInternal());
}else{
@@ -278,13 +283,13 @@ public void initEncoder(int w, double minVal, double maxVal, int n, double radiu
throw new IllegalStateException(
"One of n, radius, resolution must be specified for a ScalarEncoder");
}
-
+
if(isPeriodic()) {
setRange(getRangeInternal());
}else{
setRange(getRangeInternal() + getResolution());
}
-
+
double nFloat = w * (getRange() / getRadius()) + 2 * getPadding();
setN((int)Math.ceil(nFloat));
}
@@ -294,7 +299,7 @@ public void initEncoder(int w, double minVal, double maxVal, int n, double radiu
* Return the bit offset of the first bit to be set in the encoder output.
* For periodic encoders, this can be a negative number when the encoded output
* wraps around.
- *
+ *
* @param c the memory
* @param input the input data
* @return an encoded array
@@ -305,10 +310,7 @@ public Integer getFirstOnBit(double input) {
}else{
if(input < getMinVal()) {
if(clipInput() && !isPeriodic()) {
- if(getVerbosity() > 0) {
- System.out.println("Clipped input " + getName() +
- "=" + input + " to minval " + getMinVal());
- }
+ LOGGER.info("Clipped input " + getName() + "=" + input + " to minval " + getMinVal());
input = getMinVal();
}else{
throw new IllegalStateException("input (" + input +") less than range (" +
@@ -316,7 +318,7 @@ public Integer getFirstOnBit(double input) {
}
}
}
-
+
if(isPeriodic()) {
if(input >= getMaxVal()) {
throw new IllegalStateException("input (" + input +") greater than periodic range (" +
@@ -325,10 +327,7 @@ public Integer getFirstOnBit(double input) {
}else{
if(input > getMaxVal()) {
if(clipInput()) {
- if(getVerbosity() > 0) {
- System.out.println("Clipped input " + getName() + "=" + input + " to maxval " + getMaxVal());
- }
-
+ LOGGER.info("Clipped input " + getName() + "=" + input + " to maxval " + getMaxVal());
input = getMaxVal();
}else{
throw new IllegalStateException("input (" + input +") greater than periodic range (" +
@@ -336,17 +335,17 @@ public Integer getFirstOnBit(double input) {
}
}
}
-
+
int centerbin;
if(isPeriodic()) {
centerbin = ((int)((input - getMinVal()) * getNInternal() / getRange())) + getPadding();
}else{
centerbin = ((int)(((input - getMinVal()) + getResolution()/2) / getResolution())) + getPadding();
}
-
+
return centerbin - getHalfWidth();
}
-
+
/**
* Check if the settings are reasonable for the SpatialPooler to work
* @param c
@@ -357,7 +356,7 @@ public void checkReasonableSettings() {
"Number of bits in the SDR (%d) must be greater than 2, and recommended >= 21 (use forced=True to override)");
}
}
-
+
/**
* {@inheritDoc}
*/
@@ -365,7 +364,7 @@ public void checkReasonableSettings() {
public List getDecoderOutputFieldTypes() {
return Arrays.asList(FieldMetaType.FLOAT);
}
-
+
/**
* Should return the output width, in bits.
*/
@@ -373,23 +372,23 @@ public List getDecoderOutputFieldTypes() {
public int getWidth() {
return getN();
}
-
+
/**
* {@inheritDoc}
* NO-OP
*/
@Override
public int[] getBucketIndices(String input) { return null; }
-
+
/**
* Returns the bucket indices.
- *
- * @param input
+ *
+ * @param input
*/
@Override
public int[] getBucketIndices(double input) {
int minbin = getFirstOnBit(input);
-
+
//For periodic encoders, the bucket index is the index of the center bit
int bucketIdx;
if(isPeriodic()) {
@@ -400,10 +399,10 @@ public int[] getBucketIndices(double input) {
}else{//for non-periodic encoders, the bucket index is the index of the left bit
bucketIdx = minbin;
}
-
+
return new int[] { bucketIdx };
}
-
+
/**
* Encodes inputData and puts the encoded value into the output array,
* which is a 1-D array of length returned by {@link Connections#getW()}.
@@ -411,7 +410,7 @@ public int[] getBucketIndices(double input) {
* Note: The output array is reused, so clear it before updating it.
* @param inputData Data to encode. This should be validated by the encoder.
* @param output 1-D array of same length returned by {@link Connections#getW()}
- *
+ *
* @return
*/
@Override
@@ -420,7 +419,7 @@ public void encodeIntoArray(Double input, int[] output) {
Arrays.fill(output, 0);
return;
}
-
+
Integer bucketVal = getFirstOnBit(input);
if(bucketVal != null) {
int bucketIdx = bucketVal;
@@ -441,26 +440,24 @@ public void encodeIntoArray(Double input, int[] output) {
minbin = 0;
}
}
-
+
ArrayUtils.setIndexesTo(output, ArrayUtils.range(minbin, maxbin + 1), 1);
}
-
- if(getVerbosity() >= 2) {
- System.out.println("");
- System.out.println("input: " + input);
- System.out.println("range: " + getMinVal() + " - " + getMaxVal());
- System.out.println("n:" + getN() + "w:" + getW() + "resolution:" + getResolution() +
+
+ LOGGER.trace("");
+ LOGGER.trace("input: " + input);
+ LOGGER.trace("range: " + getMinVal() + " - " + getMaxVal());
+ LOGGER.trace("n:" + getN() + "w:" + getW() + "resolution:" + getResolution() +
"radius:" + getRadius() + "periodic:" + isPeriodic());
- System.out.println("output: " + Arrays.toString(output));
- System.out.println("input desc: " + decode(output, ""));
- }
+ LOGGER.trace("output: " + Arrays.toString(output));
+ LOGGER.trace("input desc: " + decode(output, ""));
}
/**
* Returns a {@link DecodeResult} which is a tuple of range names
- * and lists of {@link RangeLists} in the first entry, and a list
+ * and lists of {@link RangeLists} in the first entry, and a list
* of descriptions for each range in the second entry.
- *
+ *
* @param encoded the encoded bit vector
* @param parentFieldName the field the vector corresponds with
* @return
@@ -470,11 +467,11 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
// For now, we simply assume any top-down output greater than 0
// is ON. Eventually, we will probably want to incorporate the strength
// of each top-down output.
- if(encoded == null || encoded.length < 1) {
+ if(encoded == null || encoded.length < 1) {
return null;
}
int[] tmpOutput = Arrays.copyOf(encoded, encoded.length);
-
+
// ------------------------------------------------------------------------
// First, assume the input pool is not sampled 100%, and fill in the
// "holes" in the encoded representation (which are likely to be present
@@ -487,7 +484,7 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
Arrays.fill(searchStr, 1);
ArrayUtils.setRangeTo(searchStr, 1, -1, 0);
int subLen = searchStr.length;
-
+
// Does this search string appear in the output?
if(isPeriodic()) {
for(int j = 0;j < getN();j++) {
@@ -505,13 +502,11 @@ public DecodeResult decode(int[] encoded, String parentFieldName) {
}
}
}
-
- if(getVerbosity() >= 2) {
- System.out.println("raw output:" + Arrays.toString(
+
+ LOGGER.trace("raw output:" + Arrays.toString(
ArrayUtils.sub(encoded, ArrayUtils.range(0, getN()))));
- System.out.println("filtered output:" + Arrays.toString(tmpOutput));
- }
-
+ LOGGER.trace("filtered output:" + Arrays.toString(tmpOutput));
+
// ------------------------------------------------------------------------
// Find each run of 1's.
int[] nz = ArrayUtils.where(tmpOutput, new Condition.Adapter() {
@@ -534,19 +529,19 @@ public boolean eval(int n) {
i += 1;
}
runs.add(new Tuple(2, run[0], run[1]));
-
+
// If we have a periodic encoder, merge the first and last run if they
// both go all the way to the edges
if(isPeriodic() && runs.size() > 1) {
int l = runs.size() - 1;
if(((Integer)runs.get(0).get(0)) == 0 && ((Integer)runs.get(l).get(0)) + ((Integer)runs.get(l).get(1)) == getN()) {
- runs.set(l, new Tuple(2,
- (Integer)runs.get(l).get(0),
+ runs.set(l, new Tuple(2,
+ (Integer)runs.get(l).get(0),
((Integer)runs.get(l).get(1)) + ((Integer)runs.get(0).get(1)) ));
runs = runs.subList(1, runs.size());
}
}
-
+
// ------------------------------------------------------------------------
// Now, for each group of 1's, determine the "left" and "right" edges, where
// the "left" edge is inset by halfwidth and the "right" edge is inset by
@@ -565,7 +560,7 @@ public boolean eval(int n) {
left = start + getHalfWidth();
right = start + runLen - 1 - getHalfWidth();
}
-
+
double inMin, inMax;
// Convert to input space.
if(!isPeriodic()) {
@@ -582,7 +577,7 @@ public boolean eval(int n) {
inMax -= getRange();
}
}
-
+
// Clip low end
if(inMin < getMinVal()) {
inMin = getMinVal();
@@ -590,7 +585,7 @@ public boolean eval(int n) {
if(inMax < getMinVal()) {
inMax = getMinVal();
}
-
+
// If we have a periodic encoder, and the max is past the edge, break into
// 2 separate ranges
if(isPeriodic() && inMax >= getMaxVal()) {
@@ -606,7 +601,7 @@ public boolean eval(int n) {
ranges.add(new MinMax(inMin, inMax));
}
}
-
+
String desc = generateRangeDescription(ranges);
String fieldName;
// Return result
@@ -615,17 +610,17 @@ public boolean eval(int n) {
}else{
fieldName = getName();
}
-
+
RangeList inner = new RangeList(ranges, desc);
Map fieldsDict = new HashMap();
fieldsDict.put(fieldName, inner);
-
+
return new DecodeResult(fieldsDict, Arrays.asList(fieldName));
}
-
+
/**
* Generate description from a text description of the ranges
- *
+ *
* @param ranges A list of {@link MinMax}es.
*/
public String generateRangeDescription(List ranges) {
@@ -643,39 +638,39 @@ public String generateRangeDescription(List ranges) {
}
return desc.toString();
}
-
+
/**
* Return the internal topDownMapping matrix used for handling the
* bucketInfo() and topDownCompute() methods. This is a matrix, one row per
* category (bucket) where each row contains the encoded output for that
* category.
- *
+ *
* @param c the connections memory
* @return the internal topDownMapping
*/
public SparseObjectMatrix getTopDownMapping() {
-
+
if(topDownMapping == null) {
//The input scalar value corresponding to each possible output encoding
if(isPeriodic()) {
setTopDownValues(
- ArrayUtils.arange(getMinVal() + getResolution() / 2.0,
+ ArrayUtils.arange(getMinVal() + getResolution() / 2.0,
getMaxVal(), getResolution()));
}else{
//Number of values is (max-min)/resolutions
setTopDownValues(
- ArrayUtils.arange(getMinVal(), getMaxVal() + getResolution() / 2.0,
+ ArrayUtils.arange(getMinVal(), getMaxVal() + getResolution() / 2.0,
getResolution()));
}
}
-
+
//Each row represents an encoded output pattern
int numCategories = getTopDownValues().length;
SparseObjectMatrix topDownMapping;
setTopDownMapping(
topDownMapping = new SparseObjectMatrix(
new int[] { numCategories }));
-
+
double[] topDownValues = getTopDownValues();
int[] outputSpace = new int[getN()];
double minVal = getMinVal();
@@ -687,13 +682,13 @@ public SparseObjectMatrix getTopDownMapping() {
encodeIntoArray(value, outputSpace);
topDownMapping.set(i, Arrays.copyOf(outputSpace, outputSpace.length));
}
-
+
return topDownMapping;
}
-
+
/**
* {@inheritDoc}
- *
+ *
* @param the input value, in this case a double
* @return a list of one input double
*/
@@ -703,20 +698,20 @@ public TDoubleList getScalars(S d) {
retVal.add((Double)d);
return retVal;
}
-
+
/**
* Returns a list of items, one for each bucket defined by this encoder.
* Each item is the value assigned to that bucket, this is the same as the
* EncoderResult.value that would be returned by getBucketInfo() for that
* bucket and is in the same format as the input that would be passed to
* encode().
- *
+ *
* This call is faster than calling getBucketInfo() on each bucket individually
* if all you need are the bucket values.
*
* @param returnType class type parameter so that this method can return encoder
* specific value types
- *
+ *
* @return list of items, each item representing the bucket value for that
* bucket.
*/
@@ -733,18 +728,18 @@ public List getBucketValues(Class t) {
}
return (List)bucketValues;
}
-
+
/**
* {@inheritDoc}
*/
@Override
public List getBucketInfo(int[] buckets) {
SparseObjectMatrix topDownMapping = getTopDownMapping();
-
+
//The "category" is simply the bucket index
int category = buckets[0];
int[] encoding = topDownMapping.getObject(category);
-
+
//Which input value does this correspond to?
double inputVal;
if(isPeriodic()) {
@@ -752,10 +747,10 @@ public List getBucketInfo(int[] buckets) {
}else{
inputVal = getMinVal() + category * getResolution();
}
-
+
return Arrays.asList(new EncoderResult(inputVal, inputVal, encoding));
}
-
+
/**
* {@inheritDoc}
*/
@@ -763,17 +758,17 @@ public List getBucketInfo(int[] buckets) {
public List topDownCompute(int[] encoded) {
//Get/generate the topDown mapping table
SparseObjectMatrix topDownMapping = getTopDownMapping();
-
+
// See which "category" we match the closest.
int category = ArrayUtils.argmax(rightVecProd(topDownMapping, encoded));
-
- return getBucketInfo(new int[] { category });
+
+ return getBucketInfo(new int[]{category});
}
-
+
/**
* Returns a list of {@link Tuple}s which in this case is a list of
* key value parameter values for this {@code ScalarEncoder}
- *
+ *
* @return a list of {@link Tuple}s
*/
public List dict() {
@@ -784,7 +779,6 @@ public List dict() {
l.add(new Tuple(2, "name", getName()));
l.add(new Tuple(2, "minval", getMinVal()));
l.add(new Tuple(2, "topDownValues", Arrays.toString(getTopDownValues())));
- l.add(new Tuple(2, "verbosity", getVerbosity()));
l.add(new Tuple(2, "clipInput", clipInput()));
l.add(new Tuple(2, "n", getN()));
l.add(new Tuple(2, "padding", getPadding()));
@@ -796,17 +790,17 @@ public List dict() {
l.add(new Tuple(2, "halfwidth", getHalfWidth()));
l.add(new Tuple(2, "resolution", getResolution()));
l.add(new Tuple(2, "rangeInternal", getRangeInternal()));
-
+
return l;
}
/**
* Returns a {@link EncoderBuilder} for constructing {@link ScalarEncoder}s
- *
+ *
* The base class architecture is put together in such a way where boilerplate
* initialization can be kept to a minimum for implementing subclasses, while avoiding
* the mistake-proneness of extremely long argument lists.
- *
+ *
* @see ScalarEncoder.Builder#setStuff(int)
*/
public static class Builder extends Encoder.Builder {
@@ -814,20 +808,20 @@ private Builder() {}
@Override
public ScalarEncoder build() {
- //Must be instantiated so that super class can initialize
+ //Must be instantiated so that super class can initialize
//boilerplate variables.
encoder = new ScalarEncoder();
-
+
//Call super class here
super.build();
-
+
////////////////////////////////////////////////////////
// Implementing classes would do setting of specific //
// vars here together with any sanity checking //
////////////////////////////////////////////////////////
-
+
((ScalarEncoder)encoder).init();
-
+
return (ScalarEncoder)encoder;
}
}
diff --git a/src/main/java/org/numenta/nupic/encoders/SparsePassThroughEncoder.java b/src/main/java/org/numenta/nupic/encoders/SparsePassThroughEncoder.java
index 96402de9..7f88b4c4 100644
--- a/src/main/java/org/numenta/nupic/encoders/SparsePassThroughEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/SparsePassThroughEncoder.java
@@ -21,79 +21,87 @@
*/
package org.numenta.nupic.encoders;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+
/**
* Sparse Pass Through Encoder
- * Convert a bitmap encoded as array indicies to an SDR
+ * Convert a bitmap encoded as array indices to an SDR
* Each encoding is an SDR in which w out of n bits are turned on.
- * The input should be an array or string of indicies to turn on
+ * The input should be an array or string of indices to turn on
* Note: the value for n must equal input length * w
* i.e. for n=8 w=1 [0,2,5] => 101001000
* or for n=8 w=1 "0,2,5" => 101001000
* i.e. for n=24 w=3 [0,2,5] => 111000111000000111000000000
* or for n=24 w=3 "0,2,5" => 111000111000000111000000000
- * @author wilsondy (from Python original)
*
+ * @author wilsondy (from Python original)
*/
public class SparsePassThroughEncoder extends PassThroughEncoder {
- private SparsePassThroughEncoder() { super(); }
-
- public SparsePassThroughEncoder(int outputWidth, Integer outputBitsOnCount) {
- super(outputWidth, outputBitsOnCount);
- }
-
- /**
- * Returns a builder for building SparsePassThroughEncoders.
- * This builder may be reused to produce multiple builders
- *
- * @return a {@code SparsePassThroughEncoder.Builder}
- */
- public static Encoder.Builder sparseBuilder() {
- return new SparsePassThroughEncoder.Builder();
- }
+ private SparsePassThroughEncoder() { super(); }
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(SparsePassThroughEncoder.class);
+
+ public SparsePassThroughEncoder(int outputWidth, Integer outputBitsOnCount) {
+ super(outputWidth, outputBitsOnCount);
+        LOGGER.info("Building new SparsePassThroughEncoder instance, outputWidth: {} outputBitsOnCount: {}", outputWidth, outputBitsOnCount);
+ }
+
+ /**
+ * Returns a builder for building SparsePassThroughEncoders.
+ * This builder may be reused to produce multiple builders
+ *
+ * @return a {@code SparsePassThroughEncoder.Builder}
+ */
+ public static Encoder.Builder sparseBuilder() {
+ return new SparsePassThroughEncoder.Builder();
+ }
+
+    /**
+     * Convert the array of indices to a bit array and then pass to parent.
+     */
+    @Override
+ public void encodeIntoArray(int[] input, int[] output) {
+
+ int[] denseInput = new int[output.length];
+ for (int i : input) {
+            if (i >= denseInput.length)
+                throw new IllegalArgumentException(String.format("Output bit count set too low, need at least %d bits", i + 1));
+ denseInput[i] = 1;
+ }
+ super.encodeIntoArray(denseInput, output);
+ LOGGER.trace("Input: {} \nOutput: {} \n", Arrays.toString(input), Arrays.toString(output));
+ }
+
+ /**
+ * Returns a {@link Encoder.Builder} for constructing {@link SparsePassThroughEncoder}s
+ *
+ * The base class architecture is put together in such a way where boilerplate
+ * initialization can be kept to a minimum for implementing subclasses, while avoiding
+ * the mistake-proneness of extremely long argument lists.
+ */
+ public static class Builder extends Encoder.Builder {
+ private Builder() {}
+
+ @Override
+ public SparsePassThroughEncoder build() {
+ //Must be instantiated so that super class can initialize
+ //boilerplate variables.
+ encoder = new SparsePassThroughEncoder();
+
+ //Call super class here
+ super.build();
+
+ ////////////////////////////////////////////////////////
+ // Implementing classes would do setting of specific //
+ // vars here together with any sanity checking //
+ ////////////////////////////////////////////////////////
- @Override
- /**
- * Convert the array of indices to a bit array and then pass to parent.
- */
- public void encodeIntoArray(int[] input, int[] output){
-
- int[] denseInput = new int[output.length];
- for (int i : input) {
- if(i > denseInput.length)
- throw new IllegalArgumentException(String.format("Output bit count set too low, need at least %d bits", i));
- denseInput[i] = 1;
- }
- super.encodeIntoArray(denseInput, output);
- }
-
- /**
- * Returns a {@link EncoderBuilder} for constructing {@link SparsePassThroughEncoder}s
- *
- * The base class architecture is put together in such a way where boilerplate
- * initialization can be kept to a minimum for implementing subclasses, while avoiding
- * the mistake-proneness of extremely long argument lists.
- *
- */
- public static class Builder extends Encoder.Builder {
- private Builder() {}
+ ((SparsePassThroughEncoder) encoder).init();
- @Override
- public SparsePassThroughEncoder build() {
- //Must be instantiated so that super class can initialize
- //boilerplate variables.
- encoder = new SparsePassThroughEncoder();
-
- //Call super class here
- super.build();
-
- ////////////////////////////////////////////////////////
- // Implementing classes would do setting of specific //
- // vars here together with any sanity checking //
- ////////////////////////////////////////////////////////
-
- ((SparsePassThroughEncoder)encoder).init();
-
- return (SparsePassThroughEncoder)encoder;
- }
- }
+ return (SparsePassThroughEncoder) encoder;
+ }
+ }
}
diff --git a/src/main/java/org/numenta/nupic/research/SpatialPooler.java b/src/main/java/org/numenta/nupic/research/SpatialPooler.java
index 876f117b..48132b37 100644
--- a/src/main/java/org/numenta/nupic/research/SpatialPooler.java
+++ b/src/main/java/org/numenta/nupic/research/SpatialPooler.java
@@ -280,7 +280,7 @@ public void updateMinDutyCyclesLocal(Connections c) {
/**
* Updates the duty cycles for each column. The OVERLAP duty cycle is a moving
- * average of the number of inputs which overlapped with the each column. The
+ * average of the number of inputs which overlapped with each column. The
* ACTIVITY duty cycles is a moving average of the frequency of activation for
* each column.
*
@@ -425,10 +425,8 @@ public double avgColumnsPerInput(Connections c) {
* survived inhibition.
*/
public void adaptSynapses(Connections c, int[] inputVector, int[] activeColumns) {
- int[] inputIndices = ArrayUtils.where(inputVector, new Condition.Adapter