Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for scalar addition, subtraction, and negation #10

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lint-test-cover-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ jobs:
xcl: ['omit', 'install']
name: "Python ${{ matrix.python-version }}"
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Install Python.
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
architecture: x64
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "oblivious"
version = "7.0.0"
version = "7.1.0"
description = """\
Python library that serves as an API for common \
cryptographic primitives used to implement OPRF, OT, \
Expand Down Expand Up @@ -33,8 +33,8 @@ mclbn256 = [
]
docs = [
"toml~=0.10.2",
"sphinx~=4.2.0",
"sphinx-rtd-theme~=1.0.0"
"sphinx~=5.3.0",
"sphinx-rtd-theme~=1.2.0"
]
test = [
"pytest~=7.2",
Expand Down
138 changes: 136 additions & 2 deletions src/oblivious/ristretto.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ class python:
:obj:`python.pnt <pnt>`, :obj:`python.bas <bas>`,
:obj:`python.can <can>`, :obj:`python.mul <mul>`,
:obj:`python.add <add>`, :obj:`python.sub <sub>`,
:obj:`python.neg <neg>`,
:obj:`python.neg <neg>`, :obj:`python.sad <sad>`,
:obj:`python.ssu <ssu>`,
:obj:`python.point <oblivious.ristretto.python.point>`, and
:obj:`python.scalar <oblivious.ristretto.python.scalar>`.
For example, you can perform addition of points using
Expand Down Expand Up @@ -385,6 +386,54 @@ def smu(s: bytes, t: bytes) -> bytes:
"""
return _sc25519_mul(s, t)

@staticmethod
def sad(s: bytes, t: bytes) -> bytes:
"""
Return the sum of two scalars.

>>> p = scalar.from_int(4)
>>> q = scalar.from_int(2)
>>> sodium.sad(p, q) == sodium.sad(q, p)
True
>>> sodium.sad(p, q).hex()
'0600000000000000000000000000000000000000000000000000000000000000'
>>> sodium.sad(-q, p).hex()
'0200000000000000000000000000000000000000000000000000000000000000'
"""
(s, t) = (int.from_bytes(s, 'little'), int.from_bytes(t, 'little'))
return (
(s + t) % (pow(2, 252) + 27742317777372353535851937790883648493)
).to_bytes(32, 'little')

@staticmethod
def ssu(s: bytes, t: bytes) -> bytes:
"""
Return the result of subtracting the right-hand scalar from the
left-hand scalar.

>>> p = scalar.from_int(4)
>>> q = scalar.from_int(2)
>>> sodium.ssu(p, q).hex()
'0200000000000000000000000000000000000000000000000000000000000000'
"""
(s, t) = (int.from_bytes(s, 'little'), int.from_bytes(t, 'little'))
return (
(s - t) % (pow(2, 252) + 27742317777372353535851937790883648493)
).to_bytes(32, 'little')

@staticmethod
def sne(s: bytes) -> bytes:
"""
Return the additive inverse of a scalar.

>>> p = scalar.from_int(4)
>>> sodium.sne(p).hex()
'e9d3f55c1a631258d69cf7a2def9de1400000000000000000000000000000010'
"""
return (
(pow(2, 252) + 27742317777372353535851937790883648493 - int.from_bytes(s, 'little'))
frederickjansen marked this conversation as resolved.
Show resolved Hide resolved
).to_bytes(32, 'little')

#
# Attempt to load primitives from libsodium, if it is present;
# otherwise, use the rbcl library, if it is present. Otherwise,
Expand Down Expand Up @@ -527,7 +576,7 @@ class encapsulates shared/dynamic library variants of both classes
:obj:`sodium.pnt <pnt>`, :obj:`sodium.bas <bas>`,
:obj:`sodium.can <can>`, :obj:`sodium.mul <mul>`,
:obj:`sodium.add <add>`, :obj:`sodium.sub <sub>`,
:obj:`sodium.neg <neg>`,
:obj:`sodium.neg <neg>`, :obj:`sodium.sne <sne>`,
:obj:`sodium.point <oblivious.ristretto.sodium.point>`, and
:obj:`sodium.scalar <oblivious.ristretto.sodium.scalar>`.
For example, you can perform addition of points using
Expand Down Expand Up @@ -734,6 +783,54 @@ def smu(s: bytes, t: bytes) -> bytes:
bytes(s), bytes(t)
)

@staticmethod
def sad(s: bytes, t: bytes) -> bytes:
"""
Return the sum of two scalars.

>>> s = sodium.scl()
>>> t = sodium.scl()
>>> sodium.sad(s, t) == sodium.sad(t, s)
True
"""
return sodium._call(
sodium._lib.crypto_core_ristretto255_scalarbytes(),
sodium._lib.crypto_core_ristretto255_scalar_add,
bytes(s), bytes(t)
)

@staticmethod
def ssu(p: bytes, q: bytes) -> bytes:
"""
Return the result of subtracting the right-hand scalar from the
left-hand scalar.

>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> sodium.ssu(p, q).hex()
'd08dcedb3a8dc87951acd91334a1faed511f49c6e9296780634b858e42347908'
"""
return sodium._call(
_sodium.crypto_core_ristretto255_scalarbytes(),
sodium._lib.crypto_core_ristretto255_scalar_sub,
bytes(p), bytes(q)
)

@staticmethod
def sne(s: bytes) -> bytes:
"""
Return the additive inverse of a scalar.

>>> p = scalar.from_int(4)
>>> sodium.sne(p).hex()
'e9d3f55c1a631258d69cf7a2def9de1400000000000000000000000000000010'
"""
return sodium._call(
_sodium.crypto_core_ristretto255_scalarbytes(),
sodium._lib.crypto_core_ristretto255_scalar_negate,
bytes(s)
)

except: # pylint: disable=W0702 # pragma: no cover
# Exported symbol.
sodium = None # pragma: no cover
Expand Down Expand Up @@ -1091,6 +1188,17 @@ def __invert__(self: scalar) -> scalar:

return self._implementation.scalar(self._implementation.inv(self))

def __neg__(self: scalar) -> scalar:
"""
Return the negation of this instance.

>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> ((p + q) + (-q)) == p
True
"""
return self._implementation.scalar(self._implementation.sne(self))

def __mul__(self: scalar, other: Union[scalar, point]) -> Union[scalar, point]:
"""
Multiply the supplied scalar or point by this instance.
Expand Down Expand Up @@ -1155,6 +1263,32 @@ def __rmul__(self: scalar, other: Union[scalar, point]):
'scalar must be on left-hand side of multiplication operator'
)

def __add__(self: scalar, other: scalar) -> scalar:
"""
Return the sum of this instance and another scalar.

>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> (p + q).hex()
'69117034205aa81808edae5d89128497ef75f5b71416d97ccfd18760ad117c0e'
>>> p + (q - q) == p
True
"""
return self._implementation.scalar(self._implementation.sad(self, other))

def __sub__(self: scalar, other: scalar) -> scalar:
"""
Return the result of subtracting another scalar from this instance.

>>> p = scalar.hash('123'.encode())
>>> q = scalar.hash('456'.encode())
>>> (p - q).hex()
'd08dcedb3a8dc87951acd91334a1faed511f49c6e9296780634b858e42347908'
>>> p - p == scalar.from_int(0)
True
"""
return self._implementation.scalar(self._implementation.ssu(self, other))

def to_bytes(self: scalar) -> bytes:
"""
Return the bytes-like object that represents this instance.
Expand Down
32 changes: 32 additions & 0 deletions test/test_ristretto.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,20 @@ def test_types_scalar_mul_point(self):
sodium_hidden_and_fallback(hidden, fallback)
self.assertTrue(isinstance(cls.scalar() * cls.point(), cls.point))

def test_types_scalar_add(self):
sodium_hidden_and_fallback(hidden, fallback)
(s0, s1) = (cls.scalar.random(), cls.scalar.random())
self.assertTrue(isinstance(s0 + s1, cls.scalar))

def test_types_scalar_sub(self):
sodium_hidden_and_fallback(hidden, fallback)
(s0, s1) = (cls.scalar.random(), cls.scalar.random())
self.assertTrue(isinstance(s0 - s1, cls.scalar))

def test_types_scalar_neg(self):
sodium_hidden_and_fallback(hidden, fallback)
self.assertTrue(isinstance(-cls.scalar.random(), cls.scalar))

class Test_algebra(TestCase):
"""
Tests of algebraic properties of primitive operations and class methods.
Expand Down Expand Up @@ -689,6 +703,24 @@ def test_algebra_scalar_mul_point_on_left_hand_side(self):
p = cls.point.hash(bytes(POINT_LEN))
self.assertRaises(TypeError, lambda: p * s)

def test_algebra_scalar_add_commute(self):
sodium_hidden_and_fallback(hidden, fallback)
for bs in fountains(SCALAR_LEN + SCALAR_LEN, limit=TRIALS_PER_TEST):
(s0, s1) = (
cls.scalar.hash(bs[:SCALAR_LEN]),
cls.scalar.hash(bs[SCALAR_LEN:])
)
self.assertEqual(cls.sad(s0, s1), cls.sad(s1, s0))

def test_algebra_scalar_add_neg_add_identity(self):
sodium_hidden_and_fallback(hidden, fallback)
for bs in fountains(SCALAR_LEN + SCALAR_LEN, limit=TRIALS_PER_TEST):
(s0, s1) = (
cls.scalar.hash(bs[:SCALAR_LEN]),
cls.scalar.hash(bs[SCALAR_LEN:])
)
self.assertEqual(cls.sad(cls.sad(s0, cls.sne(s0)), s1), s1)

return (
Test_primitives,
Test_classes,
Expand Down
Loading