Skip to content

Commit

Permalink
Expose int4 c++ type to "include" folder for use in external projects.
Browse files Browse the repository at this point in the history
This involves separating the C++ definition from the numpy type
information to avoid the numpy dependency.

Also renamed `_custom_floats` to `_ml_dtypes_lib` for the compiled
shared library, since it includes more than just floating-point types.

We will use it in XLA and TensorFlow.

PiperOrigin-RevId: 561775184
  • Loading branch information
cantonios authored and The ml_dtypes Authors committed Aug 31, 2023
1 parent 54e375a commit fa8060c
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 102 deletions.
17 changes: 8 additions & 9 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@

from typing import Type

from ml_dtypes._custom_floats import bfloat16
from ml_dtypes._custom_floats import float8_e4m3b11fnuz
from ml_dtypes._custom_floats import float8_e4m3fn
from ml_dtypes._custom_floats import float8_e4m3fnuz
from ml_dtypes._custom_floats import float8_e5m2
from ml_dtypes._custom_floats import float8_e5m2fnuz
from ml_dtypes._custom_floats import int4
from ml_dtypes._custom_floats import uint4
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint4
import numpy as np

bfloat16: Type[np.generic]
Expand Down
13 changes: 6 additions & 7 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

from typing import Dict

from ml_dtypes._custom_floats import bfloat16
from ml_dtypes._custom_floats import float8_e4m3b11fnuz
from ml_dtypes._custom_floats import float8_e4m3fn
from ml_dtypes._custom_floats import float8_e4m3fnuz
from ml_dtypes._custom_floats import float8_e5m2
from ml_dtypes._custom_floats import float8_e5m2fnuz

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/_iinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

"""Overload of numpy.iinfo to handle dtypes defined in ml_dtypes."""

from ml_dtypes._custom_floats import int4
from ml_dtypes._custom_floats import uint4

from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint4
import numpy as np

_int4_dtype = np.dtype(int4)
Expand Down
9 changes: 5 additions & 4 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ limitations under the License.

#include "Eigen/Core"
#include "_src/custom_float.h"
#include "_src/int4.h"
#include "_src/int4_numpy.h"
#include "include/float8.h"
#include "include/int4.h"

namespace ml_dtypes {

Expand Down Expand Up @@ -297,7 +298,7 @@ bool Initialize() {

static PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_custom_floats",
"_ml_dtypes_ext",
};

// TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+.
Expand All @@ -308,14 +309,14 @@ static PyModuleDef module_def = {
#define EXPORT_SYMBOL __attribute__((visibility("default")))
#endif

extern "C" EXPORT_SYMBOL PyObject* PyInit__custom_floats() {
extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def));
if (!m) {
return nullptr;
}
if (!Initialize()) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_RuntimeError, "cannot load _custom_floats module.");
PyErr_SetString(PyExc_RuntimeError, "cannot load _ml_dtypes_ext module.");
}
return nullptr;
}
Expand Down
82 changes: 4 additions & 78 deletions ml_dtypes/_src/int4.h → ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,95 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ML_DTYPES_INT4_H_
#define ML_DTYPES_INT4_H_
#ifndef ML_DTYPES_INT4_NUMPY_H_
#define ML_DTYPES_INT4_NUMPY_H_

// Must be included first
// clang-format off
#include "_src/numpy.h"
// clang-format on

#include <cstdint> //NOLINT
#include <optional> //NOLINT
#include <ostream> //NOLINT
#include <sstream> //NOLINT

#include "Eigen/Core"
#include "_src/common.h" // NOLINT
#include "_src/ufuncs.h" // NOLINT
#include "include/int4.h"

namespace ml_dtypes {

template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;

public:
i4() : v(0) {}
explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
template <typename T>
explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
i4(const i4& other) = default;

static constexpr i4 lowest() {
return std::is_signed<UnderlyingTy>::value ? i4(-8) : i4(0);
}
static constexpr i4 highest() {
return std::is_signed<UnderlyingTy>::value ? i4(7) : i4(15);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
explicit operator T() const {
return static_cast<T>(v);
}
// NOLINTNEXTLINE(google-explicit-constructor)
operator std::optional<int64_t>() const { return static_cast<int64_t>(v); }

i4 operator-() const { return i4(-v); }
i4 operator+(const i4& other) const { return i4((v + other.v)); }
i4 operator-(const i4& other) const { return i4((v - other.v)); }
i4 operator*(const i4& other) const { return i4((v * other.v)); }
i4 operator/(const i4& other) const { return i4((v / other.v)); }
i4 operator%(const i4& other) const { return i4((v % other.v)); }

i4 operator>>(const int amount) const { return i4((v >> amount)); }
i4 operator<<(const int amount) const { return i4((v << amount)); }

bool operator==(const i4& other) const { return v == other.v; }
bool operator!=(const i4& other) const { return v != other.v; }
bool operator<(const i4& other) const { return v < other.v; }
bool operator>(const i4& other) const { return v > other.v; }
bool operator<=(const i4& other) const { return v <= other.v; }
bool operator>=(const i4& other) const { return v >= other.v; }

bool operator==(const int64_t other) const { return v == other; }
bool operator!=(const int64_t other) const { return v != other; }
bool operator<(const int64_t other) const { return v < other; }
bool operator>(const int64_t other) const { return v > other; }
bool operator<=(const int64_t other) const { return v <= other; }
bool operator>=(const int64_t other) const { return v >= other; }

i4& operator++() {
v = (v + 1) & 0x0F;
return *this;
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
return os;
}

std::string ToString() const {
std::ostringstream os;
os << static_cast<int16_t>(v);
return os.str();
}
};

using int4 = i4<int8_t>;
using uint4 = i4<uint8_t>;

template <typename T>
struct Int4TypeDescriptor {
static int Dtype() { return npy_type; }
Expand Down Expand Up @@ -878,4 +804,4 @@ bool RegisterInt4Dtype(PyObject* numpy) {

} // namespace ml_dtypes

#endif // ML_DTYPES_INT4_H_
#endif // ML_DTYPES_INT4_NUMPY_H_
99 changes: 99 additions & 0 deletions ml_dtypes/include/int4.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/* Copyright 2023 The ml_dtypes Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ML_DTYPES_INT4_H_
#define ML_DTYPES_INT4_H_

#include <cstdint>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>

namespace ml_dtypes {

template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;

public:
i4() : v(0) {}
explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
template <typename T>
explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
i4(const i4& other) = default;

static constexpr i4 lowest() {
return std::is_signed<UnderlyingTy>::value ? i4(-8) : i4(0);
}
static constexpr i4 highest() {
return std::is_signed<UnderlyingTy>::value ? i4(7) : i4(15);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
explicit operator T() const {
return static_cast<T>(v);
}
// NOLINTNEXTLINE(google-explicit-constructor)
operator std::optional<int64_t>() const { return static_cast<int64_t>(v); }

i4 operator-() const { return i4(-v); }
i4 operator+(const i4& other) const { return i4((v + other.v)); }
i4 operator-(const i4& other) const { return i4((v - other.v)); }
i4 operator*(const i4& other) const { return i4((v * other.v)); }
i4 operator/(const i4& other) const { return i4((v / other.v)); }
i4 operator%(const i4& other) const { return i4((v % other.v)); }

i4 operator>>(const int amount) const { return i4((v >> amount)); }
i4 operator<<(const int amount) const { return i4((v << amount)); }

bool operator==(const i4& other) const { return v == other.v; }
bool operator!=(const i4& other) const { return v != other.v; }
bool operator<(const i4& other) const { return v < other.v; }
bool operator>(const i4& other) const { return v > other.v; }
bool operator<=(const i4& other) const { return v <= other.v; }
bool operator>=(const i4& other) const { return v >= other.v; }

bool operator==(const int64_t other) const { return v == other; }
bool operator!=(const int64_t other) const { return v != other; }
bool operator<(const int64_t other) const { return v < other; }
bool operator>(const int64_t other) const { return v > other; }
bool operator<=(const int64_t other) const { return v <= other; }
bool operator>=(const int64_t other) const { return v >= other; }

i4& operator++() {
v = (v + 1) & 0x0F;
return *this;
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
return os;
}

std::string ToString() const {
std::ostringstream os;
os << static_cast<int16_t>(v);
return os.str();
}
};

using int4 = i4<int8_t>;
using uint4 = i4<uint8_t>;

} // namespace ml_dtypes

#endif // ML_DTYPES_INT4_H_
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def find_package_modules(self, package, package_dir):
setup(
ext_modules=[
Pybind11Extension(
"ml_dtypes._custom_floats",
"ml_dtypes._ml_dtypes_lib",
[
"ml_dtypes/_src/dtypes.cc",
"ml_dtypes/_src/numpy.cc",
Expand Down

0 comments on commit fa8060c

Please sign in to comment.