diff --git a/src/htm/algorithms/SDRClassifier.cpp b/src/htm/algorithms/SDRClassifier.cpp index 277cca1ae4..ac807da209 100644 --- a/src/htm/algorithms/SDRClassifier.cpp +++ b/src/htm/algorithms/SDRClassifier.cpp @@ -52,7 +52,7 @@ PDF Classifier::infer(const SDR & pattern) const { NTA_WARN << "Classifier: must call `learn` before `infer`."; return PDF(numCategories_, std::nan("")); //empty array [] } - NTA_ASSERT(pattern.size == dimensions_) << "Input SDR does not match previously seen size!"; + NTA_CHECK(pattern.size == dimensions_) << "Input SDR does not match previously seen size!"; // Accumulate feed forward input. PDF probabilities( numCategories_, 0.0f ); diff --git a/src/htm/regions/ClassifierRegion.cpp b/src/htm/regions/ClassifierRegion.cpp index 556f5a729b..527d8f8cfb 100644 --- a/src/htm/regions/ClassifierRegion.cpp +++ b/src/htm/regions/ClassifierRegion.cpp @@ -78,16 +78,20 @@ namespace htm { }, inputs: { bucket: { description: "The quantized value of the current sample, one from each encoder if more than one, for the learn step", - type: Real64, count: 0}, - pattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM", - type: SDR, count: 0} - }, + type: Real64, count: 0}, + pattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM", + type: SDR, count: 0}, + inferPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: predictiveCells from TM", + type: SDR, count: 0}, + learnPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM", + type: SDR, count: 0} + }, outputs: { pdf: { description: "probability distribution function (pdf) for each category or bucket. Sorted by title. Warning, buffer length will grow.", type: Real64, count: 0}, titles: { description: "Quantized values of used samples which are the Titles corresponding to the pdf indexes. Sorted by title. Warning, buffer length will grow.", type: Real64, count: 0}, - predicted: { description: "An index (into pdf and titles) with the highest probability of being the match with the current pattern.", + predicted: { description: "An index (into pdf and titles) with the highest probability of being the match with the current inferPattern.", type: UInt32, count: 1} } } @@ -139,16 +143,13 @@ Dimensions ClassifierRegion::askImplForOutputDimensions(const std::string &name) void ClassifierRegion::compute() { SDR &pattern = getInput("pattern")->getData().getSDR(); - // Note: if there is no link to 'pattern' input, the 'pattern' SDR length is 0 - // and SDRClassifier::infer() will throw an exception. - if (learn_) { Array &b = getInput("bucket")->getData(); // 'bucket' is a list of quantized samples being processed for this iteration. // There are one of these for each encoder (or value being encoded). // The values might not be consecutive, or in different ranges, or different things entirely. // We build a map and a corresponding vector containing the quantized samples actually used. - // This vector becomes the titles. The index into this list will be a consecutive list that + // This vector becomes the titles. The index into this list will be a consecutive list that // we can presented to the Classifier which produces the pdf. Note that the indexes used // by the classifier are not sorted by title but rather by the order in which an index is first seen. std::vector categoryIdxList; @@ -166,9 +167,25 @@ void ClassifierRegion::compute() { } categoryIdxList.push_back(c); } - classifier_->learn(pattern, categoryIdxList); + + SDR &learnPattern = getInput("learnPattern")->getData().getSDR(); + if (learnPattern.size == 0) { + classifier_->learn(pattern, categoryIdxList); + } else { + classifier_->learn(learnPattern, categoryIdxList); + } + } + + SDR &inferPattern = getInput("inferPattern")->getData().getSDR(); + // Note: if there is no link to 'inferPattern' input, the 'inferPattern' SDR length is 0 + // and SDRClassifier::infer() will throw an exception. + // + PDF pdf; + if (inferPattern.size == 0) { + pdf = classifier_->infer(pattern); + } else { + pdf = classifier_->infer(inferPattern); } - PDF pdf = classifier_->infer(pattern); // Adjust the buffer size to match the pdf. if (getOutput("pdf")->getData().getCount() < pdf.size()) { @@ -189,7 +206,7 @@ void ClassifierRegion::compute() { size_t i = itm.second; if (pdf[i] > m) { m = pdf[i]; - predicted[0] = static_cast(j); // index of the quantized sample with the highest probability of matching the pattern + predicted[0] = static_cast(j); // index of the quantized sample with the highest probability of matching the inferPattern } out[j] = pdf[i]; titles[j] = bucketList[i];