Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resolves #321, Added support for passing additional headers in JWE encryption. #322

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### News ###

* Remove support for python 3.6
* Add support for additional headers in JWE encryption.

### Housekeeping ###

Expand Down
2 changes: 1 addition & 1 deletion jose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.3.0"
__version__ = "3.4.0"
__author__ = "Michael Davis"
__license__ = "MIT"
__copyright__ = "Copyright 2016 Michael Davis"
Expand Down
16 changes: 13 additions & 3 deletions jose/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from .utils import base64url_decode, base64url_encode, ensure_binary


def encrypt(plaintext, key, encryption=ALGORITHMS.A256GCM, algorithm=ALGORITHMS.DIR, zip=None, cty=None, kid=None):
def encrypt(
plaintext, key, encryption=ALGORITHMS.A256GCM, algorithm=ALGORITHMS.DIR, zip=None, cty=None, kid=None,
additional_headers=None
):
"""Encrypts plaintext and returns a JWE cmpact serialization string.

Args:
Expand All @@ -28,6 +31,10 @@ def encrypt(plaintext, key, encryption=ALGORITHMS.A256GCM, algorithm=ALGORITHMS.
cty (str, optional): The media type for the secured content.
See http://www.iana.org/assignments/media-types/media-types.xhtml
kid (str, optional): Key ID for the provided key
additional_headers (dict, optional): Additional JWE protected headers.
These headers will be added to the default headers. Any headers
that are added as additional headers will override the default
headers.

Returns:
bytes: The string representation of the header, encrypted key,
Expand All @@ -48,7 +55,7 @@ def encrypt(plaintext, key, encryption=ALGORITHMS.A256GCM, algorithm=ALGORITHMS.
if encryption not in ALGORITHMS.SUPPORTED:
raise JWEError("Algorithm %s not supported." % encryption)
key = jwk.construct(key, algorithm)
encoded_header = _encoded_header(algorithm, encryption, zip, cty, kid)
encoded_header = _encoded_header(algorithm, encryption, zip, cty, kid, additional_headers)

plaintext = _compress(zip, plaintext)
enc_cek, iv, cipher_text, auth_tag = _encrypt_and_auth(key, algorithm, encryption, zip, plaintext, encoded_header)
Expand Down Expand Up @@ -327,7 +334,7 @@ def _jwe_compact_deserialize(jwe_bytes):
return header, header_segment, encrypted_key, iv, ciphertext, auth_tag


def _encoded_header(alg, enc, zip, cty, kid):
def _encoded_header(alg, enc, zip, cty, kid, additional_headers):
"""
Generate an appropriate JOSE header based on the values provided
Args:
Expand All @@ -347,6 +354,9 @@ def _encoded_header(alg, enc, zip, cty, kid):
header["cty"] = cty
if kid:
header["kid"] = kid

header.update(additional_headers or {})

json_header = json.dumps(
header,
separators=(",", ":"),
Expand Down
13 changes: 13 additions & 0 deletions tests/test_jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ class TestDecrypt:
b"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkEyNTZDQkMtSFM1MTIifQ.Kbd5rSN1afyre2DbkXOmGKkCNZ09TfAwNpDn1Ic7_HJNS42VDx584ReiEzpyIoWek8l87h1oZL0OC0f1ceEuuTR-_rZzKNqq6t44EvXvRusSHg_mTm8qYwyJIkJsD_Zgh0HUza20X6Ypu4ZheTzw70krFYhFnBKNXzhdrf4Bbz8e7IEeR7Po2VqOzx6JPNFsJ1tRSb9r4w60-1qq0MSdl2VItvHVY4fg-bts2k2sJ_Ub8VtRLY1MzPc1rFcI10x_AD52ntW-8T_BvY8R7Ci0cLfEycGlOM-pJOtJVY4bQisx-PvLgPoKlfTMX251m_np9ImSov9edy57-jy427l28g.w5rYu_XKzUCwTScFQ3fGOA.6zntLreCPN2Eo6aLmuqYrkyF2hOBXzNlArOOJ0iZ9TA.xiF5HLIBmIE8FCog-CZwXpIUjP6XgpncwXjw--dM57I",
id="alg: RSA-OAEP, enc: A256CBC-HS512",
),
pytest.param( # JWE with custom headers.
b"eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwidGVzdC1oZWFkZXIxIjoidmFsMSIsInRlc3QtaGVhZGVyMiI6InZhbDEifQ.tZpAFnpWe-Kump0E16wE0k-7VSjY-Sdzmj3TrnuoVgaEz4dvFs8jTknUNHgsu4USzf6JrNoTB3mK8rM30z3lgsMqJ5zs4QPOvR7CuXAXdRf5Mje9cyeiJKebqumgR5P1d1D6GWrqoO9oDHBBOXcvRAkzS_siv0SAXLue7sV4e1F5re50oD2i9-FW9L-DLnFIHc_iUKjuOW00xyjxyDAw62thb2iV_ZBD8m-oz9tRxR3NQbGOvKdBOM_29lcxhVZq4Wspv3117hoyyni6KJBg8DLVuk9Rkt4DZQdZa7PcaoeHH5AIC_wsWJTI3yIuZVYri2pX3KVbrSsAz3zB9dbj8A.vyNmAMvzl7OiaPCVVfapsg.vIJcOra4VqL1MnXGjFJEtdwYEF-YW73DPAbvN7mEtso.NiXr0iwQehLqvAUUPkqfWL_N56Nu3b7rCVY7FDyuRYM",
id="alg: RSA-OAEP, enc: A256CBC-HS512",
),
pytest.param(
b"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.SUDoqix7_PhGaNeCxYEgmvZt-Bhj-EoPfnTbJpxgvdUSVk6cn2XjAJxiVHTaeM8_DPmxxeKqt-JEVljc7lUmHQpAW1Cule7ySw498OgG6q4ddpBZEPXqAHpqlfATrhGpEq0WPRZJwvbyKUd08rND1r4SePZg8sag6cvbiPbMHIzQSjGPkDwWt1P5ue7n1ySmxqGenjPlzl4g_n5wwPGG5e3RGmoiVQh2Stybp9j2fiLNzHKcO5_9BJxMR4DEB0DE3NGhszXFQneP009j4wxm5kKzuja0ks9tEdNAJ3NLWnQhU-w0_xeePj8SGxJXuGIQT0ox9yQlD-HnmlEqMWYplg.5XuF3e3g7ck1RRy8.VSph3xlmrPI3z6jcLdh862GaDq6_-g.3WcUUUcy1NZ-aFYU8u9KHA",
id="alg: RSA-OAEP, enc: A128GCM",
Expand Down Expand Up @@ -525,3 +529,12 @@ def test_kid_header_not_present_when_not_provided(self):
encrypted = jwe.encrypt("Text", PUBLIC_KEY_PEM, enc, alg)
header = json.loads(base64url_decode(encrypted.split(b".")[0]))
assert "kid" not in header

@pytest.mark.skipif(AESKey is None, reason="No AES backend")
def test_additional_headers_present_when_provided(self):
enc = ALGORITHMS.A256CBC_HS512
alg = ALGORITHMS.RSA_OAEP_256
additional_headers = {"test-header1": "val1", "test-header2": "val1"}
encrypted = jwe.encrypt("Test", PUBLIC_KEY_PEM, enc, alg, additional_headers=additional_headers.copy())
header = json.loads(base64url_decode(encrypted.split(b".")[0]))
assert set(header.items()).issuperset(set(additional_headers.items()))