generated from project-mirai/mirai-console-plugin-template
-
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: messagefilter: add ll4j lib, model
- Loading branch information
Showing
30 changed files
with
1,226 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package huzpsb.ll4j.data; | ||
|
||
import java.io.File; | ||
import java.util.Scanner; | ||
|
||
public class CsvLoader { | ||
public static DataSet load(String path, int labelIndex) { | ||
try { | ||
DataSet data = new DataSet(); | ||
Scanner sc = new Scanner(new File(path), "UTF-8"); | ||
String[] header = sc.nextLine().split(","); | ||
int n = header.length; | ||
while (sc.hasNextLine()) { | ||
String[] line = sc.nextLine().split(","); | ||
double[] x = new double[n - 1]; | ||
int label = Integer.parseInt(line[labelIndex]); | ||
int idx = 0; | ||
for (int i = 0; i < n; i++) { | ||
if (i != labelIndex) { | ||
x[idx++] = Double.parseDouble(line[i]); | ||
} | ||
} | ||
DataEntry entry = new DataEntry(label, x); | ||
data.split.add(entry); | ||
} | ||
return data; | ||
} catch (Exception ex) { | ||
throw new RuntimeException(ex); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package huzpsb.ll4j.data; | ||
|
||
public class DataEntry { | ||
public final int type; | ||
public final double[] values; | ||
|
||
public DataEntry(int type, double[] values) { | ||
this.type = type; | ||
this.values = values; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
package huzpsb.ll4j.data; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public class DataSet { | ||
public final List<DataEntry> split = new ArrayList<>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package huzpsb.ll4j.filter; | ||
// MinRt is a minimal runtime for LL4J deep learning framework | ||
// While filter.MinRt is as small as possible, it is SLOW and INEFFICIENT. | ||
// Use filter.MinRt only when you need to run LL4J models in a memory-constrained environment. | ||
// Copyright (c) 2024 huzpsb [admin<at>huzpsb<dot>eu<dot>org] | ||
// Licensed under the WTFPL license. You may remove this notice at will. | ||
|
||
public class MinRt { | ||
public MinRt() { | ||
throw new UnsupportedOperationException("This is a static class"); | ||
} | ||
|
||
public static int doAi(double[] input, String[] script) { | ||
double[] current = new double[input.length]; | ||
System.arraycopy(input, 0, current, 0, input.length); | ||
for (String str : script) { | ||
if (str.length() < 2) { | ||
continue; | ||
} | ||
String[] tokens = str.split(" "); | ||
switch (tokens[0]) { | ||
case "D": | ||
int ic = Integer.parseInt(tokens[1]); | ||
int oc = Integer.parseInt(tokens[2]); | ||
if (current.length != ic) { | ||
throw new RuntimeException("Wrong input size for Dense layer (expected " + ic + ", got " + current.length + ")"); | ||
} | ||
double[] tmp = new double[oc]; | ||
for (int i = 0; i < oc; i++) { | ||
double sum = 0; | ||
for (int j = 0; j < ic; j++) { | ||
sum += current[j] * Double.parseDouble(tokens[3 + i + j * oc]); | ||
} | ||
tmp[i] = sum; | ||
} | ||
current = tmp; | ||
break; | ||
case "L": | ||
int n = Integer.parseInt(tokens[1]); | ||
if (current.length != n) { | ||
throw new RuntimeException("Wrong input size for LeakyRelu layer (expected " + n + ", got " + current.length + ")"); | ||
} | ||
for (int i = 0; i < n; i++) { | ||
current[i] = current[i] > 0 ? current[i] : current[i] * 0.01; | ||
} | ||
break; | ||
case "J": | ||
int m = Integer.parseInt(tokens[1]); | ||
if (current.length != m) { | ||
throw new RuntimeException("Wrong input size for Judge layer (expected " + m + ", got " + current.length + ")"); | ||
} | ||
int idx = 0; | ||
for (int i = 1; i < m; i++) { | ||
if (current[i] > current[idx]) { | ||
idx = i; | ||
} | ||
} | ||
return idx; | ||
default: | ||
throw new RuntimeException("Unknown layer type"); | ||
} | ||
} | ||
throw new RuntimeException("No output layer"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package huzpsb.ll4j.filter; | ||
|
||
import java.io.File; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Scanner; | ||
|
||
public class TextFilter { | ||
private static final String[] script; | ||
private static final Tokenizer tokenizer; | ||
|
||
static { | ||
try { | ||
List<String> scriptList = new ArrayList<>(); | ||
Scanner sc = new Scanner(new File("judge.model")); | ||
while (sc.hasNextLine()) { | ||
scriptList.add(sc.nextLine()); | ||
} | ||
script = scriptList.toArray(new String[0]); | ||
tokenizer = Tokenizer.loadFromFile("tokenize.model"); | ||
} catch (Exception ex) { | ||
throw new ExceptionInInitializerError(ex); | ||
} | ||
} | ||
|
||
public TextFilter() { | ||
throw new UnsupportedOperationException("This is a static class"); | ||
} | ||
|
||
public static boolean isIllegal(String text) { | ||
return MinRt.doAi(tokenizer.tokenize(text), script) == 1; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package huzpsb.ll4j.filter; | ||
|
||
import java.io.File; | ||
import java.util.Scanner; | ||
|
||
public class Tokenizer { | ||
private final String[] vocab; | ||
|
||
public Tokenizer(String[] vocab) { | ||
this.vocab = vocab; | ||
} | ||
|
||
public static Tokenizer loadFromFile(String filename) { | ||
try { | ||
Scanner scanner = new Scanner(new File(filename)); | ||
int size = Integer.parseInt(scanner.nextLine()); | ||
String[] vocab = new String[size]; | ||
for (int i = 0; i < size; i++) { | ||
vocab[i] = scanner.nextLine(); | ||
} | ||
return new Tokenizer(vocab); | ||
} catch (Exception ignored) { | ||
} | ||
return null; | ||
} | ||
|
||
public double[] tokenize(String text) { | ||
double[] values = new double[vocab.length]; | ||
for (int i = 0; i < values.length; i++) { | ||
if (text.contains(vocab[i])) { | ||
values[i] = 1; | ||
} | ||
} | ||
return values; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package huzpsb.ll4j.layer; | ||
|
||
import huzpsb.ll4j.utils.NRandom; | ||
|
||
public abstract class AbstractLayer { | ||
public final int input_size; | ||
public final int output_size; | ||
public final NRandom random; | ||
public double[] input; | ||
public double[] output; | ||
public double[] input_error = null; | ||
public double[] output_error = null; | ||
|
||
public AbstractLayer(int inputSize, int outputSize) { | ||
this(inputSize, outputSize, 1145151919810L); | ||
} | ||
|
||
public AbstractLayer(int inputSize, int outputSize, long seed) { | ||
input_size = inputSize; | ||
output_size = outputSize; | ||
input = new double[input_size]; | ||
output = new double[output_size]; | ||
random = new NRandom(seed); | ||
} | ||
|
||
public void makeInputError() { | ||
if (input_error == null) { | ||
input_error = new double[input_size]; | ||
} | ||
} | ||
|
||
public void makeOutputError() { | ||
if (output_error == null) { | ||
output_error = new double[output_size]; | ||
} | ||
} | ||
|
||
public abstract void forward(); | ||
|
||
public abstract void backward(); | ||
|
||
public abstract void update(double learningRate); | ||
|
||
public abstract void randomize(double rv); | ||
|
||
public abstract void initialize(); | ||
|
||
public abstract void serialize(StringBuilder sb); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package huzpsb.ll4j.layer; | ||
|
||
public class DenseLayer extends AbstractLayer { | ||
public double[][] weights; | ||
|
||
public DenseLayer(int inputSize, int outputSize) { | ||
super(inputSize, outputSize); | ||
weights = new double[inputSize][outputSize]; | ||
initialize(); | ||
} | ||
|
||
@Override | ||
public void forward() { | ||
for (int i = 0; i < output_size; i++) { | ||
output[i] = 0; | ||
for (int j = 0; j < input_size; j++) { | ||
output[i] += input[j] * weights[j][i]; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public void backward() { | ||
makeInputError(); | ||
for (int i = 0; i < input_size; i++) { | ||
input_error[i] = 0; | ||
for (int j = 0; j < output_size; j++) { | ||
input_error[i] += output_error[j] * weights[i][j]; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public void update(double learningRate) { | ||
for (int i = 0; i < input_size; i++) { | ||
for (int j = 0; j < output_size; j++) { | ||
double delta = learningRate * output_error[j] * input[i]; | ||
weights[i][j] -= delta; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public void randomize(double rv) { | ||
for (int i = 0; i < input_size; i++) { | ||
for (int j = 0; j < output_size; j++) { | ||
weights[i][j] *= random.nextDouble(1 - rv, 1 + rv); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public void initialize() { | ||
for (int i = 0; i < input_size; i++) { | ||
for (int j = 0; j < output_size; j++) { | ||
weights[i][j] = random.nextGaussian(0, 1.0 / Math.sqrt(input_size)); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public void serialize(StringBuilder sb) { | ||
sb.append("D ").append(input_size).append(" ").append(output_size).append(" "); | ||
for (int i = 0; i < input_size; i++) { | ||
for (int j = 0; j < output_size; j++) { | ||
sb.append(weights[i][j]).append(" "); | ||
} | ||
} | ||
sb.append("\n"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package huzpsb.ll4j.layer; | ||
|
||
public class JudgeLayer extends AbstractLayer { | ||
public int result = 0; | ||
|
||
public JudgeLayer(int size) { | ||
super(size, size); | ||
} | ||
|
||
@Override | ||
public void forward() { | ||
for (int i = 0; i < input_size; i++) { | ||
if (Double.isNaN(input[i])) { | ||
throw new RuntimeException("input[" + i + "] is NaN! Plz reduce learning rate!"); | ||
} | ||
if (input[i] > input[result]) { | ||
result = i; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public void backward() { | ||
makeInputError(); | ||
for (int i = 0; i < input_size; i++) { | ||
if (i == result) { | ||
input_error[i] = input[i] - 1; | ||
} else { | ||
input_error[i] = input[i]; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public void update(double learningRate) { | ||
} | ||
|
||
@Override | ||
public void randomize(double rv) { | ||
} | ||
|
||
@Override | ||
public void initialize() { | ||
} | ||
|
||
@Override | ||
public void serialize(StringBuilder sb) { | ||
sb.append("J ").append(input_size).append("\n"); | ||
} | ||
} |
Oops, something went wrong.