Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate int4 C++ types into TensorFlow. #89

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@

from typing import Type

from ml_dtypes._custom_floats import bfloat16
from ml_dtypes._custom_floats import float8_e4m3b11fnuz
from ml_dtypes._custom_floats import float8_e4m3fn
from ml_dtypes._custom_floats import float8_e4m3fnuz
from ml_dtypes._custom_floats import float8_e5m2
from ml_dtypes._custom_floats import float8_e5m2fnuz
from ml_dtypes._custom_floats import int4
from ml_dtypes._custom_floats import uint4
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo

from ml_dtypes._ml_dtypes_lib import bfloat16
from ml_dtypes._ml_dtypes_lib import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_lib import float8_e4m3fn
from ml_dtypes._ml_dtypes_lib import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_lib import float8_e5m2
from ml_dtypes._ml_dtypes_lib import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_lib import int4
from ml_dtypes._ml_dtypes_lib import uint4
import numpy as np

bfloat16: Type[np.generic]
Expand Down
13 changes: 6 additions & 7 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

from typing import Dict

from ml_dtypes._custom_floats import bfloat16
from ml_dtypes._custom_floats import float8_e4m3b11fnuz
from ml_dtypes._custom_floats import float8_e4m3fn
from ml_dtypes._custom_floats import float8_e4m3fnuz
from ml_dtypes._custom_floats import float8_e5m2
from ml_dtypes._custom_floats import float8_e5m2fnuz

from ml_dtypes._ml_dtypes_lib import bfloat16
from ml_dtypes._ml_dtypes_lib import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_lib import float8_e4m3fn
from ml_dtypes._ml_dtypes_lib import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_lib import float8_e5m2
from ml_dtypes._ml_dtypes_lib import float8_e5m2fnuz
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/_iinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

"""Overload of numpy.iinfo to handle dtypes defined in ml_dtypes."""

from ml_dtypes._custom_floats import int4
from ml_dtypes._custom_floats import uint4

from ml_dtypes._ml_dtypes_lib import int4
from ml_dtypes._ml_dtypes_lib import uint4
import numpy as np

_int4_dtype = np.dtype(int4)
Expand Down
9 changes: 5 additions & 4 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ limitations under the License.

#include "Eigen/Core"
#include "_src/custom_float.h"
#include "_src/int4.h"
#include "_src/int4_numpy.h"
#include "include/float8.h"
#include "include/int4.h"

namespace ml_dtypes {

Expand Down Expand Up @@ -297,7 +298,7 @@ bool Initialize() {

static PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_custom_floats",
"_ml_dtypes_lib",
};

// TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+.
Expand All @@ -308,14 +309,14 @@ static PyModuleDef module_def = {
#define EXPORT_SYMBOL __attribute__((visibility("default")))
#endif

extern "C" EXPORT_SYMBOL PyObject* PyInit__custom_floats() {
extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_lib() {
Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def));
if (!m) {
return nullptr;
}
if (!Initialize()) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_RuntimeError, "cannot load _custom_floats module.");
PyErr_SetString(PyExc_RuntimeError, "cannot load _ml_dtypes_lib module.");
}
return nullptr;
}
Expand Down
82 changes: 4 additions & 78 deletions ml_dtypes/_src/int4.h → ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,95 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ML_DTYPES_INT4_H_
#define ML_DTYPES_INT4_H_
#ifndef ML_DTYPES_INT4_NUMPY_H_
#define ML_DTYPES_INT4_NUMPY_H_

// Must be included first
// clang-format off
#include "_src/numpy.h"
// clang-format on

#include <cstdint> //NOLINT
#include <optional> //NOLINT
#include <ostream> //NOLINT
#include <sstream> //NOLINT

#include "Eigen/Core"
#include "_src/common.h" // NOLINT
#include "_src/ufuncs.h" // NOLINT
#include "include/int4.h"

namespace ml_dtypes {

template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;

public:
i4() : v(0) {}
explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
template <typename T>
explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
i4(const i4& other) = default;

static constexpr i4 lowest() {
return std::is_signed<UnderlyingTy>::value ? i4(-8) : i4(0);
}
static constexpr i4 highest() {
return std::is_signed<UnderlyingTy>::value ? i4(7) : i4(15);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
explicit operator T() const {
return static_cast<T>(v);
}
// NOLINTNEXTLINE(google-explicit-constructor)
operator std::optional<int64_t>() const { return static_cast<int64_t>(v); }

i4 operator-() const { return i4(-v); }
i4 operator+(const i4& other) const { return i4((v + other.v)); }
i4 operator-(const i4& other) const { return i4((v - other.v)); }
i4 operator*(const i4& other) const { return i4((v * other.v)); }
i4 operator/(const i4& other) const { return i4((v / other.v)); }
i4 operator%(const i4& other) const { return i4((v % other.v)); }

i4 operator>>(const int amount) const { return i4((v >> amount)); }
i4 operator<<(const int amount) const { return i4((v << amount)); }

bool operator==(const i4& other) const { return v == other.v; }
bool operator!=(const i4& other) const { return v != other.v; }
bool operator<(const i4& other) const { return v < other.v; }
bool operator>(const i4& other) const { return v > other.v; }
bool operator<=(const i4& other) const { return v <= other.v; }
bool operator>=(const i4& other) const { return v >= other.v; }

bool operator==(const int64_t other) const { return v == other; }
bool operator!=(const int64_t other) const { return v != other; }
bool operator<(const int64_t other) const { return v < other; }
bool operator>(const int64_t other) const { return v > other; }
bool operator<=(const int64_t other) const { return v <= other; }
bool operator>=(const int64_t other) const { return v >= other; }

i4& operator++() {
v = (v + 1) & 0x0F;
return *this;
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
return os;
}

std::string ToString() const {
std::ostringstream os;
os << static_cast<int16_t>(v);
return os.str();
}
};

using int4 = i4<int8_t>;
using uint4 = i4<uint8_t>;

template <typename T>
struct Int4TypeDescriptor {
static int Dtype() { return npy_type; }
Expand Down Expand Up @@ -878,4 +804,4 @@ bool RegisterInt4Dtype(PyObject* numpy) {

} // namespace ml_dtypes

#endif // ML_DTYPES_INT4_H_
#endif // ML_DTYPES_INT4_NUMPY_H_
100 changes: 100 additions & 0 deletions ml_dtypes/include/int4.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/* Copyright 2023 The ml_dtypes Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ML_DTYPES_INT4_H_
#define ML_DTYPES_INT4_H_

#include <cstdint>
#include <limits>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>

namespace ml_dtypes {

template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;

public:
i4() : v(0) {}
explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
template <typename T>
explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
i4(const i4& other) = default;

static constexpr i4 lowest() {
return std::is_signed<UnderlyingTy>::value ? i4(-8) : i4(0);
}
static constexpr i4 highest() {
return std::is_signed<UnderlyingTy>::value ? i4(7) : i4(15);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
explicit operator T() const {
return static_cast<T>(v);
}
// NOLINTNEXTLINE(google-explicit-constructor)
operator std::optional<int64_t>() const { return static_cast<int64_t>(v); }

i4 operator-() const { return i4(-v); }
i4 operator+(const i4& other) const { return i4((v + other.v)); }
i4 operator-(const i4& other) const { return i4((v - other.v)); }
i4 operator*(const i4& other) const { return i4((v * other.v)); }
i4 operator/(const i4& other) const { return i4((v / other.v)); }
i4 operator%(const i4& other) const { return i4((v % other.v)); }

i4 operator>>(const int amount) const { return i4((v >> amount)); }
i4 operator<<(const int amount) const { return i4((v << amount)); }

bool operator==(const i4& other) const { return v == other.v; }
bool operator!=(const i4& other) const { return v != other.v; }
bool operator<(const i4& other) const { return v < other.v; }
bool operator>(const i4& other) const { return v > other.v; }
bool operator<=(const i4& other) const { return v <= other.v; }
bool operator>=(const i4& other) const { return v >= other.v; }

bool operator==(const int64_t other) const { return v == other; }
bool operator!=(const int64_t other) const { return v != other; }
bool operator<(const int64_t other) const { return v < other; }
bool operator>(const int64_t other) const { return v > other; }
bool operator<=(const int64_t other) const { return v <= other; }
bool operator>=(const int64_t other) const { return v >= other; }

i4& operator++() {
v = (v + 1) & 0x0F;
return *this;
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
return os;
}

std::string ToString() const {
std::ostringstream os;
os << static_cast<int16_t>(v);
return os.str();
}
};

using int4 = i4<int8_t>;
using uint4 = i4<uint8_t>;

} // namespace ml_dtypes

#endif // ML_DTYPES_INT4_H_
Loading
Loading