Skip to content

Commit

Permalink
feat: messagefilter: add ll4j lib, model
Browse files Browse the repository at this point in the history
  • Loading branch information
guimc233 committed Feb 23, 2024
1 parent 7dcdc45 commit 6ef3d8a
Show file tree
Hide file tree
Showing 30 changed files with 1,226 additions and 29 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
* 这是[Mirai](https://github.com/mamoe/mirai)框架的插件,依赖于[Mirai Console](https://github.com/mamoe/mirai-console)运行.
* 本插件可以为使用者提供群消息合规性检查,并作出相应的处罚.
* 还有一些有趣的小东西 \>_\<
* Thanks to [LL4J](https://github.com/LL4J/)

## License
This project is subject to the [GNU Affero General Public License v3.0](LICENSE). This applies only to source code located directly in this clean repository. During the development and compilation process, additional source code may be used to which we have obtained no rights. Such code is not covered by the GPL license.
Expand Down
31 changes: 31 additions & 0 deletions src/main/java/huzpsb/ll4j/data/CsvLoader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package huzpsb.ll4j.data;

import java.io.File;
import java.util.Scanner;

/**
 * Reads a labelled data set from a comma-separated text file.
 */
public class CsvLoader {
    /**
     * Loads every data row of the CSV file at {@code path} into a new {@link DataSet}.
     *
     * <p>The file is read as UTF-8. The first line is treated as a header and is only
     * used to determine the column count. Every following line becomes one
     * {@link DataEntry}: the column at {@code labelIndex} is parsed as the integer
     * label, and all other columns are parsed as doubles, in file order.
     *
     * @param path       path of the CSV file to read
     * @param labelIndex zero-based index of the column holding the integer label
     * @return the populated data set
     * @throws RuntimeException wrapping any I/O, parsing or indexing failure
     */
    public static DataSet load(String path, int labelIndex) {
        // try-with-resources closes the underlying file even when a row fails to parse.
        try (Scanner sc = new Scanner(new File(path), "UTF-8")) {
            DataSet data = new DataSet();
            String[] header = sc.nextLine().split(",");
            int n = header.length;
            while (sc.hasNextLine()) {
                String[] line = sc.nextLine().split(",");
                // One slot per column except the label column.
                double[] x = new double[n - 1];
                int label = Integer.parseInt(line[labelIndex]);
                int idx = 0;
                for (int i = 0; i < n; i++) {
                    if (i != labelIndex) {
                        x[idx++] = Double.parseDouble(line[i]);
                    }
                }
                data.split.add(new DataEntry(label, x));
            }
            return data;
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }
}
11 changes: 11 additions & 0 deletions src/main/java/huzpsb/ll4j/data/DataEntry.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package huzpsb.ll4j.data;

/**
 * A single sample: an integer {@code type} paired with an array of numeric values.
 */
public class DataEntry {
    /** Integer tag of this sample. */
    public final int type;

    /** Numeric values of this sample; the array is stored by reference, not copied. */
    public final double[] values;

    /**
     * Creates an entry holding the given tag and value array.
     *
     * @param type   integer tag of the sample
     * @param values numeric values of the sample (kept by reference)
     */
    public DataEntry(int type, double[] values) {
        this.values = values;
        this.type = type;
    }
}
8 changes: 8 additions & 0 deletions src/main/java/huzpsb/ll4j/data/DataSet.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package huzpsb.ll4j.data;

import java.util.ArrayList;
import java.util.List;

/**
 * Container for the entries that make up a data set.
 */
public class DataSet {
// Backing list of entries. It is exposed directly, so code holding a DataSet
// can append to it or iterate over it without going through accessors.
public final List<DataEntry> split = new ArrayList<>();
}
65 changes: 65 additions & 0 deletions src/main/java/huzpsb/ll4j/filter/MinRt.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package huzpsb.ll4j.filter;
// MinRt is a minimal runtime for LL4J deep learning framework
// While filter.MinRt is as small as possible, it is SLOW and INEFFICIENT.
// Use filter.MinRt only when you need to run LL4J models in a memory-constrained environment.
// Copyright (c) 2024 huzpsb [admin<at>huzpsb<dot>eu<dot>org]
// Licensed under the WTFPL license. You may remove this notice at will.

/**
 * Interpreter for textual LL4J model scripts.
 *
 * <p>A script is an array of lines, each describing one layer as space-separated tokens:
 * <ul>
 *   <li>{@code D <in> <out> <w...>} - dense layer; the weight for input {@code j} and
 *       output {@code i} is the token at position {@code 3 + i + j * out}</li>
 *   <li>{@code L <n>} - leaky ReLU with slope 0.01 for non-positive values</li>
 *   <li>{@code J <n>} - judge layer; returns the index of the largest value</li>
 * </ul>
 * Lines shorter than two characters are skipped.
 */
public class MinRt {
    public MinRt() {
        throw new UnsupportedOperationException("This is a static class");
    }

    /**
     * Runs {@code script} on a copy of {@code input} and returns the result of the first
     * judge layer reached.
     *
     * @param input  input vector; it is copied and never modified
     * @param script model script, one layer per element
     * @return index of the largest activation at the judge layer (lowest index wins ties)
     * @throws RuntimeException if a layer's declared size does not match the current
     *                          vector, a layer type is unknown, or the script ends
     *                          without a judge layer
     */
    public static int doAi(double[] input, String[] script) {
        double[] activations = input.clone();
        for (String line : script) {
            if (line.length() < 2) {
                continue;
            }
            String[] tokens = line.split(" ");
            String kind = tokens[0];
            if ("D".equals(kind)) {
                activations = dense(activations, tokens);
            } else if ("L".equals(kind)) {
                leakyRelu(activations, tokens);
            } else if ("J".equals(kind)) {
                return judge(activations, tokens);
            } else {
                throw new RuntimeException("Unknown layer type");
            }
        }
        throw new RuntimeException("No output layer");
    }

    /** Applies a dense layer described by {@code tokens} and returns the new vector. */
    private static double[] dense(double[] in, String[] tokens) {
        int inCount = Integer.parseInt(tokens[1]);
        int outCount = Integer.parseInt(tokens[2]);
        if (in.length != inCount) {
            throw new RuntimeException("Wrong input size for Dense layer (expected " + inCount + ", got " + in.length + ")");
        }
        double[] out = new double[outCount];
        for (int o = 0; o < outCount; o++) {
            double acc = 0;
            for (int k = 0; k < inCount; k++) {
                // Weights are parsed on demand straight from the token array.
                acc += in[k] * Double.parseDouble(tokens[3 + o + k * outCount]);
            }
            out[o] = acc;
        }
        return out;
    }

    /** Applies leaky ReLU in place to {@code values}. */
    private static void leakyRelu(double[] values, String[] tokens) {
        int size = Integer.parseInt(tokens[1]);
        if (values.length != size) {
            throw new RuntimeException("Wrong input size for LeakyRelu layer (expected " + size + ", got " + values.length + ")");
        }
        for (int k = 0; k < size; k++) {
            // Everything that is not strictly positive (including NaN) is scaled.
            if (!(values[k] > 0)) {
                values[k] = values[k] * 0.01;
            }
        }
    }

    /** Returns the index of the first maximum of {@code values}. */
    private static int judge(double[] values, String[] tokens) {
        int size = Integer.parseInt(tokens[1]);
        if (values.length != size) {
            throw new RuntimeException("Wrong input size for Judge layer (expected " + size + ", got " + values.length + ")");
        }
        int best = 0;
        for (int k = 1; k < size; k++) {
            if (values[k] > values[best]) {
                best = k;
            }
        }
        return best;
    }
}
33 changes: 33 additions & 0 deletions src/main/java/huzpsb/ll4j/filter/TextFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package huzpsb.ll4j.filter;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

/**
 * Static text classifier: tokenizes a message with the vocabulary from
 * {@code tokenize.model} and runs the model script from {@code judge.model}
 * through {@link MinRt}. Both files are read once, from the working directory,
 * when this class is initialized.
 */
public class TextFilter {
    private static final String[] script;
    private static final Tokenizer tokenizer;

    static {
        try {
            List<String> scriptList = new ArrayList<>();
            // try-with-resources releases the file handle once the script is read.
            try (Scanner sc = new Scanner(new File("judge.model"))) {
                while (sc.hasNextLine()) {
                    scriptList.add(sc.nextLine());
                }
            }
            // loadFromFile reports failure by returning null; fail here with a clear
            // cause instead of a bare NullPointerException on the first isIllegal call.
            Tokenizer loaded = Tokenizer.loadFromFile("tokenize.model");
            if (loaded == null) {
                throw new IllegalStateException("Could not load tokenizer from tokenize.model");
            }
            script = scriptList.toArray(new String[0]);
            tokenizer = loaded;
        } catch (Exception ex) {
            throw new ExceptionInInitializerError(ex);
        }
    }

    public TextFilter() {
        throw new UnsupportedOperationException("This is a static class");
    }

    /**
     * Classifies {@code text} with the loaded model.
     *
     * @param text message text to check
     * @return {@code true} when the model's judge layer selects class index 1
     */
    public static boolean isIllegal(String text) {
        return MinRt.doAi(tokenizer.tokenize(text), script) == 1;
    }
}
36 changes: 36 additions & 0 deletions src/main/java/huzpsb/ll4j/filter/Tokenizer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package huzpsb.ll4j.filter;

import java.io.File;
import java.util.Scanner;

/**
 * Turns a text into a fixed-length presence vector over a vocabulary of substrings.
 */
public class Tokenizer {
    private final String[] vocab;

    /**
     * @param vocab vocabulary entries; the array is kept by reference
     */
    public Tokenizer(String[] vocab) {
        this.vocab = vocab;
    }

    /**
     * Loads a vocabulary file: the first line holds the entry count, followed by one
     * entry per line.
     *
     * @param filename path of the vocabulary file
     * @return the tokenizer, or {@code null} if the file cannot be read or parsed
     */
    public static Tokenizer loadFromFile(String filename) {
        // NOTE(review): the file is decoded with the platform default charset, whereas
        // CsvLoader reads UTF-8 explicitly - confirm which encoding the model file uses.
        // try-with-resources closes the file on success and on failure alike.
        try (Scanner scanner = new Scanner(new File(filename))) {
            int size = Integer.parseInt(scanner.nextLine());
            String[] vocab = new String[size];
            for (int i = 0; i < size; i++) {
                vocab[i] = scanner.nextLine();
            }
            return new Tokenizer(vocab);
        } catch (Exception ignored) {
            // Existing contract: any failure is reported to the caller as null.
            return null;
        }
    }

    /**
     * Builds the presence vector for {@code text}.
     *
     * @param text text to scan
     * @return an array with one slot per vocabulary entry, set to 1 when {@code text}
     *         contains that entry as a substring and 0 otherwise
     */
    public double[] tokenize(String text) {
        double[] values = new double[vocab.length];
        for (int i = 0; i < values.length; i++) {
            if (text.contains(vocab[i])) {
                values[i] = 1;
            }
        }
        return values;
    }
}
49 changes: 49 additions & 0 deletions src/main/java/huzpsb/ll4j/layer/AbstractLayer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package huzpsb.ll4j.layer;

import huzpsb.ll4j.utils.NRandom;

/**
 * Base class for network layers. It owns the input/output buffers, the lazily
 * created error buffers, and a seeded random source for subclasses.
 */
public abstract class AbstractLayer {
    public final int input_size;
    public final int output_size;
    public final NRandom random;
    public double[] input;
    public double[] output;
    public double[] input_error = null;
    public double[] output_error = null;

    /**
     * Creates a layer using the fixed default seed.
     *
     * @param inputSize  length of the input buffer
     * @param outputSize length of the output buffer
     */
    public AbstractLayer(int inputSize, int outputSize) {
        this(inputSize, outputSize, 1145151919810L);
    }

    /**
     * Creates a layer with freshly allocated input and output buffers.
     *
     * @param inputSize  length of the input buffer
     * @param outputSize length of the output buffer
     * @param seed       seed passed to the layer's random source
     */
    public AbstractLayer(int inputSize, int outputSize, long seed) {
        this.input_size = inputSize;
        this.output_size = outputSize;
        this.input = new double[inputSize];
        this.output = new double[outputSize];
        this.random = new NRandom(seed);
    }

    /** Allocates {@code input_error} on first use; later calls keep the existing buffer. */
    public void makeInputError() {
        if (input_error != null) {
            return;
        }
        input_error = new double[input_size];
    }

    /** Allocates {@code output_error} on first use; later calls keep the existing buffer. */
    public void makeOutputError() {
        if (output_error != null) {
            return;
        }
        output_error = new double[output_size];
    }

    /** Computes {@code output} from {@code input}. */
    public abstract void forward();

    /** Computes {@code input_error} from {@code output_error}. */
    public abstract void backward();

    /** Adjusts the layer's parameters, scaled by {@code learningRate}. */
    public abstract void update(double learningRate);

    /** Perturbs the layer's parameters; {@code rv} controls the amount. */
    public abstract void randomize(double rv);

    /** Sets the layer's parameters to their starting values. */
    public abstract void initialize();

    /** Appends this layer's textual form to {@code sb}. */
    public abstract void serialize(StringBuilder sb);
}
71 changes: 71 additions & 0 deletions src/main/java/huzpsb/ll4j/layer/DenseLayer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package huzpsb.ll4j.layer;

/**
 * Fully connected layer without bias: {@code output[o] = sum over i of input[i] * weights[i][o]}.
 */
public class DenseLayer extends AbstractLayer {
    /** Weight matrix indexed as {@code weights[inputIndex][outputIndex]}. */
    public double[][] weights;

    /**
     * Creates the layer and fills the weights via {@link #initialize()}.
     *
     * @param inputSize  number of inputs
     * @param outputSize number of outputs
     */
    public DenseLayer(int inputSize, int outputSize) {
        super(inputSize, outputSize);
        weights = new double[inputSize][outputSize];
        initialize();
    }

    /** Computes each output as the weighted sum of all inputs. */
    @Override
    public void forward() {
        for (int out = 0; out < output_size; out++) {
            output[out] = 0;
            for (int in = 0; in < input_size; in++) {
                output[out] += input[in] * weights[in][out];
            }
        }
    }

    /** Propagates {@code output_error} back through the weights into {@code input_error}. */
    @Override
    public void backward() {
        makeInputError();
        for (int in = 0; in < input_size; in++) {
            double[] row = weights[in];
            input_error[in] = 0;
            for (int out = 0; out < output_size; out++) {
                input_error[in] += output_error[out] * row[out];
            }
        }
    }

    /** Subtracts {@code learningRate * output_error[out] * input[in]} from every weight. */
    @Override
    public void update(double learningRate) {
        for (int in = 0; in < input_size; in++) {
            double[] row = weights[in];
            for (int out = 0; out < output_size; out++) {
                row[out] -= learningRate * output_error[out] * input[in];
            }
        }
    }

    /** Multiplies every weight by a random factor drawn with bounds {@code 1 - rv} and {@code 1 + rv}. */
    @Override
    public void randomize(double rv) {
        double low = 1 - rv;
        double high = 1 + rv;
        for (int in = 0; in < input_size; in++) {
            double[] row = weights[in];
            for (int out = 0; out < output_size; out++) {
                row[out] *= random.nextDouble(low, high);
            }
        }
    }

    /** Draws every weight via {@code random.nextGaussian(0, 1 / sqrt(input_size))}. */
    @Override
    public void initialize() {
        double spread = 1.0 / Math.sqrt(input_size);
        for (int in = 0; in < input_size; in++) {
            double[] row = weights[in];
            for (int out = 0; out < output_size; out++) {
                row[out] = random.nextGaussian(0, spread);
            }
        }
    }

    /**
     * Appends one line of the form {@code "D <in> <out> <weights...>"}, with the weights
     * written input-major and each followed by a space.
     */
    @Override
    public void serialize(StringBuilder sb) {
        sb.append("D ");
        sb.append(input_size);
        sb.append(" ");
        sb.append(output_size);
        sb.append(" ");
        for (double[] row : weights) {
            for (double weight : row) {
                sb.append(weight);
                sb.append(" ");
            }
        }
        sb.append("\n");
    }
}
50 changes: 50 additions & 0 deletions src/main/java/huzpsb/ll4j/layer/JudgeLayer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package huzpsb.ll4j.layer;

/**
 * Final classification layer: {@link #forward()} picks the index of the largest input,
 * and {@link #backward()} produces an error vector relative to the index in {@code result}.
 */
public class JudgeLayer extends AbstractLayer {
    /**
     * Index chosen by the last {@link #forward()} call. The field is public and also
     * read by {@link #backward()} as the index whose error is {@code input - 1}.
     */
    public int result = 0;

    /**
     * @param size number of classes; used as both input and output size
     */
    public JudgeLayer(int size) {
        super(size, size);
    }

    /**
     * Stores in {@code result} the index of the largest input, with the lowest index
     * winning ties.
     *
     * @throws RuntimeException if any input is NaN
     */
    @Override
    public void forward() {
        // Start every scan from index 0. Without this reset the scan started from the
        // previous value of result, so ties (e.g. an all-zero input) resolved to stale
        // state and disagreed with MinRt's judge layer, which always starts from 0.
        result = 0;
        for (int i = 0; i < input_size; i++) {
            if (Double.isNaN(input[i])) {
                throw new RuntimeException("input[" + i + "] is NaN! Plz reduce learning rate!");
            }
            if (input[i] > input[result]) {
                result = i;
            }
        }
    }

    /**
     * Fills {@code input_error} with {@code input[i] - 1} at index {@code result} and
     * {@code input[i]} everywhere else.
     */
    @Override
    public void backward() {
        makeInputError();
        for (int i = 0; i < input_size; i++) {
            if (i == result) {
                input_error[i] = input[i] - 1;
            } else {
                input_error[i] = input[i];
            }
        }
    }

    /** No parameters to update. */
    @Override
    public void update(double learningRate) {
    }

    /** No parameters to randomize. */
    @Override
    public void randomize(double rv) {
    }

    /** No parameters to initialize. */
    @Override
    public void initialize() {
    }

    /** Appends one line of the form {@code "J <size>"}. */
    @Override
    public void serialize(StringBuilder sb) {
        sb.append("J ").append(input_size).append("\n");
    }
}
Loading

0 comments on commit 6ef3d8a

Please sign in to comment.