diff --git a/.clang-format b/.clang-format
new file mode 100644
index 00000000..173f8a6a
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,2 @@
+BasedOnStyle: LLVM
+ColumnLimit: 120
diff --git a/.clangd b/.clangd
new file mode 100644
index 00000000..1ce440cd
--- /dev/null
+++ b/.clangd
@@ -0,0 +1,2 @@
+CompileFlags:
+ CompilationDatabase: build-vscode/
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 00000000..9d146f53
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,82 @@
+name: CI
+
+on: [push, pull_request]
+# Workflow-level environment shared by the jobs below.
+env:
+ # Settings consumed by cibuildwheel (`python -m cibuildwheel` in each build job).
+ CIBW_BUILD_VERBOSITY: 3
+ CIBW_TEST_REQUIRES: "pytest"
+ CIBW_TEST_COMMAND: "pytest -svv --durations=20 {project}/tests/python/"
+ # Project-specific knobs: the jobs read these through `${{ env.MLC_* }}` to pin
+ # the cibuildwheel release, the host Python, and the per-OS CIBW_BUILD selector.
+ MLC_CIBW_VERSION: "2.20.0"
+ MLC_PYTHON_VERSION: "3.9"
+ MLC_CIBW_WIN_BUILD: "cp39-win_amd64"
+ MLC_CIBW_MAC_BUILD: "cp39-macosx_arm64"
+ # NOTE(review): Linux builds cp312 while Windows/macOS build cp39 — confirm this is intentional.
+ MLC_CIBW_LINUX_BUILD: "cp312-manylinux_x86_64"
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - uses: pre-commit/action@v3.0.1
+ windows:
+ name: Windows
+ runs-on: windows-latest
+ needs: pre-commit
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: "recursive"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Install cibuildwheel
+ run: python -m pip install cibuildwheel=="${{ env.MLC_CIBW_VERSION }}"
+ - name: Build wheels
+ run: python -m cibuildwheel --output-dir wheelhouse
+ env:
+ CIBW_BEFORE_ALL: ".\\scripts\\cpp_tests.bat"
+ CIBW_BUILD: ${{ env.MLC_CIBW_WIN_BUILD }}
+ - name: Show package contents
+ run: python scripts/show_wheel_content.py wheelhouse
+ macos:
+ name: MacOS
+ runs-on: macos-latest
+ needs: pre-commit
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: "recursive"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Install cibuildwheel
+ run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }}
+ - name: Build wheels
+ run: python -m cibuildwheel --output-dir wheelhouse
+ env:
+ CIBW_BEFORE_ALL: "./scripts/cpp_tests.sh"
+ CIBW_BUILD: ${{ env.MLC_CIBW_MAC_BUILD }}
+ - name: Show package contents
+ run: python scripts/show_wheel_content.py wheelhouse
+ linux:
+ name: Linux
+ runs-on: ubuntu-latest
+ needs: pre-commit
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: "recursive"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Install cibuildwheel
+ run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }}
+ - name: Build wheels
+ run: python -m cibuildwheel --output-dir wheelhouse
+ env:
+ CIBW_BEFORE_ALL: "./scripts/setup_manylinux2014.sh && ./scripts/cpp_tests.sh"
+ CIBW_BUILD: ${{ env.MLC_CIBW_LINUX_BUILD }}
+ - name: Show package contents
+ run: python scripts/show_wheel_content.py wheelhouse
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 00000000..64fa06a6
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,168 @@
+name: Release
+
+on:
+ release:
+ types: [published]
+
+env:
+ CIBW_BUILD_VERBOSITY: 3
+ CIBW_TEST_COMMAND: "python -c \"import mlc\""
+ CIBW_SKIP: "cp313-win_amd64" # Python 3.13 is not quite ready yet
+ MLC_CIBW_VERSION: "2.20.0"
+ MLC_PYTHON_VERSION: "3.9"
+ MLC_CIBW_WIN_BUILD: "cp3*-win_amd64"
+ MLC_CIBW_MAC_BUILD: "cp3*-macosx_arm64"
+ MLC_CIBW_MAC_X86_BUILD: "cp3*-macosx_x86_64"
+ MLC_CIBW_LINUX_BUILD: "cp3*-manylinux_x86_64"
+
+jobs:
+ windows:
+ name: Windows
+ runs-on: windows-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: "recursive"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Install cibuildwheel
+ run: python -m pip install cibuildwheel=="${{ env.MLC_CIBW_VERSION }}"
+ - name: Build wheels
+ run: python -m cibuildwheel --output-dir wheelhouse
+ env:
+ CIBW_BUILD: ${{ env.MLC_CIBW_WIN_BUILD }}
+ - name: Show package contents
+ run: python scripts/show_wheel_content.py wheelhouse
+ - name: Upload wheels
+ uses: actions/upload-artifact@v4
+ with:
+ name: wheels-windows
+ path: ./wheelhouse/*.whl
+ macos:
+ name: MacOS
+ runs-on: macos-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: "recursive"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Install cibuildwheel
+ run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }}
+ - name: Build wheels
+ run: python -m cibuildwheel --output-dir wheelhouse
+ env:
+ CIBW_BUILD: ${{ env.MLC_CIBW_MAC_BUILD }}
+ - name: Show package contents
+ run: python scripts/show_wheel_content.py wheelhouse
+ - name: Upload wheels
+ uses: actions/upload-artifact@v4
+ with:
+ name: wheels-macos
+ path: ./wheelhouse/*.whl
+ macos-x86:
+ name: MacOS-x86
+ runs-on: macos-13
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: "recursive"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Install cibuildwheel
+ run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }}
+ - name: Build wheels
+ run: python -m cibuildwheel --output-dir wheelhouse
+ env:
+ CIBW_BUILD: ${{ env.MLC_CIBW_MAC_X86_BUILD }}
+ - name: Show package contents
+ run: python scripts/show_wheel_content.py wheelhouse
+ - name: Upload wheels
+ uses: actions/upload-artifact@v4
+ with:
+ name: wheels-macos-x86
+ path: ./wheelhouse/*.whl
+ linux:
+ name: Linux
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: "recursive"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Install cibuildwheel
+ run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }}
+ - name: Build wheels
+ run: python -m cibuildwheel --output-dir wheelhouse
+ env:
+ CIBW_BUILD: ${{ env.MLC_CIBW_LINUX_BUILD }}
+ - name: Show package contents
+ run: python scripts/show_wheel_content.py wheelhouse
+ - name: Upload wheels
+ uses: actions/upload-artifact@v4
+ with:
+ name: wheels-linux
+ path: ./wheelhouse/*.whl
+ publish:
+ name: Publish
+ runs-on: ubuntu-latest
+ needs: [windows, macos, linux, macos-x86]
+ environment:
+ name: pypi
+ url: https://pypi.org/p/mlc-python
+ permissions:
+ id-token: write
+ contents: write
+ steps:
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ env.MLC_PYTHON_VERSION }}
+ - name: Checkout code
+ uses: actions/checkout@v4
+ - name: Download all artifacts
+ uses: actions/download-artifact@v4
+ with:
+ path: ./wheelhouse
+ - name: Prepare distribution files
+ run: |
+ mkdir -p dist
+ mv wheelhouse/wheels-macos-x86/*.whl dist/
+ mv wheelhouse/wheels-macos/*.whl dist/
+ mv wheelhouse/wheels-linux/*.whl dist/
+ mv wheelhouse/wheels-windows/*.whl dist/
+      # Attach every wheel in dist/ to the GitHub release that triggered this run.
+      - name: Upload wheels to release
+        env:
+          GITHUB_TOKEN: ${{ secrets.PYMLC_GITHUB_TOKEN }}
+          # Workflow context is handed over through the environment instead of being
+          # spliced into the script text, so a tag name containing shell
+          # metacharacters is treated as data and never executed.
+          REPOSITORY: ${{ github.repository }}
+          TAG_NAME: ${{ github.ref_name }}
+        run: |
+          # Get the release ID
+          release_id=$(gh api "repos/${REPOSITORY}/releases/tags/${TAG_NAME}" --jq '.id')
+          if [ -z "$release_id" ]; then
+            echo "Error: Could not find release with tag ${TAG_NAME}"
+            exit 1
+          fi
+          upload_url="https://uploads.github.com/repos/${REPOSITORY}/releases/${release_id}/assets"
+          echo "Uploading to release ID: $release_id. Upload URL: $upload_url"
+          for wheel in dist/*.whl; do
+            # Quote the path so unusual file names are not word-split or globbed.
+            wheel_name=$(basename "$wheel")
+            echo "⏫ Uploading $wheel_name"
+            # `-w "%{http_code}"` appends the 3-digit HTTP status to the response body;
+            # it is split off again below.
+            response=$(curl -s -w "%{http_code}" -X POST \
+              -H "Authorization: token $GITHUB_TOKEN" \
+              -H "Content-Type: application/octet-stream" \
+              --data-binary @"$wheel" \
+              "$upload_url?name=$wheel_name")
+            http_code=${response: -3}
+            # Quoted so an empty status (e.g. curl produced no output) yields a clean
+            # test failure instead of a `[` syntax error.
+            if [ "$http_code" -eq 201 ]; then
+              echo "🎉 Successfully uploaded $wheel_name"
+            else
+              echo "❌ Failed to upload $wheel_name. HTTP status: $http_code. Error response: ${response%???}"
+              exit 1
+            fi
+          done
+ - name: Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ packages-dir: dist/
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..54e90205
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.vscode
+*.dSYM
+build
+build-cpp-tests
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..28b086b9
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,9 @@
+[submodule "3rdparty/dlpack"]
+ path = 3rdparty/dlpack
+ url = https://github.com/dmlc/dlpack.git
+[submodule "3rdparty/libbacktrace"]
+ path = 3rdparty/libbacktrace
+ url = https://github.com/ianlancetaylor/libbacktrace
+[submodule "3rdparty/googletest"]
+ path = 3rdparty/googletest
+ url = https://github.com/google/googletest.git
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..bf1659c0
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,54 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v5.0.0
+ hooks:
+ - id: trailing-whitespace
+ - id: mixed-line-ending
+ - id: end-of-file-fixer
+ - id: check-yaml
+ - id: check-toml
+ - id: check-added-large-files
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.4.7
+ hooks:
+ - id: ruff
+ types_or: [python, pyi, jupyter]
+ args: [--fix]
+ - id: ruff-format
+ types_or: [python, pyi, jupyter]
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: "v1.11.1"
+ hooks:
+ - id: mypy
+ additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest"]
+ args: [--show-error-codes]
+ - repo: https://github.com/pre-commit/mirrors-clang-format
+ rev: "v18.1.5"
+ hooks:
+ - id: clang-format
+ - repo: https://github.com/MarcoGorelli/cython-lint
+ rev: v0.16.2
+ hooks:
+ - id: cython-lint
+ - id: double-quote-cython-strings
+ - repo: https://github.com/scop/pre-commit-shfmt
+ rev: v3.8.0-1
+ hooks:
+ - id: shfmt
+ - repo: https://github.com/shellcheck-py/shellcheck-py
+ rev: v0.10.0.1
+ hooks:
+ - id: shellcheck
+ # - repo: https://github.com/cheshirekow/cmake-format-precommit
+ # rev: v0.6.10
+ # hooks:
+ # - id: cmake-format
+ # - id: cmake-lint
+ - repo: https://github.com/compilerla/conventional-pre-commit
+ rev: v2.1.1
+ hooks:
+ - id: conventional-pre-commit
+ stages: [commit-msg]
+ args: [feat, fix, ci, chore, test]
diff --git a/3rdparty/dlpack b/3rdparty/dlpack
new file mode 160000
index 00000000..bbd2f4d3
--- /dev/null
+++ b/3rdparty/dlpack
@@ -0,0 +1 @@
+Subproject commit bbd2f4d32427e548797929af08cfe2a9cbb3cf12
diff --git a/3rdparty/googletest b/3rdparty/googletest
new file mode 160000
index 00000000..f8d7d77c
--- /dev/null
+++ b/3rdparty/googletest
@@ -0,0 +1 @@
+Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571
diff --git a/3rdparty/libbacktrace b/3rdparty/libbacktrace
new file mode 160000
index 00000000..febbb9bf
--- /dev/null
+++ b/3rdparty/libbacktrace
@@ -0,0 +1 @@
+Subproject commit febbb9bff98b39ee596aa15d1f58e4bba442cd6a
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 00000000..961ccd48
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,107 @@
+cmake_minimum_required(VERSION 3.15)
+
+project(
+ mlc
+ VERSION 0.0.14
+ DESCRIPTION "MLC-Python"
+ LANGUAGES C CXX
+)
+
+option(MLC_BUILD_TESTS "Build tests. This option will enable a test target `mlc_tests`." OFF)
+option(MLC_BUILD_PY "Build Python bindings." OFF)
+option(MLC_BUILD_REGISTRY
+ "Support for objects with non-static type indices. When turned on, \
+ targets linked against `mlc` will allow objects that comes with non-pre-defined type indices, \
+ so that the object hierarchy could expand without limitation. \
+ This will require the downstream targets to link against target `mlc_registry` to be effective."
+ OFF
+)
+
+include(cmake/Utils/CxxWarning.cmake)
+include(cmake/Utils/Sanitizer.cmake)
+include(cmake/Utils/Library.cmake)
+include(cmake/Utils/AddLibbacktrace.cmake)
+include(cmake/Utils/DebugSymbol.cmake)
+
+########## Target: `dlpack_header` ##########
+
+add_library(dlpack_header INTERFACE)
+target_include_directories(dlpack_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include")
+
+########## Target: `mlc` ##########
+
+add_library(mlc INTERFACE)
+target_link_libraries(mlc INTERFACE dlpack_header)
+target_compile_features(mlc INTERFACE cxx_std_17)
+target_include_directories(mlc INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/include")
+
+########## Target: `mlc_registry` ##########
+
+if (MLC_BUILD_REGISTRY)
+ add_library(mlc_registry_objs OBJECT
+ "${CMAKE_CURRENT_SOURCE_DIR}/cpp/c_api.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/cpp/c_api_tests.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/cpp/printer.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/cpp/json.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/cpp/structure.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/cpp/traceback.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/cpp/traceback_win.cc"
+ )
+ set_target_properties(
+ mlc_registry_objs PROPERTIES
+ POSITION_INDEPENDENT_CODE ON
+ CXX_STANDARD 17
+ CXX_EXTENSIONS OFF
+ CXX_STANDARD_REQUIRED ON
+ CXX_VISIBILITY_PRESET hidden
+ VISIBILITY_INLINES_HIDDEN ON
+ PREFIX "lib"
+ )
+ add_cxx_warning(mlc_registry_objs)
+ target_link_libraries(mlc_registry_objs PRIVATE dlpack_header)
+ target_include_directories(mlc_registry_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include")
+ add_target_from_obj(mlc_registry mlc_registry_objs)
+ if (TARGET libbacktrace)
+ target_link_libraries(mlc_registry_objs PRIVATE libbacktrace)
+ target_link_libraries(mlc_registry_shared PRIVATE libbacktrace)
+ target_link_libraries(mlc_registry_static PRIVATE libbacktrace)
+ add_debug_symbol_apple(mlc_registry_shared "lib/")
+ endif ()
+ install(TARGETS mlc_registry_static DESTINATION "lib/")
+ install(TARGETS mlc_registry_shared DESTINATION "lib/")
+endif (MLC_BUILD_REGISTRY)
+
+########## Target: `mlc_py` ##########
+
+if (MLC_BUILD_PY)
+  # Python bindings: each `.pyx` under python/mlc/_cython/ is translated to C++ by
+  # Cython, and the generated sources are compiled into one extension module
+  # (target `mlc_py`, output name `core`) that is installed into `_cython/`.
+  find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
+  file(GLOB _cython_sources CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/python/mlc/_cython/*.pyx")
+  set(_cython_outputs "")
+  foreach(_in IN LISTS _cython_sources)
+    # Quoted so a checkout path containing spaces or `;` stays a single argument.
+    get_filename_component(_file "${_in}" NAME_WLE)
+    set(_out "${CMAKE_BINARY_DIR}/_cython/${_file}_cython.cc")
+    message(STATUS "Cythonize: ${_in} -> ${_out}")
+    # Regenerate the C++ source whenever the corresponding `.pyx` changes.
+    add_custom_command(
+      OUTPUT "${_out}" DEPENDS "${_in}"
+      COMMENT "Making `${_out}` from `${_in}`"
+      COMMAND Python::Interpreter -m cython "${_in}" --output-file "${_out}" --cplus
+      VERBATIM
+    )
+    list(APPEND _cython_outputs "${_out}")
+  endforeach()
+  # WITH_SOABI appends the interpreter's ABI tag to the module file name.
+  Python_add_library(mlc_py MODULE ${_cython_outputs} WITH_SOABI)
+  set_target_properties(mlc_py PROPERTIES OUTPUT_NAME "core")
+  target_link_libraries(mlc_py PUBLIC mlc)
+  install(TARGETS mlc_py DESTINATION _cython/)
+endif (MLC_BUILD_PY)
+
+########## Adding tests ##########
+
+# Tests are configured only when this project is the top-level one (its name equals
+# CMAKE_PROJECT_NAME) and MLC_BUILD_TESTS is ON.
+# The operands are quoted so `if()` compares the expanded strings literally; an
+# unquoted expansion would be dereferenced again if a variable with the same name
+# as the project happened to be defined.
+if ("${PROJECT_NAME}" STREQUAL "${CMAKE_PROJECT_NAME}")
+  if (MLC_BUILD_TESTS)
+    enable_testing()
+    message(STATUS "Enable Testing")
+    include(cmake/Utils/AddGoogleTest.cmake)
+    add_subdirectory(tests/cpp/)
+  endif()
+endif ()
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..261eeb9e
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..2b93b258
--- /dev/null
+++ b/README.md
@@ -0,0 +1,237 @@
+
+
+
+ MLC-Python
+
+
+* [:inbox_tray: Installation](#inbox_tray-installation)
+* [:key: Key Features](#key-key-features)
+ + [:building_construction: MLC Dataclass](#building_construction-mlc-dataclass)
+ + [:dart: Structure-Aware Tooling](#dart-structure-aware-tooling)
+ + [:snake: Text Formats in Python](#snake-text-formats-in-python)
+ + [:zap: Zero-Copy Interoperability with C++ Plugins](#zap-zero-copy-interoperability-with-c-plugins)
+* [:fuelpump: Development](#fuelpump-development)
+ + [:gear: Editable Build](#gear-editable-build)
+ + [:ferris_wheel: Create Wheels](#ferris_wheel-create-wheels)
+
+
+MLC is a Python-first toolkit that streamlines the development of AI compilers, runtimes, and compound AI systems with its Pythonic dataclasses, structure-aware tooling, and Python-based text formats.
+
+Beyond pure Python, MLC natively supports zero-copy interoperation with C++ plugins, and enables a smooth engineering practice transitioning from Python to hybrid or Python-free development.
+
+## :inbox_tray: Installation
+
+```bash
+pip install -U mlc-python
+```
+
+## :key: Key Features
+
+### :building_construction: MLC Dataclass
+
+MLC dataclass is similar to Python’s native dataclass:
+
+```python
+import mlc.dataclasses as mlcd
+
+@mlcd.py_class("demo.MyClass")
+class MyClass(mlcd.PyClass):
+ a: int
+ b: str
+ c: float | None
+
+instance = MyClass(12, "test", c=None)
+```
+
+**Type safety**. MLC dataclass enforces strict type checking using Cython and C++.
+
+```python
+>>> instance.c = 10; print(instance)
+demo.MyClass(a=12, b='test', c=10.0)
+
+>>> instance.c = "wrong type"
+TypeError: must be real number, not str
+
+>>> instance.non_exist = 1
+AttributeError: 'MyClass' object has no attribute 'non_exist' and no __dict__ for setting new attributes
+```
+
+**Serialization**. MLC dataclasses are picklable and JSON-serializable.
+
+```python
+>>> MyClass.from_json(instance.json())
+demo.MyClass(a=12, b='test', c=None)
+
+>>> import pickle; pickle.loads(pickle.dumps(instance))
+demo.MyClass(a=12, b='test', c=None)
+```
+
+### :dart: Structure-Aware Tooling
+
+An extra `structure` field is used to specify a dataclass's structure, indicating def-sites and scoping in an IR.
+
+ Define a toy IR with `structure`.
+
+```python
+import mlc.dataclasses as mlcd
+
+@mlcd.py_class
+class Expr(mlcd.PyClass):
+ def __add__(self, other):
+ return Add(a=self, b=other)
+
+@mlcd.py_class(structure="nobind")
+class Add(Expr):
+ a: Expr
+ b: Expr
+
+@mlcd.py_class(structure="var")
+class Var(Expr):
+ name: str = mlcd.field(structure=None) # excludes `name` from defined structure
+
+@mlcd.py_class(structure="bind")
+class Let(Expr):
+ rhs: Expr
+ lhs: Var = mlcd.field(structure="bind") # `Let.lhs` is the def-site
+ body: Expr
+```
+
+
+
+**Structural equality**. Member method `eq_s` compares the structural equality (alpha equivalence) of two IRs represented by MLC's structured dataclass.
+
+```python
+>>> x, y, z = Var("x"), Var("y"), Var("z")
+>>> L1 = Let(rhs=x + y, lhs=z, body=z) # let z = x + y; z
+>>> L2 = Let(rhs=y + z, lhs=x, body=x) # let x = y + z; x
+>>> L3 = Let(rhs=x + x, lhs=z, body=z) # let z = x + x; z
+>>> L1.eq_s(L2)
+True
+>>> L1.eq_s(L3, assert_mode=True)
+ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
+```
+
+**Structural hashing**. The structure of MLC dataclasses can be hashed via `hash_s`, which guarantees that if two dataclasses are alpha-equivalent, they share the same structural hash:
+
+```python
+>>> L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s()
+>>> assert L1_hash == L2_hash
+>>> assert L1_hash != L3_hash
+```
+
+### :snake: Text Formats in Python
+
+**IR Printer.** By defining an `__ir_print__` method, which converts an IR node to MLC's Python-style AST, MLC's `IRPrinter` handles variable scoping, renaming and syntax highlighting automatically for a text format based on Python syntax.
+
+Defining Python-based text format on a toy IR using `__ir_print__`.
+
+```python
+import mlc.dataclasses as mlcd
+import mlc.printer as mlcp
+from mlc.printer import ast as mlt
+
+@mlcd.py_class
+class Expr(mlcd.PyClass): ...
+
+@mlcd.py_class
+class Stmt(mlcd.PyClass): ...
+
+@mlcd.py_class
+class Var(Expr):
+ name: str
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ if not printer.var_is_defined(obj=self):
+ printer.var_def(obj=self, frame=printer.frames[-1], name=self.name)
+ return printer.var_get(obj=self)
+
+@mlcd.py_class
+class Add(Expr):
+ lhs: Expr
+ rhs: Expr
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
+ rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
+ return lhs + rhs
+
+@mlcd.py_class
+class Assign(Stmt):
+ lhs: Var
+ rhs: Expr
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
+ printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name)
+ lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
+ return mlt.Assign(lhs=lhs, rhs=rhs)
+
+@mlcd.py_class
+class Func(mlcd.PyClass):
+ name: str
+ args: list[Var]
+ stmts: list[Stmt]
+ ret: Var
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ with printer.with_frame(mlcp.DefaultFrame()):
+ for arg in self.args:
+ printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name)
+ args: list[mlt.Expr] = [printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)]
+ stmts: list[mlt.Expr] = [printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)]
+ ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"]))
+ return mlt.Function(
+ name=mlt.Id(self.name),
+ args=[mlt.Assign(lhs=arg, rhs=None) for arg in args],
+ decorators=[],
+ return_type=None,
+ body=[*stmts, ret_stmt],
+ )
+
+# An example IR:
+a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
+f = Func(
+ name="f",
+ args=[a, b, c],
+ stmts=[
+ Assign(lhs=d, rhs=Add(a, b)), # d = a + b
+ Assign(lhs=e, rhs=Add(d, c)), # e = d + c
+ ],
+ ret=e,
+)
+```
+
+
+
+Two printer APIs are provided for Python-based text format:
+- `mlc.printer.to_python` that converts an IR fragment to Python text, and
+- `mlc.printer.print_python` that further renders the text with proper syntax highlighting.
+
+```python
+>>> print(mlcp.to_python(f)) # Stringify to Python
+def f(a, b, c):
+ d = a + b
+ e = d + c
+ return e
+>>> mlcp.print_python(f) # Syntax highlighting
+```
+
+### :zap: Zero-Copy Interoperability with C++ Plugins
+
+TBD
+
+## :fuelpump: Development
+
+### :gear: Editable Build
+
+```bash
+pip install --verbose --editable ".[dev]"
+pre-commit install
+```
+
+### :ferris_wheel: Create Wheels
+
+This project uses `cibuildwheel` to build cross-platform wheels. See `.github/workflows/release.yml` for more details.
+
+```bash
+export CIBW_BUILD_VERBOSITY=3
+export CIBW_BUILD="cp3*-manylinux_x86_64"
+python -m pip install pipx
+pipx run cibuildwheel==2.20.0 --output-dir wheelhouse
+```
diff --git a/cmake/Utils/AddGoogleTest.cmake b/cmake/Utils/AddGoogleTest.cmake
new file mode 100644
index 00000000..f2e52643
--- /dev/null
+++ b/cmake/Utils/AddGoogleTest.cmake
@@ -0,0 +1,43 @@
+# Force-set the cache options read by the vendored googletest build before it is added below.
+set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE)
+# Only gtest is built; gmock is switched off.
+set(BUILD_GMOCK OFF CACHE BOOL "" FORCE)
+set(BUILD_GTEST ON CACHE BOOL "" FORCE)
+# googletest is expected under 3rdparty/ relative to the directory currently being processed.
+set(GOOGLETEST_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/googletest)
+
+add_subdirectory(${GOOGLETEST_ROOT})
+# Provides gtest_discover_tests(), which add_googletest() in this file calls.
+include(GoogleTest)
+
+# Keep the gtest targets out of the default build and out of the compile-commands export, and
+# place their outputs under lib/ and bin/ of the top-level build tree.
+set_target_properties(gtest gtest_main
+ PROPERTIES
+ EXPORT_COMPILE_COMMANDS OFF
+ EXCLUDE_FROM_ALL ON
+ FOLDER 3rdparty
+ ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
+ LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
+ RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
+)
+
+# Install gtest and gtest_main
+install(TARGETS gtest gtest_main DESTINATION "lib/")
+
+# Mark advanced variables
+mark_as_advanced(
+ BUILD_GMOCK BUILD_GTEST BUILD_SHARED_LIBS
+ gmock_build_tests gtest_build_samples gtest_build_tests
+ gtest_disable_pthreads gtest_force_shared_crt gtest_hide_internal_symbols
+)
+
+# Register the existing executable target `target_name` as a GoogleTest suite.
+#
+# The target is linked against gtest_main, added to CTest once as a whole through add_test(),
+# and its individual test cases are additionally discovered through gtest_discover_tests().
+# NOTE(review): both registrations are kept, so each case presumably runs twice under CTest —
+# confirm this is intended.
+macro(add_googletest target_name)
+ # Whole-binary test entry, run from the directory that invoked this macro.
+ add_test(
+ NAME ${target_name}
+ COMMAND ${target_name}
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ )
+ target_link_libraries(${target_name} PRIVATE gtest_main)
+ # Per-case entries; PRE_TEST defers discovery until the tests are about to be run.
+ gtest_discover_tests(${target_name}
+ WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
+ DISCOVERY_MODE PRE_TEST
+ PROPERTIES
+ VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
+ )
+ # IDE folder grouping only.
+ set_target_properties(${target_name} PROPERTIES FOLDER tests)
+endmacro()
diff --git a/cmake/Utils/AddLibbacktrace.cmake b/cmake/Utils/AddLibbacktrace.cmake
new file mode 100644
index 00000000..c9697538
--- /dev/null
+++ b/cmake/Utils/AddLibbacktrace.cmake
@@ -0,0 +1,50 @@
+include(ExternalProject)
+
+# Build the vendored 3rdparty/libbacktrace with its own configure/make scripts through
+# ExternalProject, and expose the result as the imported static library target `libbacktrace`.
+# Takes no arguments; everything is placed under ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace.
+function(_libbacktrace_compile)
+ set(_libbacktrace_source ${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace)
+ set(_libbacktrace_prefix ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace)
+ # On Darwin, a compiler whose path starts with /Library or /Applications is replaced by
+ # /usr/bin/cc for the libbacktrace build; every other compiler is passed through unchanged.
+ if(CMAKE_SYSTEM_NAME MATCHES "Darwin" AND (CMAKE_C_COMPILER MATCHES "^/Library" OR CMAKE_C_COMPILER MATCHES "^/Applications"))
+ set(_cmake_c_compiler "/usr/bin/cc")
+ else()
+ set(_cmake_c_compiler "${CMAKE_C_COMPILER}")
+ endif()
+
+ # Create the output directories up front, before the external build has run.
+ # NOTE(review): presumably needed so the INTERFACE_INCLUDE_DIRECTORIES path set below already
+ # exists at configure time — confirm.
+ file(MAKE_DIRECTORY ${_libbacktrace_prefix}/include)
+ file(MAKE_DIRECTORY ${_libbacktrace_prefix}/lib)
+ # NOTE(review): MACHINE_NAME (used for --host) is not set in this file — verify it is defined
+ # by the including project.
+ # NOTE(review): `make -j` places no limit on the number of parallel jobs.
+ ExternalProject_Add(project_libbacktrace
+ PREFIX libbacktrace
+ SOURCE_DIR ${_libbacktrace_source}
+ BINARY_DIR ${_libbacktrace_prefix}
+ CONFIGURE_COMMAND
+ ${_libbacktrace_source}/configure
+ "--prefix=${_libbacktrace_prefix}"
+ "--with-pic"
+ "CC=${_cmake_c_compiler}"
+ "CPP=${_cmake_c_compiler} -E"
+ "CFLAGS=${CMAKE_C_FLAGS}"
+ "LDFLAGS=${CMAKE_EXE_LINKER_FLAGS}"
+ "NM=${CMAKE_NM}"
+ "STRIP=${CMAKE_STRIP}"
+ "--host=${MACHINE_NAME}"
+ BUILD_COMMAND make -j
+ BUILD_BYPRODUCTS ${_libbacktrace_prefix}/lib/libbacktrace.a ${_libbacktrace_prefix}/include/backtrace.h
+ INSTALL_DIR ${_libbacktrace_prefix}
+ INSTALL_COMMAND make install
+ LOG_CONFIGURE ON
+ LOG_INSTALL ON
+ LOG_BUILD ON
+ LOG_OUTPUT_ON_FAILURE ON
+ )
+ # Insert an (empty) `checkout` step between the download and configure steps.
+ ExternalProject_Add_Step(project_libbacktrace checkout DEPENDERS configure DEPENDEES download)
+ set_target_properties(project_libbacktrace PROPERTIES EXCLUDE_FROM_ALL TRUE)
+ # Imported target that consumers link against; it depends on the external build above.
+ add_library(libbacktrace STATIC IMPORTED)
+ add_dependencies(libbacktrace project_libbacktrace)
+ set_target_properties(libbacktrace PROPERTIES
+ IMPORTED_LOCATION ${_libbacktrace_prefix}/lib/libbacktrace.a
+ INTERFACE_INCLUDE_DIRECTORIES ${_libbacktrace_prefix}/include
+ )
+endfunction()
+
+# The libbacktrace target is only defined when the compiler is not MSVC.
+if(NOT MSVC)
+ _libbacktrace_compile()
+endif()
diff --git a/cmake/Utils/CxxWarning.cmake b/cmake/Utils/CxxWarning.cmake
new file mode 100644
index 00000000..50ee5b61
--- /dev/null
+++ b/cmake/Utils/CxxWarning.cmake
@@ -0,0 +1,13 @@
+# Turn on strict warnings, treated as errors, for `target_name`.
+#
+# GNU / Clang / AppleClang get -Werror -Wall -Wextra -Wpedantic, MSVC gets /W4 /WX, and any
+# other compiler aborts configuration with a fatal error. The options are PRIVATE to the target.
+function(add_cxx_warning target_name)
+  if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang")
+    set(_cxx_warning_flags "-Werror" "-Wall" "-Wextra" "-Wpedantic")
+  elseif(MSVC)
+    set(_cxx_warning_flags "/W4" "/WX")
+  else()
+    message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}")
+  endif()
+  target_compile_options(${target_name} PRIVATE ${_cxx_warning_flags})
+endfunction()
diff --git a/cmake/Utils/DebugSymbol.cmake b/cmake/Utils/DebugSymbol.cmake
new file mode 100644
index 00000000..7f5bbae2
--- /dev/null
+++ b/cmake/Utils/DebugSymbol.cmake
@@ -0,0 +1,14 @@
+# Locate dsymutil once on Apple platforms; the result is cached in DSYMUTIL_PROGRAM and hidden
+# from the default cache view.
+if(APPLE)
+ find_program(DSYMUTIL_PROGRAM dsymutil)
+ mark_as_advanced(DSYMUTIL_PROGRAM)
+endif()
+
+# On Apple platforms, run dsymutil after `_target` is built and install the resulting .dSYM
+# into `_directory`. Does nothing on other platforms.
+# NOTE(review): the bare `$` in the COMMAND and install() lines looks like a truncated generator
+# expression (presumably the target's output file) — confirm against the original file.
+# NOTE(review): DSYMUTIL_PROGRAM is used without checking that find_program() succeeded.
+function(add_debug_symbol_apple _target _directory)
+ if (APPLE)
+ add_custom_command(TARGET ${_target} POST_BUILD
+ COMMAND ${DSYMUTIL_PROGRAM} ARGS $
+ COMMENT "Running dsymutil" VERBATIM
+ )
+ install(FILES $.dSYM DESTINATION ${_directory})
+ endif (APPLE)
+endfunction()
diff --git a/cmake/Utils/Library.cmake b/cmake/Utils/Library.cmake
new file mode 100644
index 00000000..26bdcef6
--- /dev/null
+++ b/cmake/Utils/Library.cmake
@@ -0,0 +1,28 @@
+# Create `${target_name}_static` and `${target_name}_shared` libraries plus an umbrella custom
+# target `${target_name}` that depends on both.
+#
+# target_name:     base name; the static archive is named "${target_name}_static" and the shared
+#                  library "${target_name}", both with a "lib" prefix and written to
+#                  ${CMAKE_BINARY_DIR}/lib.
+# obj_target_name: object-library target; under MSVC it gets the MLC_EXPORTS definition.
+# NOTE(review): the bare `$` passed to both add_library() calls and the `$<$:Debug>` fragment
+# look like truncated generator expressions (presumably the objects of obj_target_name and a
+# Debug-configuration test) — confirm against the original file.
+function(add_target_from_obj target_name obj_target_name)
+ add_library(${target_name}_static STATIC $)
+ set_target_properties(
+ ${target_name}_static PROPERTIES
+ OUTPUT_NAME "${target_name}_static"
+ PREFIX "lib"
+ ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
+ LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
+ )
+ add_library(${target_name}_shared SHARED $)
+ set_target_properties(
+ ${target_name}_shared PROPERTIES
+ OUTPUT_NAME "${target_name}"
+ PREFIX "lib"
+ ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
+ LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
+ )
+ # Umbrella target: building `${target_name}` builds both flavours.
+ add_custom_target(${target_name})
+ add_dependencies(${target_name} ${target_name}_static ${target_name}_shared)
+ if (MSVC)
+ target_compile_definitions(${obj_target_name} PRIVATE MLC_EXPORTS)
+ # All three targets share the same MSVC runtime library selection.
+ set_target_properties(
+ ${obj_target_name} ${target_name}_shared ${target_name}_static
+ PROPERTIES
+ MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL"
+ )
+ endif()
+endfunction()
diff --git a/cmake/Utils/Sanitizer.cmake b/cmake/Utils/Sanitizer.cmake
new file mode 100644
index 00000000..c1facc19
--- /dev/null
+++ b/cmake/Utils/Sanitizer.cmake
@@ -0,0 +1,18 @@
+# Enable AddressSanitizer on `target_name`.
+#
+# Only acts for GNU / Clang / AppleClang; other compilers are left untouched. The toolchain is
+# first probed with a trivial program built with -fsanitize=address. If the probe fails, a
+# warning is printed and the target is left unchanged instead of being given flags that would
+# break its build. Flags are INTERFACE for interface libraries and PRIVATE for everything else.
+function(add_sanitizer_address target_name)
+  if(NOT CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang")
+    return()
+  endif()
+
+  # check_cxx_source_compiles() is provided by this module.
+  include(CheckCXXSourceCompiles)
+  # Save and restore CMAKE_REQUIRED_FLAGS so the probe does not leak into later checks.
+  set(_saved_CRF "${CMAKE_REQUIRED_FLAGS}")
+  set(CMAKE_REQUIRED_FLAGS "-fsanitize=address")
+  check_cxx_source_compiles("int main() { return 0; }" COMPILER_SUPPORTS_ASAN)
+  set(CMAKE_REQUIRED_FLAGS "${_saved_CRF}")
+  if(NOT COMPILER_SUPPORTS_ASAN)
+    message(WARNING "AddressSanitizer is not supported by ${CMAKE_CXX_COMPILER}; "
+                    "skipping -fsanitize=address for target ${target_name}")
+    return()
+  endif()
+
+  # Interface libraries cannot carry PRIVATE options.
+  get_target_property(_target_type ${target_name} TYPE)
+  if("${_target_type}" STREQUAL "INTERFACE_LIBRARY")
+    set(_scope INTERFACE)
+  else()
+    set(_scope PRIVATE)
+  endif()
+  target_link_options(${target_name} ${_scope} "-fsanitize=address")
+  target_compile_options(${target_name} ${_scope} "-fsanitize=address" "-fno-omit-frame-pointer" "-g")
+endfunction()
diff --git a/cpp/c_api.cc b/cpp/c_api.cc
new file mode 100644
index 00000000..4b4f6651
--- /dev/null
+++ b/cpp/c_api.cc
@@ -0,0 +1,187 @@
+#include "./registry.h"
+#include
+
+namespace mlc {
+namespace registry {
+// Returns the process-wide type table. The instance is created by TypeTable::New() the first
+// time this function runs (function-local static) and is never deleted here.
+TypeTable *TypeTable::Global() {
+ static TypeTable *instance = TypeTable::New();
+ return instance;
+}
+} // namespace registry
+} // namespace mlc
+
+using ::mlc::Any;
+using ::mlc::AnyView;
+using ::mlc::ErrorObj;
+using ::mlc::FuncObj;
+using ::mlc::Ref;
+using ::mlc::registry::TypeTable;
+
+namespace {
+// Per-thread slot that the MLC_SAFE_CALL_END(&last_error) calls in this file are handed;
+// MLCGetLastError() moves its content out.
+thread_local Any last_error;
+// Global function "mlc.ffi.LoadDSO": forwards `name` to LoadDSO on the table returned by
+// TypeTable::Get(nullptr).
+MLC_REGISTER_FUNC("mlc.ffi.LoadDSO").set_body([](std::string name) { TypeTable::Get(nullptr)->LoadDSO(name); });
+} // namespace
+
+// Returns the calling thread's last recorded error by moving it out of `last_error` into the
+// returned MLCAny.
+// NOTE(review): `ret` is not initialized before the move-assignment through the cast — confirm
+// the assignment does not read the old content of `ret`.
+MLC_API MLCAny MLCGetLastError() {
+ MLCAny ret;
+ *static_cast(&ret) = std::move(last_error);
+ return ret;
+}
+
+// Registers a type on the table obtained from TypeTable::Get(_self) and stores the resulting
+// MLCTypeInfo pointer in `*out_type_info`.
+// Note the callee's argument order differs from this signature: TypeRegister receives
+// (parent_type_index, type_index, type_key).
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCTypeRegister(MLCTypeTableHandle _self, int32_t parent_type_index, const char *type_key,
+ int32_t type_index, MLCTypeInfo **out_type_info) {
+ MLC_SAFE_CALL_BEGIN();
+ *out_type_info = TypeTable::Get(_self)->TypeRegister(parent_type_index, type_index, type_key);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Looks up type information by numeric type index on the table obtained from
+// TypeTable::Get(_self) and stores the result in `*ret`.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCTypeIndex2Info(MLCTypeTableHandle _self, int32_t type_index, MLCTypeInfo **ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *ret = TypeTable::Get(_self)->GetTypeInfo(type_index);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Looks up type information by string type key on the table obtained from
+// TypeTable::Get(_self) and stores the result in `*ret`.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCTypeKey2Info(MLCTypeTableHandle _self, const char *type_key, MLCTypeInfo **ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *ret = TypeTable::Get(_self)->GetTypeInfo(type_key);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Forwards `num_fields` field descriptors starting at `fields` to SetFields for `type_index`
+// on the table obtained from TypeTable::Get(_self).
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCTypeRegisterFields(MLCTypeTableHandle _self, int32_t type_index, int64_t num_fields,
+ MLCTypeField *fields) {
+ MLC_SAFE_CALL_BEGIN();
+ TypeTable::Get(_self)->SetFields(type_index, num_fields, fields);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Forwards the structure description of `type_index` (a kind plus `num_sub_structures` entries
+// from the two parallel arrays) to SetStructure on the table obtained from TypeTable::Get(_self).
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCTypeRegisterStructure(MLCTypeTableHandle _self, int32_t type_index, int32_t structure_kind,
+ int64_t num_sub_structures, int32_t *sub_structure_indices,
+ int32_t *sub_structure_kinds) {
+ MLC_SAFE_CALL_BEGIN();
+ TypeTable::Get(_self)->SetStructure(type_index, structure_kind, num_sub_structures, sub_structure_indices,
+ sub_structure_kinds);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Forwards `method` to AddMethod for `type_index` on the table obtained from TypeTable::Get(_self).
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCTypeAddMethod(MLCTypeTableHandle _self, int32_t type_index, MLCTypeMethod method) {
+ MLC_SAFE_CALL_BEGIN();
+ TypeTable::Get(_self)->AddMethod(type_index, method);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Stores in `*ret` the vtable handle that GetGlobalVTable returns for `key` on the table
+// obtained from TypeTable::Get(_self).
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCVTableGetGlobal(MLCTypeTableHandle _self, const char *key, MLCVTableHandle *ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *ret = TypeTable::Get(_self)->GetGlobalVTable(key);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Writes into `*ret` the result of GetFunc(type_index, allow_ancestor) on `vtable`.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCVTableGetFunc(MLCVTableHandle vtable, int32_t type_index, int32_t allow_ancestor, MLCAny *ret) {
+ using ::mlc::registry::VTable;
+ MLC_SAFE_CALL_BEGIN();
+ *static_cast(ret) = static_cast(vtable)->GetFunc(type_index, allow_ancestor);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Calls Set(type_index, func, override_mode) on `vtable`.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCVTableSetFunc(MLCVTableHandle vtable, int32_t type_index, MLCFunc *func, int32_t override_mode) {
+ using ::mlc::registry::VTable;
+ MLC_SAFE_CALL_BEGIN();
+ static_cast(vtable)->Set(type_index, static_cast(func), override_mode);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Creates a new type table with TypeTable::New() and stores its handle in `*ret`.
+// This function does not retain the table; MLCDynTypeTypeTableDestroy deletes a handle.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCDynTypeTypeTableCreate(MLCTypeTableHandle *ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *ret = TypeTable::New();
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Deletes the object behind `handle` with `delete`.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCDynTypeTypeTableDestroy(MLCTypeTableHandle handle) {
+ MLC_SAFE_CALL_BEGIN();
+ delete static_cast(handle);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Calls IncRef on the object held by `any`; does nothing when IsTypeIndexPOD reports the type
+// index as POD. `any` is dereferenced without a null check.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCAnyIncRef(MLCAny *any) {
+ MLC_SAFE_CALL_BEGIN();
+ if (!::mlc::base::IsTypeIndexPOD(any->type_index)) {
+ ::mlc::base::IncRef(any->v.v_obj);
+ }
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Calls DecRef on the object held by `any`; does nothing when IsTypeIndexPOD reports the type
+// index as POD. `any` is dereferenced without a null check, and `*any` itself is not modified.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCAnyDecRef(MLCAny *any) {
+ MLC_SAFE_CALL_BEGIN();
+ if (!::mlc::base::IsTypeIndexPOD(any->type_index)) {
+ ::mlc::base::DecRef(any->v.v_obj);
+ }
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Converts `*any` in place: the value is assigned into a zero-initialized temporary through a
+// pair of casts, and the temporary is then copied back over `*any`.
+// NOTE(review): judging by the function name, the casts presumably turn a non-owning view into
+// an owning value — confirm against the cast target types, which are not visible here.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCAnyInplaceViewToOwned(MLCAny *any) {
+ MLC_SAFE_CALL_BEGIN();
+ MLCAny tmp{};
+ static_cast(tmp) = static_cast(*any);
+ *any = tmp;
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Calls SetFunc(name, func, allow_override) on the table obtained from TypeTable::Get(_self).
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCFuncSetGlobal(MLCTypeTableHandle _self, const char *name, MLCAny func, int allow_override) {
+ MLC_SAFE_CALL_BEGIN();
+ TypeTable::Get(_self)->SetFunc(name, static_cast(func), allow_override);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Writes into `*ret` the result of GetFunc(name) on the table obtained from TypeTable::Get(_self).
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCFuncGetGlobal(MLCTypeTableHandle _self, const char *name, MLCAny *ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *static_cast(ret) = TypeTable::Get(_self)->GetFunc(name);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Invokes `func` through its own `safe_call` entry point and returns that call's result code.
+// `args` / `num_args` and `ret` are passed through unchanged. Unlike the other entry points in
+// this file, this one does not use MLC_SAFE_CALL_BEGIN/END and does not touch `last_error`.
+MLC_API int32_t MLCFuncSafeCall(MLCFunc *func, int32_t num_args, MLCAny *args, MLCAny *ret) {
+#if defined(__APPLE__)
+  if (func->_mlc_header.type_index != MLCTypeIndex::kMLCFunc) {
+    // Diagnostic only: report an unexpected header type index. It goes to stderr (as the other
+    // diagnostics in this file do) so it cannot mix into the host program's stdout.
+    std::cerr << "func->type_index: " << func->_mlc_header.type_index << std::endl;
+  }
+#endif
+  return func->safe_call(func, num_args, args, ret);
+}
+
+// Wraps a foreign callable (`self` plus its `deleter` and `safe_call` entry points) with
+// FuncObj::FromForeign and writes the result into `*ret`.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCFuncCreate(void *self, MLCDeleterType deleter, MLCFuncSafeCallType safe_call, MLCAny *ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *static_cast(ret) = FuncObj::FromForeign(self, deleter, safe_call);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Constructs a new object from (`kind`, `num_bytes`, `bytes`) via Ref::New and writes it into `*ret`.
+// NOTE(review): presumably an ErrorObj, given the function name and the `using ::mlc::ErrorObj`
+// above — the template argument is not visible here.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCErrorCreate(const char *kind, int64_t num_bytes, const char *bytes, MLCAny *ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *static_cast(ret) = Ref::New(kind, num_bytes, bytes);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Retrieves the strings describing `error` through GetInfo and returns them as a
+// (`*num_strs`, `*strs`) pair.
+// `*strs` points into a thread_local buffer owned by this function, so the caller must not free
+// it, and it is only valid until this function is called again on the same thread.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCErrorGetInfo(MLCAny error, int32_t *num_strs, const char ***strs) {
+ MLC_SAFE_CALL_BEGIN();
+ thread_local std::vector ret;
+ static_cast(error).operator Ref()->GetInfo(&ret);
+ *num_strs = static_cast(ret.size());
+ *strs = ret.data();
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Allocates an external object with mlc::AllocExternObject(type_index, num_bytes) and writes it
+// into `*ret`. Note the argument order is swapped relative to this function's signature.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t MLCExtObjCreate(int32_t num_bytes, int32_t type_index, MLCAny *ret) {
+ MLC_SAFE_CALL_BEGIN();
+ *static_cast(ret) = mlc::AllocExternObject(type_index, num_bytes);
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Deletes the external object at `objptr` with ::mlc::core::DeleteExternObject and returns a
+// status code. MLCExtObjDelete below wraps this and aborts on a non-zero result.
+// The body runs between MLC_SAFE_CALL_BEGIN/END, which is handed &last_error (macros defined elsewhere).
+MLC_API int32_t _MLCExtObjDeleteImpl(void *objptr) {
+ MLC_SAFE_CALL_BEGIN();
+ ::mlc::core::DeleteExternObject(static_cast<::mlc::Object *>(objptr));
+ MLC_SAFE_CALL_END(&last_error);
+}
+
+// Void-returning deleter for external objects. It has no way to report failure, so a non-zero
+// code from _MLCExtObjDeleteImpl is printed to stderr together with `last_error` and the
+// process is aborted.
+MLC_API void MLCExtObjDelete(void *objptr) {
+ if (int32_t error_code = _MLCExtObjDeleteImpl(objptr)) {
+ std::cerr << "Error code (" << error_code << ") when deleting external object: " << last_error << std::endl;
+ std::abort();
+ }
+}
diff --git a/cpp/c_api_tests.cc b/cpp/c_api_tests.cc
new file mode 100644
index 00000000..838e88c5
--- /dev/null
+++ b/cpp/c_api_tests.cc
@@ -0,0 +1,195 @@
+#include
+
+namespace mlc {
+namespace {
+
+/**************** FFI ****************/
+
+// Test helpers registered under "mlc.testing.cxx_*". `cxx_none` and `cxx_null` take no
+// arguments (returning nothing and a null pointer respectively); every other one returns its
+// single argument unchanged, one per argument type (int, double, void *, DLDataType, DLDevice,
+// const char *).
+MLC_REGISTER_FUNC("mlc.testing.cxx_none").set_body([]() -> void { return; });
+MLC_REGISTER_FUNC("mlc.testing.cxx_null").set_body([]() -> void * { return nullptr; });
+MLC_REGISTER_FUNC("mlc.testing.cxx_int").set_body([](int x) -> int { return x; });
+MLC_REGISTER_FUNC("mlc.testing.cxx_float").set_body([](double x) -> double { return x; });
+MLC_REGISTER_FUNC("mlc.testing.cxx_ptr").set_body([](void *x) -> void * { return x; });
+MLC_REGISTER_FUNC("mlc.testing.cxx_dtype").set_body([](DLDataType x) { return x; });
+MLC_REGISTER_FUNC("mlc.testing.cxx_device").set_body([](DLDevice x) { return x; });
+MLC_REGISTER_FUNC("mlc.testing.cxx_raw_str").set_body([](const char *x) { return x; });
+
+/**************** Reflection ****************/
+
+// Test fixture object registered under the type key "mlc.testing.c_class".
+// It carries one member per kind of value exercised by the reflection tests: plain scalars,
+// containers, and an Optional counterpart (opt_*) for most of them. The constructor takes one
+// argument per member, in declaration order, and copies each into the member of the same name.
+struct TestingCClassObj : public Object {
+ // Scalars, POD structs and untyped containers.
+ int8_t i8;
+ int16_t i16;
+ int32_t i32;
+ int64_t i64;
+ float f32;
+ double f64;
+ void *raw_ptr;
+ DLDataType dtype;
+ DLDevice device;
+ Any any;
+ Func func;
+ UList ulist;
+ UDict udict;
+ Str str_;
+ Str str_readonly;
+
+ // List / Dict containers.
+ List list_any;
+ List> list_list_int;
+ Dict dict_any_any;
+ Dict dict_str_any;
+ Dict dict_any_str;
+ Dict> dict_str_list_int;
+
+ // Optional counterparts of the scalar fields.
+ Optional opt_i64;
+ Optional opt_f64;
+ Optional opt_raw_ptr;
+ Optional opt_dtype;
+ Optional opt_device;
+ Optional opt_func;
+ Optional opt_ulist;
+ Optional opt_udict;
+ Optional opt_str;
+
+ // Optional counterparts of the container fields.
+ Optional> opt_list_any;
+ Optional>> opt_list_list_int;
+ Optional> opt_dict_any_any;
+ Optional> opt_dict_str_any;
+ Optional> opt_dict_any_str;
+ Optional>> opt_dict_str_list_int;
+
+ explicit TestingCClassObj(int8_t i8, int16_t i16, int32_t i32, int64_t i64, float f32, double f64, void *raw_ptr,
+ DLDataType dtype, DLDevice device, Any any, Func func, UList ulist, UDict udict, Str str_,
+ Str str_readonly, List list_any, List> list_list_int,
+ Dict dict_any_any, Dict dict_str_any, Dict dict_any_str,
+ Dict> dict_str_list_int, Optional opt_i64, Optional opt_f64,
+ Optional opt_raw_ptr, Optional opt_dtype, Optional opt_device,
+ Optional opt_func, Optional opt_ulist, Optional opt_udict,
+ Optional opt_str, Optional> opt_list_any,
+ Optional>> opt_list_list_int, Optional> opt_dict_any_any,
+ Optional> opt_dict_str_any, Optional> opt_dict_any_str,
+ Optional>> opt_dict_str_list_int)
+ : i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device),
+ any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any),
+ list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any),
+ dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_i64(opt_i64), opt_f64(opt_f64),
+ opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func),
+ opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any),
+ opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any),
+ opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {}
+
+ // Member function exposed through reflection as "i64_plus_one".
+ int64_t i64_plus_one() const { return i64 + 1; }
+
+ MLC_DEF_DYN_TYPE(TestingCClassObj, Object, "mlc.testing.c_class");
+};
+
+// Reference type for TestingCClassObj together with its reflection table.
+// Every member is registered under its own name; `str_readonly` is the only one registered
+// with FieldReadOnly. The table also exposes the member function `i64_plus_one` and a static
+// `__init__` built with InitOf.
+struct TestingCClass : public ObjectRef {
+ MLC_DEF_OBJ_REF(TestingCClass, TestingCClassObj, ObjectRef)
+ .Field("i8", &TestingCClassObj::i8)
+ .Field("i16", &TestingCClassObj::i16)
+ .Field("i32", &TestingCClassObj::i32)
+ .Field("i64", &TestingCClassObj::i64)
+ .Field("f32", &TestingCClassObj::f32)
+ .Field("f64", &TestingCClassObj::f64)
+ .Field("raw_ptr", &TestingCClassObj::raw_ptr)
+ .Field("dtype", &TestingCClassObj::dtype)
+ .Field("device", &TestingCClassObj::device)
+ .Field("any", &TestingCClassObj::any)
+ .Field("func", &TestingCClassObj::func)
+ .Field("ulist", &TestingCClassObj::ulist)
+ .Field("udict", &TestingCClassObj::udict)
+ .Field("str_", &TestingCClassObj::str_)
+ .FieldReadOnly("str_readonly", &TestingCClassObj::str_readonly)
+ .Field("list_any", &TestingCClassObj::list_any)
+ .Field("list_list_int", &TestingCClassObj::list_list_int)
+ .Field("dict_any_any", &TestingCClassObj::dict_any_any)
+ .Field("dict_str_any", &TestingCClassObj::dict_str_any)
+ .Field("dict_any_str", &TestingCClassObj::dict_any_str)
+ .Field("dict_str_list_int", &TestingCClassObj::dict_str_list_int)
+ .Field("opt_i64", &TestingCClassObj::opt_i64)
+ .Field("opt_f64", &TestingCClassObj::opt_f64)
+ .Field("opt_raw_ptr", &TestingCClassObj::opt_raw_ptr)
+ .Field("opt_dtype", &TestingCClassObj::opt_dtype)
+ .Field("opt_device", &TestingCClassObj::opt_device)
+ .Field("opt_func", &TestingCClassObj::opt_func)
+ .Field("opt_ulist", &TestingCClassObj::opt_ulist)
+ .Field("opt_udict", &TestingCClassObj::opt_udict)
+ .Field("opt_str", &TestingCClassObj::opt_str)
+ .Field("opt_list_any", &TestingCClassObj::opt_list_any)
+ .Field("opt_list_list_int", &TestingCClassObj::opt_list_list_int)
+ .Field("opt_dict_any_any", &TestingCClassObj::opt_dict_any_any)
+ .Field("opt_dict_str_any", &TestingCClassObj::opt_dict_str_any)
+ .Field("opt_dict_any_str", &TestingCClassObj::opt_dict_any_str)
+ .Field("opt_dict_str_list_int", &TestingCClassObj::opt_dict_str_list_int)
+ .MemFn("i64_plus_one", &TestingCClassObj::i64_plus_one)
+ // NOTE(review): the InitOf argument list presumably mirrors the TestingCClassObj
+ // constructor parameters — keep the two in sync.
+ .StaticFn("__init__",
+ InitOf, List>, Dict, Dict,
+ Dict, Dict>, Optional, Optional, Optional,
+ Optional, Optional, Optional, Optional, Optional,
+ Optional, Optional>, Optional>>, Optional>,
+ Optional>, Optional>, Optional>>>);
+};
+
+/**************** Traceback ****************/
+
+// Throws a ValueError directly from C++.
+MLC_REGISTER_FUNC("mlc.testing.throw_exception_from_c").set_body([]() {
+ // Simply throw an exception in C++
+ MLC_THROW(ValueError) << "This is an error message";
+});
+
+// Calls `func` with no arguments and lets whatever it throws propagate to the caller.
+MLC_REGISTER_FUNC("mlc.testing.throw_exception_from_ffi_in_c").set_body([](FuncObj *func) {
+ // call a Python function which throws an exception
+ (*func)();
+});
+
+// Calls `func` with no arguments, catches the mlc Exception it throws, and returns the
+// exception's payload (`error.data_`) as a value. If nothing is thrown, the
+// default-constructed Any is returned.
+MLC_REGISTER_FUNC("mlc.testing.throw_exception_from_ffi").set_body([](FuncObj *func) {
+ // Call a Python function `func` which throws an exception and returns it
+ Any ret;
+ try {
+ (*func)();
+ } catch (Exception &error) {
+ ret = std::move(error.data_);
+ }
+ return ret;
+});
+
+/**************** Type checking ****************/
+
+// Given the spelling of a container type, returns a Func whose single parameter has that type
+// and which returns its argument unchanged.
+// NOTE(review): presumably used so that argument conversion at the call boundary performs the
+// nested type check under test — confirm against the Python tests.
+// Despite the "_list" suffix in the registered name, both list and dict spellings are handled.
+// Any other `name` reaches MLC_UNREACHABLE().
+MLC_REGISTER_FUNC("mlc.testing.nested_type_checking_list").set_body([](Str name) {
+ if (name == "list") {
+ using Type = UList;
+ return Func([](Type v) { return v; });
+ }
+ if (name == "list[Any]") {
+ using Type = List;
+ return Func([](Type v) { return v; });
+ }
+ if (name == "list[list[int]]") {
+ using Type = List>;
+ return Func([](Type v) { return v; });
+ }
+ if (name == "dict") {
+ using Type = UDict;
+ return Func([](Type v) { return v; });
+ }
+ if (name == "dict[str, Any]") {
+ using Type = Dict;
+ return Func([](Type v) { return v; });
+ }
+ if (name == "dict[Any, str]") {
+ using Type = Dict;
+ return Func([](Type v) { return v; });
+ }
+ if (name == "dict[Any, Any]") {
+ using Type = Dict;
+ return Func([](Type v) { return v; });
+ }
+ if (name == "dict[str, list[int]]") {
+ using Type = Dict>;
+ return Func([](Type v) { return v; });
+ }
+ MLC_UNREACHABLE();
+});
+
+} // namespace
+} // namespace mlc
diff --git a/cpp/json.cc b/cpp/json.cc
new file mode 100644
index 00000000..2af5b474
--- /dev/null
+++ b/cpp/json.cc
@@ -0,0 +1,497 @@
+#include
+#include
+#include
+#include
+
+namespace mlc {
+namespace core {
+namespace {
+
+// Forward declarations of the entry points defined later in this file.
+mlc::Str Serialize(Any any);
+Any Deserialize(const char *json_str, int64_t json_str_len);
+Any JSONLoads(const char *json_str, int64_t json_str_len);
+// Convenience overloads. The C-string forms pass -1 as the length and the Str forms pass the
+// string's own data pointer and size.
+// NOTE(review): presumably -1 means "NUL-terminated, determine the length inside" — confirm in
+// the two-argument JSONLoads.
+MLC_INLINE Any Deserialize(const char *json_str) { return Deserialize(json_str, -1); }
+MLC_INLINE Any Deserialize(const Str &json_str) { return Deserialize(json_str->data(), json_str->size()); }
+MLC_INLINE Any JSONLoads(const char *json_str) { return JSONLoads(json_str, -1); }
+MLC_INLINE Any JSONLoads(const Str &json_str) { return JSONLoads(json_str->data(), json_str->size()); }
+
+// Serialize `any` into JSON text of the form {"values": [...], "type_keys": [...]}.
+//
+// Layout of "values":
+//   - When `any` holds an object, the objects reachable from it are visited with TopoVisit and
+//     emitted one per entry. A string object is emitted as an escaped JSON string; any other
+//     object as `[json_type_index, field, field, ...]`.
+//   - Inside an object entry, a bare integer refers to another entry by position (see
+//     EmitObject); an int is `[type, v]`; a float is a bare number; None is `null`; device and
+//     dtype are `[type, "text"]`.
+//   - When `any` holds None / int / float / device / dtype, "values" has a single entry using
+//     the same encodings.
+// "type_keys" lists the type keys in the order their json_type_index values were handed out.
+//
+// Throws TypeError for values that cannot be serialized (raw pointers, raw C strings, and the
+// object kinds rejected in `on_visit`).
+inline mlc::Str Serialize(Any any) {
+  using mlc::base::TypeTraits;
+  std::vector type_keys;
+  // Maps a type key to its position in `type_keys`, appending the key the first time it is seen.
+  auto get_json_type_index = [type_key2index = std::unordered_map(),
+                              &type_keys](const char *type_key) mutable -> int32_t {
+    if (auto it = type_key2index.find(type_key); it != type_key2index.end()) {
+      return it->second;
+    }
+    int32_t type_index = static_cast(type_key2index.size());
+    type_key2index[type_key] = type_index;
+    type_keys.push_back(type_key);
+    return type_index;
+  };
+  using TObj2Idx = std::unordered_map;
+  using TJsonTypeIndex = decltype(get_json_type_index);
+  // Field visitor: every Emit* call appends ", <encoded value>" to the current object entry.
+  struct Emitter {
+    // clang-format off
+    MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
+    MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
+    MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
+    MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
+    MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
+    MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
+    MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
+    MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast(*v)); }
+    MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); }
+    MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); }
+    MLC_INLINE void operator()(MLCTypeField *, Optional *) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
+    MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
+    MLC_INLINE void operator()(MLCTypeField *, const char **) { MLC_THROW(TypeError) << "Unserializable type: const char *"; }
+    // clang-format on
+    inline void EmitNil() { (*os) << ", null"; }
+    // Fixed notation always prints a decimal point, so a float never reads back as an integer.
+    inline void EmitFloat(double v) { (*os) << ", " << std::fixed << std::setprecision(19) << v; }
+    inline void EmitInt(int64_t v) {
+      int32_t type_int = (*get_json_type_index)(TypeTraits::type_str);
+      (*os) << ", [" << type_int << ", " << v << "]";
+    }
+    inline void EmitDevice(DLDevice v) {
+      int32_t type_device = (*get_json_type_index)(TypeTraits::type_str);
+      // The textual form must be a quoted JSON string, matching the top-level branch below.
+      (*os) << ", [" << type_device << ", \"" << ::mlc::base::TypeTraits::__str__(v) << "\"]";
+    }
+    inline void EmitDType(DLDataType v) {
+      int32_t type_dtype = (*get_json_type_index)(TypeTraits::type_str);
+      // The textual form must be a quoted JSON string, matching the top-level branch below.
+      (*os) << ", [" << type_dtype << ", \"" << ::mlc::base::TypeTraits::__str__(v) << "\"]";
+    }
+    inline void EmitAny(const Any *any) {
+      int32_t type_index = any->type_index;
+      if (type_index == kMLCNone) {
+        EmitNil();
+      } else if (type_index == kMLCInt) {
+        EmitInt(any->operator int64_t());
+      } else if (type_index == kMLCFloat) {
+        EmitFloat(any->operator double());
+      } else if (type_index == kMLCDevice) {
+        EmitDevice(any->operator DLDevice());
+      } else if (type_index == kMLCDataType) {
+        EmitDType(any->operator DLDataType());
+      } else if (type_index >= kMLCStaticObjectBegin) {
+        EmitObject(any->operator Object *());
+      } else {
+        MLC_THROW(TypeError) << "Cannot serialize type: " << ::mlc::base::TypeIndex2TypeKey(type_index);
+      }
+    }
+    // Emits a reference to an object that has already been given an index in `obj2index`.
+    inline void EmitObject(Object *obj) {
+      if (!obj) {
+        MLC_THROW(InternalError) << "This should never happen: null object pointer during EmitObject";
+      }
+      int32_t obj_idx = obj2index->at(obj);
+      if (obj_idx == -1) {
+        MLC_THROW(InternalError) << "This should never happen: topological ordering violated";
+      }
+      (*os) << ", " << obj_idx;
+    }
+    std::ostringstream *os;
+    TJsonTypeIndex *get_json_type_index;
+    const TObj2Idx *obj2index;
+  };
+
+  std::ostringstream os;
+  // Called by TopoVisit once per object; writes one entry of "values", comma-separated.
+  auto on_visit = [get_json_type_index = &get_json_type_index, os = &os, is_first_object = true](
+                      Object *object, MLCTypeInfo *type_info, const TObj2Idx &obj2index) mutable -> void {
+    Emitter emitter{os, get_json_type_index, &obj2index};
+    if (is_first_object) {
+      is_first_object = false;
+    } else {
+      os->put(',');
+    }
+    // Strings are emitted as plain JSON strings rather than `[type, ...]` entries.
+    if (StrObj *str = object->TryCast()) {
+      str->PrintEscape(*os);
+      return;
+    }
+    (*os) << '[' << (*get_json_type_index)(type_info->type_key);
+    if (UListObj *list = object->TryCast()) {
+      for (Any &any : *list) {
+        emitter(nullptr, &any);
+      }
+    } else if (UDictObj *dict = object->TryCast()) {
+      // Dict entries are flattened as key, value, key, value, ...
+      for (auto &kv : *dict) {
+        emitter(nullptr, &kv.first);
+        emitter(nullptr, &kv.second);
+      }
+    } else if (object->IsInstance() || object->IsInstance()) {
+      MLC_THROW(TypeError) << "Unserializable type: " << object->GetTypeKey();
+    } else {
+      VisitFields(object, type_info, emitter);
+    }
+    os->put(']');
+  };
+  os << "{\"values\": [";
+  if (any.type_index >= kMLCStaticObjectBegin) { // Topological the objects according to dependency
+    TopoVisit(any.operator Object *(), nullptr, on_visit);
+  } else if (any.type_index == kMLCNone) {
+    os << "null";
+  } else if (any.type_index == kMLCInt) {
+    int32_t type_int = get_json_type_index(TypeTraits::type_str);
+    int64_t v = any;
+    os << "[" << type_int << ", " << v << "]";
+  } else if (any.type_index == kMLCFloat) {
+    double v = any;
+    // Same formatting as Emitter::EmitFloat: with default formatting 2.0 would be written as
+    // `2`, which reads back as an integer (and an integer entry means "reference").
+    os << std::fixed << std::setprecision(19) << v;
+  } else if (any.type_index == kMLCDevice) {
+    int32_t type_device = get_json_type_index(TypeTraits::type_str);
+    DLDevice v = any;
+    os << "[" << type_device << ", \"" << TypeTraits::__str__(v) << "\"]";
+  } else if (any.type_index == kMLCDataType) {
+    int32_t type_dtype = get_json_type_index(TypeTraits::type_str);
+    DLDataType v = any;
+    os << "[" << type_dtype << ", \"" << TypeTraits::__str__(v) << "\"]";
+  } else {
+    MLC_THROW(TypeError) << "Cannot serialize type: " << mlc::base::TypeIndex2TypeKey(any.type_index);
+  }
+  os << "], \"type_keys\": [";
+  for (size_t i = 0; i < type_keys.size(); ++i) {
+    if (i > 0) {
+      os << ", ";
+    }
+    os << '"' << type_keys[i] << '\"';
+  }
+  os << "]}";
+  return os.str();
+}
+
+// Rebuild a value from JSON text in the {"values": [...], "type_keys": [...]} layout written by
+// Serialize, and return the last entry of "values".
+//
+// json_str / json_str_len: the JSON text, forwarded unchanged to JSONLoads.
+//
+// Entries of "values" are materialized in order:
+//   - a list `[json_type_index, arg, ...]` is built by calling the `__init__` registered for
+//     that type; an integer argument refers to an earlier entry, and a nested list argument is
+//     itself built through `__init__`;
+//   - a bare integer aliases an earlier entry;
+//   - a string, float or null is kept as parsed.
+// Throws ValueError for malformed input (empty "values", a reference that does not point at an
+// earlier entry, or an unexpected value kind) and InternalError when a listed type has no
+// `__init__`.
+inline Any Deserialize(const char *json_str, int64_t json_str_len) {
+  // NOTE(review): the return codes of MLCVTableGetGlobal / MLCVTableGetFunc are not checked.
+  MLCVTableHandle init_vtable;
+  MLCVTableGetGlobal(nullptr, "__init__", &init_vtable);
+  // Step 0. Parse JSON string
+  UDict json_obj = JSONLoads(json_str, json_str_len);
+  // Step 1. type_key => constructors
+  UList type_keys = json_obj->at("type_keys");
+  std::vector constructors;
+  constructors.reserve(type_keys.size());
+  for (Str type_key : type_keys) {
+    Any init_func;
+    int32_t type_index = ::mlc::base::TypeKey2TypeIndex(type_key->data());
+    MLCVTableGetFunc(init_vtable, type_index, false, &init_func);
+    if (!::mlc::base::IsTypeIndexNone(init_func.type_index)) {
+      constructors.push_back(init_func.operator Func());
+    } else {
+      MLC_THROW(InternalError) << "Method `__init__` is not defined for type " << type_key;
+    }
+  }
+  // Calls the constructor selected by args[0] with the remaining elements as arguments.
+  auto invoke_init = [&constructors](UList args) {
+    int32_t json_type_index = args[0];
+    Any ret;
+    ::mlc::base::FuncCall(constructors.at(json_type_index).get(), static_cast(args.size()) - 1,
+                          args->data() + 1, &ret);
+    return ret;
+  };
+  // Step 2. Translate JSON object to objects
+  UList values = json_obj->at("values");
+  if (values->size() == 0) {
+    // Nothing to return: fail explicitly rather than calling back() on an empty list.
+    MLC_THROW(ValueError) << "Invalid serialized data: `values` is empty";
+  }
+  for (int64_t i = 0; i < values->size(); ++i) {
+    Any obj = values[i];
+    if (obj.type_index == kMLCList) {
+      UList list = obj.operator UList();
+      int32_t json_type_index = list[0];
+      for (int64_t j = 1; j < list.size(); ++j) {
+        Any arg = list[j];
+        if (arg.type_index == kMLCInt) {
+          int64_t k = arg;
+          // A reference is only meaningful if it points at an already materialized entry.
+          if (k >= 0 && k < i) {
+            list[j] = values[k];
+          } else {
+            MLC_THROW(ValueError) << "Invalid reference when parsing type `" << type_keys[json_type_index]
+                                  << "`: referring #" << k << " at #" << i << ". v = " << obj;
+          }
+        } else if (arg.type_index == kMLCList) {
+          list[j] = invoke_init(arg.operator UList());
+        } else if (arg.type_index == kMLCStr || arg.type_index == kMLCFloat || arg.type_index == kMLCNone) {
+          // Do nothing
+        } else {
+          MLC_THROW(ValueError) << "Unexpected value: " << arg;
+        }
+      }
+      values[i] = invoke_init(list);
+    } else if (obj.type_index == kMLCInt) {
+      int64_t k = obj;
+      // Same rule as for nested references: only earlier entries may be aliased.
+      if (k >= 0 && k < i) {
+        values[i] = values[k];
+      } else {
+        MLC_THROW(ValueError) << "Invalid reference: referring #" << k << " at #" << i;
+      }
+    } else if (obj.type_index == kMLCStr || obj.type_index == kMLCFloat || obj.type_index == kMLCNone) {
+      // Do nothing: Serialize writes top-level strings, floats and None in exactly this form.
+    } else {
+      MLC_THROW(ValueError) << "Unexpected value: " << obj;
+    }
+  }
+  return values->back();
+}
+
+inline Any JSONLoads(const char *json_str, int64_t json_str_len) {
+ struct JSONParser {
+ Any Parse() {
+ SkipWhitespace();
+ Any result = ParseValue();
+ SkipWhitespace();
+ if (i != json_str_len) {
+ MLC_THROW(ValueError) << "JSON parsing failure at position " << i
+ << ": Extra data after valid JSON. JSON string: " << json_str;
+ }
+ return result;
+ }
+
+ void ExpectChar(char c) {
+ if (json_str[i] == c) {
+ ++i;
+ } else {
+ MLC_THROW(ValueError) << "JSON parsing failure at position " << i << ": Expected '" << c << "' but got '"
+ << json_str[i] << "'. JSON string: " << json_str;
+ }
+ }
+
+ char PeekChar() { return i < json_str_len ? json_str[i] : '\0'; }
+
+ void SkipWhitespace() {
+ while (i < json_str_len && std::isspace(json_str[i])) {
+ ++i;
+ }
+ }
+
+ void ExpectString(const char *expected, int64_t len) {
+ if (i + len <= json_str_len && std::strncmp(json_str + i, expected, len) == 0) {
+ i = i + len;
+ } else {
+ MLC_THROW(ValueError) << "JSON parsing failure at position " << i << ": Expected '" << expected
+ << ". JSON string: " << json_str;
+ }
+ }
+
+ Any ParseNull() {
+ ExpectString("null", 4);
+ return Any(nullptr);
+ }
+
+ Any ParseBoolean() {
+ if (PeekChar() == 't') {
+ ExpectString("true", 4);
+ return Any(1);
+ } else {
+ ExpectString("false", 5);
+ return Any(0);
+ }
+ }
+
+ Any ParseNumber() { // Parse a number at the cursor: an int64 when the whole token is an integer, otherwise a double.
+ int64_t start = i;
+ // Identify the end of the numeric sequence
+ while (i < json_str_len) {
+ char c = json_str[i];
+ if (c == '.' || c == 'e' || c == 'E' || c == '+' || c == '-' || std::isdigit(static_cast<unsigned char>(c))) {
+ ++i;
+ } else {
+ break;
+ }
+ }
+ std::string num_str(json_str + start, i - start);
+ std::size_t pos = 0;
+ try {
+ // Attempt to parse as integer; accepted only when every character of the token was consumed
+ int64_t int_value = std::stoll(num_str, &pos);
+ if (pos == num_str.size()) {
+ i = start + static_cast<int64_t>(pos); // Update the main index
+ return Any(int_value);
+ }
+ } catch (const std::invalid_argument &) { // Not an integer, proceed to parse as double
+ } catch (const std::out_of_range &) { // Integer out of range, proceed to parse as double
+ }
+ try {
+ // Attempt to parse as double
+ double double_value = std::stod(num_str, &pos);
+ if (pos == num_str.size()) {
+ i = start + static_cast<int64_t>(pos); // Update the main index
+ return Any(double_value);
+ }
+ } catch (const std::invalid_argument &) { // fall through to the error below
+ } catch (const std::out_of_range &) { // fall through to the error below
+ }
+ MLC_THROW(ValueError) << "JSON parsing failure at position " << i
+ << ": Invalid number format. JSON string: " << json_str;
+ MLC_UNREACHABLE();
+ }
+
+ Any ParseStr() { // Decode the double-quoted JSON string at the cursor into a Str; throws ValueError if malformed.
+ ExpectChar('"');
+ std::ostringstream oss;
+ auto is_hex4 = [this](int64_t pos) { // true when the four bytes starting at `pos` exist and are all hex digits
+ for (int64_t k = 0; k < 4; ++k) {
+ if (pos + k >= json_str_len || !std::isxdigit(static_cast<unsigned char>(json_str[pos + k]))) {
+ return false;
+ }
+ }
+ return true;
+ };
+ while (true) {
+ if (i >= json_str_len) {
+ MLC_THROW(ValueError) << "JSON parsing failure at position " << i
+ << ": Unterminated string. JSON string: " << json_str;
+ }
+ char c = json_str[i++];
+ if (c == '"') {
+ return Any(Str(oss.str())); // End of string
+ } else if (c == '\\') { // Handle escape sequences
+ if (i >= json_str_len) {
+ MLC_THROW(ValueError) << "JSON parsing failure at position " << i
+ << ": Incomplete escape sequence. JSON string: " << json_str;
+ }
+ char next = json_str[i++];
+ switch (next) {
+ case 'n': oss << '\n'; break;
+ case 't': oss << '\t'; break;
+ case 'r': oss << '\r'; break;
+ case 'b': oss << '\b'; break; // backspace and form feed are part of the JSON grammar
+ case 'f': oss << '\f'; break;
+ case '\\': oss << '\\'; break;
+ case '"': oss << '\"'; break;
+ case 'x': { // non-standard two-digit hex escape, emitted as one raw byte
+ if (i + 1 < json_str_len && std::isxdigit(static_cast<unsigned char>(json_str[i])) &&
+ std::isxdigit(static_cast<unsigned char>(json_str[i + 1]))) {
+ int32_t value = std::stoi(std::string(json_str + i, 2), nullptr, 16);
+ oss << static_cast<char>(value);
+ i += 2;
+ } else {
+ MLC_THROW(ValueError) << "Invalid hexadecimal escape sequence at position " << i - 2
+ << " in string: " << json_str;
+ }
+ break;
+ }
+ case 'u': {
+ if (!is_hex4(i)) {
+ MLC_THROW(ValueError) << "Invalid Unicode escape sequence at position " << i - 2
+ << " in string: " << json_str;
+ }
+ int32_t codepoint = std::stoi(std::string(json_str + i, 4), nullptr, 16);
+ i += 4;
+ // A high surrogate directly followed by an escaped low surrogate encodes one code point above U+FFFF.
+ if (codepoint >= 0xD800 && codepoint <= 0xDBFF && i + 1 < json_str_len && json_str[i] == '\\' &&
+ json_str[i + 1] == 'u' && is_hex4(i + 2)) {
+ int32_t low = std::stoi(std::string(json_str + i + 2, 4), nullptr, 16);
+ if (low >= 0xDC00 && low <= 0xDFFF) {
+ codepoint = 0x10000 + ((codepoint - 0xD800) << 10) + (low - 0xDC00);
+ i += 6; // consume the second `\uXXXX` as well
+ }
+ }
+ if (codepoint <= 0x7F) { // 1-byte UTF-8
+ oss << static_cast<char>(codepoint);
+ } else if (codepoint <= 0x7FF) { // 2-byte UTF-8
+ oss << static_cast<char>(0xC0 | (codepoint >> 6)) << static_cast<char>(0x80 | (codepoint & 0x3F));
+ } else if (codepoint <= 0xFFFF) { // 3-byte UTF-8 (an unpaired surrogate is still emitted this way)
+ oss << static_cast<char>(0xE0 | (codepoint >> 12)) << static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F))
+ << static_cast<char>(0x80 | (codepoint & 0x3F));
+ } else { // 4-byte UTF-8
+ oss << static_cast<char>(0xF0 | (codepoint >> 18)) << static_cast<char>(0x80 | ((codepoint >> 12) & 0x3F))
+ << static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F)) << static_cast<char>(0x80 | (codepoint & 0x3F));
+ }
+ break;
+ }
+ default: oss << next; break; // Unrecognized escape sequence (this includes `\/`), interpret literally
+ }
+ } else {
+ oss << c; // Regular character, copy as-is
+ }
+ }
+ }
+
+ UList ParseArray() {
+ // Grammar: '[' ws ']'  |  '[' value (',' value)* ']' with optional whitespace
+ // around every value. After a ',' another value is required.
+ UList elements;
+ ExpectChar('[');
+ SkipWhitespace();
+ bool closed = PeekChar() == ']';
+ while (!closed) {
+ SkipWhitespace();
+ elements.push_back(ParseValue());
+ SkipWhitespace();
+ closed = PeekChar() == ']';
+ if (!closed) {
+ ExpectChar(','); // anything other than ',' or ']' after a value is an error
+ }
+ }
+ ExpectChar(']');
+ return elements;
+ }
+
+ Any ParseObject() {
+ // Grammar: '{' ws '}'  |  '{' member (',' member)* '}' where a member is
+ // string ':' value; whitespace is allowed around every token.
+ UDict members;
+ ExpectChar('{');
+ SkipWhitespace();
+ bool closed = PeekChar() == '}';
+ while (!closed) {
+ SkipWhitespace();
+ Any key = ParseStr();
+ SkipWhitespace();
+ ExpectChar(':');
+ SkipWhitespace();
+ Any value = ParseValue();
+ members[key] = value;
+ SkipWhitespace();
+ closed = PeekChar() == '}';
+ if (!closed) {
+ ExpectChar(','); // anything other than ',' or '}' after a member is an error
+ }
+ }
+ ExpectChar('}');
+ return Any(members);
+ }
+
+ Any ParseValue() { // Parse one JSON value, dispatching on the first non-whitespace byte.
+ SkipWhitespace();
+ char c = PeekChar(); // '\0' at end of input, which lands in the error branch below
+ if (c == '"') {
+ return ParseStr();
+ } else if (c == '{') {
+ return ParseObject();
+ } else if (c == '[') {
+ return ParseArray();
+ } else if (c == 'n') {
+ return ParseNull();
+ } else if (c == 't' || c == 'f') {
+ return ParseBoolean();
+ } else if (std::isdigit(static_cast<unsigned char>(c)) || c == '-') { // cast: bytes >= 0x80 must not reach isdigit as negative values
+ return ParseNumber();
+ } else {
+ MLC_THROW(ValueError) << "JSON parsing failure at position " << i << ": Unexpected character: " << c
+ << ". JSON string: " << json_str;
+ }
+ MLC_UNREACHABLE();
+ }
+ int64_t i;
+ int64_t json_str_len;
+ const char *json_str;
+ };
+ if (json_str_len < 0) {
+ json_str_len = static_cast(std::strlen(json_str));
+ }
+ return JSONParser{0, json_str_len, json_str}.Parse();
+}
+
+MLC_REGISTER_FUNC("mlc.core.JSONLoads").set_body([](AnyView json_str) { // registered under the name "mlc.core.JSONLoads"
+ if (json_str.type_index == kMLCRawStr) {
+ return ::mlc::core::JSONLoads(json_str.operator const char *()); // raw C string: forward the pointer as-is
+ } else {
+ ::mlc::Str str = json_str; // every other argument is first converted to a Str
+ return ::mlc::core::JSONLoads(str);
+ }
+});
+MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize); // forwards straight to Serialize
+MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) { // same raw-string / Str split as JSONLoads
+ if (json_str.type_index == kMLCRawStr) {
+ return ::mlc::core::Deserialize(json_str.operator const char *());
+ } else {
+ return ::mlc::core::Deserialize(json_str.operator ::mlc::Str());
+ }
+});
+} // namespace
+} // namespace core
+} // namespace mlc
diff --git a/cpp/printer.cc b/cpp/printer.cc
new file mode 100644
index 00000000..30d7db03
--- /dev/null
+++ b/cpp/printer.cc
@@ -0,0 +1,1044 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace mlc {
+namespace printer {
+namespace {
+using ByteSpan = std::pair;
+
+class DocPrinter { // Base printer: renders doc nodes into a text buffer while tracking line starts and underline spans.
+public:
+ explicit DocPrinter(const PrinterConfig &options);
+ virtual ~DocPrinter() = default;
+
+ void Append(const Node &doc); // prints `doc` using a default-constructed PrinterConfig
+ void Append(const Node &doc, const PrinterConfig &cfg); // also registers cfg->path_to_underline before printing
+ ::mlc::Str GetString() const; // buffered text, trimmed and decorated by DecorateText
+
+protected:
+ void PrintDoc(const Node &doc); // prints one doc and reports its byte span to MarkSpan for each source path
+ virtual void PrintTypedDoc(const Literal &doc) = 0; // one pure-virtual overload per concrete doc type
+ virtual void PrintTypedDoc(const Id &doc) = 0;
+ virtual void PrintTypedDoc(const Attr &doc) = 0;
+ virtual void PrintTypedDoc(const Index &doc) = 0;
+ virtual void PrintTypedDoc(const Operation &doc) = 0;
+ virtual void PrintTypedDoc(const Call &doc) = 0;
+ virtual void PrintTypedDoc(const Lambda &doc) = 0;
+ virtual void PrintTypedDoc(const List &doc) = 0;
+ virtual void PrintTypedDoc(const Tuple &doc) = 0;
+ virtual void PrintTypedDoc(const Dict &doc) = 0;
+ virtual void PrintTypedDoc(const Slice &doc) = 0;
+ virtual void PrintTypedDoc(const StmtBlock &doc) = 0;
+ virtual void PrintTypedDoc(const Assign &doc) = 0;
+ virtual void PrintTypedDoc(const If &doc) = 0;
+ virtual void PrintTypedDoc(const While &doc) = 0;
+ virtual void PrintTypedDoc(const For &doc) = 0;
+ virtual void PrintTypedDoc(const With &doc) = 0;
+ virtual void PrintTypedDoc(const ExprStmt &doc) = 0;
+ virtual void PrintTypedDoc(const Assert &doc) = 0;
+ virtual void PrintTypedDoc(const Return &doc) = 0;
+ virtual void PrintTypedDoc(const Function &doc) = 0;
+ virtual void PrintTypedDoc(const Class &doc) = 0;
+ virtual void PrintTypedDoc(const Comment &doc) = 0;
+ virtual void PrintTypedDoc(const DocString &doc) = 0;
+
+ void PrintTypedDoc(const NodeObj *doc) { // runtime dispatch: looks up the node's type index and calls the typed overload
+ using PrinterVTable = std::unordered_map>;
+ // clang-format off
+ #define MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Type) { Type::TObj::_type_index, [](DocPrinter *printer, const NodeObj *doc) { printer->PrintTypedDoc(Type(doc->Cast())); } }
+ // clang-format on
+ static PrinterVTable vtable{ // built once; one entry per doc type listed below
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Literal), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Id),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Attr), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Index),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Operation), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Call),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Lambda), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(List),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Tuple), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Dict),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Slice), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(StmtBlock),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Assign), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(If),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(While), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(For),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(With), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(ExprStmt),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Assert), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Return),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Function), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Class),
+ MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Comment), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(DocString),
+ };
+ // clang-format off
+ #undef MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_
+ // clang-format on
+ vtable.at(doc->GetTypeIndex())(this, doc); // .at() rejects a type index that has no entry in the table
+ }
+ void IncreaseIndent() { indent_ += options_->indent_spaces; } // one indentation level deeper
+ void DecreaseIndent() { indent_ -= options_->indent_spaces; } // one indentation level shallower
+ std::ostream &NewLine() { // starts a new output line at the current indentation and returns the stream
+ size_t start_pos = output_.tellp();
+ output_ << "\n";
+ line_starts_.push_back(output_.tellp()); // the byte right after '\n' begins the next line
+ for (int i = 0; i < indent_; ++i) {
+ output_ << ' ';
+ }
+ size_t end_pos = output_.tellp();
+ underlines_exempted_.push_back({start_pos, end_pos}); // the newline plus indentation is never underlined
+ return output_;
+ }
+ std::ostringstream output_; // the text produced so far
+ std::vector underlines_exempted_; // byte spans excluded from underlining
+
+private:
+ using ObjectPath = ::mlc::core::ObjectPath;
+
+ void MarkSpan(const ByteSpan &span, const ObjectPath &path); // records `span` as an underline candidate for matching paths
+ PrinterConfig options_;
+ int indent_ = 0; // current indentation, counted in spaces
+ std::vector line_starts_; // byte offset at which each output line begins
+ mlc::List path_to_underline_; // paths requested for underlining, one slot per path
+ std::vector> current_underline_candidates_; // candidate spans, indexed like path_to_underline_
+ std::vector current_max_path_length_; // longest matching path length seen so far, indexed like path_to_underline_
+ std::vector underlines_; // spans collected by Append from the candidate lists
+};
+
+inline const char *OpKindToString(OperationObj::Kind kind) { // Python spelling of an operator kind; throws ValueError for kinds not listed
+ switch (kind) {
+ case OperationObj::Kind::kUSub:
+ return "-";
+ case OperationObj::Kind::kInvert:
+ return "~";
+ case OperationObj::Kind::kNot:
+ return "not "; // the only entry that carries its own trailing space
+ case OperationObj::Kind::kAdd:
+ return "+";
+ case OperationObj::Kind::kSub:
+ return "-";
+ case OperationObj::Kind::kMult:
+ return "*";
+ case OperationObj::Kind::kDiv:
+ return "/";
+ case OperationObj::Kind::kFloorDiv:
+ return "//";
+ case OperationObj::Kind::kMod:
+ return "%";
+ case OperationObj::Kind::kPow:
+ return "**";
+ case OperationObj::Kind::kLShift:
+ return "<<";
+ case OperationObj::Kind::kRShift:
+ return ">>";
+ case OperationObj::Kind::kBitAnd:
+ return "&";
+ case OperationObj::Kind::kBitOr:
+ return "|";
+ case OperationObj::Kind::kBitXor:
+ return "^";
+ case OperationObj::Kind::kLt:
+ return "<";
+ case OperationObj::Kind::kLtE:
+ return "<=";
+ case OperationObj::Kind::kEq:
+ return "==";
+ case OperationObj::Kind::kNotEq:
+ return "!=";
+ case OperationObj::Kind::kGt:
+ return ">";
+ case OperationObj::Kind::kGtE:
+ return ">=";
+ case OperationObj::Kind::kAnd:
+ return "and";
+ case OperationObj::Kind::kOr:
+ return "or";
+ default: // every other kind, kIfThenElse included, has no single-token spelling here
+ MLC_THROW(ValueError) << "Unknown operation kind: " << static_cast(kind);
+ }
+ MLC_UNREACHABLE();
+}
+
+inline const char *OpKindToString(int64_t kind) { return OpKindToString(static_cast(kind)); } // convenience overload for a kind held as a plain integer
+
+/*!
+ * \brief Operator precedence based on https://docs.python.org/3/reference/expressions.html#operator-precedence
+ */
+enum class ExprPrecedence : int32_t { // a larger value binds tighter
+ /*! \brief Unknown precedence; GetExprPrecedence uses it for operation kinds without a table entry */
+ kUnkown = 0,
+ /*! \brief Lambda Expression */
+ kLambda = 1,
+ /*! \brief Conditional Expression */
+ kIfThenElse = 2,
+ /*! \brief Boolean OR */
+ kBooleanOr = 3,
+ /*! \brief Boolean AND */
+ kBooleanAnd = 4,
+ /*! \brief Boolean NOT */
+ kBooleanNot = 5,
+ /*! \brief Comparisons */
+ kComparison = 6,
+ /*! \brief Bitwise OR */
+ kBitwiseOr = 7,
+ /*! \brief Bitwise XOR */
+ kBitwiseXor = 8,
+ /*! \brief Bitwise AND */
+ kBitwiseAnd = 9,
+ /*! \brief Shift Operators */
+ kShift = 10,
+ /*! \brief Addition and subtraction */
+ kAdd = 11,
+ /*! \brief Multiplication, division, floor division, remainder */
+ kMult = 12,
+ /*! \brief Positive, negative and bitwise NOT */
+ kUnary = 13,
+ /*! \brief Exponentiation */
+ kExp = 14,
+ /*! \brief Index access, attribute access, call and atom expression */
+ kIdentity = 15,
+};
+
+inline ExprPrecedence GetExprPrecedence(const Expr &doc) { // precedence of an expression doc; throws ValueError when none is known
+ // Key is the type index of Doc
+ static const std::unordered_map doc_type_precedence = {
+ {LiteralObj::_type_index, ExprPrecedence::kIdentity}, {IdObj::_type_index, ExprPrecedence::kIdentity},
+ {AttrObj::_type_index, ExprPrecedence::kIdentity}, {IndexObj::_type_index, ExprPrecedence::kIdentity},
+ {CallObj::_type_index, ExprPrecedence::kIdentity}, {LambdaObj::_type_index, ExprPrecedence::kLambda},
+ {TupleObj::_type_index, ExprPrecedence::kIdentity}, {ListObj::_type_index, ExprPrecedence::kIdentity},
+ {DictObj::_type_index, ExprPrecedence::kIdentity},
+ };
+ // Key is the value of OperationObj::Kind
+ static const std::vector op_kind_precedence = []() { // built once, on first call
+ using OpKind = OperationObj::Kind;
+ std::map raw_table = {
+ {OpKind::kUSub, ExprPrecedence::kUnary}, {OpKind::kInvert, ExprPrecedence::kUnary},
+ {OpKind::kNot, ExprPrecedence::kBooleanNot}, {OpKind::kAdd, ExprPrecedence::kAdd},
+ {OpKind::kSub, ExprPrecedence::kAdd}, {OpKind::kMult, ExprPrecedence::kMult},
+ {OpKind::kDiv, ExprPrecedence::kMult}, {OpKind::kFloorDiv, ExprPrecedence::kMult},
+ {OpKind::kMod, ExprPrecedence::kMult}, {OpKind::kPow, ExprPrecedence::kExp},
+ {OpKind::kLShift, ExprPrecedence::kShift}, {OpKind::kRShift, ExprPrecedence::kShift},
+ {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, {OpKind::kBitOr, ExprPrecedence::kBitwiseOr},
+ {OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, {OpKind::kLt, ExprPrecedence::kComparison},
+ {OpKind::kLtE, ExprPrecedence::kComparison}, {OpKind::kEq, ExprPrecedence::kComparison},
+ {OpKind::kNotEq, ExprPrecedence::kComparison}, {OpKind::kGt, ExprPrecedence::kComparison},
+ {OpKind::kGtE, ExprPrecedence::kComparison}, {OpKind::kAnd, ExprPrecedence::kBooleanAnd},
+ {OpKind::kOr, ExprPrecedence::kBooleanOr}, {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse},
+ };
+ std::vector table(static_cast(OpKind::kSpecialEnd) + 1, ExprPrecedence::kUnkown); // dense table; unlisted kinds stay kUnkown
+ for (const auto &kv : raw_table) {
+ table[static_cast(kv.first)] = kv.second;
+ }
+ return table;
+ }();
+ if (const auto *op_doc = doc->TryCast()) { // operations are looked up by operator kind
+ ExprPrecedence precedence = op_kind_precedence.at(op_doc->op); // .at() is bounds-checked
+ if (precedence == ExprPrecedence::kUnkown) {
+ MLC_THROW(ValueError) << "Unknown precedence for operator: " << op_doc->op;
+ }
+ return precedence;
+ } else if (auto it = doc_type_precedence.find(doc->GetTypeIndex()); it != doc_type_precedence.end()) { // all other docs by type index
+ return it->second;
+ }
+ MLC_THROW(ValueError) << "Unknown precedence for doc type: " << doc->GetTypeKey();
+ MLC_UNREACHABLE();
+}
+
+inline std::vector MergeAndExemptSpans(const std::vector &spans,
+ const std::vector &spans_exempted) { // returns the parts of `spans` not covered by `spans_exempted`, adjacent pieces merged
+ // use prefix sum to merge and exempt spans
+ std::vector res;
+ std::vector> prefix_stamp;
+ for (ByteSpan span : spans) {
+ prefix_stamp.push_back({span.first, 1}); // +1 where a span opens
+ prefix_stamp.push_back({span.second, -1}); // -1 where it closes
+ }
+ // at most spans.size() spans accumulated in prefix sum
+ // use spans.size() + 1 as stamp unit to exempt all positive spans
+ // with only one negative span
+ int max_n = static_cast(spans.size()) + 1;
+ for (ByteSpan span : spans_exempted) {
+ prefix_stamp.push_back({span.first, -max_n});
+ prefix_stamp.push_back({span.second, max_n});
+ }
+ std::sort(prefix_stamp.begin(), prefix_stamp.end()); // orders by position, then by stamp value
+ int prefix_sum = 0;
+ int n = static_cast(prefix_stamp.size());
+ for (int i = 0; i < n - 1; ++i) { // the last stamp only closes a range, so it is never a range start
+ prefix_sum += prefix_stamp[i].second;
+ // positive prefix sum leads to spans without exemption
+ // different stamp positions guarantee the stamps in same position accumulated
+ if (prefix_sum > 0 && prefix_stamp[i].first < prefix_stamp[i + 1].first) {
+ if (res.size() && res.back().second == prefix_stamp[i].first) {
+ // merge to the last spans if it is successive
+ res.back().second = prefix_stamp[i + 1].first;
+ } else {
+ // add a new independent span
+ res.push_back({prefix_stamp[i].first, prefix_stamp[i + 1].first});
+ }
+ }
+ }
+ return res;
+}
+
+inline size_t GetTextWidth(const mlc::Str &text, const ByteSpan &span) { // number of printable bytes of `text` in [span.first, span.second)
+ // FIXME: this only works for ASCII characters.
+ // To do this "correctly", we need to parse UTF-8 into codepoints
+ // and call wcwidth() or equivalent for every codepoint.
+ size_t ret = 0;
+ for (size_t i = span.first; i < span.second; ++i) { // `<` also ends the loop on an inverted span instead of wrapping around
+ if (isprint(static_cast<unsigned char>(text[i]))) { // cast: bytes >= 0x80 must not reach isprint as negative values
+ ret += 1;
+ }
+ }
+ return ret;
+}
+
+inline size_t MoveBack(size_t pos, size_t distance) { return pos >= distance ? pos - distance : 0; } // pos - distance, floored at 0
+
+inline size_t MoveForward(size_t pos, size_t distance, size_t max) {
+ return max - pos >= distance ? pos + distance : max; // pos + distance, capped at max
+}
+
+inline size_t GetLineIndex(size_t byte_pos, const std::vector &line_starts) { // index of the line containing `byte_pos`
+ auto it = std::upper_bound(line_starts.begin(), line_starts.end(), byte_pos); // first line that starts after byte_pos
+ return (it - line_starts.begin()) - 1; // assumes line_starts is sorted and its first entry is <= byte_pos -- TODO confirm
+}
+
+using UnderlineIter = typename std::vector::const_iterator;
+
+inline ByteSpan PopNextUnderline(UnderlineIter *next_underline, UnderlineIter end_underline) {
+ // Yields the span under the cursor and advances the cursor; once the cursor has
+ // reached `end_underline`, a sentinel span at the largest 64-bit offset is returned.
+ constexpr size_t kMaxSizeT = 18446744073709551615U;
+ UnderlineIter &cursor = *next_underline;
+ if (cursor == end_underline) return {kMaxSizeT, kMaxSizeT};
+ return *cursor++;
+}
+
+inline void PrintChunk(const std::pair &lines_range,
+ const std::pair &underlines, const mlc::Str &text,
+ const std::vector &line_starts, const PrinterConfig &options, size_t line_number_width,
+ std::ostringstream *out) { // writes lines [lines_range.first, lines_range.second) of `text`, each followed by a caret row when underlined
+ UnderlineIter next_underline = underlines.first;
+ ByteSpan current_underline = PopNextUnderline(&next_underline, underlines.second);
+
+ for (size_t line_idx = lines_range.first; line_idx < lines_range.second; ++line_idx) {
+ if (options->print_line_numbers) {
+ (*out) << std::setw(line_number_width) << std::right << std::to_string(line_idx + 1) << ' '; // 1-based line number
+ }
+ size_t line_start = line_starts.at(line_idx);
+ size_t line_end = line_idx + 1 == line_starts.size() ? text.size() : line_starts.at(line_idx + 1); // the last line runs to the end of the text
+ (*out) << std::string_view(text.data() + line_start, line_end - line_start);
+ bool printed_underline = false; // becomes true once the caret row for this line has been started
+ size_t line_pos = line_start; // byte position up to which the caret row has been produced
+ bool printed_extra_caret = 0; // a caret was emitted for a zero-width span and must be deducted from what follows
+ while (current_underline.first < line_end) { // handle every underline that begins before this line ends
+ if (!printed_underline) {
+ for (size_t i = 0; i < line_number_width; ++i) {
+ (*out) << ' ';
+ }
+ printed_underline = true;
+ }
+ size_t underline_end_for_line = std::min(line_end, current_underline.second); // clip the underline to this line
+ size_t num_spaces = GetTextWidth(text, {line_pos, current_underline.first});
+ if (num_spaces > 0 && printed_extra_caret) {
+ num_spaces -= 1;
+ printed_extra_caret = false;
+ }
+ for (size_t i = 0; i < num_spaces; ++i) {
+ (*out) << ' ';
+ }
+
+ size_t num_carets = GetTextWidth(text, {current_underline.first, underline_end_for_line});
+ if (num_carets == 0 && !printed_extra_caret) {
+ // Special case: when underlining an empty or unprintable string, make sure to print
+ // at least one caret still.
+ num_carets = 1;
+ printed_extra_caret = true;
+ } else if (num_carets > 0 && printed_extra_caret) {
+ num_carets -= 1;
+ printed_extra_caret = false;
+ }
+ for (size_t i = 0; i < num_carets; ++i) {
+ (*out) << '^';
+ }
+
+ line_pos = current_underline.first = underline_end_for_line; // whatever is left of the underline continues on the next line
+ if (current_underline.first == current_underline.second) {
+ current_underline = PopNextUnderline(&next_underline, underlines.second); // fully consumed: fetch the next one
+ }
+ }
+ if (printed_underline) {
+ (*out) << '\n';
+ }
+ }
+}
+
+inline void PrintCut(size_t num_lines_skipped, std::ostringstream *out) {
+ // Emit an elision marker, except when nothing was elided.
+ if (num_lines_skipped == 0) return;
+ *out << "(... " << num_lines_skipped << " lines skipped ...)\n";
+}
+
+inline std::pair GetLinesForUnderline(const ByteSpan &underline, const std::vector &line_starts,
+ size_t num_lines, const PrinterConfig &options) { // line range [first, end) covering the underline plus context lines
+ size_t first_line_of_underline = GetLineIndex(underline.first, line_starts);
+ size_t first_line_of_chunk = MoveBack(first_line_of_underline, options->num_context_lines); // context before, floored at line 0
+ size_t end_line_of_underline = GetLineIndex(underline.second - 1, line_starts) + 1; // assumes underline.second > 0 -- TODO confirm
+ size_t end_line_of_chunk = MoveForward(end_line_of_underline, options->num_context_lines, num_lines); // context after, capped at num_lines
+
+ return {first_line_of_chunk, end_line_of_chunk};
+}
+
+// If there is only one line between the chunks, it is better to print it as is,
+// rather than something like "(... 1 line skipped ...)".
+constexpr const size_t kMinLinesToCutOut = 2;
+
+inline bool TryMergeChunks(std::pair *cur_chunk, const std::pair &new_chunk) { // extends *cur_chunk over new_chunk when the gap is below kMinLinesToCutOut
+ if (new_chunk.first < cur_chunk->second + kMinLinesToCutOut) {
+ cur_chunk->second = new_chunk.second; // only the end moves; the start of cur_chunk is kept
+ return true;
+ } else {
+ return false; // gap is large enough to be printed as a cut; *cur_chunk is left untouched
+ }
+}
+
+inline size_t GetNumLines(const mlc::Str &text, const std::vector &line_starts) { // line count; requires a non-empty line_starts
+ if (static_cast(line_starts.back()) == text.size()) { // the last recorded line starts exactly at the end of the text
+ // Final empty line doesn't count as a line
+ return line_starts.size() - 1;
+ } else {
+ return line_starts.size();
+ }
+}
+
+inline size_t GetLineNumberWidth(size_t num_lines, const PrinterConfig &options) {
+ // Gutter width: digit count of `num_lines` plus one, or zero when line numbers are off.
+ if (!options->print_line_numbers) {
+ return 0;
+ }
+ return std::to_string(num_lines).size() + 1;
+}
+
+inline mlc::Str DecorateText(const mlc::Str &text, const std::vector &line_starts, const PrinterConfig &options,
+ const std::vector &underlines) { // adds line numbers / caret rows; with underlines, prints only the chunks around them
+ size_t num_lines = GetNumLines(text, line_starts);
+ size_t line_number_width = GetLineNumberWidth(num_lines, options);
+
+ std::ostringstream ret;
+ if (underlines.empty()) { // nothing to underline: print every line as a single chunk
+ PrintChunk({0, num_lines}, {underlines.begin(), underlines.begin()}, text, line_starts, options, line_number_width,
+ &ret);
+ return ret.str();
+ }
+
+ size_t last_end_line = 0; // end line of the most recently printed chunk
+ std::pair cur_chunk = GetLinesForUnderline(underlines[0], line_starts, num_lines, options);
+ if (cur_chunk.first < kMinLinesToCutOut) {
+ cur_chunk.first = 0; // too few leading lines to be worth cutting: start from the top
+ }
+
+ auto first_underline_in_cur_chunk = underlines.begin();
+ for (auto underline_it = underlines.begin() + 1; underline_it != underlines.end(); ++underline_it) {
+ std::pair new_chunk = GetLinesForUnderline(*underline_it, line_starts, num_lines, options);
+
+ if (!TryMergeChunks(&cur_chunk, new_chunk)) { // not mergeable: flush the current chunk and start a new one
+ PrintCut(cur_chunk.first - last_end_line, &ret);
+ PrintChunk(cur_chunk, {first_underline_in_cur_chunk, underline_it}, text, line_starts, options, line_number_width,
+ &ret);
+ last_end_line = cur_chunk.second;
+ cur_chunk = new_chunk;
+ first_underline_in_cur_chunk = underline_it;
+ }
+ }
+
+ PrintCut(cur_chunk.first - last_end_line, &ret);
+ if (num_lines - cur_chunk.second < kMinLinesToCutOut) {
+ cur_chunk.second = num_lines; // too few trailing lines to be worth cutting: run to the end
+ }
+ PrintChunk(cur_chunk, {first_underline_in_cur_chunk, underlines.end()}, text, line_starts, options, line_number_width,
+ &ret);
+ PrintCut(num_lines - cur_chunk.second, &ret); // marker for any lines left after the final chunk
+ return ret.str();
+}
+
+inline DocPrinter::DocPrinter(const PrinterConfig &options) : options_(options) { line_starts_.push_back(0); } // the first line starts at byte offset 0
+
+inline void DocPrinter::Append(const Node &doc) { Append(doc, PrinterConfig()); } // shorthand: append with a default-constructed config
+
+inline void DocPrinter::Append(const Node &doc, const PrinterConfig &cfg) { // prints `doc` and collects underline spans for cfg->path_to_underline
+ for (ObjectPath p : cfg->path_to_underline) { // one slot per requested path in each of the three parallel containers
+ path_to_underline_.push_back(p);
+ current_max_path_length_.push_back(0);
+ current_underline_candidates_.push_back(std::vector());
+ }
+ PrintDoc(doc);
+ for (const auto &c : current_underline_candidates_) { // NOTE(review): candidates left from earlier Append calls are inserted again -- confirm intended
+ underlines_.insert(underlines_.end(), c.begin(), c.end());
+ }
+}
+
+inline ::mlc::Str DocPrinter::GetString() const {
+ mlc::Str rendered = output_.str();
+ // Strip trailing whitespace (for instance indentation written after the last newline).
+ for (; !rendered->empty() && std::isspace(rendered->back()); rendered->pop_back()) {
+ }
+ // Merge the recorded underline spans, drop the exempted ranges, then decorate the text.
+ return DecorateText(rendered, line_starts_, options_, MergeAndExemptSpans(underlines_, underlines_exempted_));
+}
+
+inline void DocPrinter::PrintDoc(const Node &doc) {
+ // Measure the byte range this doc occupies in the output, then offer that
+ // range to MarkSpan once for every source path attached to the doc.
+ const size_t begin = output_.tellp();
+ PrintTypedDoc(doc.get());
+ const size_t end = output_.tellp();
+ for (ObjectPath path : doc->source_paths) MarkSpan({begin, end}, path);
+}
+
+inline void DocPrinter::MarkSpan(const ByteSpan &span, const ObjectPath &path) { // offers `span` as an underline candidate to every tracked path
+ int n = static_cast(path_to_underline_.size());
+ for (int i = 0; i < n; ++i) {
+ ObjectPath p = path_to_underline_[i];
+ if (path->length >= current_max_path_length_[i] && path->IsPrefixOf(p.get())) { // a prefix of the target at least as long as the best so far
+ if (path->length > current_max_path_length_[i]) { // strictly longer match: shorter candidates are discarded
+ current_max_path_length_[i] = static_cast(path->length);
+ current_underline_candidates_[i].clear();
+ }
+ current_underline_candidates_[i].push_back(span); // matches of equal length accumulate
+ }
+ }
+}
+
+class PythonDocPrinter : public DocPrinter { // DocPrinter that renders every doc node as Python source text
+public:
+ explicit PythonDocPrinter(const PrinterConfig &options) : DocPrinter(options) {}
+
+protected:
+ using DocPrinter::PrintDoc;
+
+ void PrintTypedDoc(const Literal &doc) final;
+ void PrintTypedDoc(const Id &doc) final;
+ void PrintTypedDoc(const Attr &doc) final;
+ void PrintTypedDoc(const Index &doc) final;
+ void PrintTypedDoc(const Operation &doc) final;
+ void PrintTypedDoc(const Call &doc) final;
+ void PrintTypedDoc(const Lambda &doc) final;
+ void PrintTypedDoc(const List &doc) final;
+ void PrintTypedDoc(const Dict &doc) final;
+ void PrintTypedDoc(const Tuple &doc) final;
+ void PrintTypedDoc(const Slice &doc) final;
+ void PrintTypedDoc(const StmtBlock &doc) final;
+ void PrintTypedDoc(const Assign &doc) final;
+ void PrintTypedDoc(const If &doc) final;
+ void PrintTypedDoc(const While &doc) final;
+ void PrintTypedDoc(const For &doc) final;
+ void PrintTypedDoc(const ExprStmt &doc) final;
+ void PrintTypedDoc(const Assert &doc) final;
+ void PrintTypedDoc(const Return &doc) final;
+ void PrintTypedDoc(const With &doc) final;
+ void PrintTypedDoc(const Function &doc) final;
+ void PrintTypedDoc(const Class &doc) final;
+ void PrintTypedDoc(const Comment &doc) final;
+ void PrintTypedDoc(const DocString &doc) final;
+
+private:
+ void NewLineWithoutIndent() { // writes a bare "\n" (no indentation) and exempts it from underlining
+ size_t start_pos = output_.tellp();
+ output_ << "\n";
+ size_t end_pos = output_.tellp();
+ underlines_exempted_.push_back({start_pos, end_pos});
+ }
+
+ template void PrintJoinedDocs(const mlc::List &docs, const char *separator) { // prints docs with `separator` between neighbours
+ bool is_first = true;
+ for (DocType doc : docs) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ output_ << separator;
+ }
+ PrintDoc(doc);
+ }
+ }
+
+ void PrintIndentedBlock(const mlc::List &docs) { // one statement per line, one indentation level deeper
+ IncreaseIndent();
+ for (Stmt d : docs) {
+ NewLine();
+ PrintDoc(d);
+ }
+ if (docs.empty()) { // an empty block is written as `pass`
+ NewLine();
+ output_ << "pass";
+ }
+ DecreaseIndent();
+ }
+
+ void PrintDecorators(const mlc::List &decorators) { // one `@decorator` per line, each followed by a new line
+ for (Expr decorator : decorators) {
+ output_ << "@";
+ PrintDoc(decorator);
+ NewLine();
+ }
+ }
+
+ /*!
+ * \brief Print expression and add parenthesis if needed.
+ */
+ void PrintChildExpr(const Expr &doc, ExprPrecedence parent_precedence, bool parenthesis_for_same_precedence = false) {
+ ExprPrecedence doc_precedence = GetExprPrecedence(doc);
+ if (doc_precedence < parent_precedence ||
+ (parenthesis_for_same_precedence && doc_precedence == parent_precedence)) {
+ output_ << "(";
+ PrintDoc(doc);
+ output_ << ")";
+ } else {
+ PrintDoc(doc);
+ }
+ }
+
+ /*!
+ * \brief Print expression and add parenthesis if doc has lower precedence than parent.
+ */
+ void PrintChildExpr(const Expr &doc, const Expr &parent, bool parenthesis_for_same_precedence = false) {
+ ExprPrecedence parent_precedence = GetExprPrecedence(parent);
+ return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence);
+ }
+
+ /*!
+ * \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent.
+ *
+ * This function should be used to print a child expression that needs to be wrapped
+ * by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b`
+ * and the `b` and `c` in `a if b else c`.
+ */
+ void PrintChildExprConservatively(const Expr &doc, const Expr &parent) {
+ PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true);
+ }
+
+ void MaybePrintCommentInline(const Stmt &stmt) { // appends ` # comment` on the current line when the statement has a comment
+ if (const mlc::StrObj *comment = stmt->comment.get()) {
+ bool has_newline = std::find(comment->begin(), comment->end(), '\n') != comment->end();
+ if (has_newline) {
+ MLC_THROW(ValueError) << "ValueError: Comment string of " << stmt->GetTypeKey()
+ << " cannot have newline, but got: " << comment->data(); // stream the text; streaming the StrObj pointer would print an address
+ }
+ size_t start_pos = output_.tellp();
+ output_ << " # " << comment->data();
+ size_t end_pos = output_.tellp();
+ underlines_exempted_.push_back({start_pos, end_pos}); // comment text is never underlined
+ }
+ }
+
+ void MaybePrintCommenMultiLines(const Stmt &stmt, bool new_line = false) { // prints the comment as `# ` lines, if the statement has one
+ if (const mlc::StrObj *comment = stmt->comment.get()) {
+ bool first_line = true;
+ size_t start_pos = output_.tellp();
+ for (const std::string_view &line : comment->Split('\n')) {
+ if (first_line) { // the first comment line continues on the current output line
+ output_ << "# " << line;
+ first_line = false;
+ } else {
+ NewLine() << "# " << line;
+ }
+ }
+ size_t end_pos = output_.tellp();
+ underlines_exempted_.push_back({start_pos, end_pos});
+ if (new_line) { // optionally end with a fresh line for whatever is printed next
+ NewLine();
+ }
+ }
+ }
+
+ void PrintDocString(const mlc::Str &comment) { // prints `comment` between triple double quotes, one source line per output line
+ size_t start_pos = output_.tellp();
+ output_ << "\"\"\"";
+ for (const std::string_view &line : comment->Split('\n')) {
+ if (line.empty()) {
+ // No indentation on empty line
+ output_ << "\n";
+ } else {
+ NewLine() << line;
+ }
+ }
+ NewLine() << "\"\"\"";
+ size_t end_pos = output_.tellp();
+ underlines_exempted_.push_back({start_pos, end_pos}); // the whole docstring is exempt from underlining
+ }
+ void PrintBlockComment(const mlc::Str &comment) { // docstring on its own line, one indentation level deeper
+ IncreaseIndent();
+ NewLine();
+ PrintDocString(comment);
+ DecreaseIndent();
+ }
+};
+
+inline void PythonDocPrinter::PrintTypedDoc(const Literal &doc) {
+ // Renders a literal as Python source text: None, an integer, a float, or an
+ // escaped string. Any other payload type raises TypeError.
+ Any value = doc->value;
+ int32_t type_index = value.GetTypeIndex();
+ if (!value.defined()) {
+ output_ << "None";
+ return;
+ }
+ if (type_index == kMLCInt) {
+ output_ << value.operator int64_t();
+ return;
+ }
+ if (type_index == kMLCStr) {
+ (value.operator mlc::Str())->PrintEscape(output_);
+ return;
+ }
+ if (type_index != kMLCFloat) {
+ MLC_THROW(TypeError) << "TypeError: Unsupported literal value type: " << value.GetTypeKey();
+ }
+ double v = value.operator double();
+ // TODO(yelite): Make float number printing roundtrippable
+ if (std::isinf(v) || std::isnan(v)) {
+ // Non-finite values are wrapped in double quotes.
+ output_ << '"' << v << '"';
+ } else if (std::nearbyint(v) == v) {
+ // Integral-valued float. The default stream formatting behaves like
+ // printf's %g and would drop the decimal point, so Python would read
+ // the text back as an int. Fixed notation with showpoint and a
+ // precision of one always keeps exactly one fractional digit ("2.0").
+ std::showpoint(output_);
+ std::fixed(output_);
+ output_.precision(1);
+ output_ << v;
+ } else {
+ // Otherwise: default float formatting, no forced decimal point, 17 significant digits.
+ std::defaultfloat(output_);
+ std::noshowpoint(output_);
+ output_.precision(17);
+ output_ << v;
+ }
+}
+
+inline void PythonDocPrinter::PrintTypedDoc(const Id &doc) { output_ << doc->name; } // an identifier prints as its bare name
+
+inline void PythonDocPrinter::PrintTypedDoc(const Attr &doc) { // renders `obj.name`
+ PrintChildExpr(doc->obj, doc); // receiver, parenthesized when its precedence is lower than this node's
+ output_ << "." << doc->name;
+}
+
+inline void PythonDocPrinter::PrintTypedDoc(const Index &doc) { // `obj[i, j]`; with no indices the subscript is the empty tuple
+ PrintChildExpr(doc->obj, doc);
+ if (doc->idx.size() == 0) {
+ output_ << "[()]";
+ return;
+ }
+ output_ << "[";
+ PrintJoinedDocs(doc->idx, ", ");
+ output_ << "]";
+}
+
+inline void PythonDocPrinter::PrintTypedDoc(const Operation &doc) { // unary, `**`, binary and conditional expressions; throws ValueError on bad arity or kind
+ using OpKind = OperationObj::Kind;
+ if (doc->op < OpKind::kUnaryEnd) {
+ // Unary Operators
+ if (doc->operands.size() != 1) {
+ MLC_THROW(ValueError) << "ValueError: Unary operator requires 1 operand, but got " << doc->operands.size();
+ }
+ output_ << OpKindToString(doc->op); // no separator is added between the operator text and its operand
+ PrintChildExpr(doc->operands[0], doc);
+ } else if (doc->op == OpKind::kPow) {
+ // Power operator is different than other binary operators
+ // It's right-associative and binds less tightly than unary operator on its right.
+ // https://docs.python.org/3/reference/expressions.html#the-power-operator
+ // https://docs.python.org/3/reference/expressions.html#operator-precedence
+ if (doc->operands.size() != 2) {
+ MLC_THROW(ValueError) << "Operator '**' requires 2 operands, but got " << doc->operands.size();
+ }
+ PrintChildExprConservatively(doc->operands[0], doc); // left operand: parenthesized even at equal precedence
+ output_ << " ** ";
+ PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary); // right operand: parenthesized only below unary precedence
+ } else if (doc->op < OpKind::kBinaryEnd) {
+ // Binary Operator
+ if (doc->operands.size() != 2) {
+ MLC_THROW(ValueError) << "Binary operator requires 2 operands, but got " << doc->operands.size();
+ }
+ PrintChildExpr(doc->operands[0], doc);
+ output_ << " " << OpKindToString(doc->op) << " ";
+ PrintChildExprConservatively(doc->operands[1], doc); // right operand: parenthesized even at equal precedence
+ } else if (doc->op == OpKind::kIfThenElse) {
+ if (doc->operands.size() != 3) {
+ MLC_THROW(ValueError) << "IfThenElse requires 3 operands, but got " << doc->operands.size();
+ }
+ PrintChildExpr(doc->operands[1], doc); // printed as `operands[1] if operands[0] else operands[2]`
+ output_ << " if ";
+ PrintChildExprConservatively(doc->operands[0], doc);
+ output_ << " else ";
+ PrintChildExprConservatively(doc->operands[2], doc);
+ } else {
+ MLC_THROW(ValueError) << "Unknown OperationDocNode::Kind " << static_cast(doc->op);
+ }
+}
+
+// Prints a call expression `callee(arg0, ..., key0=value0, ...)`.
+// Throws ValueError if the keyword-name and keyword-value lists differ in length.
+inline void PythonDocPrinter::PrintTypedDoc(const Call &doc) {
+  PrintChildExpr(doc->callee, doc);
+  output_ << "(";
+  // Emits ", " before every argument except the very first one, positional or keyword.
+  bool need_separator = false;
+  auto separate = [&]() {
+    if (need_separator) {
+      output_ << ", ";
+    }
+    need_separator = true;
+  };
+  // Positional arguments
+  for (Expr arg : doc->args) {
+    separate();
+    PrintDoc(arg);
+  }
+  // Keyword arguments; the length check happens only after the positional ones were written.
+  if (doc->kwargs_keys.size() != doc->kwargs_values.size()) {
+    MLC_THROW(ValueError) << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values.";
+  }
+  for (int64_t i = 0; i < doc->kwargs_keys.size(); i++) {
+    separate();
+    const mlc::Str &keyword = doc->kwargs_keys[i];
+    output_ << keyword;
+    output_ << "=";
+    PrintDoc(doc->kwargs_values[i]);
+  }
+
+  output_ << ")";
+}
+
+// Prints a lambda expression as `lambda <args>: <body>`, with arguments separated by ", ".
+inline void PythonDocPrinter::PrintTypedDoc(const Lambda &doc) {
+ output_ << "lambda ";
+ PrintJoinedDocs(doc->args, ", ");
+ output_ << ": ";
+ // The body is printed as a child expression of the lambda itself.
+ PrintChildExpr(doc->body, doc);
+}
+
+// Prints a list literal `[v0, v1, ...]`; an empty list prints as `[]`.
+inline void PythonDocPrinter::PrintTypedDoc(const List &doc) {
+ output_ << "[";
+ PrintJoinedDocs(doc->values, ", ");
+ output_ << "]";
+}
+
+// Prints a tuple literal in parentheses.
+inline void PythonDocPrinter::PrintTypedDoc(const Tuple &doc) {
+  output_ << "(";
+  if (doc->values.size() != 1) {
+    PrintJoinedDocs(doc->values, ", ");
+  } else {
+    // A one-element tuple needs the trailing comma: `(x,)`.
+    PrintDoc(doc->values[0]);
+    output_ << ",";
+  }
+  output_ << ")";
+}
+
+// Prints a dict literal `{k0: v0, k1: v1, ...}`.
+// Throws ValueError if the key and value lists differ in length.
+inline void PythonDocPrinter::PrintTypedDoc(const Dict &doc) {
+  if (doc->keys.size() != doc->values.size()) {
+    MLC_THROW(ValueError) << "DictDoc should have equal number of elements in keys and values.";
+  }
+  output_ << "{";
+  // `position` walks the value list in step with the keys.
+  size_t position = 0;
+  for (Expr key : doc->keys) {
+    const bool is_first_entry = (position == 0);
+    if (!is_first_entry) {
+      output_ << ", ";
+    }
+    PrintDoc(key);
+    output_ << ": ";
+    PrintDoc(doc->values[position]);
+    ++position;
+  }
+  output_ << "}";
+}
+
+// Prints a slice as `[start]:[stop][:step]`; each part is written only when it is present.
+inline void PythonDocPrinter::PrintTypedDoc(const Slice &doc) {
+  // Prints an optional bound when it holds a value, otherwise writes nothing.
+  auto print_if_present = [&](auto &&bound) {
+    if (bound != nullptr) {
+      PrintDoc(bound.value());
+    }
+  };
+  print_if_present(doc->start);
+  output_ << ":";
+  print_if_present(doc->stop);
+  // The second colon appears only together with a step.
+  if (doc->step != nullptr) {
+    output_ << ":";
+    PrintDoc(doc->step.value());
+  }
+}
+
+// Prints the statements of a block one after another, separated by NewLine().
+inline void PythonDocPrinter::PrintTypedDoc(const StmtBlock &doc) {
+  // No line break before the first statement; one before every later statement.
+  bool need_newline = false;
+  for (Stmt stmt : doc->stmts) {
+    if (need_newline) {
+      NewLine();
+    }
+    need_newline = true;
+    PrintDoc(stmt);
+  }
+}
+
+// Prints an assignment statement `lhs[: annotation][ = rhs]`, then an inline comment if the doc has one.
+inline void PythonDocPrinter::PrintTypedDoc(const Assign &doc) {
+ // When the target casts successfully (presumably to a tuple doc, judging by `tuple_doc->values`),
+ // its elements are written comma-separated without surrounding parentheses.
+ // NOTE(review): unlike the For printer, a one-element target gets no trailing comma here and an
+ // empty one prints nothing -- confirm that is intended.
+ if (const auto *tuple_doc = doc->lhs->TryCast()) {
+ PrintJoinedDocs(tuple_doc->values, ", ");
+ } else {
+ PrintDoc(doc->lhs);
+ }
+
+ if (doc->annotation.defined()) {
+ output_ << ": ";
+ PrintDoc(doc->annotation.value());
+ }
+ if (doc->rhs.defined()) {
+ output_ << " = ";
+ if (const auto *tuple_doc = doc->rhs.TryCast()) {
+ // Parentheses are dropped only for two or more elements; shorter values go through the
+ // regular printer.
+ if (tuple_doc->values.size() > 1) {
+ PrintJoinedDocs(tuple_doc->values, ", ");
+ } else {
+ PrintDoc(doc->rhs.value());
+ }
+ } else {
+ PrintDoc(doc->rhs.value());
+ }
+ }
+ MaybePrintCommentInline(doc);
+}
+
+// Prints an `if` statement with its indented then-branch and, when non-empty, an `else:` branch.
+inline void PythonDocPrinter::PrintTypedDoc(const If &doc) {
+  MaybePrintCommenMultiLines(doc, true);
+  output_ << "if ";
+  PrintDoc(doc->cond);
+  output_ << ":";
+  PrintIndentedBlock(doc->then_branch);
+  // Nothing more to write when there is no else-branch.
+  if (doc->else_branch.empty()) {
+    return;
+  }
+  NewLine();
+  output_ << "else:";
+  PrintIndentedBlock(doc->else_branch);
+}
+
+// Prints a `while <cond>:` loop followed by its indented body.
+inline void PythonDocPrinter::PrintTypedDoc(const While &doc) {
+ // Any comment attached to the statement is handled first.
+ MaybePrintCommenMultiLines(doc, true);
+ output_ << "while ";
+ PrintDoc(doc->cond);
+ output_ << ":";
+ PrintIndentedBlock(doc->body);
+}
+
+// Prints a `for <lhs> in <rhs>:` loop followed by its indented body.
+inline void PythonDocPrinter::PrintTypedDoc(const For &doc) {
+ MaybePrintCommenMultiLines(doc, true);
+ output_ << "for ";
+ // When the loop target casts successfully (presumably to a tuple doc, judging by `tuple->values`),
+ // it is written without parentheses.
+ if (const auto *tuple = doc->lhs->TryCast()) {
+ if (tuple->values.size() == 1) {
+ // A single element keeps its trailing comma so it still unpacks: `for x, in ...`.
+ PrintDoc(tuple->values[0]);
+ output_ << ",";
+ } else {
+ PrintJoinedDocs(tuple->values, ", ");
+ }
+ } else {
+ PrintDoc(doc->lhs);
+ }
+ output_ << " in ";
+ PrintDoc(doc->rhs);
+ output_ << ":";
+ PrintIndentedBlock(doc->body);
+}
+
+// Prints a `with <rhs>[ as <lhs>]:` statement followed by its indented body.
+inline void PythonDocPrinter::PrintTypedDoc(const With &doc) {
+ MaybePrintCommenMultiLines(doc, true);
+ output_ << "with ";
+ PrintDoc(doc->rhs);
+ // The ` as <name>` clause is written only when a target is present.
+ if (doc->lhs != nullptr) {
+ output_ << " as ";
+ PrintDoc(doc->lhs.value());
+ }
+ output_ << ":";
+
+ PrintIndentedBlock(doc->body);
+}
+
+// Prints an expression used as a statement, followed by its inline comment if any.
+inline void PythonDocPrinter::PrintTypedDoc(const ExprStmt &doc) {
+ PrintDoc(doc->expr);
+ MaybePrintCommentInline(doc);
+}
+
+// Prints `assert <cond>[, <msg>]`, followed by the inline comment if any.
+inline void PythonDocPrinter::PrintTypedDoc(const Assert &doc) {
+ output_ << "assert ";
+ PrintDoc(doc->cond);
+ // The message part is optional.
+ if (doc->msg.defined()) {
+ output_ << ", ";
+ PrintDoc(doc->msg.value());
+ }
+ MaybePrintCommentInline(doc);
+}
+
+// Prints `return[ <value>]`, followed by the inline comment if any.
+inline void PythonDocPrinter::PrintTypedDoc(const Return &doc) {
+ output_ << "return";
+ // A bare `return` is written when no value is defined.
+ if (doc->value.defined()) {
+ output_ << " ";
+ PrintDoc(doc->value.value());
+ }
+ MaybePrintCommentInline(doc);
+}
+
+// Prints a function definition: decorators, `def name(args)[ -> return_type]:`,
+// an optional block comment, and the indented body.
+inline void PythonDocPrinter::PrintTypedDoc(const Function &doc) {
+ PrintDecorators(doc->decorators);
+ output_ << "def ";
+ PrintDoc(doc->name);
+ output_ << "(";
+ PrintJoinedDocs(doc->args, ", ");
+ output_ << ")";
+ if (doc->return_type.defined()) {
+ output_ << " -> ";
+ PrintDoc(doc->return_type.value());
+ }
+ output_ << ":";
+ // The comment goes after the signature line and before the body.
+ if (doc->comment.defined()) {
+ PrintBlockComment(doc->comment.value());
+ }
+ PrintIndentedBlock(doc->body);
+ // Unlike the Class printer, a function ends with an extra line break.
+ NewLineWithoutIndent();
+}
+
+// Prints a class definition: decorators, `class name:`, an optional block comment, and the indented body.
+// No base-class list is written by this printer.
+inline void PythonDocPrinter::PrintTypedDoc(const Class &doc) {
+ PrintDecorators(doc->decorators);
+ output_ << "class ";
+ PrintDoc(doc->name);
+ output_ << ":";
+ if (doc->comment.defined()) {
+ PrintBlockComment(doc->comment.value());
+ }
+ PrintIndentedBlock(doc->body);
+}
+
+// Prints a standalone comment; writes nothing when the doc carries no comment text.
+inline void PythonDocPrinter::PrintTypedDoc(const Comment &doc) {
+  if (!doc->comment.defined()) {
+    return;
+  }
+  MaybePrintCommenMultiLines(doc, false);
+}
+
+// Prints a docstring; writes nothing when the text is missing or empty.
+inline void PythonDocPrinter::PrintTypedDoc(const DocString &doc) {
+  if (!doc->comment.defined()) {
+    return;
+  }
+  mlc::Str text(doc->comment.value());
+  if (text->empty()) {
+    return;
+  }
+  PrintDocString(text);
+}
+
+// Renders `node` as Python source text according to `cfg` and returns it with trailing whitespace removed.
+// A negative `cfg->num_context_lines` is replaced by the largest int32 value before printing.
+// NOTE(review): that replacement is written through `cfg`; if PrinterConfig is a shared reference the
+// caller's config is changed too -- confirm.
+mlc::Str DocToPythonScript(Node node, PrinterConfig cfg) {
+  if (cfg->num_context_lines < 0) {
+    constexpr int32_t kMaxInt32 = 2147483647;
+    cfg->num_context_lines = kMaxInt32;
+  }
+  PythonDocPrinter printer(cfg);
+  printer.Append(node, cfg);
+  mlc::Str result = printer.GetString();
+  // std::isspace requires a value representable as unsigned char (or EOF). A plain `char` holding a
+  // non-ASCII byte may be negative, which is undefined behaviour, hence the cast.
+  while (!result->empty() && std::isspace(static_cast<unsigned char>(result->back()))) {
+    result->pop_back();
+  }
+  return result;
+}
+
+// Register the two printer entry points as global functions under the `mlc.printer.*` names.
+MLC_REGISTER_FUNC("mlc.printer.DocToPythonScript").set_body(DocToPythonScript);
+MLC_REGISTER_FUNC("mlc.printer.ToPython").set_body(::mlc::printer::ToPython);
+
+} // namespace
+} // namespace printer
+} // namespace mlc
diff --git a/cpp/registry.h b/cpp/registry.h
new file mode 100644
index 00000000..5675d637
--- /dev/null
+++ b/cpp/registry.h
@@ -0,0 +1,504 @@
+#ifndef MLC_CPP_REGISTRY_H_
+#define MLC_CPP_REGISTRY_H_
+#ifdef _MSC_VER
+#include
+#else
+#include
+#endif
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace mlc {
+namespace registry {
+
+// Owns the handle of one dynamically loaded shared library.
+// The library is loaded in the constructor (ValueError on failure) and unloaded in the destructor.
+struct DSOLibrary {
+  ~DSOLibrary() { Unload(); }
+
+  // The handle is released in the destructor, so a copy would release the same handle twice.
+  DSOLibrary(const DSOLibrary &) = delete;
+  DSOLibrary &operator=(const DSOLibrary &) = delete;
+
+#ifdef _MSC_VER
+  // Loads the library called `name` with LoadLibraryW.
+  // NOTE(review): the name is widened character by character, which is only correct for ASCII paths.
+  DSOLibrary(std::string name) {
+    std::wstring wname(name.begin(), name.end());
+    lib_handle_ = LoadLibraryW(wname.c_str());
+    if (lib_handle_ == nullptr) {
+      MLC_THROW(ValueError) << "Failed to load dynamic shared library " << name;
+    }
+  }
+
+  // Releases the library if one is loaded; safe to call more than once.
+  void Unload() {
+    if (lib_handle_ != nullptr) {
+      FreeLibrary(lib_handle_);
+      lib_handle_ = nullptr;
+    }
+  }
+
+  // Looks up an exported symbol by name; the result of GetProcAddress is returned as `void *`.
+  void *GetSymbol_(const char *name) {
+    return reinterpret_cast<void *>(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
+  }
+
+  // Same spelling as the non-MSVC branch, so callers need no platform-specific code.
+  void *GetSymbol(const char *name) { return GetSymbol_(name); }
+
+  HMODULE lib_handle_{nullptr};
+#else
+  // Loads the library called `name` with dlopen(RTLD_LAZY | RTLD_LOCAL).
+  DSOLibrary(std::string name) {
+    lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
+    if (lib_handle_ == nullptr) {
+      MLC_THROW(ValueError) << "Failed to load dynamic shared library " << name << " " << dlerror();
+    }
+  }
+
+  // Releases the library if one is loaded; safe to call more than once.
+  void Unload() {
+    if (lib_handle_ != nullptr) {
+      dlclose(lib_handle_);
+      lib_handle_ = nullptr;
+    }
+  }
+
+  // Looks up a symbol by name with dlsym.
+  void *GetSymbol(const char *name) { return dlsym(lib_handle_, name); }
+
+  void *lib_handle_{nullptr};
+#endif
+};
+
+struct ResourcePool {
+ using ObjPtr = std::unique_ptr;
+
+ // Registers `ptr` with the pool: increments its reference count and stores an owning entry,
+ // keyed by the address, whose deleter is ::mlc::base::DecRef. Null pointers are ignored.
+ // `objects` is a multimap, so adding the same address again creates another entry.
+ void AddObj(void *ptr) {
+ if (ptr != nullptr) {
+ MLCAny *ptr_cast = reinterpret_cast(ptr);
+ ::mlc::base::IncRef(ptr_cast);
+ this->objects.insert({ptr, ObjPtr(ptr_cast, ::mlc::base::DecRef)});
+ }
+ }
+
+ // Removes one entry registered for `ptr`, which releases the reference that entry held.
+ // Null pointers and addresses that are not registered are ignored.
+ void DelObj(void *ptr) {
+   if (ptr == nullptr) {
+     return;
+   }
+   // Erasing the end iterator is undefined behaviour, so an unknown address must be filtered out.
+   auto it = this->objects.find(ptr);
+   if (it != this->objects.end()) {
+     this->objects.erase(it);
+   }
+ }
+
+ // Allocates a pool-owned array of `size` elements of PODType by forwarding to NewPODArrayImpl
+ // with sizeof(PODType). The memory is not initialized here.
+ template PODType *NewPODArray(int64_t size) {
+ return static_cast(this->NewPODArrayImpl(size, sizeof(PODType)));
+ }
+
+ // Returns a pool-owned copy of the NUL-terminated string `source`, or nullptr when `source` is nullptr.
+ const char *NewStr(const char *source) {
+ if (source == nullptr) {
+ return nullptr;
+ }
+ size_t len = strlen(source);
+ // len + 1 so that the terminating NUL is allocated and copied as well.
+ char *ptr = this->NewPODArray(len + 1);
+ std::memcpy(ptr, source, len + 1);
+ return ptr;
+ }
+
+ // Drops the pool entry for the array at `ptr`; a null pointer is ignored.
+ void DelPODArray(const void *ptr) {
+   if (ptr == nullptr) {
+     return;
+   }
+   // Erasing by key does nothing when the address is not registered.
+   this->pod_array.erase(ptr);
+ }
+
+private:
+ // Allocates `size * pod_size` bytes with std::malloc and records the block in `pod_array`, with
+ // std::free as the deleter. Returns nullptr for size == 0. Aborts the process if the address is
+ // already recorded.
+ // NOTE(review): the std::malloc result is not checked, and a negative `size` is not rejected.
+ MLC_INLINE void *NewPODArrayImpl(int64_t size, size_t pod_size) {
+ if (size == 0) {
+ return nullptr;
+ }
+ ::mlc::base::PODArray owned(static_cast(std::malloc(size * pod_size)), std::free);
+ // Take the raw address before ownership moves into the map.
+ void *ptr = owned.get();
+ auto [it, success] = this->pod_array.emplace(ptr, std::move(owned));
+ if (!success) {
+ std::cerr << "Array already registered: " << ptr;
+ std::abort();
+ }
+ return ptr;
+ }
+
+ std::unordered_map pod_array;
+ std::unordered_multimap objects;
+};
+
+struct TypeTable;
+
+// A named table of functions keyed by type index, tied to the TypeTable that created it.
+// Set and GetFunc are only declared here; their definitions are elsewhere.
+struct VTable {
+ explicit VTable(TypeTable *type_table, std::string name) : type_table(type_table), name(std::move(name)), data() {}
+ // Stores `func` for `type_index`; the meaning of `override_mode` is defined by the implementation.
+ void Set(int32_t type_index, FuncObj *func, int32_t override_mode);
+ // Looks up the function for `type_index`; `allow_ancestor` presumably enables fallback to
+ // ancestor types -- confirm in the definition.
+ FuncObj *GetFunc(int32_t type_index, bool allow_ancestor) const;
+
+private:
+ TypeTable *type_table;
+ std::string name;
+ std::unordered_map data;
+};
+
+struct TypeInfoWrapper {
+ MLCTypeInfo info{};
+ ResourcePool *pool = nullptr;
+ int64_t num_fields = 0;
+ std::vector methods = {};
+
+ // Releases everything Reset() covers; does nothing if no pool was ever attached.
+ ~TypeInfoWrapper() { this->Reset(); }
+
+ // Returns every pool allocation held by this wrapper and detaches it from the pool.
+ // Does nothing when no pool is attached, so it is safe to call repeatedly.
+ void Reset() {
+   if (this->pool) {
+     this->pool->DelPODArray(this->info.type_key);
+     this->pool->DelPODArray(this->info.type_ancestors);
+     this->ResetFields();
+     this->ResetMethods();
+     // SetStructure allocates its index/kind arrays from the pool as well; release them here so they
+     // are not left behind, and so `info` keeps no pointers to them.
+     this->ResetStructure();
+     this->info.type_key = nullptr;
+     this->info.type_ancestors = nullptr;
+     this->pool = nullptr;
+   }
+ }
+
+ // Frees the field table: every duplicated field name, then the array itself, and clears the count.
+ // Does nothing when no fields are set.
+ // NOTE(review): SetFields registers each field's `ty` through pool->AddObj, but nothing here calls
+ // DelObj for it (ResetMethods does so for `func`) -- confirm whether that reference should be released.
+ void ResetFields() {
+ if (this->num_fields > 0) {
+ MLCTypeField *&fields = this->info.fields;
+ for (int32_t i = 0; i < this->num_fields; i++) {
+ this->pool->DelPODArray(fields[i].name);
+ }
+ this->pool->DelPODArray(fields);
+ this->info.fields = nullptr;
+ this->num_fields = 0;
+ }
+ }
+
+ // Releases the name copy and the function reference of every method, then empties the table.
+ void ResetMethods() {
+   if (this->methods.empty()) {
+     return;
+   }
+   for (MLCTypeMethod &entry : this->methods) {
+     // Entries without a name (such as the trailing placeholder AddMethod appends) own nothing.
+     if (entry.name == nullptr) {
+       continue;
+     }
+     this->pool->DelPODArray(entry.name);
+     this->pool->DelObj(entry.func);
+   }
+   this->info.methods = nullptr;
+   this->methods.clear();
+ }
+
+ // Frees the two sub-structure arrays, each only if it is currently set, and nulls the pointers.
+ void ResetStructure() {
+   if (this->info.sub_structure_indices != nullptr) {
+     this->pool->DelPODArray(this->info.sub_structure_indices);
+     this->info.sub_structure_indices = nullptr;
+   }
+   if (this->info.sub_structure_kinds != nullptr) {
+     this->pool->DelPODArray(this->info.sub_structure_kinds);
+     this->info.sub_structure_kinds = nullptr;
+   }
+ }
+
+ // Replaces the field table with a pool-owned copy of `fields[0 .. new_num_fields)`.
+ // Each copied entry registers its `ty` with the pool and gets its own copy of the name.
+ // Throws ValueError if an entry's `index` differs from its position in `fields`.
+ // NOTE(review): when that throw happens, `num_fields` is already updated and later entries of the
+ // new array are still uninitialized.
+ void SetFields(int64_t new_num_fields, MLCTypeField *fields) {
+ this->ResetFields();
+ this->num_fields = new_num_fields;
+ // One extra slot holds an empty terminator entry after the last field.
+ MLCTypeField *dsts = this->info.fields = this->pool->NewPODArray(new_num_fields + 1);
+ for (int64_t i = 0; i < num_fields; i++) {
+ MLCTypeField *dst = dsts + i;
+ *dst = fields[i];
+ this->pool->AddObj(dst->ty);
+ dst->name = this->pool->NewStr(dst->name);
+ if (dst->index != i) {
+ MLC_THROW(ValueError) << "Field index mismatch: " << i << " vs " << dst->index;
+ }
+ }
+ dsts[num_fields] = MLCTypeField{};
+ // The stored table is ordered by byte offset, not by the original index; the terminator stays last.
+ std::sort(dsts, dsts + num_fields,
+ [](const MLCTypeField &a, const MLCTypeField &b) { return a.offset < b.offset; });
+ }
+
+ // Adds `method`, or replaces the function and kind of an existing method with the same name.
+ // `methods` always ends with one empty placeholder entry once it is non-empty.
+ void AddMethod(MLCTypeMethod method) {
+ // Replace path: swap the pool reference from the old function to the new one; the name is kept.
+ for (MLCTypeMethod &m : this->methods) {
+ if (m.name && std::strcmp(m.name, method.name) == 0) {
+ this->pool->DelObj(m.func);
+ this->pool->AddObj(reinterpret_cast(method.func));
+ m.func = method.func;
+ m.kind = method.kind;
+ return;
+ }
+ }
+ // Append path: the current last entry is the placeholder; create it first if the table is empty.
+ if (this->methods.empty()) {
+ this->methods.emplace_back();
+ }
+ MLCTypeMethod &dst = this->methods.back();
+ dst = method;
+ dst.name = this->pool->NewStr(dst.name);
+ this->pool->AddObj(reinterpret_cast(dst.func));
+ // Append a fresh placeholder, then refresh the raw pointer since the vector may have reallocated.
+ this->methods.emplace_back();
+ this->info.methods = this->methods.data();
+ }
+
+ // Replaces the structure description: stores `structure_kind` and pool-owned copies of the two
+ // input arrays. With num_sub_structures <= 0 both array pointers are left null.
+ void SetStructure(int32_t structure_kind, int64_t num_sub_structures, int32_t *sub_structure_indices,
+ int32_t *sub_structure_kinds) {
+ this->ResetStructure();
+ this->info.structure_kind = structure_kind;
+ if (num_sub_structures > 0) {
+ // One extra element per array holds the -1 terminator written below.
+ this->info.sub_structure_indices = this->pool->NewPODArray(num_sub_structures + 1);
+ this->info.sub_structure_kinds = this->pool->NewPODArray(num_sub_structures + 1);
+ std::memcpy(this->info.sub_structure_indices, sub_structure_indices, num_sub_structures * sizeof(int32_t));
+ std::memcpy(this->info.sub_structure_kinds, sub_structure_kinds, num_sub_structures * sizeof(int32_t));
+ // Both copies are stored in reverse of the input order.
+ std::reverse(this->info.sub_structure_indices, this->info.sub_structure_indices + num_sub_structures);
+ std::reverse(this->info.sub_structure_kinds, this->info.sub_structure_kinds + num_sub_structures);
+ this->info.sub_structure_indices[num_sub_structures] = -1;
+ this->info.sub_structure_kinds[num_sub_structures] = -1;
+ } else {
+ this->info.sub_structure_indices = nullptr;
+ this->info.sub_structure_kinds = nullptr;
+ }
+ }
+};
+
+struct TypeTable {
+ using ObjPtr = std::unique_ptr;
+
+ int32_t num_types;
+ std::vector> type_table;
+ std::unordered_map type_key_to_info;
+ std::unordered_map global_funcs;
+ std::unordered_map> global_vtables;
+ std::unordered_map> dso_libraries;
+ ResourcePool pool;
+
+ // Returns the type info registered at `type_index`, or nullptr when the index is out of range
+ // or no type is registered in that slot.
+ MLCTypeInfo *GetTypeInfo(int32_t type_index) {
+   if (type_index < 0 || static_cast<size_t>(type_index) >= this->type_table.size()) {
+     return nullptr;
+   }
+   // TypeRegister grows the table in steps of 1024, so an in-range slot may still be empty;
+   // it must not be dereferenced.
+   TypeInfoWrapper *wrapper = this->type_table[type_index].get();
+   return wrapper == nullptr ? nullptr : &wrapper->info;
+ }
+
+ // Returns the type info registered under `type_key`, or nullptr when the key is unknown.
+ MLCTypeInfo *GetTypeInfo(const char *type_key) {
+   if (auto it = this->type_key_to_info.find(type_key); it != this->type_key_to_info.end()) {
+     return it->second;
+   }
+   return nullptr;
+ }
+
+ // Looks up the function for `type_index` in the global vtable named `attr_key`.
+ // Returns nullptr when no vtable with that name exists; otherwise returns what VTable::GetFunc yields.
+ FuncObj *GetVTable(int32_t type_index, const char *attr_key, bool allow_ancestor) {
+   auto it = this->global_vtables.find(attr_key);
+   if (it == this->global_vtables.end()) {
+     return nullptr;
+   }
+   return it->second->GetFunc(type_index, allow_ancestor);
+ }
+
+ // Returns the global vtable called `name`, creating and storing an empty one on first use.
+ // The returned pointer stays owned by `global_vtables`.
+ VTable *GetGlobalVTable(const char *name) {
+ if (auto it = this->global_vtables.find(name); it != this->global_vtables.end()) {
+ return it->second.get();
+ } else {
+ std::unique_ptr &vtable = this->global_vtables[name] = std::make_unique(this, name);
+ return vtable.get();
+ }
+ }
+
+ // Returns the global function registered under `name`, or nullptr when there is none.
+ FuncObj *GetFunc(const char *name) {
+   auto it = this->global_funcs.find(name);
+   return it == this->global_funcs.end() ? nullptr : it->second;
+ }
+
+ // New() and Global() are defined elsewhere.
+ static TypeTable *New();
+ static TypeTable *Global();
+ // Resolves a handle to a table: a null handle selects the table returned by Global().
+ static TypeTable *Get(MLCTypeTableHandle self) {
+ return (self == nullptr) ? TypeTable::Global() : static_cast(self);
+ }
+
+ // Registers the type `type_key` and returns its info.
+ // If the key is already registered the existing info is returned, unless an explicit `type_index`
+ // (not -1) disagrees with the stored one, in which case KeyError is thrown.
+ // `type_index == -1` asks for the next free index; `parent_type_index == -1` means no parent.
+ MLCTypeInfo *TypeRegister(int32_t parent_type_index, int32_t type_index, const char *type_key) {
+ // Step 1.Check if the type is already registered
+ if (auto it = this->type_key_to_info.find(type_key); it != this->type_key_to_info.end()) {
+ MLCTypeInfo *ret = it->second;
+ if (type_index != -1 && type_index != ret->type_index) {
+ MLC_THROW(KeyError) << "Type `" << type_key << "` registered with type index `" << ret->type_index
+ << "`, but re-registered with type index: " << type_index;
+ }
+ return ret;
+ }
+ // Step 2. Manipulate type table
+ // NOTE(review): an explicit index does not advance `num_types` and replaces whatever wrapper
+ // already occupies that slot -- confirm callers keep the two index ranges apart.
+ if (type_index == -1) {
+ type_index = this->num_types++;
+ }
+ // Grow to the next multiple of 1024 that covers the index; new slots stay empty.
+ if (type_index >= static_cast(this->type_table.size())) {
+ this->type_table.resize((type_index / 1024 + 1) * 1024);
+ }
+ std::unique_ptr &wrapper = this->type_table.at(type_index) = std::make_unique();
+ this->type_key_to_info[type_key] = &wrapper->info;
+ // Step 3. Initialize type info
+ MLCTypeInfo *parent = parent_type_index == -1 ? nullptr : this->GetTypeInfo(parent_type_index);
+ MLCTypeInfo *info = &wrapper->info;
+ info->type_index = type_index;
+ info->type_key = this->pool.NewStr(type_key);
+ info->type_key_hash = ::mlc::base::StrHash(type_key, std::strlen(type_key));
+ info->type_depth = (parent == nullptr) ? 0 : (parent->type_depth + 1);
+ // The ancestor list is the parent's ancestor list followed by the parent itself.
+ // For depth 0 the pool returns nullptr.
+ info->type_ancestors = this->pool.NewPODArray(info->type_depth);
+ if (parent) {
+ std::copy(parent->type_ancestors, parent->type_ancestors + parent->type_depth, info->type_ancestors);
+ info->type_ancestors[parent->type_depth] = parent_type_index;
+ }
+ info->fields = nullptr;
+ info->methods = nullptr;
+ info->structure_kind = 0;
+ info->sub_structure_indices = nullptr;
+ info->sub_structure_kinds = nullptr;
+ // Attaching the pool is what GetTypeInfoWrapper later checks to accept the slot as registered.
+ wrapper->pool = &this->pool;
+ return info;
+ }
+
+ // Registers `func` as the global function `name` and adds it to the resource pool.
+ // Throws KeyError when the name is already taken and `allow_override` is false.
+ void SetFunc(const char *name, FuncObj *func, bool allow_override) {
+   auto [slot, inserted] = this->global_funcs.try_emplace(std::string(name), nullptr);
+   // Writing is allowed for a fresh entry, or for an existing one when overriding is permitted.
+   const bool may_write = inserted || allow_override;
+   if (may_write) {
+     slot->second = func;
+   } else {
+     MLC_THROW(KeyError) << "Global function already registered: " << name;
+   }
+   this->pool.AddObj(func);
+ }
+
+ // Returns the wrapper registered at `type_index`.
+ // Throws KeyError when the index is outside the table, the slot is empty, or the wrapper is not
+ // attached to this table's pool.
+ TypeInfoWrapper *GetTypeInfoWrapper(int32_t type_index) {
+   TypeInfoWrapper *found = nullptr;
+   // An explicit range check replaces catching std::out_of_range; a negative index is out of range.
+   const bool in_range = type_index >= 0 && static_cast<size_t>(type_index) < this->type_table.size();
+   if (in_range) {
+     found = this->type_table[type_index].get();
+   }
+   if (found == nullptr || found->pool != &this->pool) {
+     MLC_THROW(KeyError) << "Type index `" << type_index << "` not registered";
+   }
+   return found;
+ }
+
+ // Sets the field table of the type at `type_index`; KeyError if that type is not registered.
+ void SetFields(int32_t type_index, int64_t num_fields, MLCTypeField *fields) {
+   TypeInfoWrapper *wrapper = this->GetTypeInfoWrapper(type_index);
+   wrapper->SetFields(num_fields, fields);
+ }
+
+ // Sets the structure description of the type at `type_index`; KeyError if that type is not registered.
+ void SetStructure(int32_t type_index, int32_t structure_kind, int64_t num_sub_structures,
+                   int32_t *sub_structure_indices, int32_t *sub_structure_kinds) {
+   TypeInfoWrapper *wrapper = this->GetTypeInfoWrapper(type_index);
+   wrapper->SetStructure(structure_kind, num_sub_structures, sub_structure_indices, sub_structure_kinds);
+ }
+
+ // Registers `method` for the type at `type_index`, both in the global vtable named after the
+ // method and in the type's own method table.
+ // NOTE(review): the vtable is updated before the type index is validated, so a KeyError from
+ // GetTypeInfoWrapper leaves the vtable entry in place -- confirm this ordering is intended.
+ void AddMethod(int32_t type_index, MLCTypeMethod method) {
+ // TODO: check `override_mode`
+ this->GetGlobalVTable(method.name)->Set(type_index, reinterpret_cast(method.func), 0);
+ this->GetTypeInfoWrapper(type_index)->AddMethod(method);
+ }
+
+ // Loads the shared library `name` once; later calls with the same name do nothing.
+ // Throws whatever the DSOLibrary constructor throws when loading fails.
+ void LoadDSO(std::string name) {
+   // Only an entry that actually holds a library counts as loaded.
+   if (auto it = this->dso_libraries.find(name); it != this->dso_libraries.end() && it->second != nullptr) {
+     return;
+   }
+   // Load first and record afterwards: if loading throws, no empty entry is left behind that
+   // would make a retry return early as if the library were available.
+   auto library = std::make_unique<DSOLibrary>(name);
+   this->dso_libraries[std::move(name)] = std::move(library);
+ }
+};
+
+struct _POD_REG {
+ inline static const int32_t _none = //
+ ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCNone))
+ .MemFn("__str__", &::mlc::base::TypeTraits::__str__);
+ inline static const int32_t _int = //
+ ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCInt))
+ .StaticFn("__init__", [](AnyView value) { return value.operator int64_t(); })
+ .StaticFn("__new_ref__",
+ [](void *dst, Optional value) { *reinterpret_cast *>(dst) = value; })
+ .MemFn("__str__", &::mlc::base::TypeTraits::__str__);
+ inline static const int32_t _float =
+ ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCFloat))
+ .StaticFn("__new_ref__",
+ [](void *dst, Optional value) { *reinterpret_cast *>(dst) = value; })
+ .MemFn("__str__", &::mlc::base::TypeTraits::__str__);
+ inline static const int32_t _ptr =
+ ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCPtr))
+ .StaticFn("__new_ref__",
+ [](void *dst, Optional