-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
11d67f3
commit da2bf35
Showing
3 changed files
with
281 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,32 @@ | ||
# ml4j-quickstart | ||
# ml4j-quickstart | ||
|
||
Get up and running with ml4j quickly, with ml4j-quickstart. | ||
|
||
There are many customisations possible with the configuration of the various | ||
ml4j components and defaults. | ||
|
||
This project autowires these components and factoris with default settings, so | ||
you can get starting working with the project with a single artifact import. | ||
|
||
## Quick Start ## | ||
|
||
Download the jar though Maven: | ||
|
||
```xml | ||
<repository> | ||
<id>ml4j-snapshots</id> | ||
<url>https://raw.githubusercontent.com/ml4j/mvn-repository/master/snapshots | ||
</url> | ||
<snapshots> | ||
<enabled>true</enabled> | ||
</snapshots> | ||
</repository> | ||
``` | ||
|
||
```xml | ||
<dependency> | ||
<groupId>org.ml4j</groupId> | ||
<artifactId>ml4j-quickstart</artifactId> | ||
<version>2.0.0-SNAPSHOT</version> | ||
</dependency> | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
<groupId>org.ml4j</groupId> | ||
<artifactId>ml4j-quickstart</artifactId> | ||
<packaging>jar</packaging> | ||
<version>2.0.0-SNAPSHOT</version> | ||
<name>ml4j-quickstart</name> | ||
<properties> | ||
</properties> | ||
<distributionManagement> | ||
</distributionManagement> | ||
<repositories> | ||
<repository> | ||
<id>ml4j-releases</id> | ||
<url>https://raw.githubusercontent.com/ml4j/mvn-repository/master/releases | ||
</url> | ||
<snapshots> | ||
<enabled>false</enabled> | ||
</snapshots> | ||
</repository> | ||
<repository> | ||
<id>ml4j-snapshots</id> | ||
<url>https://raw.githubusercontent.com/ml4j/mvn-repository/master/snapshots | ||
</url> | ||
<snapshots> | ||
<enabled>true</enabled> | ||
</snapshots> | ||
</repository> | ||
</repositories> | ||
<dependencies> | ||
<dependency> | ||
<groupId>org.ml4j</groupId> | ||
<artifactId>ml4j-sessions-impl</artifactId> | ||
<version>2.0.0-SNAPSHOT</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.ml4j</groupId> | ||
<artifactId>ml4j-default-components</artifactId> | ||
<version>2.0.0-SNAPSHOT</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.ml4j</groupId> | ||
<artifactId>ml4j-nn-impl</artifactId> | ||
<version>2.0.0-SNAPSHOT</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.ml4j</groupId> | ||
<artifactId>ml4j-layered-nn-impl</artifactId> | ||
<version>2.0.0-SNAPSHOT</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.ml4j</groupId> | ||
<artifactId>neural-network-architectures</artifactId> | ||
<version>1.0.0-SNAPSHOT</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.projectlombok</groupId> | ||
<artifactId>lombok</artifactId> | ||
<version>1.16.22</version> | ||
<scope>provided</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.mockito</groupId> | ||
<artifactId>mockito-core</artifactId> | ||
<version>1.10.19</version> | ||
<scope>test</scope> | ||
</dependency> | ||
</dependencies> | ||
<build> | ||
<plugins> | ||
<plugin> | ||
<groupId>org.apache.maven.plugins</groupId> | ||
<artifactId>maven-compiler-plugin</artifactId> | ||
<configuration> | ||
<source>1.8</source> | ||
<target>1.8</target> | ||
</configuration> | ||
</plugin> | ||
</plugins> | ||
</build> | ||
<reporting> | ||
<plugins> | ||
</plugins> | ||
</reporting> | ||
</project> | ||
|
161 changes: 161 additions & 0 deletions
161
src/main/java/org/ml4j/nn/quickstart/sessions/factories/QuickstartSessionFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
/* | ||
* Copyright 2020 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except | ||
* in compliance with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
*/ | ||
package org.ml4j.nn.quickstart.sessions.factories; | ||
|
||
import org.ml4j.MatrixFactory; | ||
import org.ml4j.jblas.JBlasRowMajorMatrixFactory; | ||
import org.ml4j.nn.activationfunctions.factories.DifferentiableActivationFunctionFactory; | ||
import org.ml4j.nn.axons.factories.AxonsFactory; | ||
import org.ml4j.nn.components.DirectedComponentsContext; | ||
import org.ml4j.nn.components.DirectedComponentsContextImpl; | ||
import org.ml4j.nn.components.factories.DirectedComponentFactory; | ||
import org.ml4j.nn.factories.DefaultAxonsFactoryImpl; | ||
import org.ml4j.nn.factories.DefaultDifferentiableActivationFunctionFactory; | ||
import org.ml4j.nn.factories.DefaultDirectedComponentFactoryImpl; | ||
import org.ml4j.nn.layers.DefaultDirectedLayerFactory; | ||
import org.ml4j.nn.layers.DirectedLayerFactory; | ||
import org.ml4j.nn.sessions.factories.DefaultSessionFactoryImpl; | ||
import org.ml4j.nn.supervised.DefaultLayeredSupervisedFeedForwardNeuralNetworkFactory; | ||
import org.ml4j.nn.supervised.DefaultSupervisedFeedForwardNeuralNetworkFactory; | ||
import org.ml4j.nn.supervised.LayeredSupervisedFeedForwardNeuralNetworkFactory; | ||
import org.ml4j.nn.supervised.SupervisedFeedForwardNeuralNetworkFactory; | ||
|
||
/** | ||
* Quick-start implementation of DefaultSessionFactory, pre-configured with | ||
* default MatrixFactory, AxonsFactory, DifferentiableActivationFunctionFactory | ||
* and DirectedComponentFactory implementations. | ||
* | ||
* @author Michael Lavelle | ||
* | ||
*/ | ||
public class QuickstartSessionFactory extends DefaultSessionFactoryImpl { | ||
|
||
public static final MatrixFactory DEFAULT_MATRIX_FACTORY = new JBlasRowMajorMatrixFactory(); | ||
|
||
public static final AxonsFactory DEFAULT_AXONS_FACTORY = createDefaultAxonsFactory(DEFAULT_MATRIX_FACTORY); | ||
|
||
public static final DifferentiableActivationFunctionFactory DEFAULT_ACTIVATION_FUNCTION_FACTORY = new DefaultDifferentiableActivationFunctionFactory(); | ||
|
||
/** | ||
* Quickstart session factory with all factories set to defaults | ||
*/ | ||
public QuickstartSessionFactory(boolean isTrainingContext) { | ||
this(DEFAULT_MATRIX_FACTORY, createDefaultDirectedComponentsContext(DEFAULT_MATRIX_FACTORY, isTrainingContext)); | ||
} | ||
|
||
/** | ||
* Quickstart session factory with all factories set to defaults, but with provided directedComponentsContext | ||
*/ | ||
public QuickstartSessionFactory(DirectedComponentsContext directedComponentsContext) { | ||
this(directedComponentsContext.getMatrixFactory(), directedComponentsContext); | ||
} | ||
|
||
/** | ||
* Quickstart session factory with specified MatrixFactory and DirectedComponentsContext, with all other | ||
* factories set to defaults. | ||
* | ||
* @param matrixFactory | ||
*/ | ||
public QuickstartSessionFactory(MatrixFactory matrixFactory, DirectedComponentsContext directedComponentsContext) { | ||
this(matrixFactory, createDefaultAxonsFactory(matrixFactory), DEFAULT_ACTIVATION_FUNCTION_FACTORY, directedComponentsContext); | ||
} | ||
|
||
/** | ||
* Quickstart session factory with specified MatrixFactory and | ||
* DifferentiableActivationFunctionFactory, with all other factories set to | ||
* defaults. | ||
* | ||
* @param matrixFactory | ||
*/ | ||
public QuickstartSessionFactory(MatrixFactory matrixFactory, | ||
DifferentiableActivationFunctionFactory activationFunctionFactory, DirectedComponentsContext directedComponentsContext) { | ||
this(matrixFactory, createDefaultAxonsFactory(matrixFactory), activationFunctionFactory, directedComponentsContext); | ||
} | ||
|
||
/** | ||
* Quickstart session factory with specified MatrixFactory and AxonsFactory, | ||
* with all other factories set to defaults. | ||
* | ||
* @param matrixFactory | ||
*/ | ||
public QuickstartSessionFactory(MatrixFactory matrixFactory, AxonsFactory axonsFactory, DirectedComponentsContext directedComponentsContext) { | ||
this(matrixFactory, axonsFactory, | ||
createDefaultDirectedComponentFactory(matrixFactory, axonsFactory, DEFAULT_ACTIVATION_FUNCTION_FACTORY, directedComponentsContext), | ||
DEFAULT_ACTIVATION_FUNCTION_FACTORY, directedComponentsContext); | ||
} | ||
|
||
/** | ||
* Quickstart session factory with specified MatrixFactory, AxonsFactory and | ||
* DifferentiableActivationFunctionFactory, with all other factories set to | ||
* defaults. | ||
* | ||
* @param matrixFactory | ||
*/ | ||
public QuickstartSessionFactory(MatrixFactory matrixFactory, AxonsFactory axonsFactory, | ||
DifferentiableActivationFunctionFactory activationFunctionFactory, DirectedComponentsContext directedComponentsContext) { | ||
this(matrixFactory, axonsFactory, | ||
createDefaultDirectedComponentFactory(matrixFactory, axonsFactory, activationFunctionFactory, directedComponentsContext), | ||
activationFunctionFactory, directedComponentsContext); | ||
} | ||
|
||
/** | ||
* Quickstart session factory with specified MatrixFactory, AxonsFactory, | ||
* DifferentiableActivationFunctionFactory and DirectedComponentFactory | ||
* | ||
* @param matrixFactory | ||
*/ | ||
public QuickstartSessionFactory(MatrixFactory matrixFactory, AxonsFactory axonsFactory, | ||
DirectedComponentFactory directedComponentFactory, | ||
DifferentiableActivationFunctionFactory activationFunctionFactory, DirectedComponentsContext directedComponentsContext) { | ||
super(matrixFactory, directedComponentFactory, | ||
createDefaultDirectedLayerFactory(axonsFactory, activationFunctionFactory, directedComponentFactory, | ||
directedComponentsContext), | ||
new DefaultSupervisedFeedForwardNeuralNetworkFactory(directedComponentFactory), | ||
new DefaultLayeredSupervisedFeedForwardNeuralNetworkFactory(directedComponentFactory), directedComponentsContext); | ||
} | ||
|
||
/** | ||
* Quickstart session factory allowing all factories to be specified | ||
* | ||
* @param matrixFactory | ||
*/ | ||
public QuickstartSessionFactory(MatrixFactory matrixFactory, DirectedComponentFactory directedComponentFactory, | ||
DirectedLayerFactory directedLayerFactory, | ||
SupervisedFeedForwardNeuralNetworkFactory supervisedFeedForwardNeuralNetworkFactory, | ||
LayeredSupervisedFeedForwardNeuralNetworkFactory layeredSupervisedFeedForwardNeuralNetworkFactory, | ||
DirectedComponentsContext directedComponentsContext) { | ||
super(matrixFactory, directedComponentFactory, directedLayerFactory, supervisedFeedForwardNeuralNetworkFactory, | ||
layeredSupervisedFeedForwardNeuralNetworkFactory, directedComponentsContext); | ||
} | ||
|
||
private static DirectedComponentsContext createDefaultDirectedComponentsContext(MatrixFactory matrixFactory, boolean isTrainingContext) { | ||
return new DirectedComponentsContextImpl(matrixFactory, isTrainingContext); | ||
} | ||
|
||
private static AxonsFactory createDefaultAxonsFactory(MatrixFactory matrixFactory) { | ||
return new DefaultAxonsFactoryImpl(matrixFactory); | ||
} | ||
|
||
private static DirectedComponentFactory createDefaultDirectedComponentFactory(MatrixFactory matrixFactory, | ||
AxonsFactory axonsFactory, DifferentiableActivationFunctionFactory activationFunctionFactory, DirectedComponentsContext directedComponentsContext) { | ||
return new DefaultDirectedComponentFactoryImpl(matrixFactory, axonsFactory, activationFunctionFactory, directedComponentsContext); | ||
} | ||
|
||
private static DirectedLayerFactory createDefaultDirectedLayerFactory(AxonsFactory axonsFactory, | ||
DifferentiableActivationFunctionFactory activationFunctionFactory, | ||
DirectedComponentFactory directedComponentFactory, DirectedComponentsContext directedComponentsContext) { | ||
return new DefaultDirectedLayerFactory(axonsFactory, activationFunctionFactory, directedComponentFactory, directedComponentsContext); | ||
} | ||
|
||
} |