Skip to content

Commit

Permalink
Introduce Structural Hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
potatomashed committed Dec 2, 2024
1 parent b8c78bc commit 9c3279b
Show file tree
Hide file tree
Showing 17 changed files with 822 additions and 388 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)

project(
mlc
VERSION 0.0.10
VERSION 0.0.11
DESCRIPTION "MLC-Python"
LANGUAGES C CXX
)
Expand Down
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class Let(Expr):
body: Expr
```

**Structural equality**. Method eq_s is ready to use to compare the structural equality (alpha equivalence) of two IRs.
**Structural equality**. Member method `eq_s` compares the structural equality (alpha equivalence) of two IRs represented by MLC's structured dataclass.

```python
"""
Expand All @@ -110,7 +110,13 @@ True
ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
```

**Structural hashing**. TBD
**Structural hashing**. The structure of MLC dataclasses can be hashed via `hash_s`, which guarantees if two dataclasses are alpha-equivalent, they will share the same structural hash:

```python
>>> L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s()
>>> assert L1_hash == L2_hash
>>> assert L1_hash != L3_hash
```

### :snake: Text Formats in Python AST

Expand Down
4 changes: 4 additions & 0 deletions cpp/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) {
}
});
MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t {
uint64_t ret = ::mlc::core::StructuralHash(obj);
return static_cast<int64_t>(ret);
});
} // namespace

MLC_API MLCAny MLCGetLastError() {
Expand Down
1 change: 1 addition & 0 deletions cpp/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ struct TypeTable {
MLCTypeInfo *info = &wrapper->info;
info->type_index = type_index;
info->type_key = this->NewArray(type_key);
info->type_key_hash = ::mlc::base::StrHash(type_key, std::strlen(type_key));
info->type_depth = (parent == nullptr) ? 0 : (parent->type_depth + 1);
info->type_ancestors = this->NewArray<int32_t>(info->type_depth);
if (parent) {
Expand Down
49 changes: 49 additions & 0 deletions include/mlc/base/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#endif
#include "./base_traits.h"
#include <cstdlib>
#include <cstring>
#include <memory>
#include <sstream>
#include <type_traits>
Expand Down Expand Up @@ -297,6 +298,54 @@ inline int64_t StrToInt(const std::string &str, size_t start_pos = 0) {
}
return result;
}

MLC_INLINE uint64_t HashCombine(uint64_t seed, uint64_t value) {
return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

MLC_INLINE int32_t StrCompare(const char *a, const char *b, int64_t a_len, int64_t b_len) {
if (a_len != b_len) {
return static_cast<int32_t>(a_len - b_len);
}
return std::strncmp(a, b, a_len);
}

inline uint64_t StrHash(const char *str, int64_t length) {
const char *it = str;
const char *end = str + length;
uint64_t result = 0;
for (; it + 8 <= end; it += 8) {
uint64_t b = (static_cast<uint64_t>(it[0]) << 56) | (static_cast<uint64_t>(it[1]) << 48) |
(static_cast<uint64_t>(it[2]) << 40) | (static_cast<uint64_t>(it[3]) << 32) |
(static_cast<uint64_t>(it[4]) << 24) | (static_cast<uint64_t>(it[5]) << 16) |
(static_cast<uint64_t>(it[6]) << 8) | static_cast<uint64_t>(it[7]);
result = HashCombine(result, b);
}
if (it < end) {
uint64_t b = 0;
if (it + 4 <= end) {
b = (static_cast<uint64_t>(it[0]) << 24) | (static_cast<uint64_t>(it[1]) << 16) |
(static_cast<uint64_t>(it[2]) << 8) | static_cast<uint64_t>(it[3]);
it += 4;
}
if (it + 2 <= end) {
b = (b << 16) | (static_cast<uint64_t>(it[0]) << 8) | static_cast<uint64_t>(it[1]);
it += 2;
}
if (it + 1 <= end) {
b = (b << 8) | static_cast<uint64_t>(it[0]);
it += 1;
}
result = HashCombine(result, b);
}
return result;
}

inline uint64_t StrHash(const char *str) {
int64_t length = static_cast<int64_t>(std::strlen(str));
return StrHash(str, length);
}

} // namespace base
} // namespace mlc

Expand Down
1 change: 1 addition & 0 deletions include/mlc/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ typedef struct {
typedef struct MLCTypeInfo {
int32_t type_index;
const char *type_key;
uint64_t type_key_hash;
int32_t type_depth;
int32_t *type_ancestors; // Range: [0, type_depth)
MLCTypeField *fields; // Ends with a field with name == nullptr
Expand Down
72 changes: 19 additions & 53 deletions include/mlc/core/str.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ struct StrObj : public MLCStr {
MLC_INLINE const char *data() const { return this->MLCStr::data; }
MLC_INLINE int64_t length() const { return this->MLCStr::length; }
MLC_INLINE int64_t size() const { return this->MLCStr::length; }
MLC_INLINE uint64_t Hash() const;
MLC_INLINE bool StartsWith(const std::string &prefix) {
int64_t N = static_cast<int64_t>(prefix.length());
return N <= MLCStr::length && strncmp(MLCStr::data, prefix.data(), prefix.length()) == 0;
Expand All @@ -53,9 +52,21 @@ struct StrObj : public MLCStr {
return N <= MLCStr::length && strncmp(MLCStr::data + MLCStr::length - N, suffix.data(), N) == 0;
}
MLC_INLINE void PrintEscape(std::ostream &os) const;
MLC_INLINE int Compare(const StrObj *other) const { return std::strncmp(c_str(), other->c_str(), this->size() + 1); }
MLC_INLINE int Compare(const std::string &other) const { return std::strncmp(c_str(), other.c_str(), size() + 1); }
MLC_INLINE int Compare(const char *other) const { return std::strncmp(this->c_str(), other, this->size() + 1); }
MLC_INLINE int32_t Compare(const char *rhs_str, int64_t rhs_len) const {
return ::mlc::base::StrCompare(this->MLCStr::data, rhs_str, this->MLCStr::length, rhs_len);
}
MLC_INLINE int32_t Compare(const StrObj *other) const {
return this->Compare(other->c_str(), other->MLCStr::length); //
}
MLC_INLINE int32_t Compare(const std::string &other) const {
return this->Compare(other.data(), static_cast<int64_t>(other.length()));
}
MLC_INLINE int32_t Compare(const char *other) const {
return this->Compare(other, static_cast<int64_t>(std::strlen(other)));
}
MLC_INLINE uint64_t Hash() const {
return ::mlc::base::StrHash(this->MLCStr::data, this->MLCStr::length); //
}
MLC_DEF_STATIC_TYPE(StrObj, Object, MLCTypeIndex::kMLCStr, "object.Str")
.FieldReadOnly("length", &MLCStr::length)
.FieldReadOnly("data", &MLCStr::data)
Expand Down Expand Up @@ -247,7 +258,7 @@ inline std::ostream &operator<<(std::ostream &os, const Object &src) {
return os;
}

void StrObj::PrintEscape(std::ostream &oss) const {
inline void StrObj::PrintEscape(std::ostream &oss) const {
const char *data = this->MLCStr::data;
int64_t length = this->MLCStr::length;
oss << '"';
Expand Down Expand Up @@ -322,53 +333,7 @@ void StrObj::PrintEscape(std::ostream &oss) const {
oss << '"';
}

} // namespace mlc

namespace mlc {
namespace core {

MLC_INLINE int32_t StrCompare(const MLCStr *a, const MLCStr *b) {
if (a->length != b->length) {
return static_cast<int32_t>(a->length - b->length);
}
return std::strncmp(a->data, b->data, a->length);
}

MLC_INLINE uint64_t StrHash(const MLCStr *str) {
const constexpr uint64_t kMultiplier = 1099511628211ULL;
const constexpr uint64_t kMod = 2147483647ULL;
const char *it = str->data;
const char *end = it + str->length;
uint64_t result = 0;
for (; it + 8 <= end; it += 8) {
uint64_t b = (static_cast<uint64_t>(it[0]) << 56) | (static_cast<uint64_t>(it[1]) << 48) |
(static_cast<uint64_t>(it[2]) << 40) | (static_cast<uint64_t>(it[3]) << 32) |
(static_cast<uint64_t>(it[4]) << 24) | (static_cast<uint64_t>(it[5]) << 16) |
(static_cast<uint64_t>(it[6]) << 8) | static_cast<uint64_t>(it[7]);
result = (result * kMultiplier + b) % kMod;
}
if (it < end) {
uint64_t b = 0;
if (it + 4 <= end) {
b = (static_cast<uint64_t>(it[0]) << 24) | (static_cast<uint64_t>(it[1]) << 16) |
(static_cast<uint64_t>(it[2]) << 8) | static_cast<uint64_t>(it[3]);
it += 4;
}
if (it + 2 <= end) {
b = (b << 16) | (static_cast<uint64_t>(it[0]) << 8) | static_cast<uint64_t>(it[1]);
it += 2;
}
if (it + 1 <= end) {
b = (b << 8) | static_cast<uint64_t>(it[0]);
it += 1;
}
result = (result * kMultiplier + b) % kMod;
}
return result;
}
} // namespace core

MLC_INLINE Str Str::FromEscaped(int64_t N, const char *str) {
inline Str Str::FromEscaped(int64_t N, const char *str) {
std::ostringstream oss;
if (N < 2 || str[0] != '\"' || str[N - 1] != '\"') {
MLC_THROW(ValueError) << "Invalid escaped string: " << str;
Expand Down Expand Up @@ -443,13 +408,14 @@ MLC_INLINE Str Str::FromEscaped(int64_t N, const char *str) {
}
return Str(oss.str());
}
} // namespace mlc

namespace mlc {
namespace base {
MLC_INLINE StrObj *StrCopyFromCharArray(const char *source, size_t length) {
return StrObj::Allocator::New(source, length + 1);
}
} // namespace base
MLC_INLINE uint64_t StrObj::Hash() const { return ::mlc::core::StrHash(reinterpret_cast<const MLCStr *>(this)); }
} // namespace mlc

#endif // MLC_CORE_STR_H_
Loading

0 comments on commit 9c3279b

Please sign in to comment.