Skip to content

Commit

Permalink
got being able to do everytihng in server-site complete. Also migrate…
Browse files Browse the repository at this point in the history
…d compare() to shared. Likely will do same with overall computing index, and will need to start messing around with log4j
  • Loading branch information
AndrewQuijano committed Oct 16, 2023
1 parent 24411fc commit b7c929e
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 140 deletions.
5 changes: 1 addition & 4 deletions src/main/java/weka/finito/client.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
import java.io.*;
import java.math.BigInteger;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Hashtable;

Expand All @@ -25,7 +22,7 @@
import security.socialistmillionaire.bob;
import weka.finito.structs.BigIntegers;

import static weka.finito.utils.Shared.hash;
import static weka.finito.utils.shared.hash;

public final class client implements Runnable {
private final String classes_file = "classes.txt";
Expand Down
83 changes: 27 additions & 56 deletions src/main/java/weka/finito/level_site_thread.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import java.lang.System;

import security.misc.HomomorphicException;
import security.socialistmillionaire.alice;
import weka.finito.structs.BigIntegers;
import weka.finito.structs.NodeInfo;
Expand All @@ -11,10 +10,12 @@
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.math.BigInteger;
import java.net.Socket;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;

import static weka.finito.utils.shared.compare;

public class level_site_thread implements Runnable {

Expand All @@ -24,9 +25,7 @@ public class level_site_thread implements Runnable {

private level_order_site level_site_data = null;

private alice Niu = null;

private Hashtable<String, BigIntegers> encrypted_features;
private final Hashtable<String, BigIntegers> encrypted_features = new Hashtable<>();
private final AES crypto;

public level_site_thread(Socket client_socket, level_order_site level_site_data, AES crypto) {
Expand All @@ -46,7 +45,11 @@ public level_site_thread(Socket client_socket, level_order_site level_site_data,
this.toClient.writeBoolean(true);
closeClientConnection();
} else if (x instanceof Hashtable) {
encrypted_features = (Hashtable<String, BigIntegers>) x;
for (Map.Entry<?, ?> entry: ((Hashtable<?, ?>) x).entrySet()){
if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) {
encrypted_features.put((String) entry.getKey(), (BigIntegers) entry.getValue());
}
}
// Have encrypted copy of thresholds if not done already for all nodes in level-site
this.level_site_data = level_site_data;
} else {
Expand All @@ -70,45 +73,6 @@ private void closeClientConnection() throws IOException {
}
}

// a - from CLIENT, should already be encrypted...
private boolean compare(NodeInfo ld, int comparisonType)
throws HomomorphicException, ClassNotFoundException, IOException {

long start_time = System.nanoTime();

BigIntegers encrypted_values = this.encrypted_features.get(ld.variable_name);
BigInteger encrypted_client_value = null;
BigInteger encrypted_thresh = null;

// Encrypt the thresh-hold correctly
if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) {
encrypted_thresh = ld.getPaillier();
encrypted_client_value = encrypted_values.getIntegerValuePaillier();
toClient.writeInt(0);
Niu.setDGKMode(false);
}
else if ((comparisonType == 3) || (comparisonType == 5)) {
encrypted_thresh = ld.getDGK();
encrypted_client_value = encrypted_values.getIntegerValueDGK();
toClient.writeInt(1);
Niu.setDGKMode(true);
}
toClient.flush();
assert encrypted_client_value != null;
long stop_time = System.nanoTime();

double run_time = (double) (stop_time - start_time);
run_time = run_time / 1000000;
System.out.printf("Comparison took %f ms\n", run_time);
if (((comparisonType == 1) && (ld.threshold == 0))
|| (comparisonType == 4) || (comparisonType == 5)) {
return Niu.Protocol4(encrypted_thresh, encrypted_client_value);
}
else {
return Niu.Protocol4(encrypted_client_value, encrypted_thresh);
}
}

// This will run the communication with client and next level site
public final void run() {
Object o;
Expand All @@ -118,16 +82,16 @@ public final void run() {
long start_time = System.nanoTime();

try {
Niu = new alice(client_socket);
alice niu = new alice(client_socket);
if (this.level_site_data == null) {
toClient.writeInt(-2);
closeClientConnection();
return;
}

List<NodeInfo> node_level_data = this.level_site_data.get_node_data();
Niu.setDGKPublicKey(this.level_site_data.dgk_public_key);
Niu.setPaillierPublicKey(this.level_site_data.paillier_public_key);
niu.setDGKPublicKey(this.level_site_data.dgk_public_key);
niu.setPaillierPublicKey(this.level_site_data.paillier_public_key);

get_previous_index = fromClient.readBoolean();
if (get_previous_index) {
Expand Down Expand Up @@ -161,32 +125,39 @@ public final void run() {
ls = node_level_data.get(node_level_index);
System.out.println("j=" + node_level_index);
if (ls.isLeaf()) {
if (n == 2 * this.level_site_data.get_current_index() || n == 2 * this.level_site_data.get_current_index() + 1) {
if (n == 2 * this.level_site_data.get_current_index()
|| n == 2 * this.level_site_data.get_current_index() + 1) {
terminalLeafFound = true;
System.out.println("Terminal leaf:" + ls.getVariableName());
}
n += 2;
}
else {
if ((n==2 * this.level_site_data.get_current_index() || n == 2 * this.level_site_data.get_current_index() + 1)) {
if ((n==2 * this.level_site_data.get_current_index()
|| n == 2 * this.level_site_data.get_current_index() + 1)) {
if (ls.comparisonType == 6) {
boolean firstInequalityHolds = compare(ls, 3);
boolean firstInequalityHolds = compare(ls, 3,
encrypted_features, toClient, niu);
if (firstInequalityHolds) {
inequalityHolds = true;
} else {
boolean secondInequalityHolds = compare(ls, 5);
}
else {
boolean secondInequalityHolds = compare(ls, 5,
encrypted_features, toClient, niu);
if (secondInequalityHolds) {
inequalityHolds = true;
}
}
} else {
inequalityHolds = compare(ls, ls.comparisonType);
}
else {
inequalityHolds = compare(ls, ls.comparisonType,
encrypted_features, toClient, niu);
}

if (inequalityHolds) {
equalsFound = true;
this.level_site_data.set_next_index(next_index);
System.out.println("New index:" + this.level_site_data.get_current_index());
System.out.println("New index: " + this.level_site_data.get_next_index());
}
}
n++;
Expand Down
81 changes: 15 additions & 66 deletions src/main/java/weka/finito/server.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import java.io.FileReader;
import java.io.IOException;

import java.math.BigInteger;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.*;
Expand All @@ -31,7 +30,7 @@
import weka.finito.structs.level_order_site;
import weka.finito.structs.NodeInfo;

import static weka.finito.utils.Shared.hash;
import static weka.finito.utils.shared.*;


public final class server implements Runnable {
Expand Down Expand Up @@ -137,46 +136,6 @@ public server(String training_data, int precision, int server_port) {
this.use_level_sites = false;
}

private boolean compare(NodeInfo ld, int comparisonType,
Hashtable<String, BigIntegers> encrypted_features,
ObjectOutputStream toClient, alice Niu)
throws ClassNotFoundException, HomomorphicException, IOException {

long start_time = System.nanoTime();

BigIntegers encrypted_values = encrypted_features.get(ld.variable_name);
BigInteger encrypted_client_value = null;
BigInteger encrypted_thresh = null;

// Encrypt the thresh-hold correctly
if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) {
encrypted_thresh = ld.getPaillier();
encrypted_client_value = encrypted_values.getIntegerValuePaillier();
toClient.writeInt(0);
Niu.setDGKMode(false);
}
else if ((comparisonType == 3) || (comparisonType == 5)) {
encrypted_thresh = ld.getDGK();
encrypted_client_value = encrypted_values.getIntegerValueDGK();
toClient.writeInt(1);
Niu.setDGKMode(true);
}
toClient.flush();
assert encrypted_client_value != null;
long stop_time = System.nanoTime();

double run_time = (double) (stop_time - start_time);
run_time = run_time / 1000000;
System.out.printf("Comparison took %f ms\n", run_time);
if (((comparisonType == 1) && (ld.threshold == 0))
|| (comparisonType == 4) || (comparisonType == 5)) {
return Niu.Protocol4(encrypted_thresh, encrypted_client_value);
}
else {
return Niu.Protocol4(encrypted_client_value, encrypted_thresh);
}
}

private void evaluate(int server_port) throws IOException, HomomorphicException {
ServerSocket serverSocket = new ServerSocket(server_port);
System.out.println("Server will be waiting for direct evaluation from client");
Expand All @@ -200,32 +159,36 @@ private void evaluate(int server_port) throws IOException, HomomorphicException
Niu.setDGKPublicKey(dgk_public);

List<NodeInfo> node_level_data;
NodeInfo ls = null;
NodeInfo ls;
long start_time = System.nanoTime();
int previous_index = 0;

// Traverse DT until you hit a leaf, the client has to track the index...
for (level_order_site level_site_data : all_level_sites) {
node_level_data = level_site_data.get_node_data();

level_site_data.set_current_index(previous_index);

// Handle at a level...
int node_level_index = 0;
int n = 0;
int next_index = 0;
boolean equalsFound = false;
boolean inequalityHolds = false;
boolean terminalLeafFound = false;

while (!equalsFound) {
while (true) {
ls = node_level_data.get(node_level_index);
System.out.println("j=" + node_level_index);
if (ls.isLeaf()) {
if (n == 2 * level_site_data.get_current_index()
|| n == 2 * level_site_data.get_current_index() + 1) {
terminalLeafFound = true;
System.out.println("Terminal leaf:" + ls.getVariableName());
// Tell the client the value
to_client_site.writeInt(-1);
to_client_site.writeObject(ls.getVariableName());
to_client_site.flush();
long stop_time = System.nanoTime();
double run_time = (double) (stop_time - start_time);
run_time = run_time / 1000000;
System.out.printf("Total Server-Site run-time took %f ms\n", run_time);
break;
}
n += 2;
Expand All @@ -250,32 +213,18 @@ private void evaluate(int server_port) throws IOException, HomomorphicException
}

if (inequalityHolds) {
equalsFound = true;
level_site_data.set_next_index(next_index);
System.out.println("New index:" + level_site_data.get_current_index());
System.out.println("Next index:" + next_index);
// Remember this for next loop...
previous_index = next_index;
break;
}
}
n++;
next_index++;
}
node_level_index++;
}

if (terminalLeafFound) {
// Tell the client the value
to_client_site.writeInt(-1);
to_client_site.writeObject(ls.getVariableName());
to_client_site.flush();
long stop_time = System.nanoTime();
double run_time = (double) (stop_time - start_time);
run_time = run_time / 1000000;
System.out.printf("Total Server-Site run-time took %f ms\n", run_time);
break;
}
else {
// Remember this for next loop...
previous_index = next_index;
}
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
Expand Down
14 changes: 0 additions & 14 deletions src/main/java/weka/finito/utils/Shared.java

This file was deleted.

Loading

0 comments on commit b7c929e

Please sign in to comment.