Skip to content

Commit

Permalink
features object now tracked index updates. So now this whole thing sh…
Browse files Browse the repository at this point in the history
…ould be thread safe too
  • Loading branch information
AndrewQuijano committed Jan 16, 2024
1 parent ed916fa commit 20766b6
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 61 deletions.
17 changes: 4 additions & 13 deletions src/main/java/weka/finito/client.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,26 @@

public final class client implements Runnable {
private final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();

private final String classes_file = "classes.txt";
private final String features_file;
private final int key_size;
private final int precision;

private final String [] level_site_ips;
private final int [] level_site_ports;
private final int port;

private String classification = null;

private KeyPair dgk;
private KeyPair paillier;
private features feature = null;
private boolean classification_complete = false;
private String [] classes;

private DGKPublicKey dgk_public_key;
private PaillierPublicKey paillier_public_key;
private DGKPrivateKey dgk_private_key;
private PaillierPrivateKey paillier_private_key;
private final HashMap<String, String> hashed_classification = new HashMap<>();
private final String server_ip;
private final int server_port;
private Integer next_index = 0;

//For k8s deployment.
public static void main(String[] args) {
Expand Down Expand Up @@ -310,11 +304,6 @@ private void communicate_with_level_site(Socket level_site)
client = new bob_joye(paillier, dgk, null);
client.set_socket(level_site);

// Send bool:
// 1- true, there is an encrypted index coming
// 2- false, there is NO encrypted index coming
client.writeInt(next_index);

// Get the comparison
// I am not sure why I need this loop, but you will only need 1 comparison.
int comparison_type;
Expand All @@ -337,7 +326,6 @@ else if (comparison_type == 1) {
// true - get leaf value
// false - get encrypted AES index for next round
classification_complete = client.readBoolean();
System.out.println("BOOLEAN READ");
if (classification_complete) {
o = client.readObject();
if (o instanceof String) {
Expand All @@ -346,7 +334,10 @@ else if (comparison_type == 1) {
}
}
else {
next_index = client.readInt();
o = from_level_site.readObject();
if (o instanceof features) {
this.feature = (features) o;
}
}
}

Expand Down
24 changes: 9 additions & 15 deletions src/main/java/weka/finito/level_site_evaluation_thread.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package weka.finito;

import java.io.ObjectOutputStream;
import java.lang.System;

import security.socialistmillionaire.alice_joye;
Expand All @@ -13,27 +14,19 @@
import static weka.finito.utils.shared.*;

public class level_site_evaluation_thread implements Runnable {

private final Socket client_socket;
private final level_order_site level_site_data;
private final features encrypted_features;
private final Socket next_level_site;
private final ObjectOutputStream oos;

// This thread is ONLY to handle evaluations
public level_site_evaluation_thread(Socket client_socket, level_order_site level_site_data,
features encrypted_features) {
features encrypted_features, ObjectOutputStream oos) {
// Have encrypted copy of thresholds if not done already for all nodes in level-site
this(client_socket, level_site_data, encrypted_features, null);
}

// This thread is ONLY to handle evaluations
public level_site_evaluation_thread(Socket client_socket, level_order_site level_site_data,
features encrypted_features, Socket next_level_site) {
// Have encrypted copy of thresholds if not done already for all nodes in level-site
this.level_site_data = level_site_data;
this.client_socket = client_socket;
this.level_site_data = level_site_data;
this.encrypted_features = encrypted_features;
this.next_level_site = next_level_site;
this.oos = oos;
}

// This will run the communication with client and next level site
Expand All @@ -45,10 +38,9 @@ public final void run() {
niu.setDGKPublicKey(this.level_site_data.dgk_public_key);
niu.setPaillierPublicKey(this.level_site_data.paillier_public_key);

level_site_data.set_current_index(niu.readInt());

// Null, keep going down the tree,
// Not null, you got the correct leaf node of your DT!
// encrypted_features will have index updated within traverse_level
NodeInfo reply = traverse_level(level_site_data, encrypted_features, niu);
niu.writeInt(-1);

Expand All @@ -59,7 +51,9 @@ public final void run() {
}
else {
niu.writeBoolean(false);
niu.writeInt(level_site_data.get_next_index());
oos.writeObject(encrypted_features);
oos.flush();
// Update Index and send it down to the next level-site
}
long stop_time = System.nanoTime();
double run_time = (double) (stop_time - start_time);
Expand Down
13 changes: 11 additions & 2 deletions src/main/java/weka/finito/level_site_server.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
import java.io.ObjectOutputStream;

Expand All @@ -20,8 +21,9 @@ public class level_site_server implements Runnable {
protected boolean isStopped = false;
protected Thread runningThread= null;
protected level_order_site level_site_parameters = null;

protected SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
private final SSLSocketFactory socket_factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
private SSLSocket next_level_site;

public static void main(String[] args) {
setup_tls();
Expand Down Expand Up @@ -86,12 +88,19 @@ public void run() {
// System.out.println("Level-Site received training data on Port: " + client_socket.getLocalPort());
oos.writeBoolean(true);
// Create a persistent connection to next level-site and oos to send the next stuff down
/*
if(level_site_parameters.get_next_level_site() != null) {
next_level_site = (SSLSocket) socket_factory.createSocket(
level_site_parameters.get_next_level_site(), level_site_parameters.get_next_level_site_port());
}
*/
closeConnection(oos, ois, client_socket);
}
else if (o instanceof features) {
// Start evaluating with the client
level_site_evaluation_thread current_level_site_class =
new level_site_evaluation_thread(client_socket, this.level_site_parameters, (features) o);
new level_site_evaluation_thread(client_socket, this.level_site_parameters,
(features) o, oos);
new Thread(current_level_site_class).start();
}
else {
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/weka/finito/server.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,14 @@ private void evaluate_with_client_directly(SSLSocket client_site)
if (client_input instanceof features) {
input = (features) client_input;
}
assert input != null;

long start_time = System.nanoTime();
int previous_index = 0;

// Traverse DT until you hit a leaf, the client has to track the index...
for (level_order_site level_site_data : all_level_sites) {
level_site_data.set_current_index(previous_index);
input.set_current_index(previous_index);

// Handle at a level...
NodeInfo leaf = traverse_level(level_site_data, input, Niu);
Expand All @@ -184,7 +185,7 @@ private void evaluate_with_client_directly(SSLSocket client_site)
}
else {
// Update the index for next merry-go-round
previous_index = level_site_data.get_next_index();
previous_index = input.get_next_index();
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/weka/finito/structs/features.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ public final class features implements Serializable {
private static final long serialVersionUID = 6000706455545108960L;
private String client_ip;
private int next_index;
private int current_index;
private final HashMap<String, BigIntegers> thresholds;

public features(String path, int precision, PaillierPublicKey paillier_public_key, DGKPublicKey dgk_public_key)
throws HomomorphicException, IOException {
this.client_ip = "";
this.next_index = 0;
this.current_index =0;
this.thresholds = read_values(path, precision, paillier_public_key, dgk_public_key);
}

Expand All @@ -43,6 +45,14 @@ public void set_next_index(int next_index) {
this.next_index = next_index;
}

public int get_current_index() {
return this.current_index;
}

public void set_current_index(int current_index) {
this.current_index = current_index;
}

public static HashMap<String, BigIntegers> read_values(String path,
int precision,
PaillierPublicKey paillier_public_key,
Expand Down
24 changes: 1 addition & 23 deletions src/main/java/weka/finito/structs/level_order_site.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@

public final class level_order_site implements Serializable {
private static final long serialVersionUID = 575566807906351024L;
private int index = 0;
private int next_index;
private final int level;

public final PaillierPublicKey paillier_public_key;
public final DGKPublicKey dgk_public_key;

private final List<NodeInfo> node_level_data = new ArrayList<>();

// Set to value to a client, to let level-site d-1 know to talk to a client next with answer
private String next_level_site = "client";

private String next_level_site = null;
private int next_level_site_port = -1;

public level_order_site(int level, PaillierPublicKey paillier_public_key, DGKPublicKey dgk_public_key) {
Expand All @@ -38,22 +32,6 @@ public List<NodeInfo> get_node_data() {
return this.node_level_data;
}

public void set_current_index(int index) {
this.index = index;
}

public void set_next_index(int next_index) {
this.next_index = next_index;
}

public int get_next_index() {
return this.next_index;
}

public int get_current_index() {
return this.index;
}

public void set_next_level_site(String next_level_site) {
this.next_level_site = next_level_site;
}
Expand Down
12 changes: 6 additions & 6 deletions src/main/java/weka/finito/utils/shared.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ public static NodeInfo traverse_level(level_order_site level_site_data,
ls = node_level_data.get(node_level_index);
System.out.println("j=" + node_level_index);
if (ls.isLeaf()) {
if (n == 2 * level_site_data.get_current_index()
|| n == 2 * level_site_data.get_current_index() + 1) {
if (n == 2 * encrypted_features.get_current_index()
|| n == 2 * encrypted_features.get_current_index() + 1) {
terminalLeafFound = true;
to_return = ls;
}
n += 2;
}
else {
if ((n == 2 * level_site_data.get_current_index()
|| n == 2 * level_site_data.get_current_index() + 1)) {
if ((n == 2 * encrypted_features.get_current_index()
|| n == 2 * encrypted_features.get_current_index() + 1)) {

if (ls.comparisonType == 6) {
inequalityHolds = compare(ls, 1,
Expand All @@ -111,8 +111,8 @@ public static NodeInfo traverse_level(level_order_site level_site_data,

if (inequalityHolds) {
equalsFound = true;
level_site_data.set_next_index(next_index);
System.out.println("New index: " + level_site_data.get_next_index());
encrypted_features.set_next_index(next_index);
System.out.println("New index: " + encrypted_features.get_next_index());
}
}
n++;
Expand Down

0 comments on commit 20766b6

Please sign in to comment.