Skip to content

Commit

Permalink
refactor: further simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
sander committed Mar 3, 2023
1 parent e96b0c2 commit b2ded6c
Show file tree
Hide file tree
Showing 25 changed files with 146 additions and 156 deletions.
10 changes: 0 additions & 10 deletions src/main/kotlin/AssociatedData.kt

This file was deleted.

6 changes: 2 additions & 4 deletions src/main/kotlin/CipherKey.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package nl.sanderdijkhuis.noise

import nl.sanderdijkhuis.noise.Size.Companion.valueSize

@JvmInline
value class CipherKey(val value: ByteArray) {
value class CipherKey(val data: Data) {

init {
require(value.valueSize == SIZE)
require(data.size == SIZE)
}

companion object {
Expand Down
23 changes: 11 additions & 12 deletions src/main/kotlin/CipherState.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@ package nl.sanderdijkhuis.noise

data class CipherState(val cryptography: Cryptography, val key: CipherKey? = null, val nonce: Nonce = Nonce.zero) {

fun encryptWithAssociatedData(associatedData: AssociatedData, plaintext: Plaintext) =
key?.let {
println("Encrypting $key $nonce $associatedData $plaintext")
State(copy(nonce = nonce.increment()), cryptography.encrypt(it, nonce, associatedData, plaintext))
} ?: let {
println("Returning plaintext $plaintext $nonce")
fun encryptWithAssociatedData(associatedData: Data, plaintext: Plaintext) =
if (key == null)
State(this, plaintext.ciphertext)
}
else
nonce.increment()?.let {
State(copy(nonce = it), cryptography.encrypt(key, nonce, associatedData, plaintext))
} ?: State(this, plaintext.ciphertext)

fun decryptWithAssociatedData(data: AssociatedData, ciphertext: Ciphertext): State<CipherState, Plaintext>? = let {
println("Decrypting $key $nonce $data $ciphertext")
fun decryptWithAssociatedData(data: Data, ciphertext: Ciphertext): State<CipherState, Plaintext>? =
if (key == null)
State(this, ciphertext.plaintext)
else
cryptography.decrypt(key, nonce, data, ciphertext)?.let {
State(copy(nonce = nonce.increment()), it)
nonce.increment()?.let { n ->
cryptography.decrypt(key, nonce, data, ciphertext)?.let { p ->
State(copy(nonce = n), p)
}
}
}
}
6 changes: 2 additions & 4 deletions src/main/kotlin/Ciphertext.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package nl.sanderdijkhuis.noise

@JvmInline
value class Ciphertext(val value: ByteArray) {
value class Ciphertext(val data: Data) {

val data get() = Data(value)

val plaintext get() = Plaintext(value)
val plaintext get() = Plaintext(data)
}
4 changes: 2 additions & 2 deletions src/main/kotlin/Cryptography.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ interface Cryptography {

fun agree(privateKey: PrivateKey, publicKey: PublicKey): SharedSecret

fun encrypt(key: CipherKey, nonce: Nonce, associatedData: AssociatedData, plaintext: Plaintext): Ciphertext
fun encrypt(key: CipherKey, nonce: Nonce, associatedData: Data, plaintext: Plaintext): Ciphertext

fun decrypt(key: CipherKey, nonce: Nonce, associatedData: AssociatedData, ciphertext: Ciphertext): Plaintext?
fun decrypt(key: CipherKey, nonce: Nonce, associatedData: Data, ciphertext: Ciphertext): Plaintext?

fun hash(data: Data): Digest
}
10 changes: 8 additions & 2 deletions src/main/kotlin/Data.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ package nl.sanderdijkhuis.noise
import nl.sanderdijkhuis.noise.Size.Companion.valueSize
import kotlin.experimental.xor

@JvmInline
value class Data(val value: ByteArray) {
data class Data(val value: ByteArray) {

init {
require(value.valueSize <= Size.MAX_MESSAGE)
Expand All @@ -14,11 +13,18 @@ value class Data(val value: ByteArray) {

val size get() = value.valueSize

val isEmpty get() = value.isEmpty()

fun xor(that: Data) = Data(let {
require(value.size == that.value.size)
ByteArray(value.size) { this.value[it].xor(that.value[it]) }
})

override fun equals(other: Any?) =
this === other || ((other as? Data)?.let { value.contentEquals(it.value) } ?: false)

override fun hashCode() = value.contentHashCode()

companion object {

val empty get() = Data(ByteArray(0))
Expand Down
2 changes: 0 additions & 2 deletions src/main/kotlin/Digest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,5 @@ value class Digest(val data: Data) {
require(data.size == HashFunction.HASH_SIZE)
}

val associatedData get() = AssociatedData(data)

val messageAuthenticationKey get() = MessageAuthenticationKey(data)
}
4 changes: 4 additions & 0 deletions src/main/kotlin/HandshakeHash.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package nl.sanderdijkhuis.noise

@JvmInline
value class HandshakeHash(val digest: Digest)
4 changes: 2 additions & 2 deletions src/main/kotlin/HandshakePattern.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ data class HandshakePattern(val name: ProtocolName, val preSharedMessagePatterns

val Noise_XN_25519_ChaChaPoly_SHA256 =
HandshakePattern(
ProtocolName("Noise_XN_25519_ChaChaPoly_SHA256".toByteArray()),
ProtocolName(Data("Noise_XN_25519_ChaChaPoly_SHA256".toByteArray())),
listOf(),
listOf(listOf(Token.E), listOf(Token.E, Token.EE), listOf(Token.S, Token.SE))
)

val Noise_NK_25519_ChaChaPoly_SHA256 =
HandshakePattern(
ProtocolName("Noise_NK_25519_ChaChaPoly_SHA256".toByteArray()),
ProtocolName(Data("Noise_NK_25519_ChaChaPoly_SHA256".toByteArray())),
listOf(listOf(), listOf(Token.S)),
listOf(listOf(Token.E, Token.ES), listOf(Token.E, Token.EE))
)
Expand Down
88 changes: 52 additions & 36 deletions src/main/kotlin/HandshakeState.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ data class HandshakeState(
val role: Role,
val symmetricState: SymmetricState,
val messagePatterns: List<List<Token>>,
val s: KeyPair? = null,
val e: KeyPair? = null,
val rs: PublicKey? = null,
val re: PublicKey? = null,
val localStaticKeyPair: KeyPair? = null,
val localEphemeralKeyPair: KeyPair? = null,
val remoteStaticKey: PublicKey? = null,
val remoteEphemeralKey: PublicKey? = null,
val trustedStaticKeys: Set<PublicKey> = emptySet()
) {

Expand All @@ -34,17 +34,17 @@ data class HandshakeState(
}
when {
state == null -> null
token == Token.E && e != null -> state.run(e.public.data) { it.mixHash(e.public.data) }
token == Token.S && s != null -> state.runAndAppendInState {
it.encryptAndHash(s.public.plaintext).map { c -> c.data }
token == Token.E && localEphemeralKeyPair != null -> state.run(localEphemeralKeyPair.public.data) { it.mixHash(localEphemeralKeyPair.public.data) }
token == Token.S && localStaticKeyPair != null -> state.runAndAppendInState {
it.encryptAndHash(localStaticKeyPair.public.plaintext).map { c -> c.data }
}

token == Token.EE -> mixKey(e, re)
token == Token.ES && role == Role.INITIATOR -> mixKey(e, rs)
token == Token.ES && role == Role.RESPONDER -> mixKey(s, re)
token == Token.SE && role == Role.INITIATOR -> mixKey(s, re)
token == Token.SE && role == Role.RESPONDER -> mixKey(e, rs)
token == Token.SS -> mixKey(s, rs)
token == Token.EE -> mixKey(localEphemeralKeyPair, remoteEphemeralKey)
token == Token.ES && role == Role.INITIATOR -> mixKey(localEphemeralKeyPair, remoteStaticKey)
token == Token.ES && role == Role.RESPONDER -> mixKey(localStaticKeyPair, remoteEphemeralKey)
token == Token.SE && role == Role.INITIATOR -> mixKey(localStaticKeyPair, remoteEphemeralKey)
token == Token.SE && role == Role.RESPONDER -> mixKey(localEphemeralKeyPair, remoteStaticKey)
token == Token.SS -> mixKey(localStaticKeyPair, remoteStaticKey)
else -> null
}
}?.runAndAppendInState { it.encryptAndHash(payload.plainText).map { c -> c.data } }
Expand All @@ -53,7 +53,14 @@ data class HandshakeState(
when {
state == null -> MessageResult.InsufficientKeyMaterial
rest.isEmpty() -> symmetricState.split()
.let { MessageResult.FinalHandshakeMessage(it.first, it.second, symmetricState.digest, state.result) }
.let {
MessageResult.FinalHandshakeMessage(
it.first,
it.second,
symmetricState.handshakeHash,
state.result
)
}

else -> MessageResult.IntermediateHandshakeMessage(
state.current.copy(messagePatterns = rest),
Expand All @@ -74,38 +81,40 @@ data class HandshakeState(
println("Token $token")
when {
state == null -> null
token == Token.E && state.current.re == null ->
token == Token.E && state.current.remoteEphemeralKey == null ->
let {
val re =
PublicKey(
state.result.value.sliceArray(
IntRange(
0,
KeyAgreementConfiguration.SIZE.value - 1
Data(
state.result.value.sliceArray(
IntRange(
0,
SharedSecret.SIZE.value - 1
)
)
)
)
println("E: read $re")
val mixed = state.current.symmetricState.mixHash(re.data)
state.copy(
current = state.current.copy(symmetricState = mixed, re = re),
result = Data(state.result.value.drop(KeyAgreementConfiguration.SIZE.value).toByteArray())
current = state.current.copy(symmetricState = mixed, remoteEphemeralKey = re),
result = Data(state.result.value.drop(SharedSecret.SIZE.value).toByteArray())
)
}

token == Token.S && state.current.rs == null -> let {
token == Token.S && state.current.remoteStaticKey == null -> let {
println("S")
val splitAt = KeyAgreementConfiguration.SIZE.value + 16
val splitAt = SharedSecret.SIZE.value + 16
val temp =
state.result.value.sliceArray(IntRange(0, splitAt - 1))
state.current.symmetricState.decryptAndHash(Ciphertext(temp))?.let {
val publicKey = PublicKey(it.result.value)
state.current.symmetricState.decryptAndHash(Ciphertext(Data(temp)))?.let {
val publicKey = PublicKey(it.result.data)
println("Public key $publicKey")
println("Trusting $trustedStaticKeys")
println("Trusted? ${trustedStaticKeys.contains(publicKey)}")
if (trustedStaticKeys.contains(publicKey))
state.copy(
current = state.current.copy(symmetricState = it.current, rs = publicKey),
current = state.current.copy(symmetricState = it.current, remoteStaticKey = publicKey),
result = Data(
state.result.value.drop(splitAt).toByteArray()
)
Expand All @@ -115,26 +124,26 @@ data class HandshakeState(
}

token == Token.EE -> let {
println("EE: Mixing ${state.current.e} ${state.current.re}")
mixKey(state.current.e, state.current.re)
println("EE: Mixing ${state.current.localEphemeralKeyPair} ${state.current.remoteEphemeralKey}")
mixKey(state.current.localEphemeralKeyPair, state.current.remoteEphemeralKey)
}

token == Token.ES && role == Role.INITIATOR -> mixKey(state.current.e, state.current.rs)
token == Token.ES && role == Role.RESPONDER -> mixKey(state.current.s, state.current.re)
token == Token.ES && role == Role.INITIATOR -> mixKey(state.current.localEphemeralKeyPair, state.current.remoteStaticKey)
token == Token.ES && role == Role.RESPONDER -> mixKey(state.current.localStaticKeyPair, state.current.remoteEphemeralKey)
token == Token.SE && role == Role.INITIATOR -> let {
println("SE")
mixKey(state.current.s, state.current.re)
mixKey(state.current.localStaticKeyPair, state.current.remoteEphemeralKey)
}

token == Token.SE && role == Role.RESPONDER -> mixKey(state.current.e, state.current.rs)
token == Token.SS -> mixKey(state.current.s, state.current.rs)
token == Token.SE && role == Role.RESPONDER -> mixKey(state.current.localEphemeralKeyPair, state.current.remoteStaticKey)
token == Token.SS -> mixKey(state.current.localStaticKeyPair, state.current.remoteStaticKey)
else -> null
}
}?.let {
it.current.symmetricState.decryptAndHash(Ciphertext(it.result.value))?.let { decrypted ->
it.current.symmetricState.decryptAndHash(Ciphertext(it.result))?.let { decrypted ->
State(
it.current.copy(symmetricState = decrypted.current), Payload(
Data(decrypted.result.value)
decrypted.result.data
)
)
}
Expand All @@ -144,7 +153,14 @@ data class HandshakeState(
when {
state == null -> MessageResult.InsufficientKeyMaterial
rest.isEmpty() -> symmetricState.split()
.let { MessageResult.FinalHandshakeMessage(it.first, it.second, symmetricState.digest, state.result) }
.let {
MessageResult.FinalHandshakeMessage(
it.first,
it.second,
symmetricState.handshakeHash,
state.result
)
}

else -> MessageResult.IntermediateHandshakeMessage(
state.current.copy(messagePatterns = rest),
Expand Down
8 changes: 2 additions & 6 deletions src/main/kotlin/InputKeyMaterial.kt
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
package nl.sanderdijkhuis.noise

import nl.sanderdijkhuis.noise.Size.Companion.valueSize

@JvmInline
value class InputKeyMaterial(val value: ByteArray) {

val data get() = Data(value)
value class InputKeyMaterial(val data: Data) {

init {
require(value.isEmpty() || value.valueSize == DEFAULT_SIZE || value.valueSize == KeyAgreementConfiguration.SIZE)
require(data.isEmpty || data.size == DEFAULT_SIZE || data.size == SharedSecret.SIZE)
}

companion object {
Expand Down
6 changes: 0 additions & 6 deletions src/main/kotlin/KeyAgreementConfiguration.kt

This file was deleted.

2 changes: 1 addition & 1 deletion src/main/kotlin/MessageAuthenticationData.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package nl.sanderdijkhuis.noise
@JvmInline
value class MessageAuthenticationData(val digest: Digest) {

val cipherKey get() = CipherKey(digest.data.value)
val cipherKey get() = CipherKey(digest.data)

val chainingKey get() = ChainingKey(digest)
}
2 changes: 1 addition & 1 deletion src/main/kotlin/MessageResult.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ sealed interface MessageResult<T> {
data class FinalHandshakeMessage<T>(
val initiatorCipherState: CipherState,
val responderCipherState: CipherState,
val handshakeHash: Digest,
val handshakeHash: HandshakeHash,
val result: T
) : MessageResult<T>
}
8 changes: 6 additions & 2 deletions src/main/kotlin/Nonce.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ value class Nonce(val value: ByteArray) {
require(value.valueSize == SIZE)
}

constructor(number: Long) : this(ByteBuffer.allocate(SIZE.value).order(ByteOrder.BIG_ENDIAN).putLong(number).array())
constructor(number: Long) : this(
ByteBuffer.allocate(SIZE.value).order(ByteOrder.BIG_ENDIAN).putLong(number).array()
)

fun increment() = Nonce(ByteBuffer.wrap(value).order(ByteOrder.BIG_ENDIAN).long + 1L)
fun increment() = if (value.contentEquals(SIZE.byteArray { 0x11 })) null else Nonce(
ByteBuffer.wrap(value).order(ByteOrder.BIG_ENDIAN).long + 1L
)

companion object {

Expand Down
4 changes: 2 additions & 2 deletions src/main/kotlin/Payload.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package nl.sanderdijkhuis.noise

data class Payload(val data: Data) {

val plainText get() = Plaintext(data.value)
}
val plainText get() = Plaintext(data)
}
6 changes: 2 additions & 4 deletions src/main/kotlin/Plaintext.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package nl.sanderdijkhuis.noise

@JvmInline
value class Plaintext(val value: ByteArray) {
value class Plaintext(val data: Data) {

val ciphertext get() = Ciphertext(value)

val data get() = Data(value)
val ciphertext get() = Ciphertext(data)
}
9 changes: 1 addition & 8 deletions src/main/kotlin/ProtocolName.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
package nl.sanderdijkhuis.noise

import nl.sanderdijkhuis.noise.Size.Companion.valueSize

@JvmInline
value class ProtocolName(val value: ByteArray) {

val data get() = Data(value)

val size get() = value.valueSize
}
value class ProtocolName(val data: Data)
Loading

0 comments on commit b2ded6c

Please sign in to comment.