Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
One Takes All (softmax) network, cluster and builders. Preparation fo…
Browse files Browse the repository at this point in the history
…r the ReadoutLayer redesign.
  • Loading branch information
okozelsk committed Dec 7, 2020
1 parent f9c4ea4 commit aa671e1
Show file tree
Hide file tree
Showing 19 changed files with 1,637 additions and 329 deletions.
4 changes: 2 additions & 2 deletions Demo/DemoConsoleApp/Examples/TTOOForecastFromScratch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,11 @@ ReadoutLayerSettings CreateReadoutLayerCfg(double foldDataRatio, int numOfAttemp
ReadoutUnitSettings highReadoutUnitCfg = new ReadoutUnitSettings("High", new ForecastTaskSettings(new RealFeatureFilterSettings()));
ReadoutUnitSettings lowReadoutUnitCfg = new ReadoutUnitSettings("Low", new ForecastTaskSettings(new RealFeatureFilterSettings()));
//Create readout layer configuration
ReadoutLayerSettings readoutLayerCfg = new ReadoutLayerSettings(new CrossvalidationSettings(foldDataRatio),
ReadoutLayerSettings readoutLayerCfg = new ReadoutLayerSettings(new ClusterSettings(new CrossvalidationSettings(foldDataRatio), defaultNetworksCfg),
new ReadoutUnitsSettings(highReadoutUnitCfg,
lowReadoutUnitCfg
),
defaultNetworksCfg
null
);
return readoutLayerCfg;
}
Expand Down
84 changes: 83 additions & 1 deletion Demo/DemoConsoleApp/Playground.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
using RCNet.MathTools;
using RCNet.Neural.Data.Filter;
using RCNet.Neural.Data;
using RCNet.Neural.Network.NonRecurrent;
using RCNet.Neural.Network.NonRecurrent.FF;

namespace Demo.DemoConsoleApp
{
Expand Down Expand Up @@ -339,14 +341,94 @@ private void TestVectorBundleFolderization(string dataFile, int numOfClasses)
return;
}

/// <summary>
/// Displays information about the readout unit regression progress.
/// </summary>
/// <param name="buildingState">Current state of the regression process</param>
/// <param name="foundBetter">Indicates that the best readout unit was changed as a result of the performed epoch</param>
private void OnRegressionEpochDone(TrainedOneTakesAllNetworkBuilder.BuildingState buildingState, bool foundBetter)
{
int reportEpochsInterval = 5;
//Progress info
if (foundBetter ||
(buildingState.Epoch % reportEpochsInterval) == 0 ||
buildingState.Epoch == buildingState.MaxEpochs ||
(buildingState.Epoch == 1 && buildingState.RegrAttemptNumber == 1)
)
{
//Build progress report message
string progressText = buildingState.GetProgressInfo(4);
//Report the progress
if((buildingState.Epoch == 1 && buildingState.RegrAttemptNumber == 1))
{
Console.WriteLine();
}
Console.Write("\x0d" + progressText + " ");
}
return;
}


private void TestTrainedOneTakesAllClusterAndBuilder(string trainDataFile, string verifyDataFile, int numOfClasses, double foldDataRatio = 0.1d)
{
Console.BufferWidth = 320;
Console.WriteLine("One Takes All - Cluster and Cluster builder test");
//Load csv data and create vector bundle
Console.WriteLine($"Loading {trainDataFile}...");
CsvDataHolder trainCsvData = new CsvDataHolder(trainDataFile);
VectorBundle trainData = VectorBundle.Load(trainCsvData, numOfClasses);
Console.WriteLine($"Loading {verifyDataFile}...");
CsvDataHolder verifyCsvData = new CsvDataHolder(verifyDataFile);
VectorBundle verifyData = VectorBundle.Load(verifyCsvData, numOfClasses);
Console.WriteLine($"Cluster training on {trainDataFile}...");
//TRAINING
List<FeedForwardNetworkSettings> netCfgs = new List<FeedForwardNetworkSettings>();
netCfgs.Add(new FeedForwardNetworkSettings(new AFAnalogSoftMaxSettings(),
new HiddenLayersSettings(new HiddenLayerSettings(30, new AFAnalogTanHSettings())),
new RPropTrainerSettings(5, 750)
)
);

TrainedOneTakesAllNetworkClusterBuilder builder =
new TrainedOneTakesAllNetworkClusterBuilder("Test",
netCfgs,
null,
null
);
builder.RegressionEpochDone += OnRegressionEpochDone;
TrainedOneTakesAllNetworkCluster tc = builder.Build(trainData, new CrossvalidationSettings(foldDataRatio));

//VERIFICATION
Console.WriteLine();
Console.WriteLine();
Console.WriteLine($"Cluster verification on {verifyDataFile}...");
Console.WriteLine();
int numOfErrors = 0;
for(int i = 0; i < verifyData.InputVectorCollection.Count; i++)
{
double[] computed = tc.Compute(verifyData.InputVectorCollection[i]);
int computedWinnerIdx = computed.MaxIdx();
int realWinnerIdx = verifyData.OutputVectorCollection[i].MaxIdx();
if (computedWinnerIdx != realWinnerIdx) ++numOfErrors;
Console.Write("\x0d" + $"({i+1}/{verifyData.InputVectorCollection.Count}) Errors:{numOfErrors}...");
}
Console.WriteLine();
Console.WriteLine($"Accuracy {(1d - (double)numOfErrors/(double)verifyData.InputVectorCollection.Count).ToString(CultureInfo.InvariantCulture)}");
Console.WriteLine();

return;
}

/// <summary>
/// Playground's entry point
/// </summary>
public void Run()
{
Console.Clear();
//TODO - place your code here
TestVectorBundleFolderization("./Data/ProximalPhalanxOutlineAgeGroup_train.csv", 3);
//TestVectorBundleFolderization("./Data/ProximalPhalanxOutlineAgeGroup_train.csv", 3);
TestTrainedOneTakesAllClusterAndBuilder("./Data/LibrasMovement_train.csv", "./Data/LibrasMovement_verify.csv", 15, 0.1d);
TestTrainedOneTakesAllClusterAndBuilder("./Data/ProximalPhalanxOutlineAgeGroup_train.csv", "./Data/ProximalPhalanxOutlineAgeGroup_verify.csv", 3, 0.1d);
return;
}

Expand Down
Loading

0 comments on commit aa671e1

Please sign in to comment.