-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathratchet.py
249 lines (201 loc) · 8.75 KB
/
ratchet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
from __future__ import absolute_import
from .interfaces.aead import AEADIFace
from .interfaces.dhkey import DHKeyPairIface
from .interfaces.ratchet import RatchetIface
from .message import Header, Message, MessageHE
from .state import State
class MaxSkippedMksExceeded(Exception):
    """Signal that skipping ahead in a receiving chain would break a limit.

    Raised by ``skip_over_mks`` when too many message keys would be skipped
    in the current receiving chain, or too many would be held in the
    skipped-key store.
    """
# Default double-ratchet encrypt/decrypt (cleartext headers)
class Ratchet(RatchetIface):
    """Double Ratchet message encryption/decryption with cleartext headers.

    Every method is a ``staticmethod`` that reads and mutates the ``State``
    passed in. Keys for skipped (out-of-order) messages are kept in
    ``state.skipped_mks`` under ``(sender DH public key bytes, msg_no)``.
    """

    # Limit on message keys skipped within one receiving chain
    # (enforced by skip_over_mks).
    MAX_SKIP = 1000
    # Limit on the total number of stored skipped message keys
    # (enforced by skip_over_mks).
    MAX_STORE = 2000

    @staticmethod
    def encrypt_message(state, pt, associated_data, aead):
        """Encrypt the string *pt* and advance the sending chain.

        Args:
            state: session ``State``; mutated in place.
            pt: plaintext ``str``; UTF-8 encoded before encryption.
            associated_data: ``bytes`` authenticated (together with the
                serialized header) but not encrypted.
            aead: a class (not an instance) implementing ``AEADIFace``.

        Returns:
            ``Message`` holding the cleartext header and the ciphertext.

        Raises:
            TypeError: if an argument has the wrong type.
        """
        if not isinstance(state, State):
            raise TypeError("state must be of type: state")
        if not isinstance(pt, str):
            raise TypeError("pt must be of type: string")
        if not isinstance(associated_data, bytes):
            raise TypeError("associated_data must be of type: bytes")
        if not issubclass(aead, AEADIFace):
            raise TypeError("aead must implement AEADIface")

        if state.delayed_send_ratchet:
            # dh_ratchet defers deriving the new sending chain until we
            # actually send; perform that root-chain step now.
            state.send.ck = state.root.ratchet(
                state.dh_pair.dh_out(state.dh_pk_r))[0]
            state.delayed_send_ratchet = False

        mk = state.send.ratchet()
        header = Header(state.dh_pair.public_key,
                        state.prev_send_len, state.send.msg_no)
        state.send.msg_no += 1
        # The serialized header is bound to the ciphertext as associated data.
        ct = aead.encrypt(mk, pt.encode("utf-8"),
                          associated_data + bytes(header))
        return Message(header, ct)

    @staticmethod
    def decrypt_message(state, msg, associated_data, aead, keypair):
        """Decrypt *msg*, performing a DH ratchet step if the sender did.

        Args:
            state: session ``State``; mutated in place.
            msg: ``Message`` produced by ``encrypt_message``.
            associated_data: the same ``bytes`` the sender authenticated.
            aead: a class (not an instance) implementing ``AEADIFace``.
            keypair: a class implementing ``DHKeyPairIface``, used to
                generate our next DH pair on a ratchet step.

        Returns:
            The UTF-8 decoded plaintext ``str``.

        Raises:
            TypeError: if an argument has the wrong type.
            MaxSkippedMksExceeded: if too many message keys would be skipped.
            Whatever ``aead.decrypt`` raises on authentication failure.
        """
        if not isinstance(state, State):
            raise TypeError("state must be of type: state")
        if not isinstance(msg, Message):
            raise TypeError("msg must be of type: Message")
        if not isinstance(associated_data, bytes):
            raise TypeError("associated_data must be of type: bytes")
        if not issubclass(aead, AEADIFace):
            raise TypeError("aead must implement AEADIface")
        if not issubclass(keypair, DHKeyPairIface):
            raise TypeError("keypair must implement DHKeyPairIface")

        # Out-of-order message whose key was stored earlier?
        pt = try_skipped_mks(state, msg.header, msg.ct, associated_data, aead)
        if pt is not None:
            state.skipped_mks.notify_event()  # successful decrypt event
            return pt

        if not state.dh_pk_r:
            # No remote ratchet key recorded yet: start a receiving chain.
            dh_ratchet(state, msg.header.dh_pk, keypair)
        elif not state.dh_pk_r.is_equal_to(msg.header.dh_pk):
            # Sender ratcheted: store the keys left in the old receiving
            # chain, then step the DH ratchet.
            skip_over_mks(state, msg.header.prev_chain_len,
                          state.dh_pk_r.pk_bytes())
            dh_ratchet(state, msg.header.dh_pk, keypair)
        # Store keys for messages skipped in the current receiving chain.
        skip_over_mks(state, msg.header.msg_no, state.dh_pk_r.pk_bytes())

        mk = state.receive.ratchet()
        state.receive.msg_no += 1
        # NOTE(review): state is mutated above before aead.decrypt
        # authenticates the message, so a forged message that fails
        # authentication still advances the session. The Double Ratchet
        # spec discards state changes on failure; consider snapshotting
        # and restoring state on exception.
        pt_bytes = aead.decrypt(mk, msg.ct, associated_data + bytes(msg.header))
        state.skipped_mks.notify_event()  # successful decrypt event
        return pt_bytes.decode("utf-8")
# Double ratchet encrypt/decrypt (header encryption variant)
class RatchetHE(RatchetIface):
    """Double Ratchet message encryption/decryption with header encryption.

    Every method is a ``staticmethod`` that reads and mutates the ``State``
    passed in. Headers are encrypted under header keys (``state.hk_s`` /
    ``state.hk_r``), and skipped message keys are stored under
    ``(receiving header key, msg_no)``.
    """

    # Mirrors Ratchet's limits. NOTE(review): skip_over_mks reads
    # Ratchet.MAX_SKIP / Ratchet.MAX_STORE, not these attributes.
    MAX_SKIP = 1000
    MAX_STORE = 2000

    @staticmethod
    def encrypt_message_he(state, pt, associated_data, aead):
        """Encrypt the string *pt* and its header, advancing the sending chain.

        Args:
            state: session ``State``; mutated in place.
            pt: plaintext ``str``; UTF-8 encoded before encryption.
            associated_data: ``bytes`` authenticated (together with the
                encrypted header) but not encrypted.
            aead: a class (not an instance) implementing ``AEADIFace``.

        Returns:
            ``MessageHE`` holding the encrypted header and the ciphertext.

        Raises:
            TypeError: if an argument has the wrong type.
        """
        if not isinstance(state, State):
            raise TypeError("state must be of type: state")
        if not isinstance(pt, str):
            raise TypeError("pt must be of type: string")
        if not isinstance(associated_data, bytes):
            raise TypeError("associated_data must be of type: bytes")
        if not issubclass(aead, AEADIFace):
            raise TypeError("aead must implement AEADIface")

        if state.delayed_send_ratchet:
            # dh_ratchet_he defers deriving the new sending chain (and next
            # sending header key) until we actually send; do it now.
            state.send.ck, state.next_hk_s = state.root.ratchet(
                state.dh_pair.dh_out(state.dh_pk_r), 2)
            state.delayed_send_ratchet = False

        mk = state.send.ratchet()
        header = Header(state.dh_pair.public_key,
                        state.prev_send_len, state.send.msg_no)
        # Header is encrypted under the sending header key with empty AD.
        hdr_ct = aead.encrypt(state.hk_s, bytes(header), b"")
        state.send.msg_no += 1
        # The *encrypted* header is bound to the ciphertext as associated data.
        ct = aead.encrypt(mk, pt.encode("utf-8"), associated_data + hdr_ct)
        return MessageHE(hdr_ct, ct)

    @staticmethod
    def decrypt_message_he(state, msg, associated_data, aead, keypair):
        """Decrypt a header-encrypted *msg*, ratcheting if the sender did.

        Args:
            state: session ``State``; mutated in place.
            msg: ``MessageHE`` produced by ``encrypt_message_he``.
            associated_data: the same ``bytes`` the sender authenticated.
            aead: a class (not an instance) implementing ``AEADIFace``.
            keypair: a class implementing ``DHKeyPairIface``, used to
                generate our next DH pair on a ratchet step.

        Returns:
            The UTF-8 decoded plaintext ``str``.

        Raises:
            TypeError: if an argument has the wrong type.
            ValueError: if the header cannot be decrypted.
            MaxSkippedMksExceeded: if too many message keys would be skipped.
            Whatever ``aead.decrypt`` raises on authentication failure.
        """
        if not isinstance(state, State):
            raise TypeError("state must be of type: state")
        # encrypt_message_he produces MessageHE; accept it explicitly
        # instead of relying on it subclassing Message.
        if not isinstance(msg, (Message, MessageHE)):
            raise TypeError("msg must be of type: MessageHE")
        if not isinstance(associated_data, bytes):
            raise TypeError("associated_data must be of type: bytes")
        if not issubclass(aead, AEADIFace):
            raise TypeError("aead must implement AEADIface")
        if not issubclass(keypair, DHKeyPairIface):
            raise TypeError("keypair must implement DHKeyPairIface")

        # Out-of-order message whose key was stored earlier?
        pt = try_skipped_mks_he(state, msg.header_ct, msg.ct,
                                associated_data, aead)
        if pt is not None:
            state.skipped_mks.notify_event()  # successful decrypt event
            return pt

        header, should_dh_ratchet = decrypt_header(state, msg.header_ct, aead)
        if should_dh_ratchet:
            # Sender ratcheted: store keys left in the old receiving chain.
            skip_over_mks(state, header.prev_chain_len, state.hk_r)
            dh_ratchet_he(state, header.dh_pk, keypair)
        # Store keys for messages skipped in the current receiving chain.
        skip_over_mks(state, header.msg_no, state.hk_r)

        mk = state.receive.ratchet()
        state.receive.msg_no += 1
        # Must match encrypt_message_he, which authenticates the encrypted
        # header bytes (not the serialized plaintext header).
        # NOTE(review): state is mutated above before authentication
        # succeeds; consider restoring it if decryption raises.
        pt_bytes = aead.decrypt(mk, msg.ct, associated_data + msg.header_ct)
        state.skipped_mks.notify_event()  # successful decrypt event
        return pt_bytes.decode("utf-8")
# Returns decrypted plaintext if message key was stored, else None
def try_skipped_mks(state, header, ct, associated_data, aead):
    """Decrypt *ct* with a previously stored (skipped) message key, if any.

    The key is looked up under ``(header.dh_pk.pk_bytes(), header.msg_no)``.
    When found, it is removed from ``state.skipped_mks`` before decryption
    and the UTF-8 decoded plaintext is returned; otherwise ``None``.
    """
    slot = (header.dh_pk.pk_bytes(), header.msg_no)
    stored_key = state.skipped_mks.lookup(slot)
    if not stored_key:
        return None
    # Skipped keys are single-use: drop it before attempting decryption.
    state.skipped_mks.delete(slot)
    plaintext = aead.decrypt(stored_key, ct, associated_data + bytes(header))
    return plaintext.decode("utf-8")
# Returns decrypted plaintext if header receive key was stored, else None.
def try_skipped_mks_he(state, header_ct, ct, associated_data, aead):
    """Decrypt a header-encrypted message with a stored skipped key, if any.

    Each stored entry ``(hk_r, msg_no) -> mk`` is tried: if ``hk_r``
    decrypts *header_ct* and the header's ``msg_no`` matches the entry, the
    entry is removed and *ct* is decrypted with ``mk``.

    Returns:
        The UTF-8 decoded plaintext ``str``, or ``None`` if no stored entry
        matches.

    Raises:
        Whatever ``aead.decrypt`` raises if the matched message key fails to
        authenticate *ct* (the entry has already been removed by then).
    """
    for (hk_r, msg_no), mk in state.skipped_mks.items():
        try:
            header_bytes = aead.decrypt(hk_r, header_ct, b"")
        except Exception:
            # Header not encrypted under this key; try the next entry.
            # (Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # are not swallowed.)
            continue
        header = Header.from_bytes(header_bytes)
        if header.msg_no == msg_no:
            # Single-use key: remove it; we return immediately, so the
            # store is not iterated after this mutation.
            state.skipped_mks.delete((hk_r, msg_no))
            pt_bytes = aead.decrypt(mk, ct, associated_data + header_ct)
            return pt_bytes.decode("utf-8")
    return None
# Returns the decrypted header, trying the current and then the next
# header receiving key. Also returns whether a DH-ratchet step is
# needed, i.e. True when the header was decrypted with the next key.
def decrypt_header(state, header_ct, aead):
    """Decrypt an encrypted message header.

    Tries the current receiving header key ``state.hk_r`` (skipped while it
    is still ``None``), then ``state.next_hk_r``. Headers are decrypted with
    empty associated data.

    Args:
        state: session state; only ``hk_r`` and ``next_hk_r`` are read.
        header_ct: the encrypted header bytes.
        aead: a class implementing ``AEADIFace``.

    Returns:
        ``(header, should_dh_ratchet)``. ``should_dh_ratchet`` is True when
        the header decrypted under ``next_hk_r``, i.e. a DH ratchet step is
        needed.

    Raises:
        ValueError: if no candidate key decrypts and parses the header.
    """
    candidates = []
    if state.hk_r is not None:  # may not have ratcheted yet
        candidates.append((state.hk_r, False))
    candidates.append((state.next_hk_r, True))

    for header_key, is_next_key in candidates:
        try:
            header_bytes = aead.decrypt(header_key, header_ct, b"")
            return Header.from_bytes(header_bytes), is_next_key
        except Exception:
            # Wrong key, failed authentication, or an unparsable header:
            # try the next candidate. (Narrowed from a bare except so
            # KeyboardInterrupt/SystemExit are not swallowed.)
            continue
    raise ValueError("Error: invalid header ciphertext.")
# Skips over and stores message keys in the current chain
# that come before provided end_msg_no. Raises exception
# if too many messages have been skipped in the current
# receiving chain.
def skip_over_mks(state, end_msg_no, map_key):
    """Derive and store receiving-chain message keys up to *end_msg_no*.

    Each skipped key is stored in ``state.skipped_mks`` under
    ``(map_key, msg_no)`` so an out-of-order message can be decrypted later.

    Args:
        state: session state; mutated in place.
        end_msg_no: message number to skip up to (exclusive).
        map_key: first element of the storage key. Callers pass the sender's
            DH public key bytes, or the receiving header key in the
            header-encryption variant.

    Raises:
        MaxSkippedMksExceeded: if the skip would exceed ``Ratchet.MAX_SKIP``
            for the current chain, or ``Ratchet.MAX_STORE`` stored keys.
    """
    # A message number at or behind the chain position skips nothing. Clamp
    # so an old/replayed number cannot *decrease* state.skipped_count and
    # loosen the MAX_SKIP limit.
    new_skip = max(0, end_msg_no - state.receive.msg_no)
    if new_skip + state.skipped_count > Ratchet.MAX_SKIP:
        raise MaxSkippedMksExceeded("Too many messages skipped in "
                                    "current chain")
    # NOTE(review): eviction below caps the store at MAX_SKIP, and new_skip
    # is bounded by MAX_SKIP above, so this MAX_STORE check cannot currently
    # fire. Confirm which limit eviction is meant to use.
    if new_skip + state.skipped_mks.count() > Ratchet.MAX_STORE:
        raise MaxSkippedMksExceeded("Too many messages stored")
    if state.receive.ck is None:
        # No receiving chain yet, so there are no keys to derive.
        return
    while state.receive.msg_no < end_msg_no:
        mk = state.receive.ratchet()
        # Evict the entry returned by front() once the store is full
        # (original comment: "del keys FIFO").
        if state.skipped_mks.count() >= Ratchet.MAX_SKIP:
            state.skipped_mks.delete(state.skipped_mks.front())
        state.skipped_mks.put((map_key, state.receive.msg_no), mk)
        state.receive.msg_no += 1
    state.skipped_count += new_skip
# Diffie-Hellman ratchet step
def dh_ratchet(state, dh_pk_r, keypair):
    """Perform a Diffie-Hellman ratchet step on receipt of a new remote key.

    Derives the new receiving chain key from our current DH pair and
    *dh_pk_r*, then generates a fresh DH pair for ourselves. Derivation of
    the matching sending chain is deferred: ``delayed_send_ratchet`` is set,
    and the send-side root step runs on the next encrypt. If it is still
    pending when this function is called again, it runs at the start of
    that call.

    Args:
        state: session state; mutated in place.
        dh_pk_r: the remote party's new DH public key (from a message header).
        keypair: class implementing ``DHKeyPairIface``; its ``generate_dh()``
            creates our next DH pair.
    """
    if state.delayed_send_ratchet:
        # Complete the pending send-side root step before the receive step.
        # NOTE(review): this uses the incoming dh_pk_r rather than the
        # previous state.dh_pk_r; confirm this is intended.
        state.send.ck = state.root.ratchet(state.dh_pair.dh_out(dh_pk_r))[0]
    state.dh_pk_r = dh_pk_r
    # Receive-side root step: DH(our current pair, new remote key).
    state.receive.ck = state.root.ratchet(state.dh_pair.dh_out(state.dh_pk_r))[0]
    state.dh_pair = keypair.generate_dh()
    # The sending chain for the new pair is derived lazily on the next send.
    state.delayed_send_ratchet = True
    # Length of the sending chain being left; encrypt_message puts it in
    # subsequent headers. Then restart all per-chain counters.
    state.prev_send_len = state.send.msg_no
    state.send.msg_no = 0
    state.receive.msg_no = 0
    state.skipped_count = 0
# Diffie-Hellman ratchet step (header encryption variant)
def dh_ratchet_he(state, dh_pk_r, keypair):
    """Perform a DH ratchet step, header-encryption variant.

    Like the plain step, this derives a new receiving chain key and a fresh
    DH pair, and defers the sending chain (``delayed_send_ratchet``). It also
    rotates header keys:
    - the current "next" sending and receiving header keys become the active
      ``hk_s`` / ``hk_r``;
    - a new ``next_hk_r`` is derived alongside the receiving chain key (the
      root ratchet is asked for 2 outputs).

    Args:
        state: session state; mutated in place.
        dh_pk_r: the remote party's new DH public key (from the decrypted
            header).
        keypair: class implementing ``DHKeyPairIface``; its ``generate_dh()``
            creates our next DH pair.
    """
    if state.delayed_send_ratchet:
        # Complete the pending send-side root step (chain key + next sending
        # header key) before the receive step.
        # NOTE(review): uses the incoming dh_pk_r rather than the previous
        # state.dh_pk_r; confirm this is intended.
        state.send.ck, state.next_hk_s = \
            state.root.ratchet(state.dh_pair.dh_out(dh_pk_r), 2)
    state.dh_pk_r = dh_pk_r
    # Promote the "next" header keys to active before deriving new ones.
    state.hk_s = state.next_hk_s
    state.hk_r = state.next_hk_r
    # Receive-side root step yields the receiving chain key and the next
    # receiving header key.
    state.receive.ck, state.next_hk_r = \
        state.root.ratchet(state.dh_pair.dh_out(state.dh_pk_r), 2)
    state.dh_pair = keypair.generate_dh()
    # The sending chain for the new pair is derived lazily on the next send.
    state.delayed_send_ratchet = True
    # Length of the sending chain being left, then reset per-chain counters.
    state.prev_send_len = state.send.msg_no
    state.send.msg_no = 0
    state.receive.msg_no = 0
    state.skipped_count = 0