Skip to content

Commit

Permalink
Starting to draft some code on training phase implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewQuijano committed Jul 29, 2023
1 parent a472600 commit 9e62bad
Show file tree
Hide file tree
Showing 16 changed files with 303 additions and 47 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ bash setup.sh
3. precision controls how accurate to measure thresholds that are decimals. If a value was 100.1, then a precision of
1 would set this value to 1001.
4. The data would point to the directory with the `answer.csv` file and all the training and testing data.
2. Currently, the [test file](src/test/java/PrivacyTest.java) will read from the `data/answers.csv` file.
2. Currently, the [test file](src/test/java/EvaluationTest.java) will read from the `data/answers.csv` file.
1. The first column is the training data set,
it is required to be a .arff file to be compatible with Weka.
Alternatively, you can pass a .model file, which is a pre-trained Weka model.
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ java {

dependencies {
implementation 'junit:junit:4.13.1'
implementation 'org.jetbrains:annotations:24.0.0'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'

Expand Down
Binary file modified libs/crypto.jar
Binary file not shown.
12 changes: 4 additions & 8 deletions src/main/java/weka/finito/AES.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,12 @@ public final class AES {
private String iv_string = null;
private final Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");

public AES(String password) throws NoSuchPaddingException, NoSuchAlgorithmException {
try {
getKeyFromPassword(password);
}
catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
e.printStackTrace();
}
public AES(String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeySpecException {
getKeyFromPassword(password);
}

public AES(String password, int iterations, int key_length) throws NoSuchPaddingException, NoSuchAlgorithmException {
public AES(String password, int iterations, int key_length) throws NoSuchPaddingException,
NoSuchAlgorithmException, InvalidKeySpecException {
this(password);
this.iterations = iterations;
this.key_length = key_length;
Expand Down
37 changes: 19 additions & 18 deletions src/main/java/weka/finito/client.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

import java.lang.System;

import security.DGK.DGKKeyPairGenerator;
import security.DGK.DGKOperations;
import security.DGK.DGKPrivateKey;
import security.DGK.DGKPublicKey;
import security.dgk.DGKKeyPairGenerator;
import security.dgk.DGKOperations;
import security.dgk.DGKPrivateKey;
import security.dgk.DGKPublicKey;
import security.misc.HomomorphicException;
import security.paillier.PaillierCipher;
import security.paillier.PaillierKeyPairGenerator;
Expand Down Expand Up @@ -172,13 +172,10 @@ private boolean need_keys() {
}
return false;
}
catch (RuntimeException e) {
return true;
}
catch (NoSuchAlgorithmException e) {
catch (NoSuchAlgorithmException | IOException | ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}

private String [] read_classes() {
// Don't forget to remember the classes of DT as well
Expand Down Expand Up @@ -263,11 +260,10 @@ private Hashtable<String, BigIntegers> read_features(String path,
}

// Function used to Evaluate
private void communicate_with_level_site(Socket level_site)
private void communicate_with_level_site(Socket level_site, bob client)
throws IOException, ClassNotFoundException, HomomorphicException {
// Communicate with each Level-Site
Object o;
bob client;

// Create I/O streams
ObjectOutputStream to_level_site = new ObjectOutputStream(level_site.getOutputStream());
Expand All @@ -278,7 +274,7 @@ private void communicate_with_level_site(Socket level_site)
to_level_site.flush();

// Send the Public Keys using Alice and Bob
client = new bob(level_site, paillier, dgk);
client.set_socket(level_site);

// Send bool:
// 1- true, there is an encrypted index coming
Expand Down Expand Up @@ -311,7 +307,7 @@ else if (comparison_type == 0) {
else if (comparison_type == 1) {
client.setDGKMode(true);
}
client.Protocol4();
client.Protocol2();
}

// Get boolean from level-site:
Expand Down Expand Up @@ -339,6 +335,7 @@ else if (comparison_type == 1) {
// Function used to Evaluate
public void run() {
this.talk_to_server_site = this.need_keys();
bob client = new bob(paillier, dgk, null);

try {
// Don't regenerate keys if you are just using a different VALUES file
Expand Down Expand Up @@ -391,7 +388,7 @@ public void run() {

try(Socket level_site = new Socket(level_site_ips[i], connection_port)) {
System.out.println("Client connected to level " + i);
communicate_with_level_site(level_site);
communicate_with_level_site(level_site, client);
}
}
long end_time = System.nanoTime();
Expand All @@ -405,10 +402,14 @@ public void run() {
}

// At the end, write your keys...
dgk_public_key.writeKey("dgk.pub");
paillier_public_key.writeKey("paillier.pub");
dgk_private_key.writeKey("dgk");
paillier_private_key.writeKey("paillier");
try {
paillier_private_key.writeKey("paillier");
dgk_public_key.writeKey("dgk.pub");
paillier_public_key.writeKey("paillier.pub");
dgk_private_key.writeKey("dgk");
} catch (IOException e) {
throw new RuntimeException(e);
}

// Remember the classes as well too...
try (BufferedWriter writer = new BufferedWriter(new FileWriter(classes_file, true))) {
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/weka/finito/level_site_server.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.lang.System;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;

public class level_site_server implements Runnable {

Expand All @@ -21,7 +22,7 @@ public class level_site_server implements Runnable {

protected AES crypto;

public static void main(String[] args) throws NoSuchPaddingException, NoSuchAlgorithmException {
public static void main(String[] args) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeySpecException {
int our_port = 0;
int our_precision = 0;
String AES_Pass = System.getenv("AES_PASS");
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/weka/finito/level_site_thread.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ else if ((comparisonType == 3) || (comparisonType == 5)) {
System.out.printf("Comparison took %f ms\n", run_time);
if (((comparisonType == 1) && (ld.threshold == 0))
|| (comparisonType == 4) || (comparisonType == 5)) {
return Niu.Protocol4(encrypted_thresh, encrypted_client_value);
return Niu.Protocol2(encrypted_thresh, encrypted_client_value);
}
else {
return Niu.Protocol4(encrypted_client_value, encrypted_thresh);
return Niu.Protocol2(encrypted_client_value, encrypted_thresh);
}
}

Expand Down
17 changes: 5 additions & 12 deletions src/main/java/weka/finito/server.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

import java.net.ServerSocket;
Expand All @@ -18,7 +16,7 @@
import java.lang.System;
import java.util.concurrent.TimeUnit;

import security.DGK.DGKPublicKey;
import security.dgk.DGKPublicKey;
import security.paillier.PaillierPublicKey;
import weka.classifiers.trees.j48.BinC45ModelSelection;
import weka.classifiers.trees.j48.C45PruneableClassifierTree;
Expand All @@ -28,6 +26,7 @@
import weka.core.SerializationHelper;
import weka.finito.structs.level_order_site;
import weka.finito.structs.NodeInfo;
import weka.finito.utils.DataHandling;


public final class server implements Runnable {
Expand Down Expand Up @@ -175,6 +174,7 @@ private static ClassifierTree train_decision_tree(String arff_file)
File training_file = new File(arff_file);
String base_name = training_file.getName().split("\\.")[0];
File output_model_file = new File("output", base_name + ".model");
Instances train;

File dir = new File("output");
if (!dir.exists()) {
Expand All @@ -189,16 +189,9 @@ private static ClassifierTree train_decision_tree(String arff_file)
printTree(j48, base_name);
return j48;
}

// If this is a .arff file
Instances train = null;

try (BufferedReader reader = new BufferedReader(new FileReader(arff_file))) {
train = new Instances(reader);
} catch (IOException e2) {
e2.printStackTrace();
else {
train = DataHandling.read_data(arff_file);
}
assert train != null;
train.setClassIndex(train.numAttributes() - 1);

// https://weka.sourceforge.io/doc.dev/weka/classifiers/trees/j48/C45ModelSelection.html
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/weka/finito/structs/NodeInfo.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package weka.finito.structs;

import security.DGK.DGKOperations;
import security.DGK.DGKPublicKey;
import security.dgk.DGKOperations;
import security.dgk.DGKPublicKey;
import security.misc.HomomorphicException;
import security.paillier.PaillierCipher;
import security.paillier.PaillierPublicKey;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/weka/finito/structs/level_order_site.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package weka.finito.structs;

import security.DGK.DGKPublicKey;
import security.dgk.DGKPublicKey;
import security.paillier.PaillierPublicKey;

import java.io.Serializable;
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/weka/finito/training/DataProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package weka.finito.training;

// This class is a Data Provider.
// They can own multiple levels, as the data shouldn't leave their site from my understanding
public class DataProvider implements Runnable {

public DataProvider() {

}

@Override
public void run() {
// Within this run, you should complete the training
}
}
29 changes: 29 additions & 0 deletions src/main/java/weka/finito/training/SiteMain.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package weka.finito.training;

import weka.Run;
import weka.classifiers.trees.j48.ClassifierTree;

// This will work with DataProviders to collect the needed data on how to make PPDT
// Will help Data Providers create level-sites
public class SiteMain implements Runnable {

// Main Objective
// 1- get info from all providers?
// 2- Somehow tell data providers what level-site(s) they are creating?

// ppdt could be full PPDT?
ClassifierTree ppdt;

String [] hosts;
int [] ports;

public SiteMain(String [] hosts, int [] ports) {
this.hosts = hosts;
this.ports = ports;
}

@Override
public void run() {

}
}
41 changes: 41 additions & 0 deletions src/main/java/weka/finito/utils/Computations.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package weka.finito.utils;

import weka.core.Attribute;
import weka.core.Utils;

import java.rmi.RemoteException;

public class Computations {
// Create a Compute Locally

/*
// Probably would be a NodeInfo object instead?
private TreeNode formMajorityClassLeaf() throws RemoteException {
// !!! FIX IMPORTANT
// IF DISTRIBUTION is all zeroes do something else
// For now just setting all classes to be equiprobable.. This is clearly wrong
boolean nonZeroFlag = false;
for (int i = 1; i<m_nhosts; i++)
m_Slaves[i-1].prepareTransVector();
m_Distribution = new double[m_insts.numClasses()];
for (int i=0; i<m_insts.numClasses(); i++) {
Attribute classattr = m_insts.classAttribute();
prepareTransVectorWithClassFilter(classattr.indexOfValue(classattr.value(i)));
m_Distribution[i] = SetIntersect();
if (m_Distribution[i] > 0) nonZeroFlag = true;
}
if (!nonZeroFlag) {
for (int i=0; i<m_insts.numClasses(); i++) {
m_Distribution[i] = 1;
}
}
Utils.normalize(m_Distribution);
m_ClassValue = Utils.maxIndex(m_Distribution);
return new LeafNode(m_ClassValue, m_Distribution);
}
*/
}
Loading

0 comments on commit 9e62bad

Please sign in to comment.