diff --git a/library/init/meta/float.lean b/library/init/meta/float.lean index 772249a9ab..697985e9f4 100644 --- a/library/init/meta/float.lean +++ b/library/init/meta/float.lean @@ -116,13 +116,17 @@ meta constant acosh : float → float meta constant atanh : float → float meta constant abs : float → float -/-- Nearest integer not less than the given value. -/ +/-- Nearest integer not less than the given value. +Returns 0 if the input is not finite. -/ meta constant ceil : float → int -/-- Nearest integer not greater than the given value. -/ +/-- Nearest integer not greater than the given value. +Returns 0 if the input is not finite. -/ meta constant floor : float → int -/-- Nearest integer not greater in magnitude than the given value. -/ +/-- Nearest integer not greater in magnitude than the given value. +Returns 0 if the input is not finite. -/ meta constant trunc : float → int -/-- Round to the nearest integer, rounding away from zero in halfway cases. -/ +/-- Round to the nearest integer, rounding away from zero in halfway cases. +Returns 0 if the input is not finite. -/ meta constant round : float → int meta constant lt : float → float → bool diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 31b05f55bd..bd342343ea 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -725,7 +725,7 @@ unsigned parser::get_small_nat() { maybe_throw_error({"invalid numeral, value does not fit in a machine integer", pos()}); return 0; } - return val.get(); + return static_cast(val); } pair parser::parse_string_lit() { diff --git a/src/library/string.cpp b/src/library/string.cpp index 2bdb51bf7d..554a182f7b 100644 --- a/src/library/string.cpp +++ b/src/library/string.cpp @@ -189,13 +189,13 @@ optional to_char_core(expr const & e) { expr const & fn = get_app_args(e, args); if (fn == *g_char_mk && args.size() == 2) { if (auto n = to_num(args[0])) { - return optional(n->get()); + return optional(static_cast(*n)); } else { return optional(); } } else if (fn == *g_char_of_nat && args.size() == 1) { if (auto n = to_num(args[0])) { - return optional(n->get()); + return optional(static_cast(*n)); } else { return optional(); } diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 403f63162b..651bedc439 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -2962,7 +2962,7 @@ static optional eval_num(expr const & e) { optional type_context_old::to_small_num(expr const & e) { if (optional r = eval_num(e)) { if (r->is()) { - unsigned r1 = r->get(); + unsigned r1 = static_cast(*r); if (r1 <= m_cache->get_nat_offset_cnstr_threshold()) return optional(r1); } diff --git a/src/library/vm/vm.cpp b/src/library/vm/vm.cpp index fff1feb1da..3ff7cc7c05 100644 --- a/src/library/vm/vm.cpp +++ b/src/library/vm/vm.cpp @@ -712,7 +712,7 @@ vm_instr mk_constructor_instr(unsigned cidx, unsigned nfields) { vm_instr mk_num_instr(mpz const & v) { if (v < LEAN_MAX_SMALL_NAT) { vm_instr r(opcode::SConstructor); - r.m_num = v.get(); + r.m_num = static_cast(v); return r; } else { vm_instr r(opcode::Num); diff --git a/src/library/vm/vm_float.cpp b/src/library/vm/vm_float.cpp index 96c4c9b04b..6b21ce3313 100644 --- a/src/library/vm/vm_float.cpp +++ b/src/library/vm/vm_float.cpp @@ -41,10 +41,11 @@ float to_float(vm_obj const & o) { } vm_obj float_of_nat(vm_obj const & a) { - // [TODO] check that the nat isn't too big to fit in an unsigned - return mk_vm_float(static_cast(to_unsigned(a))); + return is_simple(a) ? mk_vm_float(cidx(a)) : mk_vm_float(to_mpz(a).get_double()); +} +vm_obj float_of_int(vm_obj const & i) { + return is_simple(i) ? mk_vm_float(to_int(i)) : mk_vm_float(to_mpz(i).get_double()); } -vm_obj float_of_int(vm_obj const & i) { return mk_vm_float(static_cast(to_int(i))); } vm_obj float_repr(vm_obj const & a) { std::ostringstream out; diff --git a/src/library/vm/vm_int.cpp b/src/library/vm/vm_int.cpp index 720e0e3335..faf664ba69 100644 --- a/src/library/vm/vm_int.cpp +++ b/src/library/vm/vm_int.cpp @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include +#include #include "library/vm/vm.h" #include "library/vm/vm_nat.h" @@ -14,33 +15,42 @@ Author: Leonardo de Moura namespace lean { // ======================================= // Builtin int operations -inline bool is_small_int(int n) { return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT; } -inline bool is_small_int(unsigned n) { return n < LEAN_MAX_SMALL_INT; } -inline bool is_small_int(long long n) { return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT; } -inline bool is_small_int(mpz const & n) { return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT; } +template +inline typename std::enable_if::is_signed, bool>::type is_small_int(const T& n) { + return LEAN_MIN_SMALL_INT <= n && n < LEAN_MAX_SMALL_INT; +} +template +inline typename std::enable_if::is_signed, bool>::type is_small_int(const T& n) { + return n < LEAN_MAX_SMALL_INT; +} -inline unsigned to_unsigned(int n) { +template +inline unsigned to_unsigned(T n) { lean_assert(is_small_int(n)); - unsigned r = static_cast(n) & 0x7FFFFFFF; + // small ints are strictly smaller than `signed`, so this is safe for `T = mpz` + signed ns = static_cast(n); + unsigned r = static_cast(ns) & 0x7FFFFFFF; lean_assert(r < LEAN_MAX_SMALL_NAT); return r; } -inline int of_unsigned(unsigned n) { +inline int of_unsigned(unsigned n) { return static_cast(n << 1) / 2; } -vm_obj mk_vm_int(int n) { - return is_small_int(n) ? mk_vm_simple(to_unsigned(n)) : mk_vm_mpz(mpz(n)); +template +vm_obj mk_vm_int_impl(T && n) { + return is_small_int(n) ? mk_vm_simple(to_unsigned(n)) : mk_vm_mpz(mpz(std::forward(n))); } -vm_obj mk_vm_int(unsigned n) { - return is_small_int(n) ? mk_vm_simple(to_unsigned(n)) : mk_vm_mpz(mpz(n)); -} - -vm_obj mk_vm_int(mpz const & n) { - return is_small_int(n) ? mk_vm_simple(to_unsigned(n.get())) : mk_vm_mpz(n); -} +vm_obj mk_vm_int(int n) { return mk_vm_int_impl(n); } +vm_obj mk_vm_int(unsigned int n) { return mk_vm_int_impl(n); } +vm_obj mk_vm_int(long n) { return mk_vm_int_impl(n); } +vm_obj mk_vm_int(unsigned long n) { return mk_vm_int_impl(n); } +vm_obj mk_vm_int(long long n) { return mk_vm_int_impl(n); } +vm_obj mk_vm_int(unsigned long long n) { return mk_vm_int_impl(n); } +vm_obj mk_vm_int(double n) { return std::isfinite(n) ? mk_vm_int_impl(n) : mk_vm_int(0); } +vm_obj mk_vm_int(mpz const & n) { return mk_vm_int_impl(n); } inline int to_small_int(vm_obj const & o) { lean_assert(is_simple(o)); @@ -51,7 +61,7 @@ int to_int(vm_obj const & o) { if (is_simple(o)) return to_small_int(o); else - return to_mpz(o).get(); + return static_cast(to_mpz(o)); } optional try_to_int(vm_obj const & o) { @@ -60,7 +70,7 @@ optional try_to_int(vm_obj const & o) { } else { mpz const & v = to_mpz(o); if (v.is()) - return optional(v.get()); + return optional(static_cast(v)); else return optional(); } @@ -235,7 +245,7 @@ vm_obj int_test_bit(vm_obj const & a1, vm_obj const & a2) { mpz const & v1 = to_mpz1(a1); mpz const & v2 = to_mpz2(a2); if (v2.is()) - return mk_vm_bool(v1.test_bit(v2.get())); + return mk_vm_bool(v1.test_bit(static_cast(v2))); else return mk_vm_bool(false); } diff --git a/src/library/vm/vm_int.h b/src/library/vm/vm_int.h index bf9638cb5e..3d8566db8a 100644 --- a/src/library/vm/vm_int.h +++ b/src/library/vm/vm_int.h @@ -11,7 +11,16 @@ namespace lean { int to_int(vm_obj const & o); optional try_to_int(vm_obj const & o); int force_to_int(vm_obj const & o, int def); -vm_obj mk_vm_int(int i); + +vm_obj mk_vm_int(int n); +vm_obj mk_vm_int(unsigned int n); +vm_obj mk_vm_int(long n); +vm_obj mk_vm_int(unsigned long n); +vm_obj mk_vm_int(long long n); +vm_obj mk_vm_int(unsigned long long n); +vm_obj mk_vm_int(double n); +vm_obj mk_vm_int(mpz const & n); + void initialize_vm_int(); void finalize_vm_int(); } diff --git a/src/library/vm/vm_nat.cpp b/src/library/vm/vm_nat.cpp index 6cea424ce6..32f3799ff8 100644 --- a/src/library/vm/vm_nat.cpp +++ b/src/library/vm/vm_nat.cpp @@ -21,7 +21,7 @@ vm_obj mk_vm_nat(unsigned n) { vm_obj mk_vm_nat(mpz const & n) { if (LEAN_LIKELY(n < LEAN_MAX_SMALL_NAT)) - return mk_vm_simple(n.get()); + return mk_vm_simple(static_cast(n)); else return mk_vm_mpz(n); } @@ -30,7 +30,7 @@ unsigned to_unsigned(vm_obj const & o) { if (LEAN_LIKELY(is_simple(o))) return cidx(o); else - return to_mpz(o).get(); + return static_cast(to_mpz(o)); } optional try_to_unsigned(vm_obj const & o) { @@ -39,7 +39,7 @@ optional try_to_unsigned(vm_obj const & o) { } else { mpz const & v = to_mpz(o); if (v.is()) - return optional(v.get()); + return optional(static_cast(v)); else return optional(); } @@ -262,7 +262,7 @@ vm_obj nat_test_bit(vm_obj const & a1, vm_obj const & a2) { mpz const & v1 = to_mpz1(a1); mpz const & v2 = to_mpz2(a2); if (v2.is()) - return mk_vm_bool(v1.test_bit(v2.get())); + return mk_vm_bool(v1.test_bit(static_cast(v2))); else return mk_vm_bool(false); } diff --git a/src/tests/util/numerics/mpz.cpp b/src/tests/util/numerics/mpz.cpp index 1ee98fef07..0125c5a990 100644 --- a/src/tests/util/numerics/mpz.cpp +++ b/src/tests/util/numerics/mpz.cpp @@ -82,13 +82,27 @@ static void tst5() { mpz m_max(max); lean_assert(m_max.is()); lean_assert(!(m_max + 1).is()); - lean_assert(m_max.get() == max); + lean_assert(static_cast(m_max) == max); T min = std::numeric_limits::min(); mpz m_min(min); lean_assert(m_min.is()); lean_assert(!(m_min - 1).is()); - lean_assert(m_min.get() == min); + lean_assert(static_cast(m_min) == min); + + if (std::numeric_limits::is_signed) { + T neg_one = -1; + mpz m_neg_one(neg_one); + lean_assert(m_neg_one.is()); + lean_assert(static_cast(m_neg_one) == neg_one); + } +} + +static void tst6() { + // the largest representable double is integral, so is fine to store in mpz + double max = std::numeric_limits::max(); + mpz n1(max); + lean_assert(n1.get_double() == max); } int main() { @@ -102,5 +116,6 @@ int main() { tst5(); tst5(); tst5(); + tst6(); return has_violations() ? 1 : 0; } diff --git a/src/util/numerics/mpz.cpp b/src/util/numerics/mpz.cpp index dd2e3fb199..f2de7f5d46 100644 --- a/src/util/numerics/mpz.cpp +++ b/src/util/numerics/mpz.cpp @@ -25,19 +25,19 @@ mpz::mpz(int64 v) : mpz(static_cast(v)) { mpz_add(m_val, m_val, tmp.m_val); } -template<> long long int mpz::get() const { +mpz::operator long long int() const { lean_assert(is()); mpz high_m, low_m; mpz_fdiv_r_2exp(low_m.m_val, m_val, 32); mpz_fdiv_q_2exp(high_m.m_val, m_val, 32); - return static_cast(high_m.get()) << 32 | low_m.get(); + return static_cast(high_m.operator signed()) << 32 | low_m.operator unsigned(); } -template<> unsigned long long int mpz::get() const { +mpz::operator unsigned long long int() const { lean_assert(is()); mpz high_m, low_m; mpz_fdiv_r_2exp(low_m.m_val, m_val, 32); mpz_fdiv_q_2exp(high_m.m_val, m_val, 32); - return static_cast(high_m.get()) << 32 | low_m.get(); + return static_cast(high_m.operator unsigned()) << 32 | low_m.operator unsigned(); } unsigned mpz::log2() const { diff --git a/src/util/numerics/mpz.h b/src/util/numerics/mpz.h index be177330c0..43196af8a8 100644 --- a/src/util/numerics/mpz.h +++ b/src/util/numerics/mpz.h @@ -33,6 +33,7 @@ class mpz { explicit mpz(int v) { mpz_init_set_si(m_val, v); } explicit mpz(uint64 v); explicit mpz(int64 v); + explicit mpz(double v) { mpz_init_set_d(m_val, v); } mpz(mpz const & s) { mpz_init_set(m_val, s.m_val); } mpz(mpz && s):mpz() { mpz_swap(m_val, s.m_val); } ~mpz() { mpz_clear(m_val); } @@ -58,9 +59,16 @@ class mpz { bool even() const { return mpz_even_p(m_val) != 0; } bool odd() const { return !even(); } - template bool is() const; - template T get() const; + template bool is() const = delete; + explicit operator long int() const; + explicit operator unsigned long int() const; + explicit operator int() const; + explicit operator unsigned int() const; + explicit operator long long int() const; + explicit operator unsigned long long int() const; + + // not a cast operator, to match `mpz` double get_double() const { return mpz_get_d(m_val); } mpz & operator=(mpz const & v) { mpz_set(m_val, v.m_val); return *this; } @@ -70,6 +78,7 @@ class mpz { mpz & operator=(long int v) { mpz_set_si(m_val, v); return *this; } mpz & operator=(unsigned int v) { return operator=(static_cast(v)); } mpz & operator=(int v) { return operator=(static_cast(v)); } + mpz & operator=(double v) { mpz_set_d(m_val, v); return *this; } friend int cmp(mpz const & a, mpz const & b) { return mpz_cmp(a.m_val, b.m_val); } friend int cmp(mpz const & a, unsigned b) { return mpz_cmp_ui(a.m_val, b); } @@ -234,12 +243,11 @@ template<> inline bool mpz::is() const { template<> inline bool mpz::is() const { return mpz(std::numeric_limits::min()) <= *this && *this <= mpz(std::numeric_limits::max()); } -template<> inline long int mpz::get() const { lean_assert(is()); return mpz_get_si(m_val); } -template<> inline unsigned long int mpz::get() const { lean_assert(is()); return mpz_get_ui(m_val); } -template<> inline int mpz::get() const { lean_assert(is()); return static_cast(get()); } -template<> inline unsigned int mpz::get() const { lean_assert(is()); return static_cast(get()); } -template<> long long int mpz::get() const; -template<> unsigned long long int mpz::get() const; +// we can't define these until the `is` specializations are declared +inline mpz::operator long int() const { lean_assert(is()); return mpz_get_si(m_val); } +inline mpz::operator unsigned long int() const { lean_assert(is()); return mpz_get_ui(m_val); } +inline mpz::operator int() const { lean_assert(is()); return static_cast(operator long int()); } +inline mpz::operator unsigned int() const { lean_assert(is()); return static_cast(operator unsigned long int()); } struct mpz_cmp_fn { int operator()(mpz const & v1, mpz const & v2) const { return cmp(v1, v2); } diff --git a/tests/lean/float.lean b/tests/lean/float.lean index 3199e38ea4..5434a8b874 100644 --- a/tests/lean/float.lean +++ b/tests/lean/float.lean @@ -95,6 +95,11 @@ list.foldl band tt $ floats.map (λ ⟨x,y⟩, prop x y) #eval cos pi #eval tan pi ≃ 0 +#eval of_nat 0x100000000 -- 2 ^ 32 +#eval of_int 0x100000000 -- 2 ^ 32 +#eval of_nat 0x10000000000000000 -- 2 ^ 64 +#eval of_int 0x10000000000000000 -- 2 ^ 64 + #eval (of_int (-12341234)) #eval to_bool $ float.floor (2.5) = 2 @@ -106,6 +111,14 @@ list.foldl band tt $ floats.map (λ ⟨x,y⟩, prop x y) #eval to_bool $ float.round (2.5) = 3 #eval to_bool $ float.round (-2.5) = -3 +#eval to_bool $ float.round (of_int 0x100000000) = 0x100000000 +#eval to_bool $ float.round (of_int 0x10000000000000000) = 0x10000000000000000 + +#eval (native.float.infinity : native.float).floor +#eval (native.float.infinity : native.float).ceil +#eval (native.float.infinity : native.float).trunc +#eval (native.float.infinity : native.float).round + #eval (of_string "hello") #eval (of_string "0.123E4") #eval (of_string "-123.123") diff --git a/tests/lean/float.lean.expected.out b/tests/lean/float.lean.expected.out index 611dfe7f67..52843de39f 100644 --- a/tests/lean/float.lean.expected.out +++ b/tests/lean/float.lean.expected.out @@ -70,6 +70,10 @@ tt tt -1 tt +4.29497e+09 +4.29497e+09 +1.84467e+19 +1.84467e+19 -1.23412e+07 tt tt @@ -79,6 +83,12 @@ tt tt tt tt +tt +tt +0 +0 +0 +0 none (some 1230) (some -123.123)