Skip to content

Commit

Permalink
fix(library/vm/vm_{int,float}): fix overflow errors (#742)
Browse files Browse the repository at this point in the history
The behavior before this change was:
```lean
#eval native.float.of_nat 0x100000000    -- 0
#eval native.float.of_int 0x100000000    -- 0
#eval (0x100000000 : native.float).floor  -- -2147483648
#eval (0x100000000 : native.float).ceil   -- -2147483648
#eval (0x100000000 : native.float).round  -- -2147483648
```
With this change, the same expressions output `4.29497e+09` (for the float results) or `4294967296` (for the integer results) as appropriate.
The fix is to convert between floats and large ints/nats through `mpz`.

To get here, this also:

* replaces `v.get<T>()` with `static_cast<T>(v)` for `mpz`.
 The benefit here is that we don't then need to treat `mpz` specially when writing functions which are generic over numeric types.
* adds a much larger family of `mk_vm_int` functions with more appropriate overloads.

I can split these into separate PRs if desired.
  • Loading branch information
eric-wieser committed Jul 12, 2022
1 parent 5f58002 commit 9dc6b1e
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 50 deletions.
12 changes: 8 additions & 4 deletions library/init/meta/float.lean
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,17 @@ meta constant acosh : float → float
meta constant atanh : float → float

meta constant abs : float → float
/-- Nearest integer not less than the given value. -/
/-- Nearest integer not less than the given value.
Returns 0 if the input is not finite. -/
meta constant ceil : float → int
/-- Nearest integer not greater than the given value. -/
/-- Nearest integer not greater than the given value.
Returns 0 if the input is not finite. -/
meta constant floor : float → int
/-- Nearest integer not greater in magnitude than the given value. -/
/-- Nearest integer not greater in magnitude than the given value.
Returns 0 if the input is not finite. -/
meta constant trunc : float → int
/-- Round to the nearest integer, rounding away from zero in halfway cases. -/
/-- Round to the nearest integer, rounding away from zero in halfway cases.
Returns 0 if the input is not finite. -/
meta constant round : float → int

meta constant lt : float → float → bool
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/lean/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ unsigned parser::get_small_nat() {
maybe_throw_error({"invalid numeral, value does not fit in a machine integer", pos()});
return 0;
}
return val.get<unsigned>();
return static_cast<unsigned>(val);
}

pair<ast_id, std::string> parser::parse_string_lit() {
Expand Down
4 changes: 2 additions & 2 deletions src/library/string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ optional<unsigned> to_char_core(expr const & e) {
expr const & fn = get_app_args(e, args);
if (fn == *g_char_mk && args.size() == 2) {
if (auto n = to_num(args[0])) {
return optional<unsigned>(n->get<unsigned>());
return optional<unsigned>(static_cast<unsigned>(*n));
} else {
return optional<unsigned>();
}
} else if (fn == *g_char_of_nat && args.size() == 1) {
if (auto n = to_num(args[0])) {
return optional<unsigned>(n->get<unsigned>());
return optional<unsigned>(static_cast<unsigned>(*n));
} else {
return optional<unsigned>();
}
Expand Down
2 changes: 1 addition & 1 deletion src/library/type_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2962,7 +2962,7 @@ static optional<mpz> eval_num(expr const & e) {
optional<unsigned> type_context_old::to_small_num(expr const & e) {
if (optional<mpz> r = eval_num(e)) {
if (r->is<unsigned>()) {
unsigned r1 = r->get<unsigned>();
unsigned r1 = static_cast<unsigned>(*r);
if (r1 <= m_cache->get_nat_offset_cnstr_threshold())
return optional<unsigned>(r1);
}
Expand Down
2 changes: 1 addition & 1 deletion src/library/vm/vm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ vm_instr mk_constructor_instr(unsigned cidx, unsigned nfields) {
vm_instr mk_num_instr(mpz const & v) {
if (v < LEAN_MAX_SMALL_NAT) {
vm_instr r(opcode::SConstructor);
r.m_num = v.get<unsigned>();
r.m_num = static_cast<unsigned>(v);
return r;
} else {
vm_instr r(opcode::Num);
Expand Down
7 changes: 4 additions & 3 deletions src/library/vm/vm_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ float to_float(vm_obj const & o) {
}

vm_obj float_of_nat(vm_obj const & a) {
// [TODO] check that the nat isn't too big to fit in an unsigned
return mk_vm_float(static_cast<float>(to_unsigned(a)));
return is_simple(a) ? mk_vm_float(cidx(a)) : mk_vm_float(to_mpz(a).get_double());
}
// Builds a VM float from a VM int object.
// Simple objects are read with to_int(); every other object is read as an mpz
// and converted with get_double(), so it never passes through the
// `int`-returning to_int().
vm_obj float_of_int(vm_obj const & i) {
    if (is_simple(i))
        return mk_vm_float(to_int(i));
    return mk_vm_float(to_mpz(i).get_double());
}
vm_obj float_of_int(vm_obj const & i) { return mk_vm_float(static_cast<float>(to_int(i))); }

vm_obj float_repr(vm_obj const & a) {
std::ostringstream out;
Expand Down
48 changes: 29 additions & 19 deletions src/library/vm/vm_int.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <iostream>
#include <limits>
#include "library/vm/vm.h"
#include "library/vm/vm_nat.h"

Expand All @@ -14,33 +15,42 @@ Author: Leonardo de Moura
namespace lean {
// =======================================
// Builtin int operations
inline bool is_small_int(int n) { return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT; }
inline bool is_small_int(unsigned n) { return n < LEAN_MAX_SMALL_INT; }
inline bool is_small_int(long long n) { return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT; }
inline bool is_small_int(mpz const & n) { return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT; }
// Generic small-int range check for types whose std::numeric_limits<T>::is_signed is true.
// Participates in overload resolution only for such T (via enable_if), and tests both
// bounds: LEAN_MIN_SMALL_INT <= n < LEAN_MAX_SMALL_INT.
template<typename T>
inline typename std::enable_if<std::numeric_limits<T>::is_signed, bool>::type is_small_int(const T& n) {
return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT;
}
// Generic small-int range check for types whose std::numeric_limits<T>::is_signed is false.
// Only the upper bound (n < LEAN_MAX_SMALL_INT) is tested; no lower-bound comparison is made.
// NOTE(review): std::numeric_limits<T>::is_signed is also false for any T without a
// numeric_limits specialization, so such types would select this overload — confirm
// that is never reached with a type that can hold negative values.
template<typename T>
inline typename std::enable_if<!std::numeric_limits<T>::is_signed, bool>::type is_small_int(const T& n) {
return n < LEAN_MAX_SMALL_INT;
}

inline unsigned to_unsigned(int n) {
template<typename T>
inline unsigned to_unsigned(T n) {
lean_assert(is_small_int(n));
unsigned r = static_cast<unsigned>(n) & 0x7FFFFFFF;
// small ints are strictly smaller than `signed`, so this is safe for `T = mpz`
signed ns = static_cast<signed>(n);
unsigned r = static_cast<unsigned>(ns) & 0x7FFFFFFF;
lean_assert(r < LEAN_MAX_SMALL_NAT);
return r;
}

inline int of_unsigned(unsigned n) {
inline int of_unsigned(unsigned n) {
return static_cast<int>(n << 1) / 2;
}

vm_obj mk_vm_int(int n) {
return is_small_int(n) ? mk_vm_simple(to_unsigned(n)) : mk_vm_mpz(mpz(n));
template<typename T>
vm_obj mk_vm_int_impl(T && n) {
return is_small_int(n) ? mk_vm_simple(to_unsigned(n)) : mk_vm_mpz(mpz(std::forward<T>(n)));
}

vm_obj mk_vm_int(unsigned n) {
return is_small_int(n) ? mk_vm_simple(to_unsigned(n)) : mk_vm_mpz(mpz(n));
}

vm_obj mk_vm_int(mpz const & n) {
return is_small_int(n) ? mk_vm_simple(to_unsigned(n.get<int>())) : mk_vm_mpz(n);
}
// Public mk_vm_int overloads, one per supported numeric source type.
// Every overload forwards to mk_vm_int_impl; the `double` overload first checks
// std::isfinite and produces the VM int 0 for non-finite inputs (NaN, +/-infinity).
// NOTE(review): std::isfinite is declared in <cmath> — confirm this translation unit
// includes it (directly or transitively); the include is not visible in this hunk.
vm_obj mk_vm_int(int n) { return mk_vm_int_impl(n); }
vm_obj mk_vm_int(unsigned int n) { return mk_vm_int_impl(n); }
vm_obj mk_vm_int(long n) { return mk_vm_int_impl(n); }
vm_obj mk_vm_int(unsigned long n) { return mk_vm_int_impl(n); }
vm_obj mk_vm_int(long long n) { return mk_vm_int_impl(n); }
vm_obj mk_vm_int(unsigned long long n) { return mk_vm_int_impl(n); }
vm_obj mk_vm_int(double n) { return std::isfinite(n) ? mk_vm_int_impl(n) : mk_vm_int(0); }
vm_obj mk_vm_int(mpz const & n) { return mk_vm_int_impl(n); }

inline int to_small_int(vm_obj const & o) {
lean_assert(is_simple(o));
Expand All @@ -51,7 +61,7 @@ int to_int(vm_obj const & o) {
if (is_simple(o))
return to_small_int(o);
else
return to_mpz(o).get<int>();
return static_cast<int>(to_mpz(o));
}

optional<int> try_to_int(vm_obj const & o) {
Expand All @@ -60,7 +70,7 @@ optional<int> try_to_int(vm_obj const & o) {
} else {
mpz const & v = to_mpz(o);
if (v.is<int>())
return optional<int>(v.get<int>());
return optional<int>(static_cast<int>(v));
else
return optional<int>();
}
Expand Down Expand Up @@ -235,7 +245,7 @@ vm_obj int_test_bit(vm_obj const & a1, vm_obj const & a2) {
mpz const & v1 = to_mpz1(a1);
mpz const & v2 = to_mpz2(a2);
if (v2.is<unsigned long int>())
return mk_vm_bool(v1.test_bit(v2.get<unsigned long int>()));
return mk_vm_bool(v1.test_bit(static_cast<unsigned long int>(v2)));
else
return mk_vm_bool(false);
}
Expand Down
11 changes: 10 additions & 1 deletion src/library/vm/vm_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@ namespace lean {
int to_int(vm_obj const & o);
optional<int> try_to_int(vm_obj const & o);
int force_to_int(vm_obj const & o, int def);
vm_obj mk_vm_int(int i);

vm_obj mk_vm_int(int n);
vm_obj mk_vm_int(unsigned int n);
vm_obj mk_vm_int(long n);
vm_obj mk_vm_int(unsigned long n);
vm_obj mk_vm_int(long long n);
vm_obj mk_vm_int(unsigned long long n);
vm_obj mk_vm_int(double n);
vm_obj mk_vm_int(mpz const & n);

void initialize_vm_int();
void finalize_vm_int();
}
8 changes: 4 additions & 4 deletions src/library/vm/vm_nat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ vm_obj mk_vm_nat(unsigned n) {

vm_obj mk_vm_nat(mpz const & n) {
if (LEAN_LIKELY(n < LEAN_MAX_SMALL_NAT))
return mk_vm_simple(n.get<unsigned>());
return mk_vm_simple(static_cast<unsigned>(n));
else
return mk_vm_mpz(n);
}
Expand All @@ -30,7 +30,7 @@ unsigned to_unsigned(vm_obj const & o) {
if (LEAN_LIKELY(is_simple(o)))
return cidx(o);
else
return to_mpz(o).get<unsigned>();
return static_cast<unsigned>(to_mpz(o));
}

optional<unsigned> try_to_unsigned(vm_obj const & o) {
Expand All @@ -39,7 +39,7 @@ optional<unsigned> try_to_unsigned(vm_obj const & o) {
} else {
mpz const & v = to_mpz(o);
if (v.is<unsigned>())
return optional<unsigned>(v.get<unsigned>());
return optional<unsigned>(static_cast<unsigned>(v));
else
return optional<unsigned>();
}
Expand Down Expand Up @@ -262,7 +262,7 @@ vm_obj nat_test_bit(vm_obj const & a1, vm_obj const & a2) {
mpz const & v1 = to_mpz1(a1);
mpz const & v2 = to_mpz2(a2);
if (v2.is<unsigned long int>())
return mk_vm_bool(v1.test_bit(v2.get<unsigned long int>()));
return mk_vm_bool(v1.test_bit(static_cast<unsigned long int>(v2)));
else
return mk_vm_bool(false);
}
Expand Down
19 changes: 17 additions & 2 deletions src/tests/util/numerics/mpz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,27 @@ static void tst5() {
mpz m_max(max);
lean_assert(m_max.is<T>());
lean_assert(!(m_max + 1).is<T>());
lean_assert(m_max.get<T>() == max);
lean_assert(static_cast<T>(m_max) == max);

T min = std::numeric_limits<T>::min();
mpz m_min(min);
lean_assert(m_min.is<T>());
lean_assert(!(m_min - 1).is<T>());
lean_assert(m_min.get<T>() == min);
lean_assert(static_cast<T>(m_min) == min);

if (std::numeric_limits<T>::is_signed) {
T neg_one = -1;
mpz m_neg_one(neg_one);
lean_assert(m_neg_one.is<T>());
lean_assert(static_cast<T>(m_neg_one) == neg_one);
}
}

static void tst6() {
// the largest representable double is integral, so is fine to store in mpz
double max = std::numeric_limits<double>::max();
mpz n1(max);
lean_assert(n1.get_double() == max);
}

int main() {
Expand All @@ -102,5 +116,6 @@ int main() {
tst5<unsigned long>();
tst5<long long>();
tst5<unsigned long long>();
tst6();
return has_violations() ? 1 : 0;
}
8 changes: 4 additions & 4 deletions src/util/numerics/mpz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ mpz::mpz(int64 v) : mpz(static_cast<unsigned>(v)) {
mpz_add(m_val, m_val, tmp.m_val);
}

template<> long long int mpz::get() const {
mpz::operator long long int() const {
lean_assert(is<long long int>());
mpz high_m, low_m;
mpz_fdiv_r_2exp(low_m.m_val, m_val, 32);
mpz_fdiv_q_2exp(high_m.m_val, m_val, 32);
return static_cast<long long int>(high_m.get<signed>()) << 32 | low_m.get<unsigned>();
return static_cast<long long int>(high_m.operator signed()) << 32 | low_m.operator unsigned();
}
template<> unsigned long long int mpz::get() const {
mpz::operator unsigned long long int() const {
lean_assert(is<unsigned long long int>());
mpz high_m, low_m;
mpz_fdiv_r_2exp(low_m.m_val, m_val, 32);
mpz_fdiv_q_2exp(high_m.m_val, m_val, 32);
return static_cast<unsigned long long int>(high_m.get<unsigned>()) << 32 | low_m.get<unsigned>();
return static_cast<unsigned long long int>(high_m.operator unsigned()) << 32 | low_m.operator unsigned();
}

unsigned mpz::log2() const {
Expand Down
24 changes: 16 additions & 8 deletions src/util/numerics/mpz.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class mpz {
explicit mpz(int v) { mpz_init_set_si(m_val, v); }
explicit mpz(uint64 v);
explicit mpz(int64 v);
// Constructs an mpz from a double by calling GMP's mpz_init_set_d.
// NOTE(review): GMP is understood to truncate the fractional part and to not accept
// NaN/infinity here — confirm against the GMP manual; callers should pass finite values only.
explicit mpz(double v) { mpz_init_set_d(m_val, v); }
mpz(mpz const & s) { mpz_init_set(m_val, s.m_val); }
mpz(mpz && s):mpz() { mpz_swap(m_val, s.m_val); }
~mpz() { mpz_clear(m_val); }
Expand All @@ -58,9 +59,16 @@ class mpz {
bool even() const { return mpz_even_p(m_val) != 0; }
bool odd() const { return !even(); }

template <typename T> bool is() const;
template <typename T> T get() const;
template <typename T> bool is() const = delete;

explicit operator long int() const;
explicit operator unsigned long int() const;
explicit operator int() const;
explicit operator unsigned int() const;
explicit operator long long int() const;
explicit operator unsigned long long int() const;

// not a cast operator, to match `mpz`
double get_double() const { return mpz_get_d(m_val); }

mpz & operator=(mpz const & v) { mpz_set(m_val, v.m_val); return *this; }
Expand All @@ -70,6 +78,7 @@ class mpz {
mpz & operator=(long int v) { mpz_set_si(m_val, v); return *this; }
mpz & operator=(unsigned int v) { return operator=(static_cast<unsigned long int>(v)); }
mpz & operator=(int v) { return operator=(static_cast<long int>(v)); }
// Assigns from a double by calling GMP's mpz_set_d, then returns *this.
// NOTE(review): as with the double constructor, the input is presumably required to be
// finite, with any fractional part truncated by GMP — confirm against the GMP manual.
mpz & operator=(double v) { mpz_set_d(m_val, v); return *this; }

friend int cmp(mpz const & a, mpz const & b) { return mpz_cmp(a.m_val, b.m_val); }
friend int cmp(mpz const & a, unsigned b) { return mpz_cmp_ui(a.m_val, b); }
Expand Down Expand Up @@ -234,12 +243,11 @@ template<> inline bool mpz::is<long long>() const {
template<> inline bool mpz::is<unsigned long long>() const {
return mpz(std::numeric_limits<unsigned long long>::min()) <= *this && *this <= mpz(std::numeric_limits<unsigned long long>::max()); }

template<> inline long int mpz::get() const { lean_assert(is<long int>()); return mpz_get_si(m_val); }
template<> inline unsigned long int mpz::get() const { lean_assert(is<unsigned long int>()); return mpz_get_ui(m_val); }
template<> inline int mpz::get() const { lean_assert(is<int>()); return static_cast<int>(get<long int>()); }
template<> inline unsigned int mpz::get() const { lean_assert(is<unsigned int>()); return static_cast<unsigned>(get<unsigned long int>()); }
template<> long long int mpz::get() const;
template<> unsigned long long int mpz::get() const;
// we can't define these until the `is` specializations are declared
// Explicit conversions to the two `long` widths. Each one asserts (lean_assert) that the
// value satisfies is<T>() for the target type and then reads it with the matching GMP
// accessor: mpz_get_si for `long int`, mpz_get_ui for `unsigned long int`.
inline mpz::operator long int() const { lean_assert(is<long int>()); return mpz_get_si(m_val); }
inline mpz::operator unsigned long int() const { lean_assert(is<unsigned long int>()); return mpz_get_ui(m_val); }
// Explicit conversion to `int`. Asserts (lean_assert) that the value satisfies is<int>(),
// reads it through the `long int` conversion, and explicitly narrows the result to the
// declared return type so that no implicit long -> int conversion is left in the return.
inline mpz::operator int() const { lean_assert(is<int>()); return static_cast<int>(operator long int()); }
// Explicit conversion to `unsigned int`. Asserts (lean_assert) that the value satisfies
// is<unsigned int>(), reads it through the `unsigned long int` conversion, and explicitly
// narrows the result to `unsigned`.
inline mpz::operator unsigned int() const { lean_assert(is<unsigned int>()); return static_cast<unsigned>(operator unsigned long int()); }

struct mpz_cmp_fn {
int operator()(mpz const & v1, mpz const & v2) const { return cmp(v1, v2); }
Expand Down
13 changes: 13 additions & 0 deletions tests/lean/float.lean
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ list.foldl band tt $ floats.map (λ ⟨x,y⟩, prop x y)
#eval cos pi
#eval tan pi ≃ 0

#eval of_nat 0x100000000 -- 2 ^ 32
#eval of_int 0x100000000 -- 2 ^ 32
#eval of_nat 0x10000000000000000 -- 2 ^ 64
#eval of_int 0x10000000000000000 -- 2 ^ 64

#eval (of_int (-12341234))

#eval to_bool $ float.floor (2.5) = 2
Expand All @@ -106,6 +111,14 @@ list.foldl band tt $ floats.map (λ ⟨x,y⟩, prop x y)
#eval to_bool $ float.round (2.5) = 3
#eval to_bool $ float.round (-2.5) = -3

#eval to_bool $ float.round (of_int 0x100000000) = 0x100000000
#eval to_bool $ float.round (of_int 0x10000000000000000) = 0x10000000000000000

#eval (native.float.infinity : native.float).floor
#eval (native.float.infinity : native.float).ceil
#eval (native.float.infinity : native.float).trunc
#eval (native.float.infinity : native.float).round

#eval (of_string "hello")
#eval (of_string "0.123E4")
#eval (of_string "-123.123")
Expand Down
10 changes: 10 additions & 0 deletions tests/lean/float.lean.expected.out
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ tt
tt
-1
tt
4.29497e+09
4.29497e+09
1.84467e+19
1.84467e+19
-1.23412e+07
tt
tt
Expand All @@ -79,6 +83,12 @@ tt
tt
tt
tt
tt
tt
0
0
0
0
none
(some 1230)
(some -123.123)
Expand Down

0 comments on commit 9dc6b1e

Please sign in to comment.