diff --git a/CMakeLists.txt b/CMakeLists.txt index 89bdb75b..c2fa4498 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15) project( mlc - VERSION 0.0.10 + VERSION 0.0.11 DESCRIPTION "MLC-Python" LANGUAGES C CXX ) diff --git a/README.md b/README.md index 290203a7..b4066565 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ class Let(Expr): body: Expr ``` -**Structural equality**. Method eq_s is ready to use to compare the structural equality (alpha equivalence) of two IRs. +**Structural equality**. Member method `eq_s` compares the structural equality (alpha equivalence) of two IRs represented by MLC's structured dataclass. ```python """ @@ -110,7 +110,13 @@ True ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound ``` -**Structural hashing**. TBD +**Structural hashing**. The structure of MLC dataclasses can be hashed via `hash_s`, which guarantees if two dataclasses are alpha-equivalent, they will share the same structural hash: + +```python +>>> L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s() +>>> assert L1_hash == L2_hash +>>> assert L1_hash != L3_hash +``` ### :snake: Text Formats in Python AST diff --git a/cpp/c_api.cc b/cpp/c_api.cc index 19a9b8b8..fd02cf6c 100644 --- a/cpp/c_api.cc +++ b/cpp/c_api.cc @@ -37,6 +37,10 @@ MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) { } }); MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual); +MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t { + uint64_t ret = ::mlc::core::StructuralHash(obj); + return static_cast(ret); +}); } // namespace MLC_API MLCAny MLCGetLastError() { diff --git a/cpp/registry.h b/cpp/registry.h index 58e4406c..b1d939b7 100644 --- a/cpp/registry.h +++ b/cpp/registry.h @@ -205,6 +205,7 @@ struct TypeTable { MLCTypeInfo *info = &wrapper->info; info->type_index = type_index; info->type_key = this->NewArray(type_key); + info->type_key_hash = ::mlc::base::StrHash(type_key, std::strlen(type_key)); info->type_depth = (parent == nullptr) ? 0 : (parent->type_depth + 1); info->type_ancestors = this->NewArray(info->type_depth); if (parent) { diff --git a/include/mlc/base/utils.h b/include/mlc/base/utils.h index 9c5c9264..d4990b6e 100644 --- a/include/mlc/base/utils.h +++ b/include/mlc/base/utils.h @@ -16,6 +16,7 @@ #endif #include "./base_traits.h" #include +#include #include #include #include @@ -297,6 +298,54 @@ inline int64_t StrToInt(const std::string &str, size_t start_pos = 0) { } return result; } + +MLC_INLINE uint64_t HashCombine(uint64_t seed, uint64_t value) { + return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2)); +} + +MLC_INLINE int32_t StrCompare(const char *a, const char *b, int64_t a_len, int64_t b_len) { + if (a_len != b_len) { + return static_cast(a_len - b_len); + } + return std::strncmp(a, b, a_len); +} + +inline uint64_t StrHash(const char *str, int64_t length) { + const char *it = str; + const char *end = str + length; + uint64_t result = 0; + for (; it + 8 <= end; it += 8) { + uint64_t b = (static_cast(it[0]) << 56) | (static_cast(it[1]) << 48) | + (static_cast(it[2]) << 40) | (static_cast(it[3]) << 32) | + (static_cast(it[4]) << 24) | (static_cast(it[5]) << 16) | + (static_cast(it[6]) << 8) | static_cast(it[7]); + result = HashCombine(result, b); + } + if (it < end) { + uint64_t b = 0; + if (it + 4 <= end) { + b = (static_cast(it[0]) << 24) | (static_cast(it[1]) << 16) | + (static_cast(it[2]) << 8) | static_cast(it[3]); + it += 4; + } + if (it + 2 <= end) { + b = (b << 16) | (static_cast(it[0]) << 8) | static_cast(it[1]); + it += 2; + } + if (it + 1 <= end) { + b = (b << 8) | static_cast(it[0]); + it += 1; + } + result = HashCombine(result, b); + } + return result; +} + +inline uint64_t StrHash(const char *str) { + int64_t length = static_cast(std::strlen(str)); + return StrHash(str, length); +} + } // namespace base } // namespace mlc diff --git a/include/mlc/c_api.h b/include/mlc/c_api.h index 726e6adb..33081d83 100644 --- a/include/mlc/c_api.h +++ b/include/mlc/c_api.h @@ -200,6 +200,7 @@ typedef struct { typedef struct MLCTypeInfo { int32_t type_index; const char *type_key; + uint64_t type_key_hash; int32_t type_depth; int32_t *type_ancestors; // Range: [0, type_depth) MLCTypeField *fields; // Ends with a field with name == nullptr diff --git a/include/mlc/core/str.h b/include/mlc/core/str.h index 2bd22d66..f310b75f 100644 --- a/include/mlc/core/str.h +++ b/include/mlc/core/str.h @@ -43,7 +43,6 @@ struct StrObj : public MLCStr { MLC_INLINE const char *data() const { return this->MLCStr::data; } MLC_INLINE int64_t length() const { return this->MLCStr::length; } MLC_INLINE int64_t size() const { return this->MLCStr::length; } - MLC_INLINE uint64_t Hash() const; MLC_INLINE bool StartsWith(const std::string &prefix) { int64_t N = static_cast(prefix.length()); return N <= MLCStr::length && strncmp(MLCStr::data, prefix.data(), prefix.length()) == 0; @@ -53,9 +52,21 @@ struct StrObj : public MLCStr { return N <= MLCStr::length && strncmp(MLCStr::data + MLCStr::length - N, suffix.data(), N) == 0; } MLC_INLINE void PrintEscape(std::ostream &os) const; - MLC_INLINE int Compare(const StrObj *other) const { return std::strncmp(c_str(), other->c_str(), this->size() + 1); } - MLC_INLINE int Compare(const std::string &other) const { return std::strncmp(c_str(), other.c_str(), size() + 1); } - MLC_INLINE int Compare(const char *other) const { return std::strncmp(this->c_str(), other, this->size() + 1); } + MLC_INLINE int32_t Compare(const char *rhs_str, int64_t rhs_len) const { + return ::mlc::base::StrCompare(this->MLCStr::data, rhs_str, this->MLCStr::length, rhs_len); + } + MLC_INLINE int32_t Compare(const StrObj *other) const { + return this->Compare(other->c_str(), other->MLCStr::length); // + } + MLC_INLINE int32_t Compare(const std::string &other) const { + return this->Compare(other.data(), static_cast(other.length())); + } + MLC_INLINE int32_t Compare(const char *other) const { + return this->Compare(other, static_cast(std::strlen(other))); + } + MLC_INLINE uint64_t Hash() const { + return ::mlc::base::StrHash(this->MLCStr::data, this->MLCStr::length); // + } MLC_DEF_STATIC_TYPE(StrObj, Object, MLCTypeIndex::kMLCStr, "object.Str") .FieldReadOnly("length", &MLCStr::length) .FieldReadOnly("data", &MLCStr::data) @@ -247,7 +258,7 @@ inline std::ostream &operator<<(std::ostream &os, const Object &src) { return os; } -void StrObj::PrintEscape(std::ostream &oss) const { +inline void StrObj::PrintEscape(std::ostream &oss) const { const char *data = this->MLCStr::data; int64_t length = this->MLCStr::length; oss << '"'; @@ -322,53 +333,7 @@ void StrObj::PrintEscape(std::ostream &oss) const { oss << '"'; } -} // namespace mlc - -namespace mlc { -namespace core { - -MLC_INLINE int32_t StrCompare(const MLCStr *a, const MLCStr *b) { - if (a->length != b->length) { - return static_cast(a->length - b->length); - } - return std::strncmp(a->data, b->data, a->length); -} - -MLC_INLINE uint64_t StrHash(const MLCStr *str) { - const constexpr uint64_t kMultiplier = 1099511628211ULL; - const constexpr uint64_t kMod = 2147483647ULL; - const char *it = str->data; - const char *end = it + str->length; - uint64_t result = 0; - for (; it + 8 <= end; it += 8) { - uint64_t b = (static_cast(it[0]) << 56) | (static_cast(it[1]) << 48) | - (static_cast(it[2]) << 40) | (static_cast(it[3]) << 32) | - (static_cast(it[4]) << 24) | (static_cast(it[5]) << 16) | - (static_cast(it[6]) << 8) | static_cast(it[7]); - result = (result * kMultiplier + b) % kMod; - } - if (it < end) { - uint64_t b = 0; - if (it + 4 <= end) { - b = (static_cast(it[0]) << 24) | (static_cast(it[1]) << 16) | - (static_cast(it[2]) << 8) | static_cast(it[3]); - it += 4; - } - if (it + 2 <= end) { - b = (b << 16) | (static_cast(it[0]) << 8) | static_cast(it[1]); - it += 2; - } - if (it + 1 <= end) { - b = (b << 8) | static_cast(it[0]); - it += 1; - } - result = (result * kMultiplier + b) % kMod; - } - return result; -} -} // namespace core - -MLC_INLINE Str Str::FromEscaped(int64_t N, const char *str) { +inline Str Str::FromEscaped(int64_t N, const char *str) { std::ostringstream oss; if (N < 2 || str[0] != '\"' || str[N - 1] != '\"') { MLC_THROW(ValueError) << "Invalid escaped string: " << str; @@ -443,13 +408,14 @@ MLC_INLINE Str Str::FromEscaped(int64_t N, const char *str) { } return Str(oss.str()); } +} // namespace mlc +namespace mlc { namespace base { MLC_INLINE StrObj *StrCopyFromCharArray(const char *source, size_t length) { return StrObj::Allocator::New(source, length + 1); } } // namespace base -MLC_INLINE uint64_t StrObj::Hash() const { return ::mlc::core::StrHash(reinterpret_cast(this)); } } // namespace mlc #endif // MLC_CORE_STR_H_ diff --git a/include/mlc/core/structure.h b/include/mlc/core/structure.h index 7d2dcf48..b317da01 100644 --- a/include/mlc/core/structure.h +++ b/include/mlc/core/structure.h @@ -1,7 +1,9 @@ #ifndef MLC_CORE_STRUCTURE_H_ #define MLC_CORE_STRUCTURE_H_ #include "./field_visitor.h" +#include #include +#include #include #include #include @@ -15,6 +17,8 @@ namespace mlc { namespace core { +uint64_t StructuralHash(Object *obj); +bool StructuralEqual(Object *lhs, Object *rhs, bool bind_free_vars, bool assert_mode); void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars); struct SEqualPath { @@ -87,37 +91,37 @@ template MLC_INLINE T *WithOffset(Object *obj, MLCTypeField *field) return reinterpret_cast(reinterpret_cast(obj) + field->offset); } -#define MLC_CORE_EQ_S_ERR_OUT(LHS, RHS, PATH) \ +#define MLC_CORE_EQ_S_ERR(LHS, RHS, PATH) \ { \ std::ostringstream err; \ err << (LHS) << " vs " << (RHS); \ throw SEqualError(err.str().c_str(), (PATH)); \ } -#define MLC_CORE_EQ_S_CMP_ANY(Cond, Type, EQ, LHS, RHS, PATH) \ +#define MLC_CORE_EQ_S_ANY(Cond, Type, EQ, LHS, RHS, PATH) \ if (Cond) { \ Type lhs_value = LHS->operator Type(); \ Type rhs_value = RHS->operator Type(); \ if (EQ(lhs_value, rhs_value)) { \ return; \ } else { \ - MLC_CORE_EQ_S_ERR_OUT(*lhs, *rhs, PATH); \ + MLC_CORE_EQ_S_ERR(*lhs, *rhs, PATH); \ } \ } -#define MLC_CORE_EQ_S_COMPARE_OPT(Type, EQ) \ +#define MLC_CORE_EQ_S_OPT(Type, EQ) \ MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind, Optional *_lhs) { \ const Type *lhs = _lhs->get(); \ const Type *rhs = WithOffset>(obj_rhs, field)->get(); \ if ((lhs != nullptr || rhs != nullptr) && (lhs == nullptr || rhs == nullptr || !EQ(*lhs, *rhs))) { \ AnyView LHS = lhs ? AnyView(*lhs) : AnyView(nullptr); \ AnyView RHS = rhs ? AnyView(*rhs) : AnyView(nullptr); \ - MLC_CORE_EQ_S_ERR_OUT(LHS, RHS, Append::Field(path, field->name)); \ + MLC_CORE_EQ_S_ERR(LHS, RHS, Append::Field(path, field->name)); \ } \ } -#define MLC_CORE_EQ_S_COMPARE(Type, EQ) \ +#define MLC_CORE_EQ_S_POD(Type, EQ) \ MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind, Type *lhs) { \ const Type *rhs = WithOffset(obj_rhs, field); \ if (!EQ(*lhs, *rhs)) { \ - MLC_CORE_EQ_S_ERR_OUT(AnyView(*lhs), AnyView(*rhs), Append::Field(path, field->name)); \ + MLC_CORE_EQ_S_ERR(AnyView(*lhs), AnyView(*rhs), Append::Field(path, field->name)); \ } \ } @@ -146,25 +150,25 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { return std::make_shared(SEqualPath{self, 2, nullptr, 0, new_dict_key}); } }; - struct FieldComparator { + struct Visitor { static bool CharArrayEqual(CharArray lhs, CharArray rhs) { return std::strcmp(lhs, rhs) == 0; } static bool FloatEqual(float lhs, float rhs) { return std::abs(lhs - rhs) < 1e-6; } static bool DoubleEqual(double lhs, double rhs) { return std::abs(lhs - rhs) < 1e-8; } - MLC_CORE_EQ_S_COMPARE_OPT(int64_t, std::equal_to()); - MLC_CORE_EQ_S_COMPARE_OPT(double, DoubleEqual); - MLC_CORE_EQ_S_COMPARE_OPT(DLDevice, DeviceEqual); - MLC_CORE_EQ_S_COMPARE_OPT(DLDataType, DataTypeEqual); - MLC_CORE_EQ_S_COMPARE_OPT(VoidPtr, std::equal_to()); - MLC_CORE_EQ_S_COMPARE(int8_t, std::equal_to()); - MLC_CORE_EQ_S_COMPARE(int16_t, std::equal_to()); - MLC_CORE_EQ_S_COMPARE(int32_t, std::equal_to()); - MLC_CORE_EQ_S_COMPARE(int64_t, std::equal_to()); - MLC_CORE_EQ_S_COMPARE(float, FloatEqual); - MLC_CORE_EQ_S_COMPARE(double, DoubleEqual); - MLC_CORE_EQ_S_COMPARE(DLDataType, DataTypeEqual); - MLC_CORE_EQ_S_COMPARE(DLDevice, DeviceEqual); - MLC_CORE_EQ_S_COMPARE(VoidPtr, std::equal_to()); - MLC_CORE_EQ_S_COMPARE(CharArray, CharArrayEqual); + MLC_CORE_EQ_S_OPT(int64_t, std::equal_to()); + MLC_CORE_EQ_S_OPT(double, DoubleEqual); + MLC_CORE_EQ_S_OPT(DLDevice, DeviceEqual); + MLC_CORE_EQ_S_OPT(DLDataType, DataTypeEqual); + MLC_CORE_EQ_S_OPT(VoidPtr, std::equal_to()); + MLC_CORE_EQ_S_POD(int8_t, std::equal_to()); + MLC_CORE_EQ_S_POD(int16_t, std::equal_to()); + MLC_CORE_EQ_S_POD(int32_t, std::equal_to()); + MLC_CORE_EQ_S_POD(int64_t, std::equal_to()); + MLC_CORE_EQ_S_POD(float, FloatEqual); + MLC_CORE_EQ_S_POD(double, DoubleEqual); + MLC_CORE_EQ_S_POD(DLDataType, DataTypeEqual); + MLC_CORE_EQ_S_POD(DLDevice, DeviceEqual); + MLC_CORE_EQ_S_POD(VoidPtr, std::equal_to()); + MLC_CORE_EQ_S_POD(CharArray, CharArrayEqual); MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, const Any *lhs) { const Any *rhs = WithOffset(obj_rhs, field); bool bind_free_vars = this->obj_bind_free_vars || field_kind == StructureFieldKind::kBind; @@ -186,17 +190,17 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { std::shared_ptr new_path) { int32_t type_index = lhs->GetTypeIndex(); if (type_index != rhs->GetTypeIndex()) { - MLC_CORE_EQ_S_ERR_OUT(lhs->GetTypeKey(), rhs->GetTypeKey(), new_path); + MLC_CORE_EQ_S_ERR(lhs->GetTypeKey(), rhs->GetTypeKey(), new_path); } if (type_index == kMLCNone) { return; } - MLC_CORE_EQ_S_CMP_ANY(type_index == kMLCInt, int64_t, std::equal_to(), lhs, rhs, new_path); - MLC_CORE_EQ_S_CMP_ANY(type_index == kMLCFloat, double, DoubleEqual, lhs, rhs, new_path); - MLC_CORE_EQ_S_CMP_ANY(type_index == kMLCPtr, VoidPtr, std::equal_to(), lhs, rhs, new_path); - MLC_CORE_EQ_S_CMP_ANY(type_index == kMLCDataType, DLDataType, DataTypeEqual, lhs, rhs, new_path); - MLC_CORE_EQ_S_CMP_ANY(type_index == kMLCDevice, DLDevice, DeviceEqual, lhs, rhs, new_path); - MLC_CORE_EQ_S_CMP_ANY(type_index == kMLCRawStr, CharArray, CharArrayEqual, lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCInt, int64_t, std::equal_to(), lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCFloat, double, DoubleEqual, lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCPtr, VoidPtr, std::equal_to(), lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCDataType, DLDataType, DataTypeEqual, lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCDevice, DLDevice, DeviceEqual, lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCRawStr, CharArray, CharArrayEqual, lhs, rhs, new_path); if (type_index < kMLCStaticObjectBegin) { MLC_THROW(InternalError) << "Unknown type key: " << lhs->GetTypeKey(); } @@ -207,13 +211,13 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { int32_t lhs_type_index = lhs ? lhs->GetTypeIndex() : kMLCNone; int32_t rhs_type_index = rhs ? rhs->GetTypeIndex() : kMLCNone; if (lhs_type_index != rhs_type_index) { - MLC_CORE_EQ_S_ERR_OUT(::mlc::base::TypeIndex2TypeKey(lhs_type_index), - ::mlc::base::TypeIndex2TypeKey(rhs_type_index), new_path); + MLC_CORE_EQ_S_ERR(::mlc::base::TypeIndex2TypeKey(lhs_type_index), + ::mlc::base::TypeIndex2TypeKey(rhs_type_index), new_path); } else if (lhs_type_index == kMLCStr) { Str lhs_str(reinterpret_cast(lhs)); Str rhs_str(reinterpret_cast(rhs)); if (lhs_str != rhs_str) { - MLC_CORE_EQ_S_ERR_OUT(lhs_str, rhs_str, new_path); + MLC_CORE_EQ_S_ERR(lhs_str, rhs_str, new_path); } } else if (lhs_type_index == kMLCFunc || lhs_type_index == kMLCError) { throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", new_path); @@ -231,7 +235,32 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { std::vector tasks; std::unordered_map eq_lhs_to_rhs; std::unordered_map eq_rhs_to_lhs; - FieldComparator::EnqueueTask(&tasks, bind_free_vars, lhs, rhs, nullptr); + + auto check_bind = [&eq_lhs_to_rhs, &eq_rhs_to_lhs](Object *lhs, Object *rhs, + const std::shared_ptr &path) -> bool { + // check binding consistency: lhs -> rhs, rhs -> lhs + auto it_lhs_to_rhs = eq_lhs_to_rhs.find(lhs); + auto it_rhs_to_lhs = eq_rhs_to_lhs.find(rhs); + bool exist_lhs_to_rhs = it_lhs_to_rhs != eq_lhs_to_rhs.end(); + bool exist_rhs_to_lhs = it_rhs_to_lhs != eq_rhs_to_lhs.end(); + // already proven equal + if (exist_lhs_to_rhs && exist_rhs_to_lhs) { + if (it_lhs_to_rhs->second == rhs && it_rhs_to_lhs->second == lhs) { + return true; + } + throw SEqualError("Inconsistent binding: LHS and RHS are both bound, but to different nodes", path); + } + // inconsistent binding + if (exist_lhs_to_rhs) { + throw SEqualError("Inconsistent binding. LHS has been bound to a different node while RHS is not bound", path); + } + if (exist_rhs_to_lhs) { + throw SEqualError("Inconsistent binding. RHS has been bound to a different node while LHS is not bound", path); + } + return false; + }; + + Visitor::EnqueueTask(&tasks, bind_free_vars, lhs, rhs, nullptr); while (!tasks.empty()) { MLCTypeInfo *type_info; std::shared_ptr path; @@ -244,31 +273,10 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { bind_free_vars = task.bind_free_vars; if (task.err) { throw SEqualError(task.err->str().c_str(), path); - } - // check binding consistency: lhs -> rhs, rhs -> lhs - auto it_lhs_to_rhs = eq_lhs_to_rhs.find(lhs); - auto it_rhs_to_lhs = eq_rhs_to_lhs.find(rhs); - bool exist_lhs_to_rhs = it_lhs_to_rhs != eq_lhs_to_rhs.end(); - bool exist_rhs_to_lhs = it_rhs_to_lhs != eq_rhs_to_lhs.end(); - // already proven equal - if (exist_lhs_to_rhs && exist_rhs_to_lhs) { - if (it_lhs_to_rhs->second == rhs && it_rhs_to_lhs->second == lhs) { - tasks.pop_back(); - continue; - } else { - throw SEqualError("Inconsistent binding: LHS and RHS are both bound, but to different nodes", path); - } - } - // inconsistent binding - if (exist_lhs_to_rhs) { - throw SEqualError("Inconsistent binding. LHS has been bound to a different node while RHS is not bound", path); - } - if (exist_rhs_to_lhs) { - throw SEqualError("Inconsistent binding. RHS has been bound to a different node while LHS is not bound", path); - } - if (!task.visited) { - task.visited = true; - } else { + } else if (check_bind(lhs, rhs, path)) { + tasks.pop_back(); + continue; + } else if (task.visited) { StructureKind kind = static_cast(type_info->structure_kind); if (kind == StructureKind::kBind || (kind == StructureKind::kVar && bind_free_vars)) { // bind lhs <-> rhs @@ -280,7 +288,9 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { tasks.pop_back(); continue; } + task.visited = true; } + // `task.visited` was `False` int64_t task_index = static_cast(tasks.size()) - 1; if (type_info->type_index == kMLCList) { UListObj *lhs_list = reinterpret_cast(lhs); @@ -288,8 +298,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { int64_t lhs_size = lhs_list->size(); int64_t rhs_size = rhs_list->size(); for (int64_t i = (lhs_size < rhs_size ? lhs_size : rhs_size) - 1; i >= 0; --i) { - FieldComparator::EnqueueAny(&tasks, bind_free_vars, &lhs_list->at(i), &rhs_list->at(i), - Append::ListIndex(path, i)); + Visitor::EnqueueAny(&tasks, bind_free_vars, &lhs_list->at(i), &rhs_list->at(i), Append::ListIndex(path, i)); } if (lhs_size != rhs_size) { auto &err = tasks[task_index].err = std::make_unique(); @@ -315,8 +324,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { not_found_lhs_keys.push_back(lhs_key); continue; } - FieldComparator::EnqueueAny(&tasks, bind_free_vars, &kv.second, &rhs_it->second, - Append::DictKey(path, lhs_key)); + Visitor::EnqueueAny(&tasks, bind_free_vars, &kv.second, &rhs_it->second, Append::DictKey(path, lhs_key)); } auto &err = tasks[task_index].err; if (!not_found_lhs_keys.empty()) { @@ -330,15 +338,260 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { (*err) << "Dict size mismatch: " << lhs_dict->size() << " vs " << rhs_dict->size(); } } else { - VisitStructure(lhs, type_info, FieldComparator{rhs, &tasks, bind_free_vars, path}); + VisitStructure(lhs, type_info, Visitor{rhs, &tasks, bind_free_vars, path}); + } + } +} + +#define MLC_CORE_HASH_S_OPT(Type, Hasher) \ + MLC_INLINE void operator()(MLCTypeField *, StructureFieldKind, Optional *_v) { \ + if (const Type *v = _v->get()) { \ + EnqueuePOD(tasks, Hasher(*v)); \ + } else { \ + EnqueuePOD(tasks, HashCache::kNoneCombined); \ + } \ + } +#define MLC_CORE_HASH_S_POD(Type, Hasher) \ + MLC_INLINE void operator()(MLCTypeField *, StructureFieldKind, Type *v) { EnqueuePOD(tasks, Hasher(*v)); } +#define MLC_CORE_HASH_S_ANY(Cond, Type, Hasher) \ + if (Cond) { \ + EnqueuePOD(tasks, Hasher(v->operator Type())); \ + return; \ + } + +struct HashCache { + inline static const uint64_t MLC_SYMBOL_HIDE kNoneCombined = + ::mlc::base::HashCombine(::mlc::base::TypeIndex2TypeInfo(kMLCNone)->type_key_hash, 0); + inline static const uint64_t MLC_SYMBOL_HIDE kInt = ::mlc::base::TypeIndex2TypeInfo(kMLCInt)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kFloat = ::mlc::base::TypeIndex2TypeInfo(kMLCFloat)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kPtr = ::mlc::base::TypeIndex2TypeInfo(kMLCPtr)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kDType = ::mlc::base::TypeIndex2TypeInfo(kMLCDataType)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kDevice = ::mlc::base::TypeIndex2TypeInfo(kMLCDevice)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kRawStr = ::mlc::base::TypeIndex2TypeInfo(kMLCRawStr)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kStrObj = ::mlc::base::TypeIndex2TypeInfo(kMLCStr)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kBound = ::mlc::base::StrHash("$$Bounds$$"); + inline static const uint64_t MLC_SYMBOL_HIDE kUnbound = ::mlc::base::StrHash("$$Unbound$$"); +}; + +template MLC_INLINE uint64_t HashTyped(uint64_t type_hash, T value) { + union { + T src; + uint64_t tgt; + } u; + u.tgt = 0; + u.src = value; + return ::mlc::base::HashCombine(type_hash, u.tgt); +} + +inline uint64_t StructuralHash(Object *obj) { + using CharArray = const char *; + using VoidPtr = ::mlc::base::VoidPtr; + using mlc::base::HashCombine; + struct Task { + Object *obj; + MLCTypeInfo *type_info; + bool visited; + bool bind_free_vars; + uint64_t hash_value; + size_t index_in_result_hashes{0xffffffffffffffff}; + }; + struct Visitor { + static uint64_t HashInteger(int64_t a) { return HashTyped(HashCache::kInt, a); } + static uint64_t HashPtr(VoidPtr a) { return HashTyped(HashCache::kPtr, a); } + static uint64_t HashDevice(DLDevice a) { return HashTyped(HashCache::kDevice, a); } + static uint64_t HashDataType(DLDataType a) { return HashTyped(HashCache::kDType, a); } + // clang-format off + static uint64_t HashFloat(float a) { return HashTyped(HashCache::kFloat, std::isnan(a) ? std::numeric_limits::quiet_NaN() : a); } + static uint64_t HashDouble(double a) { return HashTyped(HashCache::kFloat, std::isnan(a) ? std::numeric_limits::quiet_NaN() : a); } + static uint64_t HashCharArray(CharArray a) { return HashTyped(HashCache::kRawStr, ::mlc::base::StrHash(a)); } + // clang-format on + MLC_CORE_HASH_S_OPT(int64_t, HashInteger); + MLC_CORE_HASH_S_OPT(double, HashDouble); + MLC_CORE_HASH_S_OPT(DLDevice, HashDevice); + MLC_CORE_HASH_S_OPT(DLDataType, HashDataType); + MLC_CORE_HASH_S_OPT(VoidPtr, HashPtr); + MLC_CORE_HASH_S_POD(int8_t, HashInteger); + MLC_CORE_HASH_S_POD(int16_t, HashInteger); + MLC_CORE_HASH_S_POD(int32_t, HashInteger); + MLC_CORE_HASH_S_POD(int64_t, HashInteger); + MLC_CORE_HASH_S_POD(float, HashFloat); + MLC_CORE_HASH_S_POD(double, HashDouble); + MLC_CORE_HASH_S_POD(DLDataType, HashDataType); + MLC_CORE_HASH_S_POD(DLDevice, HashDevice); + MLC_CORE_HASH_S_POD(VoidPtr, HashPtr); + MLC_CORE_HASH_S_POD(CharArray, HashCharArray); + MLC_INLINE void operator()(MLCTypeField *, StructureFieldKind field_kind, const Any *v) { + bool bind_free_vars = this->obj_bind_free_vars || field_kind == StructureFieldKind::kBind; + EnqueueAny(tasks, bind_free_vars, v); + } + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, ObjectRef *_v) { + HandleObject(field, field_kind, _v->get()); + } + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, Optional *_v) { + HandleObject(field, field_kind, _v->get()); + } + inline void HandleObject(MLCTypeField *, StructureFieldKind field_kind, Object *v) { + bool bind_free_vars = this->obj_bind_free_vars || field_kind == StructureFieldKind::kBind; + EnqueueTask(tasks, bind_free_vars, v); + } + static void EnqueuePOD(std::vector *tasks, uint64_t hash_value) { + tasks->emplace_back(Task{nullptr, nullptr, false, false, hash_value}); } + static void EnqueueAny(std::vector *tasks, bool bind_free_vars, const Any *v) { + int32_t type_index = v->GetTypeIndex(); + MLC_CORE_HASH_S_ANY(type_index == kMLCInt, int64_t, HashInteger); + MLC_CORE_HASH_S_ANY(type_index == kMLCFloat, double, HashDouble); + MLC_CORE_HASH_S_ANY(type_index == kMLCPtr, VoidPtr, HashPtr); + MLC_CORE_HASH_S_ANY(type_index == kMLCDataType, DLDataType, HashDataType); + MLC_CORE_HASH_S_ANY(type_index == kMLCDevice, DLDevice, HashDevice); + MLC_CORE_HASH_S_ANY(type_index == kMLCRawStr, CharArray, HashCharArray); + EnqueueTask(tasks, bind_free_vars, v->operator Object *()); + } + static void EnqueueTask(std::vector *tasks, bool bind_free_vars, Object *obj) { + int32_t type_index = obj ? obj->GetTypeIndex() : kMLCNone; + if (type_index == kMLCNone) { + EnqueuePOD(tasks, HashCache::kNoneCombined); + } else if (type_index == kMLCStr) { + const MLCStr *str = reinterpret_cast(obj); + uint64_t hash_value = ::mlc::base::StrHash(str->data, str->length); + hash_value = HashTyped(HashCache::kStrObj, hash_value); + EnqueuePOD(tasks, hash_value); + } else if (type_index == kMLCFunc || type_index == kMLCError) { + throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", nullptr); + } else { + MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index); + tasks->emplace_back(Task{obj, type_info, false, bind_free_vars, type_info->type_key_hash}); + } + } + + std::vector *tasks; + bool obj_bind_free_vars; + }; + std::vector tasks; + std::vector result_hashes; + std::unordered_map obj2hash; + int64_t num_bound_nodes = 0; + int64_t num_unbound_vars = 0; + Visitor::EnqueueTask(&tasks, false, obj); + while (!tasks.empty()) { + MLCTypeInfo *type_info; + bool bind_free_vars; + uint64_t hash_value; + { + Task &task = tasks.back(); + hash_value = task.hash_value; + obj = task.obj; + type_info = task.type_info; + bind_free_vars = task.bind_free_vars; + if (task.visited) { + if (result_hashes.size() < task.index_in_result_hashes) { + MLC_THROW(InternalError) + << "Internal invariant violated: `result_hashes.size() < task.index_in_result_hashes` (" + << result_hashes.size() << " vs " << task.index_in_result_hashes << ")"; + } + for (; result_hashes.size() > task.index_in_result_hashes; result_hashes.pop_back()) { + hash_value = HashCombine(hash_value, result_hashes.back()); + } + StructureKind kind = static_cast(type_info->structure_kind); + if (kind == StructureKind::kBind || (kind == StructureKind::kVar && bind_free_vars)) { + hash_value = HashCombine(hash_value, HashCache::kBound); + hash_value = HashCombine(hash_value, num_bound_nodes++); + } else if (kind == StructureKind::kVar && !bind_free_vars) { + hash_value = HashCombine(hash_value, HashCache::kUnbound); + hash_value = HashCombine(hash_value, num_unbound_vars++); + } + obj2hash[obj] = hash_value; + result_hashes.push_back(hash_value); + tasks.pop_back(); + continue; + } else if (auto it = obj2hash.find(obj); it != obj2hash.end()) { + result_hashes.push_back(it->second); + tasks.pop_back(); + continue; + } else if (obj == nullptr) { + result_hashes.push_back(hash_value); + tasks.pop_back(); + continue; + } + task.visited = true; + task.index_in_result_hashes = result_hashes.size(); + } + // `task.visited` was `False` + if (type_info->type_index == kMLCList) { + UListObj *list = reinterpret_cast(obj); + hash_value = HashCombine(hash_value, list->size()); + for (int64_t i = list->size() - 1; i >= 0; --i) { + Visitor::EnqueueAny(&tasks, bind_free_vars, &list->at(i)); + } + } else if (type_info->type_index == kMLCDict) { + UDictObj *dict = reinterpret_cast(obj); + hash_value = HashCombine(hash_value, dict->size()); + struct KVPair { + uint64_t hash; + AnyView key; + AnyView value; + }; + std::vector kv_pairs; + for (auto &[k, v] : *dict) { + uint64_t hash = 0; + if (k.type_index == kMLCNone) { + hash = HashCache::kNoneCombined; + } else if (k.type_index == kMLCInt) { + hash = Visitor::HashInteger(k.v.v_int64); + } else if (k.type_index == kMLCFloat) { + hash = Visitor::HashDouble(k.v.v_float64); + } else if (k.type_index == kMLCPtr) { + hash = Visitor::HashPtr(k.v.v_ptr); + } else if (k.type_index == kMLCDataType) { + hash = Visitor::HashDataType(k.v.v_dtype); + } else if (k.type_index == kMLCDevice) { + hash = Visitor::HashDevice(k.v.v_device); + } else if (k.type_index == kMLCStr) { + const StrObj *str = k; + hash = ::mlc::base::StrHash(str->data(), str->length()); + hash = HashTyped(HashCache::kStrObj, hash); + } else if (k.type_index >= kMLCStaticObjectBegin) { + obj = k; + if (auto it = obj2hash.find(obj); it != obj2hash.end()) { + hash = it->second; + } else { + continue; // Skip unbound nodes + } + } + kv_pairs.push_back(KVPair{hash, k, v}); + } + std::sort(kv_pairs.begin(), kv_pairs.end(), [](const KVPair &a, const KVPair &b) { return a.hash < b.hash; }); + for (size_t i = 0; i < kv_pairs.size();) { + // [i, j) are of the same hash + size_t j = i + 1; + for (; j < kv_pairs.size() && kv_pairs[i].hash == kv_pairs[j].hash; ++j) { + } + // skip cases where multiple keys have the same hash + if (i + 1 == j) { + Any k = kv_pairs[i].key; + Any v = kv_pairs[i].value; + Visitor::EnqueueAny(&tasks, bind_free_vars, &k); + Visitor::EnqueueAny(&tasks, bind_free_vars, &v); + } + } + } else { + VisitStructure(obj, type_info, Visitor{&tasks, bind_free_vars}); + } + } + if (result_hashes.size() != 1) { + MLC_THROW(InternalError) << "Internal invariant violated: `result_hashes.size() != 1` (" << result_hashes.size() + << ")"; } + return result_hashes[0]; } -#undef MLC_CORE_EQ_S_COMPARE_OPT -#undef MLC_CORE_EQ_S_COMPARE -#undef MLC_CORE_EQ_S_ERR_OUT -#undef MLC_CORE_EQ_S_CMP_ANY +#undef MLC_CORE_EQ_S_OPT +#undef MLC_CORE_EQ_S_POD +#undef MLC_CORE_EQ_S_ANY +#undef MLC_CORE_EQ_S_ERR +#undef MLC_CORE_HASH_S_OPT +#undef MLC_CORE_HASH_S_POD +#undef MLC_CORE_HASH_S_ANY } // namespace core } // namespace mlc diff --git a/include/mlc/core/udict.h b/include/mlc/core/udict.h index 0bb2b291..5f998b27 100644 --- a/include/mlc/core/udict.h +++ b/include/mlc/core/udict.h @@ -7,9 +7,10 @@ namespace mlc { namespace core { struct AnyHash { - uint64_t operator()(const MLCAny &a) const { + inline uint64_t operator()(const MLCAny &a) const { if (a.type_index == static_cast(MLCTypeIndex::kMLCStr)) { - return ::mlc::core::StrHash(reinterpret_cast(a.v.v_obj)); + const MLCStr *str = reinterpret_cast(a.v.v_obj); + return ::mlc::base::StrHash(str->data, str->length); } union { int64_t i64; @@ -26,7 +27,9 @@ struct AnyEqual { return false; } if (a.type_index == static_cast(MLCTypeIndex::kMLCStr)) { - return ::mlc::core::StrCompare(reinterpret_cast(a.v.v_obj), reinterpret_cast(b.v.v_obj)) == 0; + const MLCStr *str_a = reinterpret_cast(a.v.v_obj); + const MLCStr *str_b = reinterpret_cast(b.v.v_obj); + return ::mlc::base::StrCompare(str_a->data, str_b->data, str_a->length, str_b->length) == 0; } return a.v.v_int64 == b.v.v_int64; } diff --git a/pyproject.toml b/pyproject.toml index 636dff1e..57329656 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mlc-python" -version = "0.0.10" +version = "0.0.11" dependencies = [ 'numpy >= 1.22', "ml-dtypes >= 0.1", diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx index 088b3871..07196f9c 100644 --- a/python/mlc/_cython/core.pyx +++ b/python/mlc/_cython/core.pyx @@ -179,10 +179,14 @@ cdef extern from "mlc/c_api.h" nogil: ctypedef struct MLCTypeInfo: int32_t type_index const char* type_key + uint64_t type_key_hash int32_t type_depth int32_t* type_ancestors MLCTypeField *fields MLCTypeMethod *methods + int32_t structure_kind + int32_t *sub_structure_indices + int32_t *sub_structure_kinds ctypedef void* MLCTypeTableHandle @@ -339,6 +343,13 @@ cdef class PyAny: def _mlc_eq_s(PyAny lhs, PyAny rhs, bind_free_vars: bool, assert_mode: bool) -> bool: return bool(func_call(_STRUCUTRAL_EQUAL, (lhs, rhs, bind_free_vars, assert_mode))) + @staticmethod + def _mlc_hash_s(PyAny x) -> object: + cdef object ret = func_call(_STRUCUTRAL_HASH, (x,)) + if ret < 0: + ret += 2 ** 63 + return ret + @classmethod def _C(cls, str name, *args): cdef PyAny func @@ -1382,4 +1393,5 @@ cdef MLCFunc* _LIST_INIT = _type_get_method(kMLCList, "__init__") cdef MLCFunc* _DICT_INIT = _type_get_method(kMLCDict, "__init__") cdef PyAny _SERIALIZE = func_get_untyped("mlc.core.JSONSerialize") # Any -> str cdef PyAny _DESERIALIZE = func_get_untyped("mlc.core.JSONDeserialize") # str -> Any -cdef PyAny _STRUCUTRAL_EQUAL = func_get_untyped("mlc.core.StructuralEqual") # (Any, Any) -> bool +cdef PyAny _STRUCUTRAL_EQUAL = func_get_untyped("mlc.core.StructuralEqual") +cdef PyAny _STRUCUTRAL_HASH = func_get_untyped("mlc.core.StructuralHash") diff --git a/python/mlc/ast/__init__.py b/python/mlc/ast/__init__.py index adcc6662..402b60b4 100644 --- a/python/mlc/ast/__init__.py +++ b/python/mlc/ast/__init__.py @@ -1,4 +1,4 @@ from . import translate +from .inspection import InspectResult, inspect_program from .mlc_ast import * # noqa: F403 -from .source import Source from .translate import translate_mlc_to_py, translate_py_to_mlc diff --git a/python/mlc/ast/inspection.py b/python/mlc/ast/inspection.py new file mode 100644 index 00000000..cb1ddb39 --- /dev/null +++ b/python/mlc/ast/inspection.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import inspect +from collections.abc import Callable, Generator +from typing import Any + +PY_GETFILE = inspect.getfile +PY_FINDSOURCE = inspect.findsource + + +@dataclasses.dataclass(init=False) +class InspectResult: + source_name: str + source_start_line: int + source_start_column: int + source: str + source_full: str + captured: dict[str, Any] + annotations: dict[str, dict[str, Any]] + + def is_defined_in_class( + self, + frames: list[inspect.FrameInfo], # obtained via `inspect.stack()` + *, + frame_offset: int = 2, + is_decorator: Callable[[str], bool] = lambda line: line.startswith("@"), + ) -> bool: + # Step 1. Inspect `frames[frame_offset]` + try: + lineno = frames[frame_offset].lineno + line = frames[frame_offset].code_context[0].strip() # type: ignore + except: + return False + # Step 2. Determine by the line itself + if is_decorator(line): + return True + if not line.startswith("class"): + return False + # Step 3. Determine by its decorators + source_lines = self.source_full.splitlines(keepends=True) + lineno_offset = 2 + try: + source_line = source_lines[lineno - lineno_offset] + except IndexError: + return False + return is_decorator(source_line.strip()) + + +def inspect_program(program: Callable | type) -> InspectResult: + ret = InspectResult() + source = inspect_source(program) + ret.source_name, ret.source_start_line, ret.source_start_column, ret.source, ret.source_full = ( + source.source_name, + source.source_start_line, + source.source_start_column, + source.source, + source.source_full, + ) + if inspect.isfunction(program): + ret.captured = inspect_capture_function(program) + ret.annotations = inspect_annotations_function(program) + elif inspect.isclass(program): + ret.captured = inspect_capture_class(program) + ret.annotations = inspect_annotations_class(program) + else: + raise TypeError(f"{program!r} is not a function or class") + return ret + + +@contextlib.contextmanager +def override_getfile() -> Generator[None, Any, None]: + try: + inspect.getfile = getfile # type: ignore[assignment] + yield + finally: + inspect.getfile = PY_GETFILE # type: ignore[assignment] + + +def getfile(obj: Any) -> str: + if not inspect.isclass(obj): + return PY_GETFILE(obj) + mod = getattr(obj, "__module__", None) + if mod is not None: + import sys + + file = getattr(sys.modules[mod], "__file__", None) + if file is not None: + return file + for _, member in inspect.getmembers(obj): + if inspect.isfunction(member): + if obj.__qualname__ + "." + member.__name__ == member.__qualname__: + return inspect.getfile(member) + raise TypeError(f"Source for {obj:!r} not found") + + +def getsourcelines(obj: Any) -> tuple[list[str], int]: + obj = inspect.unwrap(obj) + lines, l_num = findsource(obj) + return inspect.getblock(lines[l_num:]), l_num + 1 + + +def findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912 + if not inspect.isclass(obj): + return PY_FINDSOURCE(obj) + + import linecache + + file = inspect.getsourcefile(obj) + if file: + linecache.checkcache(file) + else: + file = inspect.getfile(obj) + if not (file.startswith("<") and file.endswith(">")): + raise OSError("source code not available") + + module = inspect.getmodule(obj, file) + if module: + lines = linecache.getlines(file, module.__dict__) + else: + lines = linecache.getlines(file) + if not lines: + raise OSError("could not get source code") + qual_names = obj.__qualname__.replace(".", "").split(".") + in_comment = 0 + scope_stack: list[str] = [] + indent_info: dict[str, int] = {} + for i, line in enumerate(lines): + n_comment = line.count('"""') + if n_comment: + # update multi-line comments status + in_comment = in_comment ^ (n_comment & 1) + continue + if in_comment: + # skip lines within multi-line comments + continue + indent = len(line) - len(line.lstrip()) + tokens = line.split() + if len(tokens) > 1: + name = None + if tokens[0] == "def": + name = tokens[1].split(":")[0].split("(")[0] + "" + elif tokens[0] == "class": + name = tokens[1].split(":")[0].split("(")[0] + # pop scope if we are less indented + while scope_stack and indent_info[scope_stack[-1]] >= indent: + scope_stack.pop() + if name: + scope_stack.append(name) + indent_info[name] = indent + if scope_stack == qual_names: + return lines, i + raise OSError("could not find class definition") + + +@dataclasses.dataclass +class Source: + source_name: str + source_start_line: int + source_start_column: int + source: str + source_full: str + + +def inspect_source(program: Callable | type) -> Source: + with override_getfile(): + source_name: str = inspect.getsourcefile(program) # type: ignore + lines, source_start_line = getsourcelines(program) # type: ignore + if lines: + source_start_column = len(lines[0]) - len(lines[0].lstrip()) + else: + source_start_column = 0 + if source_start_column and lines: + source = "\n".join([l[source_start_column:].rstrip() for l in lines]) + else: + source = "".join(lines) + try: + # It will cause a problem when running in Jupyter Notebook. + # `mod` will be , which is a built-in module + # and `getsource` will throw a TypeError + mod = inspect.getmodule(program) + if mod: + source_full = inspect.getsource(mod) + else: + source_full = source + except TypeError: + # It's a work around for Jupyter problem. + # Since `findsource` is an internal API of inspect, we just use it + # as a fallback method. + src, _ = inspect.findsource(program) # type: ignore + source_full = "".join(src) + return Source( + source_name=source_name, + source_start_line=source_start_line, + source_start_column=source_start_column, + source=source, + source_full=source_full, + ) + + +def inspect_annotations_function(program: Callable | type) -> dict[str, dict[str, Any]]: + return {program.__name__: program.__annotations__} + + +def inspect_annotations_class(program: Callable | type) -> dict[str, dict[str, Any]]: + annotations = {} + for name, func in program.__dict__.items(): + if inspect.isfunction(func): + annotations[name] = func.__annotations__ + return annotations + + +def inspect_capture_function(func: Callable) -> dict[str, Any]: + def _getclosurevars(func: Callable) -> dict[str, Any]: + # Mofiied from `inspect.getclosurevars` + if inspect.ismethod(func): + func = func.__func__ + if not inspect.isfunction(func): + raise TypeError(f"{func!r} is not a Python function") + code = func.__code__ + # Nonlocal references are named in co_freevars and resolved + # by looking them up in __closure__ by positional index + nonlocal_vars = {} + if func.__closure__ is not None: + for var, cell in zip(code.co_freevars, func.__closure__): + try: + nonlocal_vars[var] = cell.cell_contents + except ValueError as err: + # cell_contents may raise ValueError if the cell is empty. + if "empty" not in str(err): + raise + return nonlocal_vars + + return { + **func.__globals__, + **_getclosurevars(func), + } + + +def inspect_capture_class(cls: type) -> dict[str, Any]: + result: dict[str, Any] = {} + for _, v in cls.__dict__.items(): + if inspect.isfunction(v): + func_vars = inspect_capture_function(v) + result.update(**func_vars) + return result diff --git a/python/mlc/ast/mlc_ast.py b/python/mlc/ast/mlc_ast.py index a6fba6bb..72bb7861 100644 --- a/python/mlc/ast/mlc_ast.py +++ b/python/mlc/ast/mlc_ast.py @@ -3,44 +3,44 @@ from typing import Any, Optional -from mlc.dataclasses import PyClass, py_class +from mlc import dataclasses as mlcd _Identifier = str -@py_class(type_key="mlc.ast.AST") -class AST(PyClass): ... +@mlcd.py_class(type_key="mlc.ast.AST", structure="bind") +class AST(mlcd.PyClass): ... -@py_class(type_key="mlc.ast.mod") +@mlcd.py_class(type_key="mlc.ast.mod", structure="bind") class mod(AST): ... -@py_class(type_key="mlc.ast.expr_context") +@mlcd.py_class(type_key="mlc.ast.expr_context", structure="bind") class expr_context(AST): ... -@py_class(type_key="mlc.ast.operator") +@mlcd.py_class(type_key="mlc.ast.operator", structure="bind") class operator(AST): ... -@py_class(type_key="mlc.ast.cmpop") +@mlcd.py_class(type_key="mlc.ast.cmpop", structure="bind") class cmpop(AST): ... -@py_class(type_key="mlc.ast.unaryop") +@mlcd.py_class(type_key="mlc.ast.unaryop", structure="bind") class unaryop(AST): ... -@py_class(type_key="mlc.ast.boolop") +@mlcd.py_class(type_key="mlc.ast.boolop", structure="bind") class boolop(AST): ... -@py_class(type_key="mlc.ast.type_ignore") +@mlcd.py_class(type_key="mlc.ast.type_ignore", structure="bind") class type_ignore(AST): ... -@py_class(type_key="mlc.ast.stmt") +@mlcd.py_class(type_key="mlc.ast.stmt", structure="bind") class stmt(AST): lineno: Optional[int] col_offset: Optional[int] @@ -48,7 +48,7 @@ class stmt(AST): end_col_offset: Optional[int] -@py_class(type_key="mlc.ast.expr") +@mlcd.py_class(type_key="mlc.ast.expr", structure="bind") class expr(AST): lineno: Optional[int] col_offset: Optional[int] @@ -56,7 +56,7 @@ class expr(AST): end_col_offset: Optional[int] -@py_class(type_key="mlc.ast.type_param") +@mlcd.py_class(type_key="mlc.ast.type_param", structure="bind") class type_param(AST): lineno: Optional[int] col_offset: Optional[int] @@ -64,7 +64,7 @@ class type_param(AST): end_col_offset: Optional[int] -@py_class(type_key="mlc.ast.pattern") +@mlcd.py_class(type_key="mlc.ast.pattern", structure="bind") class pattern(AST): lineno: Optional[int] col_offset: Optional[int] @@ -72,7 +72,7 @@ class pattern(AST): end_col_offset: Optional[int] -@py_class(type_key="mlc.ast.arg") +@mlcd.py_class(type_key="mlc.ast.arg", structure="bind") class arg(AST): lineno: Optional[int] col_offset: Optional[int] @@ -83,7 +83,7 @@ class arg(AST): type_comment: Optional[str] -@py_class(type_key="mlc.ast.keyword") +@mlcd.py_class(type_key="mlc.ast.keyword", structure="bind") class keyword(AST): lineno: Optional[int] col_offset: Optional[int] @@ -93,7 +93,7 @@ class keyword(AST): value: expr -@py_class(type_key="mlc.ast.alias") +@mlcd.py_class(type_key="mlc.ast.alias", structure="bind") class alias(AST): lineno: Optional[int] col_offset: Optional[int] @@ -103,7 +103,7 @@ class alias(AST): asname: Optional[_Identifier] -@py_class(type_key="mlc.ast.arguments") +@mlcd.py_class(type_key="mlc.ast.arguments", structure="bind") class arguments(AST): posonlyargs: list[arg] args: list[arg] @@ -114,7 +114,7 @@ class arguments(AST): defaults: list[expr] -@py_class(type_key="mlc.ast.comprehension") +@mlcd.py_class(type_key="mlc.ast.comprehension", structure="bind") class comprehension(AST): target: expr iter: expr @@ -122,41 +122,41 @@ class comprehension(AST): is_async: int -@py_class(type_key="mlc.ast.withitem") +@mlcd.py_class(type_key="mlc.ast.withitem", structure="bind") class withitem(AST): context_expr: expr optional_vars: Optional[expr] -@py_class(type_key="mlc.ast.TypeIgnore") +@mlcd.py_class(type_key="mlc.ast.TypeIgnore", structure="bind") class TypeIgnore(type_ignore): lineno: Optional[int] tag: str -@py_class(type_key="mlc.ast.Module") +@mlcd.py_class(type_key="mlc.ast.Module", structure="bind") class Module(mod): body: list[stmt] type_ignores: list[TypeIgnore] -@py_class(type_key="mlc.ast.Interactive") +@mlcd.py_class(type_key="mlc.ast.Interactive", structure="bind") class Interactive(mod): body: list[stmt] -@py_class(type_key="mlc.ast.Expression") +@mlcd.py_class(type_key="mlc.ast.Expression", structure="bind") class Expression(mod): body: expr -@py_class(type_key="mlc.ast.FunctionType") +@mlcd.py_class(type_key="mlc.ast.FunctionType", structure="bind") class FunctionType(mod): argtypes: list[expr] returns: expr -@py_class(type_key="mlc.ast.FunctionDef") +@mlcd.py_class(type_key="mlc.ast.FunctionDef", structure="bind") class FunctionDef(stmt): name: _Identifier args: arguments @@ -167,7 +167,7 @@ class FunctionDef(stmt): type_params: Optional[list[type_param]] -@py_class(type_key="mlc.ast.AsyncFunctionDef") +@mlcd.py_class(type_key="mlc.ast.AsyncFunctionDef", structure="bind") class AsyncFunctionDef(stmt): name: _Identifier args: arguments @@ -178,7 +178,7 @@ class AsyncFunctionDef(stmt): type_params: Optional[list[type_param]] -@py_class(type_key="mlc.ast.ClassDef") +@mlcd.py_class(type_key="mlc.ast.ClassDef", structure="bind") class ClassDef(stmt): name: _Identifier bases: list[expr] @@ -188,45 +188,45 @@ class ClassDef(stmt): type_params: Optional[list[type_param]] -@py_class(type_key="mlc.ast.Return") +@mlcd.py_class(type_key="mlc.ast.Return", structure="bind") class Return(stmt): value: Optional[expr] -@py_class(type_key="mlc.ast.Delete") +@mlcd.py_class(type_key="mlc.ast.Delete", structure="bind") class Delete(stmt): targets: list[expr] -@py_class(type_key="mlc.ast.Assign") +@mlcd.py_class(type_key="mlc.ast.Assign", structure="bind") class Assign(stmt): targets: list[expr] value: expr type_comment: Optional[str] -@py_class(type_key="mlc.ast.Attribute") +@mlcd.py_class(type_key="mlc.ast.Attribute", structure="bind") class Attribute(expr): value: expr attr: _Identifier ctx: expr_context -@py_class(type_key="mlc.ast.Subscript") +@mlcd.py_class(type_key="mlc.ast.Subscript", structure="bind") class Subscript(expr): value: expr slice: expr ctx: expr_context -@py_class(type_key="mlc.ast.AugAssign") +@mlcd.py_class(type_key="mlc.ast.AugAssign", structure="bind") class AugAssign(stmt): target: Any # Name | Attribute | Subscript op: operator value: expr -@py_class(type_key="mlc.ast.AnnAssign") +@mlcd.py_class(type_key="mlc.ast.AnnAssign", structure="bind") class AnnAssign(stmt): target: Any # Name | Attribute | Subscript annotation: expr @@ -234,7 +234,7 @@ class AnnAssign(stmt): simple: int -@py_class(type_key="mlc.ast.For") +@mlcd.py_class(type_key="mlc.ast.For", structure="bind") class For(stmt): target: expr iter: expr @@ -243,7 +243,7 @@ class For(stmt): type_comment: Optional[str] -@py_class(type_key="mlc.ast.AsyncFor") +@mlcd.py_class(type_key="mlc.ast.AsyncFor", structure="bind") class AsyncFor(stmt): target: expr iter: expr @@ -252,54 +252,54 @@ class AsyncFor(stmt): type_comment: Optional[str] -@py_class(type_key="mlc.ast.While") +@mlcd.py_class(type_key="mlc.ast.While", structure="bind") class While(stmt): test: expr body: list[stmt] orelse: list[stmt] -@py_class(type_key="mlc.ast.If") +@mlcd.py_class(type_key="mlc.ast.If", structure="bind") class If(stmt): test: expr body: list[stmt] orelse: list[stmt] -@py_class(type_key="mlc.ast.With") +@mlcd.py_class(type_key="mlc.ast.With", structure="bind") class With(stmt): items: list[withitem] body: list[stmt] type_comment: Optional[str] -@py_class(type_key="mlc.ast.AsyncWith") +@mlcd.py_class(type_key="mlc.ast.AsyncWith", structure="bind") class AsyncWith(stmt): items: list[withitem] body: list[stmt] type_comment: Optional[str] -@py_class(type_key="mlc.ast.match_case") +@mlcd.py_class(type_key="mlc.ast.match_case", structure="bind") class match_case(AST): pattern: pattern guard: Optional[expr] body: list[stmt] -@py_class(type_key="mlc.ast.Match") +@mlcd.py_class(type_key="mlc.ast.Match", structure="bind") class Match(stmt): subject: expr cases: list[match_case] -@py_class(type_key="mlc.ast.Raise") +@mlcd.py_class(type_key="mlc.ast.Raise", structure="bind") class Raise(stmt): exc: Optional[expr] cause: Optional[expr] -@py_class(type_key="mlc.ast.ExceptHandler") +@mlcd.py_class(type_key="mlc.ast.ExceptHandler", structure="bind") class ExceptHandler(AST): lineno: Optional[int] col_offset: Optional[int] @@ -310,7 +310,7 @@ class ExceptHandler(AST): body: list[stmt] -@py_class(type_key="mlc.ast.Try") +@mlcd.py_class(type_key="mlc.ast.Try", structure="bind") class Try(stmt): body: list[stmt] handlers: list[ExceptHandler] @@ -318,7 +318,7 @@ class Try(stmt): finalbody: list[stmt] -@py_class(type_key="mlc.ast.TryStar") +@mlcd.py_class(type_key="mlc.ast.TryStar", structure="bind") class TryStar(stmt): body: list[stmt] handlers: list[ExceptHandler] @@ -326,359 +326,359 @@ class TryStar(stmt): finalbody: list[stmt] -@py_class(type_key="mlc.ast.Assert") +@mlcd.py_class(type_key="mlc.ast.Assert", structure="bind") class Assert(stmt): test: expr msg: Optional[expr] -@py_class(type_key="mlc.ast.Import") +@mlcd.py_class(type_key="mlc.ast.Import", structure="bind") class Import(stmt): names: list[alias] -@py_class(type_key="mlc.ast.ImportFrom") +@mlcd.py_class(type_key="mlc.ast.ImportFrom", structure="bind") class ImportFrom(stmt): module: Optional[str] names: list[alias] level: int -@py_class(type_key="mlc.ast.Global") +@mlcd.py_class(type_key="mlc.ast.Global", structure="bind") class Global(stmt): names: list[_Identifier] -@py_class(type_key="mlc.ast.Nonlocal") +@mlcd.py_class(type_key="mlc.ast.Nonlocal", structure="bind") class Nonlocal(stmt): names: list[_Identifier] -@py_class(type_key="mlc.ast.Expr") +@mlcd.py_class(type_key="mlc.ast.Expr", structure="bind") class Expr(stmt): value: expr -@py_class(type_key="mlc.ast.Pass") +@mlcd.py_class(type_key="mlc.ast.Pass", structure="bind") class Pass(stmt): ... -@py_class(type_key="mlc.ast.Break") +@mlcd.py_class(type_key="mlc.ast.Break", structure="bind") class Break(stmt): ... -@py_class(type_key="mlc.ast.Continue") +@mlcd.py_class(type_key="mlc.ast.Continue", structure="bind") class Continue(stmt): ... -@py_class(type_key="mlc.ast.BoolOp") +@mlcd.py_class(type_key="mlc.ast.BoolOp", structure="bind") class BoolOp(expr): op: boolop values: list[expr] -@py_class(type_key="mlc.ast.Name") +@mlcd.py_class(type_key="mlc.ast.Name", structure="bind") class Name(expr): id: _Identifier ctx: expr_context -@py_class(type_key="mlc.ast.NamedExpr") +@mlcd.py_class(type_key="mlc.ast.NamedExpr", structure="bind") class NamedExpr(expr): target: Name value: expr -@py_class(type_key="mlc.ast.BinOp") +@mlcd.py_class(type_key="mlc.ast.BinOp", structure="bind") class BinOp(expr): left: expr op: operator right: expr -@py_class(type_key="mlc.ast.UnaryOp") +@mlcd.py_class(type_key="mlc.ast.UnaryOp", structure="bind") class UnaryOp(expr): op: unaryop operand: expr -@py_class(type_key="mlc.ast.Lambda") +@mlcd.py_class(type_key="mlc.ast.Lambda", structure="bind") class Lambda(expr): args: arguments body: expr -@py_class(type_key="mlc.ast.IfExp") +@mlcd.py_class(type_key="mlc.ast.IfExp", structure="bind") class IfExp(expr): test: expr body: expr orelse: expr -@py_class(type_key="mlc.ast.Dict") +@mlcd.py_class(type_key="mlc.ast.Dict", structure="bind") class Dict(expr): keys: list[Optional[expr]] values: list[expr] -@py_class(type_key="mlc.ast.Set") +@mlcd.py_class(type_key="mlc.ast.Set", structure="bind") class Set(expr): elts: list[expr] -@py_class(type_key="mlc.ast.ListComp") +@mlcd.py_class(type_key="mlc.ast.ListComp", structure="bind") class ListComp(expr): elt: expr generators: list[comprehension] -@py_class(type_key="mlc.ast.SetComp") +@mlcd.py_class(type_key="mlc.ast.SetComp", structure="bind") class SetComp(expr): elt: expr generators: list[comprehension] -@py_class(type_key="mlc.ast.DictComp") +@mlcd.py_class(type_key="mlc.ast.DictComp", structure="bind") class DictComp(expr): key: expr value: expr generators: list[comprehension] -@py_class(type_key="mlc.ast.GeneratorExp") +@mlcd.py_class(type_key="mlc.ast.GeneratorExp", structure="bind") class GeneratorExp(expr): elt: expr generators: list[comprehension] -@py_class(type_key="mlc.ast.Await") +@mlcd.py_class(type_key="mlc.ast.Await", structure="bind") class Await(expr): value: expr -@py_class(type_key="mlc.ast.Yield") +@mlcd.py_class(type_key="mlc.ast.Yield", structure="bind") class Yield(expr): value: Optional[expr] -@py_class(type_key="mlc.ast.YieldFrom") +@mlcd.py_class(type_key="mlc.ast.YieldFrom", structure="bind") class YieldFrom(expr): value: expr -@py_class(type_key="mlc.ast.Compare") +@mlcd.py_class(type_key="mlc.ast.Compare", structure="bind") class Compare(expr): left: expr ops: list[cmpop] comparators: list[expr] -@py_class(type_key="mlc.ast.Call") +@mlcd.py_class(type_key="mlc.ast.Call", structure="bind") class Call(expr): func: expr args: list[expr] keywords: list[keyword] -@py_class(type_key="mlc.ast.FormattedValue") +@mlcd.py_class(type_key="mlc.ast.FormattedValue", structure="bind") class FormattedValue(expr): value: expr conversion: int format_spec: Optional[expr] -@py_class(type_key="mlc.ast.JoinedStr") +@mlcd.py_class(type_key="mlc.ast.JoinedStr", structure="bind") class JoinedStr(expr): values: list[expr] -@py_class(type_key="mlc.ast.Ellipsis") -class Ellipsis(PyClass): ... +@mlcd.py_class(type_key="mlc.ast.Ellipsis", structure="bind") +class Ellipsis(mlcd.PyClass): ... -@py_class(type_key="mlc.ast.Constant") +@mlcd.py_class(type_key="mlc.ast.Constant", structure="bind") class Constant(expr): value: Any # None, str, bytes, bool, int, float, complex, Ellipsis kind: Optional[str] -@py_class(type_key="mlc.ast.Starred") +@mlcd.py_class(type_key="mlc.ast.Starred", structure="bind") class Starred(expr): value: expr ctx: expr_context -@py_class(type_key="mlc.ast.List") +@mlcd.py_class(type_key="mlc.ast.List", structure="bind") class List(expr): elts: list[expr] ctx: expr_context -@py_class(type_key="mlc.ast.Tuple") +@mlcd.py_class(type_key="mlc.ast.Tuple", structure="bind") class Tuple(expr): elts: list[expr] ctx: expr_context dims: list[expr] -@py_class(type_key="mlc.ast.Slice") +@mlcd.py_class(type_key="mlc.ast.Slice", structure="bind") class Slice(expr): lower: Optional[expr] upper: Optional[expr] step: Optional[expr] -@py_class(type_key="mlc.ast.Load") +@mlcd.py_class(type_key="mlc.ast.Load", structure="bind") class Load(expr_context): ... -@py_class(type_key="mlc.ast.Store") +@mlcd.py_class(type_key="mlc.ast.Store", structure="bind") class Store(expr_context): ... -@py_class(type_key="mlc.ast.Del") +@mlcd.py_class(type_key="mlc.ast.Del", structure="bind") class Del(expr_context): ... -@py_class(type_key="mlc.ast.And") +@mlcd.py_class(type_key="mlc.ast.And", structure="bind") class And(boolop): ... -@py_class(type_key="mlc.ast.Or") +@mlcd.py_class(type_key="mlc.ast.Or", structure="bind") class Or(boolop): ... -@py_class(type_key="mlc.ast.Add") +@mlcd.py_class(type_key="mlc.ast.Add", structure="bind") class Add(operator): ... -@py_class(type_key="mlc.ast.Sub") +@mlcd.py_class(type_key="mlc.ast.Sub", structure="bind") class Sub(operator): ... -@py_class(type_key="mlc.ast.Mult") +@mlcd.py_class(type_key="mlc.ast.Mult", structure="bind") class Mult(operator): ... -@py_class(type_key="mlc.ast.MatMult") +@mlcd.py_class(type_key="mlc.ast.MatMult", structure="bind") class MatMult(operator): ... -@py_class(type_key="mlc.ast.Div") +@mlcd.py_class(type_key="mlc.ast.Div", structure="bind") class Div(operator): ... -@py_class(type_key="mlc.ast.Mod") +@mlcd.py_class(type_key="mlc.ast.Mod", structure="bind") class Mod(operator): ... -@py_class(type_key="mlc.ast.Pow") +@mlcd.py_class(type_key="mlc.ast.Pow", structure="bind") class Pow(operator): ... -@py_class(type_key="mlc.ast.LShift") +@mlcd.py_class(type_key="mlc.ast.LShift", structure="bind") class LShift(operator): ... -@py_class(type_key="mlc.ast.RShift") +@mlcd.py_class(type_key="mlc.ast.RShift", structure="bind") class RShift(operator): ... -@py_class(type_key="mlc.ast.BitOr") +@mlcd.py_class(type_key="mlc.ast.BitOr", structure="bind") class BitOr(operator): ... -@py_class(type_key="mlc.ast.BitXor") +@mlcd.py_class(type_key="mlc.ast.BitXor", structure="bind") class BitXor(operator): ... -@py_class(type_key="mlc.ast.BitAnd") +@mlcd.py_class(type_key="mlc.ast.BitAnd", structure="bind") class BitAnd(operator): ... -@py_class(type_key="mlc.ast.FloorDiv") +@mlcd.py_class(type_key="mlc.ast.FloorDiv", structure="bind") class FloorDiv(operator): ... -@py_class(type_key="mlc.ast.Invert") +@mlcd.py_class(type_key="mlc.ast.Invert", structure="bind") class Invert(unaryop): ... -@py_class(type_key="mlc.ast.Not") +@mlcd.py_class(type_key="mlc.ast.Not", structure="bind") class Not(unaryop): ... -@py_class(type_key="mlc.ast.UAdd") +@mlcd.py_class(type_key="mlc.ast.UAdd", structure="bind") class UAdd(unaryop): ... -@py_class(type_key="mlc.ast.USub") +@mlcd.py_class(type_key="mlc.ast.USub", structure="bind") class USub(unaryop): ... -@py_class(type_key="mlc.ast.Eq") +@mlcd.py_class(type_key="mlc.ast.Eq", structure="bind") class Eq(cmpop): ... -@py_class(type_key="mlc.ast.NotEq") +@mlcd.py_class(type_key="mlc.ast.NotEq", structure="bind") class NotEq(cmpop): ... -@py_class(type_key="mlc.ast.Lt") +@mlcd.py_class(type_key="mlc.ast.Lt", structure="bind") class Lt(cmpop): ... -@py_class(type_key="mlc.ast.LtE") +@mlcd.py_class(type_key="mlc.ast.LtE", structure="bind") class LtE(cmpop): ... -@py_class(type_key="mlc.ast.Gt") +@mlcd.py_class(type_key="mlc.ast.Gt", structure="bind") class Gt(cmpop): ... -@py_class(type_key="mlc.ast.GtE") +@mlcd.py_class(type_key="mlc.ast.GtE", structure="bind") class GtE(cmpop): ... -@py_class(type_key="mlc.ast.Is") +@mlcd.py_class(type_key="mlc.ast.Is", structure="bind") class Is(cmpop): ... -@py_class(type_key="mlc.ast.IsNot") +@mlcd.py_class(type_key="mlc.ast.IsNot", structure="bind") class IsNot(cmpop): ... -@py_class(type_key="mlc.ast.In") +@mlcd.py_class(type_key="mlc.ast.In", structure="bind") class In(cmpop): ... -@py_class(type_key="mlc.ast.NotIn") +@mlcd.py_class(type_key="mlc.ast.NotIn", structure="bind") class NotIn(cmpop): ... -@py_class(type_key="mlc.ast.MatchValue") +@mlcd.py_class(type_key="mlc.ast.MatchValue", structure="bind") class MatchValue(pattern): value: expr -@py_class(type_key="mlc.ast.MatchSingleton") +@mlcd.py_class(type_key="mlc.ast.MatchSingleton", structure="bind") class MatchSingleton(pattern): value: int # boolean -@py_class(type_key="mlc.ast.MatchSequence") +@mlcd.py_class(type_key="mlc.ast.MatchSequence", structure="bind") class MatchSequence(pattern): patterns: list[pattern] -@py_class(type_key="mlc.ast.MatchMapping") +@mlcd.py_class(type_key="mlc.ast.MatchMapping", structure="bind") class MatchMapping(pattern): keys: list[expr] patterns: list[pattern] rest: Optional[_Identifier] -@py_class(type_key="mlc.ast.MatchClass") +@mlcd.py_class(type_key="mlc.ast.MatchClass", structure="bind") class MatchClass(pattern): cls: expr patterns: list[pattern] @@ -686,42 +686,42 @@ class MatchClass(pattern): kwd_patterns: list[pattern] -@py_class(type_key="mlc.ast.MatchStar") +@mlcd.py_class(type_key="mlc.ast.MatchStar", structure="bind") class MatchStar(pattern): name: Optional[_Identifier] -@py_class(type_key="mlc.ast.MatchAs") +@mlcd.py_class(type_key="mlc.ast.MatchAs", structure="bind") class MatchAs(pattern): pattern: Optional[pattern] name: Optional[_Identifier] -@py_class(type_key="mlc.ast.MatchOr") +@mlcd.py_class(type_key="mlc.ast.MatchOr", structure="bind") class MatchOr(pattern): patterns: list[pattern] -@py_class(type_key="mlc.ast.TypeVar") +@mlcd.py_class(type_key="mlc.ast.TypeVar", structure="bind") class TypeVar(type_param): name: _Identifier bound: Optional[expr] default_value: Optional[expr] -@py_class(type_key="mlc.ast.ParamSpec") +@mlcd.py_class(type_key="mlc.ast.ParamSpec", structure="bind") class ParamSpec(type_param): name: _Identifier default_value: Optional[expr] -@py_class(type_key="mlc.ast.TypeVarTuple") +@mlcd.py_class(type_key="mlc.ast.TypeVarTuple", structure="bind") class TypeVarTuple(type_param): name: _Identifier default_value: Optional[expr] -@py_class(type_key="mlc.ast.TypeAlias") +@mlcd.py_class(type_key="mlc.ast.TypeAlias", structure="bind") class TypeAlias(stmt): name: Name type_params: Optional[list[type_param]] diff --git a/python/mlc/ast/source.py b/python/mlc/ast/source.py deleted file mode 100644 index 0a9b5736..00000000 --- a/python/mlc/ast/source.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -import ast -import inspect -import linecache -import sys -from typing import Any - - -class Source: - source_name: str - start_line: int - start_column: int - source: str - full_source: str - - def __init__( - self, - program: str | ast.AST, - ) -> None: - if isinstance(program, str): - self.source_name = "" - self.start_line = 1 - self.start_column = 0 - self.source = program - self.full_source = program - return - - self.source_name = inspect.getsourcefile(program) # type: ignore - lines, self.start_line = getsourcelines(program) # type: ignore - if lines: - self.start_column = len(lines[0]) - len(lines[0].lstrip()) - else: - self.start_column = 0 - if self.start_column and lines: - self.source = "\n".join([l[self.start_column :].rstrip() for l in lines]) - else: - self.source = "".join(lines) - try: - # It will cause a problem when running in Jupyter Notebook. - # `mod` will be , which is a built-in module - # and `getsource` will throw a TypeError - mod = inspect.getmodule(program) - if mod: - self.full_source = inspect.getsource(mod) - else: - self.full_source = self.source - except TypeError: - # It's a work around for Jupyter problem. - # Since `findsource` is an internal API of inspect, we just use it - # as a fallback method. - src, _ = inspect.findsource(program) # type: ignore - self.full_source = "".join(src) - - -def getfile(obj: Any) -> str: - if not inspect.isclass(obj): - return _getfile(obj) - mod = getattr(obj, "__module__", None) - if mod is not None: - file = getattr(sys.modules[mod], "__file__", None) - if file is not None: - return file - for _, member in inspect.getmembers(obj): - if inspect.isfunction(member): - if obj.__qualname__ + "." + member.__name__ == member.__qualname__: - return inspect.getfile(member) - raise TypeError(f"Source for {obj:!r} not found") - - -def getsourcelines(obj: Any) -> tuple[list[str], int]: - obj = inspect.unwrap(obj) - lines, l_num = findsource(obj) - return inspect.getblock(lines[l_num:]), l_num + 1 - - -def findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912 - if not inspect.isclass(obj): - return _findsource(obj) - - file = inspect.getsourcefile(obj) - if file: - linecache.checkcache(file) - else: - file = inspect.getfile(obj) - if not (file.startswith("<") and file.endswith(">")): - raise OSError("source code not available") - - module = inspect.getmodule(obj, file) - if module: - lines = linecache.getlines(file, module.__dict__) - else: - lines = linecache.getlines(file) - if not lines: - raise OSError("could not get source code") - qual_names = obj.__qualname__.replace(".", "").split(".") - in_comment = 0 - scope_stack: list[str] = [] - indent_info: dict[str, int] = {} - for i, line in enumerate(lines): - n_comment = line.count('"""') - if n_comment: - # update multi-line comments status - in_comment = in_comment ^ (n_comment & 1) - continue - if in_comment: - # skip lines within multi-line comments - continue - indent = len(line) - len(line.lstrip()) - tokens = line.split() - if len(tokens) > 1: - name = None - if tokens[0] == "def": - name = tokens[1].split(":")[0].split("(")[0] + "" - elif tokens[0] == "class": - name = tokens[1].split(":")[0].split("(")[0] - # pop scope if we are less indented - while scope_stack and indent_info[scope_stack[-1]] >= indent: - scope_stack.pop() - if name: - scope_stack.append(name) - indent_info[name] = indent - if scope_stack == qual_names: - return lines, i - raise OSError("could not find class definition") - - -_getfile = inspect.getfile -_findsource = inspect.findsource -inspect.getfile = getfile # type: ignore[assignment] diff --git a/python/mlc/core/object.py b/python/mlc/core/object.py index 0e87c106..f7f1e189 100644 --- a/python/mlc/core/object.py +++ b/python/mlc/core/object.py @@ -21,3 +21,6 @@ def eq_s( assert_mode: bool = False, ) -> bool: return PyAny._mlc_eq_s(self, other, bind_free_vars, assert_mode) # type: ignore[attr-defined] + + def hash_s(self) -> int: + return PyAny._mlc_hash_s(self) # type: ignore[attr-defined] diff --git a/tests/python/test_dataclasses_structure.py b/tests/python/test_dataclasses_structure.py index 73bee759..9639eeeb 100644 --- a/tests/python/test_dataclasses_structure.py +++ b/tests/python/test_dataclasses_structure.py @@ -75,6 +75,7 @@ def test_free_var_1() -> None: lhs = x rhs = x lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) + assert lhs.hash_s() == rhs.hash_s() with pytest.raises(ValueError) as e: lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}: Unbound variable" @@ -85,6 +86,7 @@ def test_free_var_2() -> None: lhs = Add(x, x) rhs = Add(x, x) lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) + assert lhs.hash_s() == rhs.hash_s() with pytest.raises(ValueError) as e: lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.a: Unbound variable" @@ -94,7 +96,10 @@ def test_cyclic() -> None: x = Var("x") y = Var("y") z = Var("z") - (x + y + z).eq_s(y + z + x, bind_free_vars=True, assert_mode=True) + lhs = x + y + z + rhs = y + z + x + lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) + assert lhs.hash_s() == rhs.hash_s() def test_tensor_type() -> None: @@ -104,10 +109,12 @@ def test_tensor_type() -> None: t4 = TensorType(shape=(1, 2), dtype="float32") # t1 == t2 t1.eq_s(t2, bind_free_vars=False, assert_mode=True) + assert t1.hash_s() == t2.hash_s() # t1 != t3, dtype mismatch with pytest.raises(ValueError) as e: t1.eq_s(t3, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.dtype: float32 vs int32" + assert t1.hash_s() != t3.hash_s() # t1 != t4, shape mismatch with pytest.raises(ValueError) as e: t1.eq_s(t4, bind_free_vars=False, assert_mode=True) @@ -115,6 +122,7 @@ def test_tensor_type() -> None: str(e.value) == "Structural equality check failed at {root}.shape: List length mismatch: 3 vs 2" ) + assert t1.hash_s() != t4.hash_s() def test_constant() -> None: @@ -125,6 +133,8 @@ def test_constant() -> None: with pytest.raises(ValueError) as e: c1.eq_s(c3, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.value: 1 vs 2" + assert c1.hash_s() == c2.hash_s() + assert c1.hash_s() != c3.hash_s() def test_let_1() -> None: @@ -145,6 +155,7 @@ def test_let_1() -> None: with pytest.raises(ValueError) as e: lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.rhs.a: Unbound variable" + assert lhs.hash_s() == rhs.hash_s() def test_let_2() -> None: @@ -162,6 +173,8 @@ def test_let_2() -> None: with pytest.raises(ValueError) as e: l1.eq_s(l3, bind_free_vars=True, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.rhs.value: 1 vs 2" + assert l1.hash_s() == l2.hash_s() + assert l1.hash_s() != l3.hash_s() def test_non_scoped_compute_1() -> None: @@ -181,6 +194,7 @@ def test_non_scoped_compute_1() -> None: lhs = y + y rhs = y + z lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) + assert lhs.hash_s() == rhs.hash_s() def test_non_scoped_compute_2() -> None: @@ -205,6 +219,7 @@ def test_non_scoped_compute_2() -> None: str(e.value) == "Structural equality check failed at {root}.b: Inconsistent binding. " "LHS has been bound to a different node while RHS is not bound" ) + assert lhs.hash_s() != rhs.hash_s() def test_func_1() -> None: @@ -227,6 +242,7 @@ def test_func_1() -> None: lhs = Func("lhs", args=[x, y], body=Let(rhs=x + y, lhs=z, body=z + z)) rhs = Func("rhs", args=[y, x], body=Let(rhs=y + x, lhs=z, body=z + z)) lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) + assert lhs.hash_s() == rhs.hash_s() def test_func_2() -> None: @@ -259,6 +275,8 @@ def test_func_2() -> None: str(e.value) == "Structural equality check failed at {root}.args: List length mismatch: 4 vs 3" ) + assert l1.hash_s() == l2.hash_s() + assert l1.hash_s() != l3.hash_s() def test_func_stmts() -> None: @@ -301,3 +319,4 @@ def test_func_stmts() -> None: == "Structural equality check failed at {root}.stmts[0].rhs.b: Inconsistent binding. " "LHS has been bound to a different node while RHS is not bound" ) + assert func_f.hash_s() != func_g.hash_s()