diff --git a/Demo/DemoConsoleApp/Examples/TTOOForecastFromScratch.cs b/Demo/DemoConsoleApp/Examples/TTOOForecastFromScratch.cs
index 3148783..cc2d939 100644
--- a/Demo/DemoConsoleApp/Examples/TTOOForecastFromScratch.cs
+++ b/Demo/DemoConsoleApp/Examples/TTOOForecastFromScratch.cs
@@ -263,11 +263,11 @@ ReadoutLayerSettings CreateReadoutLayerCfg(double foldDataRatio, int numOfAttemp
ReadoutUnitSettings highReadoutUnitCfg = new ReadoutUnitSettings("High", new ForecastTaskSettings(new RealFeatureFilterSettings()));
ReadoutUnitSettings lowReadoutUnitCfg = new ReadoutUnitSettings("Low", new ForecastTaskSettings(new RealFeatureFilterSettings()));
//Create readout layer configuration
- ReadoutLayerSettings readoutLayerCfg = new ReadoutLayerSettings(new CrossvalidationSettings(foldDataRatio),
+ ReadoutLayerSettings readoutLayerCfg = new ReadoutLayerSettings(new ClusterSettings(new CrossvalidationSettings(foldDataRatio), defaultNetworksCfg),
new ReadoutUnitsSettings(highReadoutUnitCfg,
lowReadoutUnitCfg
),
- defaultNetworksCfg
+ null
);
return readoutLayerCfg;
}
diff --git a/Demo/DemoConsoleApp/Playground.cs b/Demo/DemoConsoleApp/Playground.cs
index 021710d..b15d858 100644
--- a/Demo/DemoConsoleApp/Playground.cs
+++ b/Demo/DemoConsoleApp/Playground.cs
@@ -13,6 +13,8 @@
using RCNet.MathTools;
using RCNet.Neural.Data.Filter;
using RCNet.Neural.Data;
+using RCNet.Neural.Network.NonRecurrent;
+using RCNet.Neural.Network.NonRecurrent.FF;
namespace Demo.DemoConsoleApp
{
@@ -339,6 +341,84 @@ private void TestVectorBundleFolderization(string dataFile, int numOfClasses)
return;
}
+ ///
+ /// Displays information about the readout unit regression progress.
+ ///
+ /// Current state of the regression process
+ /// Indicates that the best readout unit was changed as a result of the performed epoch
+ private void OnRegressionEpochDone(TrainedOneTakesAllNetworkBuilder.BuildingState buildingState, bool foundBetter)
+ {
+ int reportEpochsInterval = 5;
+ //Progress info
+ if (foundBetter ||
+ (buildingState.Epoch % reportEpochsInterval) == 0 ||
+ buildingState.Epoch == buildingState.MaxEpochs ||
+ (buildingState.Epoch == 1 && buildingState.RegrAttemptNumber == 1)
+ )
+ {
+ //Build progress report message
+ string progressText = buildingState.GetProgressInfo(4);
+ //Report the progress
+ if((buildingState.Epoch == 1 && buildingState.RegrAttemptNumber == 1))
+ {
+ Console.WriteLine();
+ }
+ Console.Write("\x0d" + progressText + " ");
+ }
+ return;
+ }
+
+
+ private void TestTrainedOneTakesAllClusterAndBuilder(string trainDataFile, string verifyDataFile, int numOfClasses, double foldDataRatio = 0.1d)
+ {
+ Console.BufferWidth = 320;
+ Console.WriteLine("One Takes All - Cluster and Cluster builder test");
+ //Load csv data and create vector bundle
+ Console.WriteLine($"Loading {trainDataFile}...");
+ CsvDataHolder trainCsvData = new CsvDataHolder(trainDataFile);
+ VectorBundle trainData = VectorBundle.Load(trainCsvData, numOfClasses);
+ Console.WriteLine($"Loading {verifyDataFile}...");
+ CsvDataHolder verifyCsvData = new CsvDataHolder(verifyDataFile);
+ VectorBundle verifyData = VectorBundle.Load(verifyCsvData, numOfClasses);
+ Console.WriteLine($"Cluster training on {trainDataFile}...");
+ //TRAINING
+ List netCfgs = new List();
+ netCfgs.Add(new FeedForwardNetworkSettings(new AFAnalogSoftMaxSettings(),
+ new HiddenLayersSettings(new HiddenLayerSettings(30, new AFAnalogTanHSettings())),
+ new RPropTrainerSettings(5, 750)
+ )
+ );
+
+ TrainedOneTakesAllNetworkClusterBuilder builder =
+ new TrainedOneTakesAllNetworkClusterBuilder("Test",
+ netCfgs,
+ null,
+ null
+ );
+ builder.RegressionEpochDone += OnRegressionEpochDone;
+ TrainedOneTakesAllNetworkCluster tc = builder.Build(trainData, new CrossvalidationSettings(foldDataRatio));
+
+ //VERIFICATION
+ Console.WriteLine();
+ Console.WriteLine();
+ Console.WriteLine($"Cluster verification on {verifyDataFile}...");
+ Console.WriteLine();
+ int numOfErrors = 0;
+ for(int i = 0; i < verifyData.InputVectorCollection.Count; i++)
+ {
+ double[] computed = tc.Compute(verifyData.InputVectorCollection[i]);
+ int computedWinnerIdx = computed.MaxIdx();
+ int realWinnerIdx = verifyData.OutputVectorCollection[i].MaxIdx();
+ if (computedWinnerIdx != realWinnerIdx) ++numOfErrors;
+ Console.Write("\x0d" + $"({i+1}/{verifyData.InputVectorCollection.Count}) Errors:{numOfErrors}...");
+ }
+ Console.WriteLine();
+ Console.WriteLine($"Accuracy {(1d - (double)numOfErrors/(double)verifyData.InputVectorCollection.Count).ToString(CultureInfo.InvariantCulture)}");
+ Console.WriteLine();
+
+ return;
+ }
+
///
/// Playground's entry point
///
@@ -346,7 +426,9 @@ public void Run()
{
Console.Clear();
//TODO - place your code here
- TestVectorBundleFolderization("./Data/ProximalPhalanxOutlineAgeGroup_train.csv", 3);
+ //TestVectorBundleFolderization("./Data/ProximalPhalanxOutlineAgeGroup_train.csv", 3);
+ TestTrainedOneTakesAllClusterAndBuilder("./Data/LibrasMovement_train.csv", "./Data/LibrasMovement_verify.csv", 15, 0.1d);
+ TestTrainedOneTakesAllClusterAndBuilder("./Data/ProximalPhalanxOutlineAgeGroup_train.csv", "./Data/ProximalPhalanxOutlineAgeGroup_verify.csv", 3, 0.1d);
return;
}
diff --git a/Demo/DemoConsoleApp/SMDemoSettings.xml b/Demo/DemoConsoleApp/SMDemoSettings.xml
index 73639d6..0b189e4 100644
--- a/Demo/DemoConsoleApp/SMDemoSettings.xml
+++ b/Demo/DemoConsoleApp/SMDemoSettings.xml
@@ -69,15 +69,17 @@
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
@@ -210,15 +212,17 @@
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
@@ -344,8 +348,10 @@
-
-
+
+
+
+
@@ -424,8 +430,10 @@
-
-
+
+
+
+
@@ -519,15 +527,29 @@
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
-
-
+
+
+
+
+
+
+
-
-
+
+
@@ -538,18 +560,6 @@
-
-
-
-
-
-
-
-
-
-
-
-
@@ -612,8 +622,10 @@
-
-
+
+
+
+
@@ -709,15 +721,17 @@
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
@@ -795,23 +809,32 @@
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
+
-
-
+
+
@@ -822,13 +845,6 @@
-
-
-
-
-
-
-
@@ -947,8 +963,10 @@
-
-
+
+
+
+
@@ -1041,15 +1059,17 @@
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
@@ -1132,8 +1152,10 @@
-
-
+
+
+
+
@@ -1223,23 +1245,25 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -1365,15 +1389,17 @@
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
@@ -1451,23 +1477,25 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -1623,23 +1651,25 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/RCNet/Extensions/DoubleArrayExtensions.cs b/RCNet/Extensions/DoubleArrayExtensions.cs
index 40bf5d4..723a041 100644
--- a/RCNet/Extensions/DoubleArrayExtensions.cs
+++ b/RCNet/Extensions/DoubleArrayExtensions.cs
@@ -66,6 +66,25 @@ public static double Max(this double[] array)
return max;
}
+ ///
+ /// Returns an index of the max value within an array
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int MaxIdx(this double[] array)
+ {
+ double max = double.MinValue;
+ int maxIdx = 0;
+ for (int i = 0; i < array.Length; i++)
+ {
+ if (max < array[i])
+ {
+ max = array[i];
+ maxIdx = i;
+ }
+ }
+ return maxIdx;
+ }
+
///
/// Returns the min value within an array
///
@@ -83,6 +102,25 @@ public static double Min(this double[] array)
return min;
}
+ ///
+ /// Returns an index of the min value within an array
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int MinIdx(this double[] array)
+ {
+ double min = double.MaxValue;
+ int minIdx = 0;
+ for (int i = 0; i < array.Length; i++)
+ {
+ if (min > array[i])
+ {
+ min = array[i];
+ minIdx = i;
+ }
+ }
+ return minIdx;
+ }
+
///
/// Returns sum of values within an array
///
diff --git a/RCNet/Neural/Data/VectorBundle.cs b/RCNet/Neural/Data/VectorBundle.cs
index b20e92c..748ff1a 100644
--- a/RCNet/Neural/Data/VectorBundle.cs
+++ b/RCNet/Neural/Data/VectorBundle.cs
@@ -223,119 +223,6 @@ public void Add(VectorBundle data)
}
//Methods
- public List Folderize(double foldDataRatio, double binBorder = double.NaN)
- {
- if (OutputVectorCollection.Count < 2)
- {
- throw new InvalidOperationException($"Insufficient number of samples ({OutputVectorCollection.Count.ToString(CultureInfo.InvariantCulture)}).");
- }
- List bundleCollection = new List();
- //Fold data ratio basic correction
- if (foldDataRatio > MaxRatioOfFoldData)
- {
- foldDataRatio = MaxRatioOfFoldData;
- }
- //Initial fold size estimation
- int foldSize = Math.Max(1, (int)Math.Round(OutputVectorCollection.Count * foldDataRatio, 0));
- //Initial number of folds
- int numOfFolds = OutputVectorCollection.Count / foldSize;
- //Folds creation
- if (double.IsNaN(binBorder))
- {
- //No binary output
- int samplesPos = 0;
- for (int bundleNum = 0; bundleNum < numOfFolds; bundleNum++)
- {
- VectorBundle bundle = new VectorBundle();
- for (int i = 0; i < foldSize && samplesPos < OutputVectorCollection.Count; i++)
- {
- bundle.InputVectorCollection.Add(InputVectorCollection[samplesPos]);
- bundle.OutputVectorCollection.Add(OutputVectorCollection[samplesPos]);
- ++samplesPos;
- }
- bundleCollection.Add(bundle);
- }
- //Remaining samples
- for (int i = 0; i < OutputVectorCollection.Count - samplesPos; i++)
- {
- int bundleIdx = i % bundleCollection.Count;
- bundleCollection[bundleIdx].InputVectorCollection.Add(InputVectorCollection[samplesPos + i]);
- bundleCollection[bundleIdx].OutputVectorCollection.Add(OutputVectorCollection[samplesPos + i]);
- }
-
- }
- else
- {
- //Binary output
- BinDistribution refBinDistr = new BinDistribution(binBorder);
- refBinDistr.Update(OutputVectorCollection, 0);
- int min01 = Math.Min(refBinDistr.NumOf[0], refBinDistr.NumOf[1]);
- if (min01 < 2)
- {
- throw new InvalidOperationException($"Insufficient bin 0 or 1 samples (less than 2).");
- }
- if (numOfFolds > min01)
- {
- numOfFolds = min01;
- }
- //Scan data
- int[] bin0SampleIdxs = new int[refBinDistr.NumOf[0]];
- int bin0SamplesPos = 0;
- int[] bin1SampleIdxs = new int[refBinDistr.NumOf[1]];
- int bin1SamplesPos = 0;
- for (int i = 0; i < OutputVectorCollection.Count; i++)
- {
- if (OutputVectorCollection[i][0] >= refBinDistr.BinBorder)
- {
- bin1SampleIdxs[bin1SamplesPos++] = i;
- }
- else
- {
- bin0SampleIdxs[bin0SamplesPos++] = i;
- }
- }
- //Determine distributions of 0 and 1 for one fold
- int bundleBin0Count = Math.Max(1, refBinDistr.NumOf[0] / numOfFolds);
- int bundleBin1Count = Math.Max(1, refBinDistr.NumOf[1] / numOfFolds);
- //Bundles creation
- bin0SamplesPos = 0;
- bin1SamplesPos = 0;
- for (int bundleNum = 0; bundleNum < numOfFolds; bundleNum++)
- {
- VectorBundle bundle = new VectorBundle();
- //Bin 0
- for (int i = 0; i < bundleBin0Count; i++)
- {
- bundle.InputVectorCollection.Add(InputVectorCollection[bin0SampleIdxs[bin0SamplesPos]]);
- bundle.OutputVectorCollection.Add(OutputVectorCollection[bin0SampleIdxs[bin0SamplesPos]]);
- ++bin0SamplesPos;
- }
- //Bin 1
- for (int i = 0; i < bundleBin1Count; i++)
- {
- bundle.InputVectorCollection.Add(InputVectorCollection[bin1SampleIdxs[bin1SamplesPos]]);
- bundle.OutputVectorCollection.Add(OutputVectorCollection[bin1SampleIdxs[bin1SamplesPos]]);
- ++bin1SamplesPos;
- }
- bundleCollection.Add(bundle);
- }
- //Remaining samples
- for (int i = 0; i < bin0SampleIdxs.Length - bin0SamplesPos; i++)
- {
- int bundleIdx = i % bundleCollection.Count;
- bundleCollection[bundleIdx].InputVectorCollection.Add(InputVectorCollection[bin0SampleIdxs[bin0SamplesPos + i]]);
- bundleCollection[bundleIdx].OutputVectorCollection.Add(OutputVectorCollection[bin0SampleIdxs[bin0SamplesPos + i]]);
- }
- for (int i = 0; i < bin1SampleIdxs.Length - bin1SamplesPos; i++)
- {
- int bundleIdx = i % bundleCollection.Count;
- bundleCollection[bundleIdx].InputVectorCollection.Add(InputVectorCollection[bin1SampleIdxs[bin1SamplesPos + i]]);
- bundleCollection[bundleIdx].OutputVectorCollection.Add(OutputVectorCollection[bin1SampleIdxs[bin1SamplesPos + i]]);
- }
- }
-
- return bundleCollection;
- }
///
/// Splits this bundle to a collection of smaller folds (sub-bundles) suitable for the cross-validation.
/// When the binBorder is specified then all output features are considered as binary
@@ -345,7 +232,7 @@ public List Folderize(double foldDataRatio, double binBorder = dou
/// Requested ratio of the samples constituting one fold (sub-bundle).
/// When specified then method keeps balanced ratios of 0 and 1 values in each fold sub-bundle.
/// Collection of created folds.
- public List Folderize_new(double foldDataRatio, double binBorder = double.NaN)
+ public List Folderize(double foldDataRatio, double binBorder = double.NaN)
{
if (OutputVectorCollection.Count < 2)
{
diff --git a/RCNet/Neural/Network/NonRecurrent/NetworkClusterSecondLevelCompSettings.cs b/RCNet/Neural/Network/NonRecurrent/NetworkClusterSecondLevelCompSettings.cs
index be888e4..1c3598e 100644
--- a/RCNet/Neural/Network/NonRecurrent/NetworkClusterSecondLevelCompSettings.cs
+++ b/RCNet/Neural/Network/NonRecurrent/NetworkClusterSecondLevelCompSettings.cs
@@ -122,7 +122,7 @@ public override XElement GetXml(string rootElemName, bool suppressDefaults)
///
public override XElement GetXml(bool suppressDefaults)
{
- return GetXml("clusterSecondLevelComputation", suppressDefaults);
+ return GetXml("secondLevelComputation", suppressDefaults);
}
}//NetworkClusterSecondLevelCompSettings
diff --git a/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetwork.cs b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetwork.cs
new file mode 100644
index 0000000..5393b38
--- /dev/null
+++ b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetwork.cs
@@ -0,0 +1,105 @@
+using RCNet.MathTools;
+using RCNet.Neural.Network.NonRecurrent.FF;
+using System;
+
+namespace RCNet.Neural.Network.NonRecurrent
+{
+ ///
+ /// Contains trained "One Takes All" network and related important error statistics from training/testing.
+ /// Network is Feed Forward Network having multiple probability outputs and SoftMax output activation.
+ ///
+ [Serializable]
+ public class TrainedOneTakesAllNetwork
+ {
+ //Attribute properties
+ ///
+ /// Name of the trained network
+ ///
+ public string NetworkName { get; set; }
+ ///
+ /// Trained network
+ ///
+ public FeedForwardNetwork Network { get; set; }
+ ///
+ /// Training error statistics
+ ///
+ public BasicStat TrainingErrorStat { get; set; }
+ ///
+ /// Training binary error statistics.
+ ///
+ public BinErrStat TrainingBinErrorStat { get; set; }
+ ///
+ /// Testing error statistics
+ ///
+ public BasicStat TestingErrorStat { get; set; }
+ ///
+ /// Testing binary error statistics.
+ ///
+ public BinErrStat TestingBinErrorStat { get; set; }
+ ///
+ /// Statistics of the network weights
+ ///
+ public BasicStat OutputWeightsStat { get; set; }
+ ///
+ /// Achieved training/testing combined precision error
+ ///
+ public double CombinedPrecisionError { get; set; }
+ ///
+ /// Achieved training/testing combined binary error.
+ ///
+ public double CombinedBinaryError { get; set; }
+
+ //Constructors
+ ///
+ /// Creates an uninitialized instance
+ ///
+ public TrainedOneTakesAllNetwork()
+ {
+ NetworkName = string.Empty;
+ Network = null;
+ TrainingErrorStat = null;
+ TrainingBinErrorStat = null;
+ TestingErrorStat = null;
+ TestingBinErrorStat = null;
+ OutputWeightsStat = null;
+ CombinedPrecisionError = -1d;
+ CombinedBinaryError = -1d;
+ return;
+ }
+
+ ///
+ /// The deep copy constructor.
+ ///
+ /// Source instance
+ public TrainedOneTakesAllNetwork(TrainedOneTakesAllNetwork source)
+ {
+ NetworkName = source.NetworkName;
+ Network = (FeedForwardNetwork)source.Network?.DeepClone();
+ TrainingErrorStat = source.TrainingErrorStat?.DeepClone();
+ TrainingBinErrorStat = source.TrainingBinErrorStat?.DeepClone();
+ TestingErrorStat = source.TestingErrorStat?.DeepClone();
+ TestingBinErrorStat = source.TestingBinErrorStat?.DeepClone();
+ OutputWeightsStat = source.OutputWeightsStat?.DeepClone();
+ CombinedPrecisionError = source.CombinedPrecisionError;
+ CombinedBinaryError = source.CombinedBinaryError;
+ return;
+ }
+
+ //Properties
+ ///
+ /// Indicates that the network ideal output is binary
+ ///
+ public bool BinaryOutput { get { return true; } }
+
+ //Methods
+ ///
+ /// Creates the deep copy instance of this instance
+ ///
+ public TrainedOneTakesAllNetwork DeepClone()
+ {
+ return new TrainedOneTakesAllNetwork(this);
+ }
+
+ }//TrainedOneTakesAllNetwork
+
+}//Namespace
diff --git a/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkBuilder.cs b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkBuilder.cs
new file mode 100644
index 0000000..57fbeac
--- /dev/null
+++ b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkBuilder.cs
@@ -0,0 +1,486 @@
+using RCNet.MathTools;
+using RCNet.Neural.Activation;
+using RCNet.Neural.Data;
+using RCNet.Neural.Network.NonRecurrent.FF;
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Text;
+
+namespace RCNet.Neural.Network.NonRecurrent
+{
+ ///
+ /// Builds trained "One Takes All" network.
+ /// Supported is Feed Forward Network having multiple probability outputs and SoftMax output activation.
+ ///
+ public class TrainedOneTakesAllNetworkBuilder
+ {
+ //Delegates
+ ///
+ /// This is the control function of the regression process and it is called after the completion of each regression training epoch.
+ /// The goal of the regression process is to train a network that will give good results both on the training data and the test data.
+ /// BuildingState object contains the best error statistics so far and the latest statistics. The primary purpose of this function is
+ /// to decide whether the latest statistics are better than the best statistics so far.
+ /// Function can also tell the regression process that it does not make any sense to continue the regression. It can
+ /// terminate the current regression attempt or whole regression process.
+ ///
+ /// Contains all the necessary information to control the regression.
+ /// Instructions for the regression process.
+ public delegate BuildingInstr RegressionControllerDelegate(BuildingState buildingState);
+
+ ///
+ /// Delegate of RegressionEpochDone event handler.
+ ///
+ /// Current state of the regression process
+ /// Indicates that the best network was found as a result of the performed epoch
+ public delegate void RegressionEpochDoneHandler(BuildingState buildingState, bool foundBetter);
+
+ ///
+ /// This informative event occurs every time the regression epoch is done
+ ///
+ [field: NonSerialized]
+ public event RegressionEpochDoneHandler RegressionEpochDone;
+
+ //Attributes
+ private readonly string _networkName;
+ private readonly FeedForwardNetworkSettings _networkSettings;
+ private readonly int _foldNum;
+ private readonly int _numOfFolds;
+ private readonly int _foldNetworkNum;
+ private readonly int _numOfFoldNetworks;
+ private readonly VectorBundle _trainingBundle;
+ private readonly VectorBundle _testingBundle;
+ private readonly Random _rand;
+ private readonly RegressionControllerDelegate _controller;
+
+ //Constructor
+ ///
+ /// Creates an instance ready to start building trained "One Takes All" network
+ ///
+ /// Name of the network to be built
+ /// Feed Forward Network configuration (required is SoftMax output activation)
+ /// Current fold number
+ /// Total number of the folds
+ /// Current fold network number
+ /// Total number of the fold networks
+ /// Bundle of predictors and ideal values to be used for training purposes. Ideal values must be 0 or 1 and only one 1 must be set in the ideal vector.
+ /// Bundle of predictors and ideal values to be used for testing purposes. Ideal values must be 0 or 1 and only one 1 must be set in the ideal vector.
+ /// Random generator to be used (optional)
+ /// Regression controller (optional)
+ public TrainedOneTakesAllNetworkBuilder(string networkName,
+ FeedForwardNetworkSettings networkSettings,
+ int foldNum,
+ int numOfFolds,
+ int foldNetworkNum,
+ int numOfFoldNetworks,
+ VectorBundle trainingBundle,
+ VectorBundle testingBundle,
+ Random rand = null,
+ RegressionControllerDelegate controller = null
+ )
+ {
+ //Check network config
+ CheckNetCfg(networkSettings);
+ //Check the data
+ CheckBinOutput(trainingBundle);
+ CheckBinOutput(testingBundle);
+ //Checking passed, continue
+ _networkName = networkName;
+ _networkSettings = networkSettings;
+ _foldNum = foldNum;
+ _numOfFolds = numOfFolds;
+ _foldNetworkNum = foldNetworkNum;
+ _numOfFoldNetworks = numOfFoldNetworks;
+ _trainingBundle = trainingBundle;
+ _testingBundle = testingBundle;
+ _rand = rand ?? new Random(0);
+ _controller = controller ?? DefaultRegressionController;
+ return;
+ }
+
+ //Properties
+
+ //Static methods
+ ///
+ /// Checks the output vector contains only 0 and one 1.
+ ///
+ /// Data to be checked
+ public static void CheckBinOutput(VectorBundle data)
+ {
+ foreach(double[] outputVector in data.OutputVectorCollection)
+ {
+ if(outputVector.Length <= 1)
+ {
+ throw new InvalidOperationException("Number of output vector values must be greater than 1.");
+ }
+ int bin1Counter = 0;
+ int bin0Counter = 0;
+ for (int i = 0; i < outputVector.Length; i++)
+ {
+ if(outputVector[i] == 0d)
+ {
+ ++bin0Counter;
+ }
+ else if(outputVector[i] == 1d)
+ {
+ ++bin1Counter;
+ }
+ else
+ {
+ throw new ArgumentException($"Output data vectors contain different values than 0 or 1.", "data");
+ }
+ }
+ if (bin1Counter != 1)
+ {
+ throw new ArgumentException($"Output data vector must contain exactly one 1.", "data");
+ }
+ }
+ return;
+ }
+
+ ///
+ /// Function checks that network has SoftMax output activation and associated RProp trainer.
+ ///
+ /// Feed Forward Network configuration
+ public static void CheckNetCfg(FeedForwardNetworkSettings netCfg)
+ {
+ if (netCfg.OutputActivationCfg.GetType() != typeof(AFAnalogSoftMaxSettings))
+ {
+ throw new ArgumentException($"Feed forward network must have SoftMax output activation.", "netCfg");
+ }
+ if (netCfg.TrainerCfg.GetType() != typeof(RPropTrainerSettings))
+ {
+ throw new ArgumentException($"Feed forward network must have associated RProp trainer.", "netCfg");
+ }
+ return;
+ }
+
+ ///
+ /// This is a default implementation of an evaluation whether the "candidate" network
+ /// achieved a better result than the best network so far
+ ///
+ /// Network to be evaluated
+ /// The best network so far
+ public static bool IsBetter(TrainedOneTakesAllNetwork candidate, TrainedOneTakesAllNetwork currentBest)
+ {
+ if (candidate.CombinedBinaryError > currentBest.CombinedBinaryError)
+ {
+ return false;
+ }
+ else if (candidate.CombinedBinaryError < currentBest.CombinedBinaryError)
+ {
+ return true;
+ }
+ //CombinedBinaryError is the same
+ else if (candidate.TestingBinErrorStat.BinValErrStat[0].Sum > currentBest.TestingBinErrorStat.BinValErrStat[0].Sum)
+ {
+ return false;
+ }
+ else if (candidate.TestingBinErrorStat.BinValErrStat[0].Sum < currentBest.TestingBinErrorStat.BinValErrStat[0].Sum)
+ {
+ return true;
+ }
+ //CombinedBinaryError is the same
+ //TestingBinErrorStat.BinValErrStat[0].Sum is the same
+ else if (candidate.TrainingBinErrorStat.BinValErrStat[0].Sum > currentBest.TrainingBinErrorStat.BinValErrStat[0].Sum)
+ {
+ return false;
+ }
+ else if (candidate.TrainingBinErrorStat.BinValErrStat[0].Sum < currentBest.TrainingBinErrorStat.BinValErrStat[0].Sum)
+ {
+ return true;
+ }
+ //CombinedBinaryError is the same
+ //TestingBinErrorStat.BinValErrStat[0].Sum is the same
+ //TrainingBinErrorStat.BinValErrStat[0].Sum is the same
+ else if (candidate.CombinedPrecisionError < currentBest.CombinedPrecisionError)
+ {
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ //Instance methods
+ private BuildingInstr DefaultRegressionController(BuildingState buildingState)
+ {
+ BuildingInstr instructions = new BuildingInstr
+ {
+ CurrentIsBetter = IsBetter(buildingState.CurrNetwork,
+ buildingState.BestNetwork
+ ),
+ StopProcess = (buildingState.BestNetwork.CombinedBinaryError == 0 &&
+ buildingState.CurrNetwork.CombinedPrecisionError > buildingState.BestNetwork.CombinedPrecisionError
+ )
+ };
+ return instructions;
+ }
+
+ ///
+ /// Builds trained network
+ ///
+ /// Trained network
+ public TrainedOneTakesAllNetwork Build()
+ {
+ TrainedOneTakesAllNetwork bestNetwork = null;
+ int bestNetworkAttempt = 0;
+ int currNetworkLastImprovementEpoch = 0;
+ double currNetworkLastImprovementCombinedPrecisionError = 0d;
+ double currNetworkLastImprovementCombinedBinaryError = 0d;
+ //Create network and trainer
+ NonRecurrentNetUtils.CreateNetworkAndTrainer(_networkSettings,
+ _trainingBundle.InputVectorCollection,
+ _trainingBundle.OutputVectorCollection,
+ _rand,
+ out INonRecurrentNetwork net,
+ out INonRecurrentNetworkTrainer trainer
+ );
+ //Iterate training cycles
+ while (trainer.Iteration())
+ {
+ //Compute current error statistics after training iteration
+ //Training data part
+ TrainedOneTakesAllNetwork currNetwork = new TrainedOneTakesAllNetwork
+ {
+ NetworkName = _networkName,
+ Network = (FeedForwardNetwork)net,
+ TrainingErrorStat = net.ComputeBatchErrorStat(_trainingBundle.InputVectorCollection, _trainingBundle.OutputVectorCollection, out List trainingComputedOutputsCollection)
+ };
+ currNetwork.TrainingBinErrorStat = new BinErrStat(Interval.IntZP1.Mid, trainingComputedOutputsCollection, _trainingBundle.OutputVectorCollection);
+ currNetwork.CombinedBinaryError = currNetwork.TrainingBinErrorStat.TotalErrStat.Sum;
+ currNetwork.CombinedPrecisionError = currNetwork.TrainingErrorStat.ArithAvg;
+ //Testing data part
+ currNetwork.TestingErrorStat = net.ComputeBatchErrorStat(_testingBundle.InputVectorCollection, _testingBundle.OutputVectorCollection, out List testingComputedOutputsCollection);
+ currNetwork.CombinedPrecisionError = Math.Max(currNetwork.CombinedPrecisionError, currNetwork.TestingErrorStat.ArithAvg);
+ currNetwork.TestingBinErrorStat = new BinErrStat(Interval.IntZP1.Mid, testingComputedOutputsCollection, _testingBundle.OutputVectorCollection);
+ currNetwork.CombinedBinaryError = Math.Max(currNetwork.CombinedBinaryError, currNetwork.TestingBinErrorStat.TotalErrStat.Sum);
+ //Restart lastImprovementEpoch when new trainer's attempt started
+ if (trainer.AttemptEpoch == 1)
+ {
+ currNetworkLastImprovementEpoch = trainer.AttemptEpoch;
+ currNetworkLastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
+ currNetworkLastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
+ }
+ //First initialization of the best network
+ if(bestNetwork == null)
+ {
+ bestNetwork = currNetwork.DeepClone();
+ bestNetworkAttempt = trainer.Attempt;
+ }
+ //RegrState instance
+ BuildingState regrState = new BuildingState(_networkName, _foldNum, _numOfFolds, _foldNetworkNum, _numOfFoldNetworks, trainer.Attempt, trainer.MaxAttempt, trainer.AttemptEpoch, trainer.MaxAttemptEpoch, currNetwork, currNetworkLastImprovementEpoch, bestNetwork, bestNetworkAttempt);
+ //Call controller
+ BuildingInstr instructions = _controller(regrState);
+ //Better?
+ if (instructions.CurrentIsBetter)
+ {
+ //Adopt current regression unit as a best one
+ bestNetwork = currNetwork.DeepClone();
+ regrState.BestNetwork = bestNetwork;
+ bestNetworkAttempt = trainer.Attempt;
+ }
+ if (currNetwork.CombinedBinaryError < currNetworkLastImprovementCombinedBinaryError || currNetwork.CombinedPrecisionError < currNetworkLastImprovementCombinedPrecisionError)
+ {
+ currNetworkLastImprovementEpoch = trainer.AttemptEpoch;
+ currNetworkLastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
+ currNetworkLastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
+ }
+ //Raise notification event
+ RegressionEpochDone?.Invoke(regrState, instructions.CurrentIsBetter);
+ //Process instructions
+ if (instructions.StopProcess)
+ {
+ break;
+ }
+ else if (instructions.StopCurrentAttempt)
+ {
+ if (!trainer.NextAttempt())
+ {
+ break;
+ }
+ }
+ }//while (iteration)
+ //Create statistics of the best network weights
+ bestNetwork.OutputWeightsStat = bestNetwork.Network.ComputeWeightsStat();
+ return bestNetwork;
+ }
+
+ //Inner classes
+ ///
+ /// The class contains information needed to control network building (regression) process.
+ /// This class is also used for progress changed event.
+ ///
+ [Serializable]
+ public class BuildingState
+ {
+ //Attribute properties
+ ///
+ /// Name of the network
+ ///
+ public string NetworkName { get; }
+ ///
+ /// Current fold number
+ ///
+ public int FoldNum { get; }
+ ///
+ /// Total number of the folds
+ ///
+ public int NumOfFolds { get; }
+ ///
+ /// Current fold network number
+ ///
+ public int FoldNetworkNum { get; }
+ ///
+ /// Total number of the fold networks
+ ///
+ public int NumOfFoldNetworks { get; }
+ ///
+ /// Current regression attempt number
+ ///
+ public int RegrAttemptNumber { get; }
+ ///
+ /// Maximum number of regression attempts
+ ///
+ public int RegrMaxAttempts { get; }
+ ///
+ /// Current epoch number
+ ///
+ public int Epoch { get; }
+ ///
+ /// Maximum number of epochs
+ ///
+ public int MaxEpochs { get; }
+ ///
+ /// Contains current network and related important error statistics.
+ ///
+ public TrainedOneTakesAllNetwork CurrNetwork { get; }
+ ///
+ /// Specifies when was lastly found an improvement of current network within the current attempt
+ ///
+ public int CurrNetworkLastImprovementEpoch { get; set; }
+ ///
+ /// Contains the best network for now and related important error statistics.
+ ///
+ public TrainedOneTakesAllNetwork BestNetwork { get; set; }
+ ///
+ /// Number of attempt in which was recognized the best network
+ ///
+ public int BestNetworkAttempt { get; set; }
+
+ ///
+ /// Creates an initialized instance
+ ///
+ /// Name of the network
+ /// Current fold number
+ /// Total number of the folds
+ /// Current fold network number
+ /// Total number of the fold networks
+ /// Current regression attempt number
+ /// Maximum number of regression attempts
+ /// Current epoch number within the current regression attempt
+ /// Maximum number of epochs
+ /// Current network and related important error statistics.
+ /// Specifies when was lastly found an improvement of current network within the current attempt.
+ /// The best network for now and related important error statistics.
+ /// Number of attempt in which was recognized the best network.
+ public BuildingState(string networkName,
+ int foldNum,
+ int numOfFolds,
+ int foldNetworkNum,
+ int numOfFoldNetworks,
+ int regrAttemptNumber,
+ int regrMaxAttempts,
+ int epoch,
+ int maxEpochs,
+ TrainedOneTakesAllNetwork currNetwork,
+ int currNetworkLastImprovementEpoch,
+ TrainedOneTakesAllNetwork bestNetwork,
+ int bestNetworkAttempt
+ )
+ {
+ NetworkName = networkName;
+ FoldNum = foldNum;
+ NumOfFolds = numOfFolds;
+ FoldNetworkNum = foldNetworkNum;
+ NumOfFoldNetworks = numOfFoldNetworks;
+ RegrAttemptNumber = regrAttemptNumber;
+ RegrMaxAttempts = regrMaxAttempts;
+ Epoch = epoch;
+ MaxEpochs = maxEpochs;
+ CurrNetwork = currNetwork;
+ CurrNetworkLastImprovementEpoch = currNetworkLastImprovementEpoch;
+ BestNetwork = bestNetwork;
+ BestNetworkAttempt = bestNetworkAttempt;
+ return;
+ }
+
+ //Methods
+ ///
+ /// Builds string containing information about the regression progress.
+ ///
+ /// Specifies how many spaces to be at the begining of the line.
+ /// Built text line
+ public string GetProgressInfo(int margin = 0)
+ {
+ //Build progress text message
+ StringBuilder progressText = new StringBuilder();
+ progressText.Append(new string(' ', margin));
+ progressText.Append("Network [");
+ progressText.Append(NetworkName);
+ progressText.Append("] Fold/Net/Attempt/Epoch: ");
+ progressText.Append(FoldNum.ToString().PadLeft(NumOfFolds.ToString().Length, '0') + "/");
+ progressText.Append(FoldNetworkNum.ToString().PadLeft(NumOfFoldNetworks.ToString().Length, '0') + "/");
+ progressText.Append(RegrAttemptNumber.ToString().PadLeft(RegrMaxAttempts.ToString().Length, '0') + "/");
+ progressText.Append(Epoch.ToString().PadLeft(MaxEpochs.ToString().Length, '0'));
+ progressText.Append(", Samples: ");
+ progressText.Append((CurrNetwork.TrainingErrorStat.NumOfSamples / CurrNetwork.Network.NumOfOutputValues).ToString() + "/");
+ progressText.Append((CurrNetwork.TestingErrorStat.NumOfSamples / CurrNetwork.Network.NumOfOutputValues).ToString());
+ progressText.Append(", Best-Train: ");
+ progressText.Append(BestNetwork.TrainingErrorStat.ArithAvg.ToString("E3", CultureInfo.InvariantCulture));
+ progressText.Append("/" + BestNetwork.TrainingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
+ progressText.Append("/" + BestNetwork.TrainingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
+ progressText.Append(", Best-Test: ");
+ progressText.Append(BestNetwork.TestingErrorStat.ArithAvg.ToString("E3", CultureInfo.InvariantCulture));
+ progressText.Append("/" + BestNetwork.TestingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
+ progressText.Append("/" + BestNetwork.TestingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
+ progressText.Append(", Curr-Train: ");
+ progressText.Append(CurrNetwork.TrainingErrorStat.ArithAvg.ToString("E3", CultureInfo.InvariantCulture));
+ progressText.Append("/" + CurrNetwork.TrainingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
+ progressText.Append("/" + CurrNetwork.TrainingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
+ progressText.Append(", Curr-Test: ");
+ progressText.Append(CurrNetwork.TestingErrorStat.ArithAvg.ToString("E3", CultureInfo.InvariantCulture));
+ progressText.Append("/" + CurrNetwork.TestingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
+ progressText.Append("/" + CurrNetwork.TestingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
+ return progressText.ToString();
+ }
+
+
+ }//BuildingState
+
+ ///
+ /// Contains instructions for the network building (regression) process
+ ///
+ public class BuildingInstr
+ {
+ //Attribute properties
+ ///
+ /// Indicates whether to terminate the current regression attempt
+ ///
+ public bool StopCurrentAttempt { get; set; } = false;
+ ///
+ /// Indicates whether to terminate the entire regression process
+ ///
+ public bool StopProcess { get; set; } = false;
+ ///
+ /// This is the most important switch indicating whether the CurrNetwork is better than
+ /// the BestNetwork
+ ///
+ public bool CurrentIsBetter { get; set; } = false;
+
+ }//BuildingInstr
+
+ }//TrainedOneTakesAllNetworkBuilder
+
+}//Namespace
diff --git a/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkCluster.cs b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkCluster.cs
new file mode 100644
index 0000000..f90c8c6
--- /dev/null
+++ b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkCluster.cs
@@ -0,0 +1,247 @@
+using System;
+using System.Collections.Generic;
+using RCNet.Extensions;
+using RCNet.MathTools;
+using RCNet.MathTools.Probability;
+using RCNet.Neural.Data;
+using RCNet.Neural.Data.Filter;
+
+namespace RCNet.Neural.Network.NonRecurrent
+{
+ ///
+ /// Computation cluster of trained "One Takes All" networks.
+ ///
+ [Serializable]
+ public class TrainedOneTakesAllNetworkCluster
+ {
+
+ //Constants
+ //Macro-weights
+ //Measures group weights
+ private const double TrainGroupWeight = 1d;
+ private const double TestGroupWeight = 1d;
+ //Measures weights
+ private const double SamplesWeight = 1d;
+ private const double PrecisionWeight = 1d;
+ private const double MisrecognizedWeight = 1d;
+ private const double UnrecognizedWeight = 1d;
+
+ //Attribute properties
+ ///
+ /// Name of the cluster
+ ///
+ public string ClusterName { get; }
+
+ ///
+ /// Indicates the cluster is finalized and ready
+ ///
+ public bool Finalized { get; private set; }
+
+ //Attributes
+ ///
+ /// Member networks
+ ///
+ private readonly List _memberNetCollection;
+
+ ///
+ /// First level weights
+ ///
+ private double[] _membersWeights;
+
+ //Constructors
+ ///
+ /// Creates an uninitialized instance
+ ///
+ /// Name of the cluster
+ public TrainedOneTakesAllNetworkCluster(string clusterName)
+ {
+ ClusterName = clusterName;
+ Finalized = false;
+ _memberNetCollection = new List();
+ _membersWeights = null;
+ return;
+ }
+
+ ///
+ /// Copy constructor
+ ///
+ /// Source cluster
+ public TrainedOneTakesAllNetworkCluster(TrainedOneTakesAllNetworkCluster source)
+ {
+ ClusterName = source.ClusterName;
+ Finalized = source.Finalized;
+ _memberNetCollection = new List(source._memberNetCollection.Count);
+ foreach (TrainedOneTakesAllNetwork tn in source._memberNetCollection)
+ {
+ _memberNetCollection.Add(tn.DeepClone());
+ }
+ _membersWeights = (double[])source._membersWeights?.Clone();
+ return;
+ }
+
+ //Properties
+ ///
+ /// Number of member networks
+ ///
+ public int NumOfMembers { get { return _memberNetCollection.Count; } }
+
+ //Methods
+ ///
+ /// Adds a new member network into the cluster
+ ///
+ /// New member network
+ /// Testing data
+ public void AddMember(TrainedOneTakesAllNetwork memberNet, VectorBundle testData)
+ {
+ //Add member to inner collection
+ _memberNetCollection.Add(memberNet);
+ return;
+ }
+
+ ///
+ /// Computes weights of given trained networks
+ ///
+ private double[] ComputeTrainedNetworksWeights(List trainedNetworks)
+ {
+ double[] weightCollection = new double[trainedNetworks.Count];
+ if (weightCollection.Length == 1)
+ {
+ weightCollection[0] = 1d;
+ }
+ else
+ {
+ //Holders for metrics from member's training and testing phases
+ //Training
+ double[] wHolderTrainSamples = new double[trainedNetworks.Count];
+ double[] wHolderTrainPrecision = new double[trainedNetworks.Count];
+ double[] wHolderTrainMisrecognized = new double[trainedNetworks.Count];
+ double[] wHolderTrainUnrecognized = new double[trainedNetworks.Count];
+ //Testing
+ double[] wHolderTestSamples = new double[trainedNetworks.Count];
+ double[] wHolderTestPrecision = new double[trainedNetworks.Count];
+ double[] wHolderTestMisrecognized = new double[trainedNetworks.Count];
+ double[] wHolderTestUnrecognized = new double[trainedNetworks.Count];
+ //Collect members' metrics
+ for (int memberIdx = 0; memberIdx < trainedNetworks.Count; memberIdx++)
+ {
+ wHolderTrainSamples[memberIdx] = trainedNetworks[memberIdx].TrainingErrorStat.NumOfSamples;
+ wHolderTrainPrecision[memberIdx] = trainedNetworks[memberIdx].TrainingErrorStat.ArithAvg;
+ wHolderTrainMisrecognized[memberIdx] = (1d - trainedNetworks[memberIdx].TrainingBinErrorStat.BinValErrStat[0].ArithAvg);
+ wHolderTrainUnrecognized[memberIdx] = (1d - trainedNetworks[memberIdx].TrainingBinErrorStat.BinValErrStat[1].ArithAvg);
+ wHolderTestSamples[memberIdx] = trainedNetworks[memberIdx].TestingErrorStat.NumOfSamples;
+ wHolderTestPrecision[memberIdx] = trainedNetworks[memberIdx].TestingErrorStat.ArithAvg;
+ wHolderTestMisrecognized[memberIdx] = (1d - trainedNetworks[memberIdx].TestingBinErrorStat.BinValErrStat[0].ArithAvg);
+ wHolderTestUnrecognized[memberIdx] = (1d - trainedNetworks[memberIdx].TestingBinErrorStat.BinValErrStat[1].ArithAvg);
+ }
+ //Align the metrics to have the same meaning and scale them to be usable as the sub-weights
+ wHolderTrainSamples.ScaleToNewSum(1d);
+ wHolderTestSamples.ScaleToNewSum(1d);
+ wHolderTrainPrecision.RevertMeaning();
+ wHolderTrainPrecision.ScaleToNewSum(1d);
+ wHolderTestPrecision.RevertMeaning();
+ wHolderTestPrecision.ScaleToNewSum(1d);
+ wHolderTrainMisrecognized.ScaleToNewSum(1d);
+ wHolderTestMisrecognized.ScaleToNewSum(1d);
+ wHolderTrainUnrecognized.ScaleToNewSum(1d);
+ wHolderTestUnrecognized.ScaleToNewSum(1d);
+ //Build the final weights
+ //Combine the sub-weights using defined macro weights of the metrics
+ for (int i = 0; i < trainedNetworks.Count; i++)
+ {
+ weightCollection[i] = TrainGroupWeight * (SamplesWeight * wHolderTrainSamples[i] +
+ PrecisionWeight * wHolderTrainPrecision[i] +
+ MisrecognizedWeight * wHolderTrainMisrecognized[i] +
+ UnrecognizedWeight * wHolderTrainUnrecognized[i]
+ ) +
+ TestGroupWeight * (SamplesWeight * wHolderTestSamples[i] +
+ PrecisionWeight * wHolderTestPrecision[i] +
+ MisrecognizedWeight * wHolderTestMisrecognized[i] +
+ UnrecognizedWeight * wHolderTestUnrecognized[i]
+ );
+ }
+ //Softmax transformation
+ weightCollection.Softmax();
+ }
+ return weightCollection;
+ }
+
+ ///
+ /// Computes the composite probabilistic result
+ ///
+ private double[] ComputeCompositeOutput(List membersResults)
+ {
+ int numOfOutputProbabilities = membersResults[0].Length;
+ double[] outProbabilities = new double[numOfOutputProbabilities];
+ for(int pIdx = 0; pIdx < numOfOutputProbabilities; pIdx++)
+ {
+ double[] memberPs = new double[NumOfMembers];
+ for(int i = 0; i < NumOfMembers; i++)
+ {
+ memberPs[i] = membersResults[i][pIdx];
+ }
+ //Compute members mixed probability
+ outProbabilities[pIdx] = PMixer.MixP(memberPs, _membersWeights);
+ }
+ //outProbabilities.Softmax();
+ outProbabilities.ScaleToNewSum(1d);
+ return outProbabilities;
+ }
+
+ ///
+ /// Computes and returns probabilistic outputs of all cluster member networks
+ ///
+ private List ComputeClusterMemberNetworks(double[] inputVector)
+ {
+ List outputVectors = new List(NumOfMembers);
+ for (int memberIdx = 0; memberIdx < NumOfMembers; memberIdx++)
+ {
+ outputVectors.Add(_memberNetCollection[memberIdx].Network.Compute(inputVector));
+ }
+ return outputVectors;
+ }
+
+ ///
+ /// Makes cluster ready to operate
+ ///
+ /// Whole data used for training of inner members
+ public void FinalizeCluster(VectorBundle dataBundle)
+ {
+ if (Finalized)
+ {
+ throw new InvalidOperationException("Cluster was already finalized.");
+ }
+ //Initialize the members' weights
+ _membersWeights = ComputeTrainedNetworksWeights(_memberNetCollection);
+ //When necessary, prepare the second level networks
+ Finalized = true;
+ return;
+ }
+
+ ///
+ /// Computes the cluster probabilistic output
+ ///
+ /// Input predictors
+ public double[] Compute(double[] predictors)
+ {
+ if (!Finalized)
+ {
+ throw new InvalidOperationException("Cluster is not finalized. Call FinalizeCluster method first.");
+ }
+ //Collect member networks outputs
+ List memberOutputs = ComputeClusterMemberNetworks(predictors);
+ //Compute the probabilistic result
+ double[] output = ComputeCompositeOutput(memberOutputs);
+ return output;
+ }
+
+ ///
+ /// Creates the deep copy of this instance
+ ///
+ public TrainedOneTakesAllNetworkCluster DeepClone()
+ {
+ return new TrainedOneTakesAllNetworkCluster(this);
+ }
+
+ }//TrainedOneTakesAllNetworkCluster
+
+}//Namespace
diff --git a/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkClusterBuilder.cs b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkClusterBuilder.cs
new file mode 100644
index 0000000..ea9e3ac
--- /dev/null
+++ b/RCNet/Neural/Network/NonRecurrent/TrainedOneTakesAllNetworkClusterBuilder.cs
@@ -0,0 +1,125 @@
+using RCNet.Extensions;
+using RCNet.MathTools;
+using RCNet.Neural.Data;
+using RCNet.Neural.Data.Filter;
+using RCNet.Neural.Network.NonRecurrent.FF;
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Linq;
+
+namespace RCNet.Neural.Network.NonRecurrent
+{
+ ///
+ /// Builds computation cluster of trained "One Takes All" networks.
+ ///
+ public class TrainedOneTakesAllNetworkClusterBuilder
+ {
+ //Events
+ ///
+ /// This informative event occurs every time the regression epoch is done
+ ///
+ [field: NonSerialized]
+ public event TrainedOneTakesAllNetworkBuilder.RegressionEpochDoneHandler RegressionEpochDone;
+
+ //Attributes
+ private readonly string _clusterName;
+ private readonly List _networkSettingsCollection;
+ private readonly Random _rand;
+ private readonly TrainedOneTakesAllNetworkBuilder.RegressionControllerDelegate _controller;
+
+ //Constructor
+ ///
+ /// Creates an instance ready to build probabilistic computation cluster
+ ///
+ /// Name of the cluster
+ /// Collection of network configurations (FeedForwardNetworkSettings)
+ /// Random generator to be used (optional)
+ /// Regression controller (optional)
+ public TrainedOneTakesAllNetworkClusterBuilder(string clusterName,
+ List networkSettingsCollection,
+ Random rand = null,
+ TrainedOneTakesAllNetworkBuilder.RegressionControllerDelegate controller = null
+ )
+ {
+ _clusterName = clusterName;
+ _networkSettingsCollection = networkSettingsCollection;
+ _rand = rand ?? new Random(0);
+ _controller = controller;
+ return;
+ }
+
+ //Properties
+
+ //Methods
+ private void OnRegressionEpochDone(TrainedOneTakesAllNetworkBuilder.BuildingState buildingState, bool foundBetter)
+ {
+ //Only propagate the inner builder's event to this builder's subscribers
+ RegressionEpochDone(buildingState, foundBetter);
+ return;
+ }
+
+ ///
+ /// Builds computation cluster of trained networks
+ ///
+ /// Data to be used for training. Take into account that rows in this bundle may be in random order after the Build call.
+ /// Crossvalidation configuration
+ public TrainedOneTakesAllNetworkCluster Build(VectorBundle dataBundle,
+ CrossvalidationSettings crossvalidationCfg
+ )
+ {
+ //Cluster of trained networks
+ TrainedOneTakesAllNetworkCluster cluster = new TrainedOneTakesAllNetworkCluster(_clusterName);
+ //Member's training
+ for (int repetitionIdx = 0; repetitionIdx < crossvalidationCfg.Repetitions; repetitionIdx++)
+ {
+ //Data split to folds
+ List subBundleCollection = dataBundle.Folderize(crossvalidationCfg.FoldDataRatio, Interval.IntZP1.Mid);
+ int numOfFoldsToBeProcessed = Math.Min(crossvalidationCfg.Folds <= 0 ? subBundleCollection.Count : crossvalidationCfg.Folds, subBundleCollection.Count);
+ //Train collection of networks for each processing fold.
+ for (int foldIdx = 0; foldIdx < numOfFoldsToBeProcessed; foldIdx++)
+ {
+ for (int netCfgIdx = 0; netCfgIdx < _networkSettingsCollection.Count; netCfgIdx++)
+ {
+ //Prepare training data bundle
+ VectorBundle trainingData = new VectorBundle();
+ for (int bundleIdx = 0; bundleIdx < subBundleCollection.Count; bundleIdx++)
+ {
+ if (bundleIdx != foldIdx)
+ {
+ trainingData.Add(subBundleCollection[bundleIdx]);
+ }
+ }
+ TrainedOneTakesAllNetworkBuilder netBuilder = new TrainedOneTakesAllNetworkBuilder(_clusterName,
+ _networkSettingsCollection[netCfgIdx],
+ (repetitionIdx * numOfFoldsToBeProcessed) + foldIdx + 1,
+ crossvalidationCfg.Repetitions * numOfFoldsToBeProcessed,
+ netCfgIdx + 1,
+ _networkSettingsCollection.Count,
+ trainingData,
+ subBundleCollection[foldIdx],
+ _rand,
+ _controller
+ );
+ //Register notification
+ netBuilder.RegressionEpochDone += OnRegressionEpochDone;
+ //Build the trained network, which then becomes a cluster member
+ TrainedOneTakesAllNetwork tn = netBuilder.Build();
+ cluster.AddMember(tn, subBundleCollection[foldIdx]);
+ }//netCfgIdx
+ }//foldIdx
+ if (repetitionIdx < crossvalidationCfg.Repetitions - 1)
+ {
+ //Reshuffle the data
+ dataBundle.Shuffle(_rand);
+ }
+ }//repetitionIdx
+ //Make the cluster operable
+ cluster.FinalizeCluster(dataBundle);
+ //Return the built cluster
+ return cluster;
+ }
+
+ }//TrainedOneTakesAllNetworkClusterBuilder
+
+}//Namespace
diff --git a/RCNet/Neural/Network/SM/Readout/ClusterSettings.cs b/RCNet/Neural/Network/SM/Readout/ClusterSettings.cs
new file mode 100644
index 0000000..51e25c8
--- /dev/null
+++ b/RCNet/Neural/Network/SM/Readout/ClusterSettings.cs
@@ -0,0 +1,137 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Xml.Linq;
+using RCNet.Neural.Network.NonRecurrent;
+
+namespace RCNet.Neural.Network.SM.Readout
+{
+ ///
+ /// Configuration of the cluster associated with the readout unit
+ ///
+ [Serializable]
+ public class ClusterSettings : RCNetBaseSettings
+ {
+ //Constants
+ ///
+ /// Name of the associated xsd type
+ ///
+ public const string XsdTypeName = "ROutLayerClusterType";
+
+ //Attribute properties
+ ///
+ /// Crossvalidation configuration
+ ///
+ public CrossvalidationSettings CrossvalidationCfg { get; }
+
+ ///
+ /// Task dependent networks settings to be applied when specific networks for readout unit are not specified
+ ///
+ public DefaultNetworksSettings DefaultNetworksCfg { get; }
+
+ ///
+ /// Configuration of the network cluster 2nd level computation
+ ///
+ public NetworkClusterSecondLevelCompSettings SecondLevelCompCfg { get; }
+
+ //Constructors
+ ///
+ /// Creates an initialized instance.
+ ///
+ /// Crossvalidation configuration
+ /// Task dependent networks settings to be applied when specific networks for readout unit are not specified
+ /// Configuration of the network cluster 2nd level computation
+ public ClusterSettings(CrossvalidationSettings crossvalidationCfg,
+ DefaultNetworksSettings defaultNetworksCfg = null,
+ NetworkClusterSecondLevelCompSettings secondLevelCompCfg = null
+ )
+ {
+ CrossvalidationCfg = (CrossvalidationSettings)crossvalidationCfg.DeepClone();
+ DefaultNetworksCfg = defaultNetworksCfg == null ? new DefaultNetworksSettings() : (DefaultNetworksSettings)defaultNetworksCfg.DeepClone();
+ SecondLevelCompCfg = (NetworkClusterSecondLevelCompSettings)secondLevelCompCfg?.DeepClone();
+ Check();
+ return;
+ }
+
+ ///
+ /// The deep copy constructor
+ ///
+ /// Source instance
+ public ClusterSettings(ClusterSettings source)
+ : this(source.CrossvalidationCfg, source.DefaultNetworksCfg, source.SecondLevelCompCfg)
+ {
+ return;
+ }
+
+ ///
+ /// Creates an initialized instance.
+ ///
+ /// Xml element containing the initialization settings
+ public ClusterSettings(XElement elem)
+ {
+ //Validation
+ XElement settingsElem = Validate(elem, XsdTypeName);
+ //Parsing
+ //Crossvalidation
+ CrossvalidationCfg = new CrossvalidationSettings(settingsElem.Element("crossvalidation"));
+ //Default networks settings
+ XElement defaultNetworksElem = settingsElem.Elements("defaultNetworks").FirstOrDefault();
+ DefaultNetworksCfg = defaultNetworksElem == null ? new DefaultNetworksSettings() : new DefaultNetworksSettings(defaultNetworksElem);
+ //Second level computation
+ XElement clusterSecondLevelCompElem = settingsElem.Elements("secondLevelComputation").FirstOrDefault();
+ if (clusterSecondLevelCompElem != null)
+ {
+ SecondLevelCompCfg = new NetworkClusterSecondLevelCompSettings(clusterSecondLevelCompElem);
+ }
+ else
+ {
+ SecondLevelCompCfg = null;
+ }
+ Check();
+ return;
+ }
+
+ //Properties
+
+ ///
+ public override bool ContainsOnlyDefaults { get { return false; } }
+
+ //Methods
+ ///
+ protected override void Check()
+ {
+ return;
+ }
+
+ ///
+ public override RCNetBaseSettings DeepClone()
+ {
+ return new ClusterSettings(this);
+ }
+
+ ///
+ public override XElement GetXml(string rootElemName, bool suppressDefaults)
+ {
+ XElement rootElem = new XElement(rootElemName, CrossvalidationCfg.GetXml(suppressDefaults));
+ if (!DefaultNetworksCfg.ContainsOnlyDefaults)
+ {
+ rootElem.Add(DefaultNetworksCfg.GetXml(suppressDefaults));
+ }
+ if (SecondLevelCompCfg != null)
+ {
+ rootElem.Add(SecondLevelCompCfg.GetXml(suppressDefaults));
+ }
+ Validate(rootElem, XsdTypeName);
+ return rootElem;
+ }
+
+ ///
+ public override XElement GetXml(bool suppressDefaults)
+ {
+ return GetXml("cluster", suppressDefaults);
+ }
+
+
+ }//ClusterSettings
+
+}//Namespace
diff --git a/RCNet/Neural/Network/SM/Readout/DefaultNetworksSettings.cs b/RCNet/Neural/Network/SM/Readout/DefaultNetworksSettings.cs
index e37cb21..a9b975f 100644
--- a/RCNet/Neural/Network/SM/Readout/DefaultNetworksSettings.cs
+++ b/RCNet/Neural/Network/SM/Readout/DefaultNetworksSettings.cs
@@ -17,7 +17,7 @@ public class DefaultNetworksSettings : RCNetBaseSettings
///
/// Name of the associated xsd type
///
- public const string XsdTypeName = "ROutLayerUnitDefaultNetworksType";
+ public const string XsdTypeName = "ROutLayerClusterDefaultNetworksType";
//Attribute properties
///
diff --git a/RCNet/Neural/Network/SM/Readout/OneWinnerDecisionMaker.cs b/RCNet/Neural/Network/SM/Readout/OneWinnerDecisionMaker.cs
new file mode 100644
index 0000000..df47aff
--- /dev/null
+++ b/RCNet/Neural/Network/SM/Readout/OneWinnerDecisionMaker.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace RCNet.Neural.Network.SM.Readout
+{
+ ///
+ ///
+ ///
+ [Serializable]
+ public class OneWinnerDecisionMaker
+ {
+
+
+
+ }//OneWinnerDecisionMaker
+
+}//Namespace
diff --git a/RCNet/Neural/Network/SM/Readout/OneWinnerDecisionMakerSettings.cs b/RCNet/Neural/Network/SM/Readout/OneWinnerDecisionMakerSettings.cs
new file mode 100644
index 0000000..ad91e9d
--- /dev/null
+++ b/RCNet/Neural/Network/SM/Readout/OneWinnerDecisionMakerSettings.cs
@@ -0,0 +1,133 @@
+using System;
+using System.Globalization;
+using System.Linq;
+using System.Xml.Linq;
+using RCNet.Neural.Activation;
+using RCNet.Neural.Network.NonRecurrent;
+using RCNet.Neural.Network.NonRecurrent.FF;
+
+namespace RCNet.Neural.Network.SM.Readout
+{
+ ///
+ /// Configuration of the OneWinnerDecisionMaker
+ ///
+ [Serializable]
+ public class OneWinnerDecisionMakerSettings : RCNetBaseSettings
+ {
+ //Constants
+ ///
+ /// Name of the associated xsd type
+ ///
+ public const string XsdTypeName = "ROutLayerOneWinnerDecisionMakerType";
+ //Default values
+ ///
+ /// Default value of the parameter specifying how rich the input for the final probabilities network will be. True means to use all available sub-predictions from cluster members and False means to use only the already aggregated predictions from clusters.
+ ///
+ public const bool DefaultConsiderAllSubPredictions = true;
+
+ //Attribute properties
+ ///
+ /// Crossvalidation configuration
+ ///
+ public CrossvalidationSettings CrossvalidationCfg { get; }
+
+ ///
+ /// Final probabilities network configuration
+ ///
+ public FeedForwardNetworkSettings NetCfg { get; }
+
+ ///
+ /// Specifies how rich the input for the final probabilities network will be. True means to use all available sub-predictions from cluster members and False means to use only the already aggregated predictions from clusters.
+ ///
+ public bool ConsiderAllSubPredictions { get; }
+
+
+ //Constructors
+ ///
+ /// Creates an initialized instance
+ ///
+ /// Crossvalidation configuration
+ /// Final probabilities network configuration
+ /// Specifies how rich the input for the final probabilities network will be. True means to use all available sub-predictions from cluster members and False means to use only the already aggregated predictions from clusters.
+ public OneWinnerDecisionMakerSettings(CrossvalidationSettings crossvalidationCfg,
+ FeedForwardNetworkSettings netCfg,
+ bool considerAllSubPredictions = DefaultConsiderAllSubPredictions
+ )
+ {
+ CrossvalidationCfg = (CrossvalidationSettings)crossvalidationCfg.DeepClone();
+ NetCfg = (FeedForwardNetworkSettings)netCfg.DeepClone();
+ ConsiderAllSubPredictions = considerAllSubPredictions;
+ Check();
+ return;
+ }
+
+ ///
+ /// Copy constructor
+ ///
+ /// Source instance
+ public OneWinnerDecisionMakerSettings(OneWinnerDecisionMakerSettings source)
+ : this(source.CrossvalidationCfg, source.NetCfg, source.ConsiderAllSubPredictions)
+ {
+ return;
+ }
+
+ ///
+ /// Creates an initialized instance.
+ ///
+ /// Xml data containing the settings.
+ public OneWinnerDecisionMakerSettings(XElement elem)
+ {
+ //Validation
+ XElement settingsElem = Validate(elem, XsdTypeName);
+ //Parsing
+ CrossvalidationCfg = new CrossvalidationSettings(settingsElem.Element("crossvalidation"));
+ NetCfg = new FeedForwardNetworkSettings(settingsElem.Element("ff"));
+ ConsiderAllSubPredictions = bool.Parse(settingsElem.Attribute("considerAllSubPredictions").Value);
+ Check();
+ return;
+ }
+
+ //Properties
+ ///
+ /// Checks the defaults
+ ///
+ public bool IsDefaultConsiderAllSubPredictions { get { return (ConsiderAllSubPredictions == DefaultConsiderAllSubPredictions); } }
+
+ ///
+ public override bool ContainsOnlyDefaults { get { return false; } }
+
+ //Methods
+ ///
+ protected override void Check()
+ {
+ TrainedOneTakesAllNetworkBuilder.CheckNetCfg(NetCfg);
+ return;
+ }
+
+ ///
+ public override RCNetBaseSettings DeepClone()
+ {
+ return new OneWinnerDecisionMakerSettings(this);
+ }
+
+ ///
+ public override XElement GetXml(string rootElemName, bool suppressDefaults)
+ {
+ XElement rootElem = new XElement(rootElemName, CrossvalidationCfg.GetXml(suppressDefaults), NetCfg.GetXml(suppressDefaults));
+ if (!suppressDefaults || !IsDefaultConsiderAllSubPredictions)
+ {
+ rootElem.Add(new XAttribute("considerAllSubPredictions", ConsiderAllSubPredictions.ToString().ToLowerInvariant()));
+ }
+ Validate(rootElem, XsdTypeName);
+ return rootElem;
+ }
+
+ ///
+ public override XElement GetXml(bool suppressDefaults)
+ {
+ return GetXml("oneWinnerDecisionMaker", suppressDefaults);
+ }
+
+ }//OneWinnerDecisionMakerSettings
+
+}//Namespace
diff --git a/RCNet/Neural/Network/SM/Readout/ReadoutLayer.cs b/RCNet/Neural/Network/SM/Readout/ReadoutLayer.cs
index 4cb1975..c972066 100644
--- a/RCNet/Neural/Network/SM/Readout/ReadoutLayer.cs
+++ b/RCNet/Neural/Network/SM/Readout/ReadoutLayer.cs
@@ -227,19 +227,23 @@ public RegressionOverview Build(VectorBundle dataBundle,
ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection[unitIdx].TaskCfg.Type == ReadoutUnit.TaskType.Classification ? InternalBinBorder : double.NaN,
rand,
controller,
- ReadoutLayerCfg.ReadoutUnitsCfg.ClusterSecondLevelCompCfg
+ ReadoutLayerCfg.ClusterCfg.SecondLevelCompCfg
);
//Register notification
readoutUnitBuilder.RegressionEpochDone += OnRegressionEpochDone;
//Build trained readout unit. Trained unit becomes to be the predicting cluster member
_readoutUnitCollection[unitIdx] = new ReadoutUnit(unitIdx,
readoutUnitBuilder.Build(readoutUnitDataBundle,
- ReadoutLayerCfg.CrossvalidationCfg,
+ ReadoutLayerCfg.ClusterCfg.CrossvalidationCfg,
_outputFeatureFilterCollection[unitIdx]
)
);
}//unitIdx
+ //TODO Train One Winner Groups decision makers if defined
+
+
+
//Readout layer is trained and ready
Trained = true;
return new RegressionOverview(ReadoutUnitErrStatCollection);
diff --git a/RCNet/Neural/Network/SM/Readout/ReadoutLayerSettings.cs b/RCNet/Neural/Network/SM/Readout/ReadoutLayerSettings.cs
index f7d51e2..a2fe92f 100644
--- a/RCNet/Neural/Network/SM/Readout/ReadoutLayerSettings.cs
+++ b/RCNet/Neural/Network/SM/Readout/ReadoutLayerSettings.cs
@@ -20,35 +20,35 @@ public class ReadoutLayerSettings : RCNetBaseSettings
//Attribute properties
///
- /// Crossvalidation configuration
+ /// Configuration of the cluster associated with the readout unit
///
- public CrossvalidationSettings CrossvalidationCfg { get; }
+ public ClusterSettings ClusterCfg { get; }
///
- /// Task dependent networks settings to be applied when specific networks for readout unit are not specified
+ /// Readout units configuration
///
- public DefaultNetworksSettings DefaultNetworksCfg { get; }
+ public ReadoutUnitsSettings ReadoutUnitsCfg { get; }
///
- /// Readout units configuration
+ /// Configuration of the decision maker applied on "One winner" groups
///
- public ReadoutUnitsSettings ReadoutUnitsCfg { get; }
+ public OneWinnerDecisionMakerSettings OneWinnerDecisionMakerCfg { get; }
//Constructors
///
/// Creates an initialized instance.
///
- /// Crossvalidation configuration
+ /// Configuration of the cluster associated with the readout unit
/// Readout units configuration
- /// Task dependent networks settings to be applied when specific networks for readout unit are not specified
- public ReadoutLayerSettings(CrossvalidationSettings crossvalidationCfg,
+ /// Configuration of the decision maker applied on "One winner" groups
+ public ReadoutLayerSettings(ClusterSettings clusterCfg,
ReadoutUnitsSettings readoutUnitsCfg,
- DefaultNetworksSettings defaultNetworksCfg = null
+ OneWinnerDecisionMakerSettings oneWinnerDecisionMakerCfg = null
)
{
- CrossvalidationCfg = (CrossvalidationSettings)crossvalidationCfg.DeepClone();
+ ClusterCfg = (ClusterSettings)clusterCfg.DeepClone();
ReadoutUnitsCfg = (ReadoutUnitsSettings)readoutUnitsCfg.DeepClone();
- DefaultNetworksCfg = defaultNetworksCfg == null ? new DefaultNetworksSettings() : (DefaultNetworksSettings)defaultNetworksCfg.DeepClone();
+ OneWinnerDecisionMakerCfg = (OneWinnerDecisionMakerSettings)oneWinnerDecisionMakerCfg?.DeepClone();
Check();
return;
}
@@ -58,7 +58,7 @@ public ReadoutLayerSettings(CrossvalidationSettings crossvalidationCfg,
///
/// Source instance
public ReadoutLayerSettings(ReadoutLayerSettings source)
- : this(source.CrossvalidationCfg, source.ReadoutUnitsCfg, source.DefaultNetworksCfg)
+ : this(source.ClusterCfg, source.ReadoutUnitsCfg, source.OneWinnerDecisionMakerCfg)
{
return;
}
@@ -73,14 +73,14 @@ public ReadoutLayerSettings(XElement elem)
//Validation
XElement settingsElem = Validate(elem, XsdTypeName);
//Parsing
- //Crossvalidation
- CrossvalidationCfg = new CrossvalidationSettings(settingsElem.Element("crossvalidation"));
- //Default networks settings
- XElement defaultNetworksElem = settingsElem.Elements("defaultNetworks").FirstOrDefault();
- DefaultNetworksCfg = defaultNetworksElem == null ? new DefaultNetworksSettings() : new DefaultNetworksSettings(defaultNetworksElem);
+ //Cluster
+ ClusterCfg = new ClusterSettings(settingsElem.Element("cluster"));
//Readout units
XElement readoutUnitsElem = settingsElem.Elements("readoutUnits").First();
ReadoutUnitsCfg = new ReadoutUnitsSettings(readoutUnitsElem);
+ //One winner decision maker
+ XElement oneWinnerDecisionMakerElem = settingsElem.Elements("oneWinnerDecisionMaker").FirstOrDefault();
+ OneWinnerDecisionMakerCfg = oneWinnerDecisionMakerElem == null ? null : new OneWinnerDecisionMakerSettings(oneWinnerDecisionMakerElem);
Check();
return;
}
@@ -108,7 +108,7 @@ protected override void Check()
{
if (rus.TaskCfg.NetworkCfgCollection.Count == 0)
{
- if (DefaultNetworksCfg.GetTaskNetworksCfgs(rus.TaskCfg.Type).Count == 0)
+ if (ClusterCfg.DefaultNetworksCfg.GetTaskNetworksCfgs(rus.TaskCfg.Type).Count == 0)
{
throw new ArgumentException($"Readout unit {rus.Name} has not associated network(s) settings.", "ReadoutUnitsCfg");
}
@@ -130,7 +130,7 @@ public List GetReadoutUnitNetworksCollection(int r
}
else
{
- return DefaultNetworksCfg.GetTaskNetworksCfgs(ReadoutUnitsCfg.ReadoutUnitCfgCollection[readoutUnitIndex].TaskCfg.Type);
+ return ClusterCfg.DefaultNetworksCfg.GetTaskNetworksCfgs(ReadoutUnitsCfg.ReadoutUnitCfgCollection[readoutUnitIndex].TaskCfg.Type);
}
}
@@ -143,12 +143,12 @@ public override RCNetBaseSettings DeepClone()
///
public override XElement GetXml(string rootElemName, bool suppressDefaults)
{
- XElement rootElem = new XElement(rootElemName, CrossvalidationCfg.GetXml(suppressDefaults));
- if (!DefaultNetworksCfg.ContainsOnlyDefaults)
+ XElement rootElem = new XElement(rootElemName, ClusterCfg.GetXml(suppressDefaults));
+ rootElem.Add(ReadoutUnitsCfg.GetXml(suppressDefaults));
+ if (OneWinnerDecisionMakerCfg != null)
{
- rootElem.Add(DefaultNetworksCfg.GetXml(suppressDefaults));
+ rootElem.Add(OneWinnerDecisionMakerCfg.GetXml(suppressDefaults));
}
- rootElem.Add(ReadoutUnitsCfg.GetXml(suppressDefaults));
Validate(rootElem, XsdTypeName);
return rootElem;
}
diff --git a/RCNet/Neural/Network/SM/Readout/ReadoutUnitsSettings.cs b/RCNet/Neural/Network/SM/Readout/ReadoutUnitsSettings.cs
index a480845..52f8e68 100644
--- a/RCNet/Neural/Network/SM/Readout/ReadoutUnitsSettings.cs
+++ b/RCNet/Neural/Network/SM/Readout/ReadoutUnitsSettings.cs
@@ -1,5 +1,4 @@
-using RCNet.Neural.Network.NonRecurrent;
-using System;
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;
@@ -24,11 +23,6 @@ public class ReadoutUnitsSettings : RCNetBaseSettings
///
public List ReadoutUnitCfgCollection { get; }
- ///
- /// Configuration of the network cluster 2nd level computation
- ///
- public NetworkClusterSecondLevelCompSettings ClusterSecondLevelCompCfg { get; }
-
///
/// Dictionary of "one winner" groups
///
@@ -40,17 +34,13 @@ public class ReadoutUnitsSettings : RCNetBaseSettings
/// Creates an initialized instance
///
/// Collection of readout unit settings
- /// Configuration of the network cluster 2nd level computation
- public ReadoutUnitsSettings(IEnumerable readoutUnitsCfgs,
- NetworkClusterSecondLevelCompSettings clusterSecondLevelCompCfg = null
- )
+ public ReadoutUnitsSettings(IEnumerable readoutUnitsCfgs)
{
ReadoutUnitCfgCollection = new List();
foreach (ReadoutUnitSettings rucfg in readoutUnitsCfgs)
{
ReadoutUnitCfgCollection.Add((ReadoutUnitSettings)rucfg.DeepClone());
}
- ClusterSecondLevelCompCfg = (NetworkClusterSecondLevelCompSettings)clusterSecondLevelCompCfg?.DeepClone();
Check();
return;
}
@@ -60,7 +50,7 @@ public ReadoutUnitsSettings(IEnumerable readoutUnitsCfgs,
///
/// Readout layer settings
public ReadoutUnitsSettings(params ReadoutUnitSettings[] readoutUnitsCfgs)
- : this(readoutUnitsCfgs.AsEnumerable(), null)
+ : this(readoutUnitsCfgs.AsEnumerable())
{
return;
}
@@ -70,7 +60,7 @@ public ReadoutUnitsSettings(params ReadoutUnitSettings[] readoutUnitsCfgs)
///
/// Source instance
public ReadoutUnitsSettings(ReadoutUnitsSettings source)
- : this(source.ReadoutUnitCfgCollection, source.ClusterSecondLevelCompCfg)
+ : this(source.ReadoutUnitCfgCollection)
{
return;
}
@@ -89,15 +79,6 @@ public ReadoutUnitsSettings(XElement elem)
{
ReadoutUnitCfgCollection.Add(new ReadoutUnitSettings(unitElem));
}
- XElement clusterSecondLevelCompElem = settingsElem.Elements("clusterSecondLevelComputation").FirstOrDefault();
- if(clusterSecondLevelCompElem != null)
- {
- ClusterSecondLevelCompCfg = new NetworkClusterSecondLevelCompSettings(clusterSecondLevelCompElem);
- }
- else
- {
- ClusterSecondLevelCompCfg = null;
- }
Check();
return;
}
@@ -194,10 +175,6 @@ public override XElement GetXml(string rootElemName, bool suppressDefaults)
{
rootElem.Add(rucfg.GetXml(suppressDefaults));
}
- if(ClusterSecondLevelCompCfg != null)
- {
- rootElem.Add(ClusterSecondLevelCompCfg.GetXml(suppressDefaults));
- }
Validate(rootElem, XsdTypeName);
return rootElem;
}
diff --git a/RCNet/Neural/Network/SM/StateMachineDesigner.cs b/RCNet/Neural/Network/SM/StateMachineDesigner.cs
index c08f16a..f33b2ce 100644
--- a/RCNet/Neural/Network/SM/StateMachineDesigner.cs
+++ b/RCNet/Neural/Network/SM/StateMachineDesigner.cs
@@ -229,9 +229,9 @@ params string[] unitName
{
unitCfgCollection.Add(new ReadoutUnitSettings(name, new ForecastTaskSettings(new RealFeatureFilterSettings())));
}
- return new ReadoutLayerSettings(crossvalidationCfg,
- new ReadoutUnitsSettings(unitCfgCollection, cluster2ndLevelComputingCfg),
- new DefaultNetworksSettings(null, new ForecastNetworksSettings(netCfg))
+ return new ReadoutLayerSettings(new ClusterSettings(crossvalidationCfg, new DefaultNetworksSettings(null, new ForecastNetworksSettings(netCfg)), cluster2ndLevelComputingCfg),
+ new ReadoutUnitsSettings(unitCfgCollection),
+ null
);
}
@@ -259,9 +259,9 @@ params string[] unitName
{
unitCfgCollection.Add(new ReadoutUnitSettings(name, new ClassificationTaskSettings(oneWinnerGroupName)));
}
- return new ReadoutLayerSettings(crossvalidationCfg,
- new ReadoutUnitsSettings(unitCfgCollection, cluster2ndLevelComputingCfg),
- new DefaultNetworksSettings(new ClassificationNetworksSettings(netCfg), null)
+ return new ReadoutLayerSettings(new ClusterSettings(crossvalidationCfg, new DefaultNetworksSettings(new ClassificationNetworksSettings(netCfg), null), cluster2ndLevelComputingCfg),
+ new ReadoutUnitsSettings(unitCfgCollection),
+ null
);
}
diff --git a/RCNet/RCNetTypes.xsd b/RCNet/RCNetTypes.xsd
index 8ee8e56..d29b377 100644
--- a/RCNet/RCNetTypes.xsd
+++ b/RCNet/RCNetTypes.xsd
@@ -1243,6 +1243,21 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -2556,8 +2571,8 @@
-
-
+
+
@@ -2604,21 +2619,45 @@
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Specifies how rich the input for the final probabilities network will be. True means to use all available sub-predictions from cluster members and False means to use only the already aggregated predictions from clusters. Default value is true.
+
+
+
+
+
+
+
+
+
+
-
-
+
+