diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..173f8a6a --- /dev/null +++ b/.clang-format @@ -0,0 +1,2 @@ +BasedOnStyle: LLVM +ColumnLimit: 120 diff --git a/.clangd b/.clangd new file mode 100644 index 00000000..1ce440cd --- /dev/null +++ b/.clangd @@ -0,0 +1,2 @@ +CompileFlags: + CompilationDatabase: build-vscode/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..9d146f53 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,82 @@ +name: CI + +on: [push, pull_request] +env: + CIBW_BUILD_VERBOSITY: 3 + CIBW_TEST_REQUIRES: "pytest" + CIBW_TEST_COMMAND: "pytest -svv --durations=20 {project}/tests/python/" + MLC_CIBW_VERSION: "2.20.0" + MLC_PYTHON_VERSION: "3.9" + MLC_CIBW_WIN_BUILD: "cp39-win_amd64" + MLC_CIBW_MAC_BUILD: "cp39-macosx_arm64" + MLC_CIBW_LINUX_BUILD: "cp312-manylinux_x86_64" + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - uses: pre-commit/action@v3.0.1 + windows: + name: Windows + runs-on: windows-latest + needs: pre-commit + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Install cibuildwheel + run: python -m pip install cibuildwheel=="${{ env.MLC_CIBW_VERSION }}" + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + env: + CIBW_BEFORE_ALL: ".\\scripts\\cpp_tests.bat" + CIBW_BUILD: ${{ env.MLC_CIBW_WIN_BUILD }} + - name: Show package contents + run: python scripts/show_wheel_content.py wheelhouse + macos: + name: MacOS + runs-on: macos-latest + needs: pre-commit + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Install cibuildwheel + run: python -m pip install 
cibuildwheel==${{ env.MLC_CIBW_VERSION }} + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + env: + CIBW_BEFORE_ALL: "./scripts/cpp_tests.sh" + CIBW_BUILD: ${{ env.MLC_CIBW_MAC_BUILD }} + - name: Show package contents + run: python scripts/show_wheel_content.py wheelhouse + linux: + name: Linux + runs-on: ubuntu-latest + needs: pre-commit + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Install cibuildwheel + run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }} + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + env: + CIBW_BEFORE_ALL: "./scripts/setup_manylinux2014.sh && ./scripts/cpp_tests.sh" + CIBW_BUILD: ${{ env.MLC_CIBW_LINUX_BUILD }} + - name: Show package contents + run: python scripts/show_wheel_content.py wheelhouse diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..64fa06a6 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,168 @@ +name: Release + +on: + release: + types: [published] + +env: + CIBW_BUILD_VERBOSITY: 3 + CIBW_TEST_COMMAND: "python -c \"import mlc\"" + CIBW_SKIP: "cp313-win_amd64" # Python 3.13 is not quite ready yet + MLC_CIBW_VERSION: "2.20.0" + MLC_PYTHON_VERSION: "3.9" + MLC_CIBW_WIN_BUILD: "cp3*-win_amd64" + MLC_CIBW_MAC_BUILD: "cp3*-macosx_arm64" + MLC_CIBW_MAC_X86_BUILD: "cp3*-macosx_x86_64" + MLC_CIBW_LINUX_BUILD: "cp3*-manylinux_x86_64" + +jobs: + windows: + name: Windows + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Install cibuildwheel + run: python -m pip install cibuildwheel=="${{ env.MLC_CIBW_VERSION }}" + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + env: + CIBW_BUILD: ${{ 
env.MLC_CIBW_WIN_BUILD }} + - name: Show package contents + run: python scripts/show_wheel_content.py wheelhouse + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-windows + path: ./wheelhouse/*.whl + macos: + name: MacOS + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Install cibuildwheel + run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }} + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + env: + CIBW_BUILD: ${{ env.MLC_CIBW_MAC_BUILD }} + - name: Show package contents + run: python scripts/show_wheel_content.py wheelhouse + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-macos + path: ./wheelhouse/*.whl + macos-x86: + name: MacOS-x86 + runs-on: macos-13 + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Install cibuildwheel + run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }} + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + env: + CIBW_BUILD: ${{ env.MLC_CIBW_MAC_X86_BUILD }} + - name: Show package contents + run: python scripts/show_wheel_content.py wheelhouse + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-macos-x86 + path: ./wheelhouse/*.whl + linux: + name: Linux + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: "recursive" + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Install cibuildwheel + run: python -m pip install cibuildwheel==${{ env.MLC_CIBW_VERSION }} + - name: Build wheels + run: python -m cibuildwheel --output-dir wheelhouse + env: + CIBW_BUILD: ${{ env.MLC_CIBW_LINUX_BUILD }} + - name: Show package 
contents + run: python scripts/show_wheel_content.py wheelhouse + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-linux + path: ./wheelhouse/*.whl + publish: + name: Publish + runs-on: ubuntu-latest + needs: [windows, macos, linux, macos-x86] + environment: + name: pypi + url: https://pypi.org/p/mlc-python + permissions: + id-token: write + contents: write + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.MLC_PYTHON_VERSION }} + - name: Checkout code + uses: actions/checkout@v4 + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: ./wheelhouse + - name: Prepare distribution files + run: | + mkdir -p dist + mv wheelhouse/wheels-macos-x86/*.whl dist/ + mv wheelhouse/wheels-macos/*.whl dist/ + mv wheelhouse/wheels-linux/*.whl dist/ + mv wheelhouse/wheels-windows/*.whl dist/ + - name: Upload wheels to release + env: + GITHUB_TOKEN: ${{ secrets.PYMLC_GITHUB_TOKEN }} + run: | + # Get the release ID + release_id=$(gh api repos/${{ github.repository }}/releases/tags/${{ github.ref_name }} --jq '.id') + if [ -z "$release_id" ]; then + echo "Error: Could not find release with tag ${{ github.ref_name }}" + exit 1 + fi + upload_url=https://uploads.github.com/repos/${{ github.repository }}/releases/$release_id/assets + echo "Uploading to release ID: $release_id. Upload URL: $upload_url" + for wheel in dist/*.whl; do + echo "⏫ Uploading $(basename $wheel)" + response=$(curl -s -w "%{http_code}" -X POST \ + -H "Authorization: token $GITHUB_TOKEN" \ + -H "Content-Type: application/octet-stream" \ + --data-binary @"$wheel" \ + "$upload_url?name=$(basename $wheel)") + http_code=${response: -3} + if [ $http_code -eq 201 ]; then + echo "🎉 Successfully uploaded $(basename $wheel)" + else + echo "❌ Failed to upload $(basename $wheel). HTTP status: $http_code. 
Error response: ${response%???}" + exit 1 + fi + done + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..54e90205 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.vscode +*.dSYM +build +build-cpp-tests diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..28b086b9 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "3rdparty/dlpack"] + path = 3rdparty/dlpack + url = https://github.com/dmlc/dlpack.git +[submodule "3rdparty/libbacktrace"] + path = 3rdparty/libbacktrace + url = https://github.com/ianlancetaylor/libbacktrace +[submodule "3rdparty/googletest"] + path = 3rdparty/googletest + url = https://github.com/google/googletest.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..bf1659c0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,54 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: mixed-line-ending + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-added-large-files + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.7 + hooks: + - id: ruff + types_or: [python, pyi, jupyter] + args: [--fix] + - id: ruff-format + types_or: [python, pyi, jupyter] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.11.1" + hooks: + - id: mypy + additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest"] + args: [--show-error-codes] + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: "v18.1.5" + hooks: + - id: clang-format + - repo: https://github.com/MarcoGorelli/cython-lint + rev: v0.16.2 + hooks: + - id: cython-lint + - id: double-quote-cython-strings + - repo: https://github.com/scop/pre-commit-shfmt + 
rev: v3.8.0-1 + hooks: + - id: shfmt + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.10.0.1 + hooks: + - id: shellcheck + # - repo: https://github.com/cheshirekow/cmake-format-precommit + # rev: v0.6.10 + # hooks: + # - id: cmake-format + # - id: cmake-lint + - repo: https://github.com/compilerla/conventional-pre-commit + rev: v2.1.1 + hooks: + - id: conventional-pre-commit + stages: [commit-msg] + args: [feat, fix, ci, chore, test] diff --git a/3rdparty/dlpack b/3rdparty/dlpack new file mode 160000 index 00000000..bbd2f4d3 --- /dev/null +++ b/3rdparty/dlpack @@ -0,0 +1 @@ +Subproject commit bbd2f4d32427e548797929af08cfe2a9cbb3cf12 diff --git a/3rdparty/googletest b/3rdparty/googletest new file mode 160000 index 00000000..f8d7d77c --- /dev/null +++ b/3rdparty/googletest @@ -0,0 +1 @@ +Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571 diff --git a/3rdparty/libbacktrace b/3rdparty/libbacktrace new file mode 160000 index 00000000..febbb9bf --- /dev/null +++ b/3rdparty/libbacktrace @@ -0,0 +1 @@ +Subproject commit febbb9bff98b39ee596aa15d1f58e4bba442cd6a diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..961ccd48 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,107 @@ +cmake_minimum_required(VERSION 3.15) + +project( + mlc + VERSION 0.0.14 + DESCRIPTION "MLC-Python" + LANGUAGES C CXX +) + +option(MLC_BUILD_TESTS "Build tests. This option will enable a test target `mlc_tests`." OFF) +option(MLC_BUILD_PY "Build Python bindings." OFF) +option(MLC_BUILD_REGISTRY + "Support for objects with non-static type indices. When turned on, \ + targets linked against `mlc` will allow objects that comes with non-pre-defined type indices, \ + so that the object hierarchy could expand without limitation. \ + This will require the downstream targets to link against target `mlc_registry` to be effective." 
+ OFF +) + +include(cmake/Utils/CxxWarning.cmake) +include(cmake/Utils/Sanitizer.cmake) +include(cmake/Utils/Library.cmake) +include(cmake/Utils/AddLibbacktrace.cmake) +include(cmake/Utils/DebugSymbol.cmake) + +########## Target: `dlpack_header` ########## + +add_library(dlpack_header INTERFACE) +target_include_directories(dlpack_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include") + +########## Target: `mlc` ########## + +add_library(mlc INTERFACE) +target_link_libraries(mlc INTERFACE dlpack_header) +target_compile_features(mlc INTERFACE cxx_std_17) +target_include_directories(mlc INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/include") + +########## Target: `mlc_registry` ########## + +if (MLC_BUILD_REGISTRY) + add_library(mlc_registry_objs OBJECT + "${CMAKE_CURRENT_SOURCE_DIR}/cpp/c_api.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/cpp/c_api_tests.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/cpp/printer.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/cpp/json.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/cpp/structure.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/cpp/traceback.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/cpp/traceback_win.cc" + ) + set_target_properties( + mlc_registry_objs PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + CXX_EXTENSIONS OFF + CXX_STANDARD_REQUIRED ON + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON + PREFIX "lib" + ) + add_cxx_warning(mlc_registry_objs) + target_link_libraries(mlc_registry_objs PRIVATE dlpack_header) + target_include_directories(mlc_registry_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include") + add_target_from_obj(mlc_registry mlc_registry_objs) + if (TARGET libbacktrace) + target_link_libraries(mlc_registry_objs PRIVATE libbacktrace) + target_link_libraries(mlc_registry_shared PRIVATE libbacktrace) + target_link_libraries(mlc_registry_static PRIVATE libbacktrace) + add_debug_symbol_apple(mlc_registry_shared "lib/") + endif () + install(TARGETS mlc_registry_static DESTINATION "lib/") + install(TARGETS mlc_registry_shared DESTINATION "lib/") 
+endif (MLC_BUILD_REGISTRY) + +########## Target: `mlc_py` ########## + +if (MLC_BUILD_PY) + find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) + file(GLOB _cython_sources CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/python/mlc/_cython/*.pyx") + set(_cython_outputs "") + foreach(_in IN LISTS _cython_sources) + get_filename_component(_file ${_in} NAME_WLE) + set(_out "${CMAKE_BINARY_DIR}/_cython/${_file}_cython.cc") + message(STATUS "Cythonize: ${_in} -> ${_out}") + add_custom_command( + OUTPUT "${_out}" DEPENDS "${_in}" + COMMENT "Making `${_out}` from `${_in}" + COMMAND Python::Interpreter -m cython "${_in}" --output-file "${_out}" --cplus + VERBATIM + ) + list(APPEND _cython_outputs "${_out}") + endforeach() + Python_add_library(mlc_py MODULE ${_cython_outputs} WITH_SOABI) + set_target_properties(mlc_py PROPERTIES OUTPUT_NAME "core") + target_link_libraries(mlc_py PUBLIC mlc) + install(TARGETS mlc_py DESTINATION _cython/) +endif (MLC_BUILD_PY) + +########## Adding tests ########## + +if (${PROJECT_NAME} STREQUAL ${CMAKE_PROJECT_NAME}) + if (MLC_BUILD_TESTS) + enable_testing() + message(STATUS "Enable Testing") + include(cmake/Utils/AddGoogleTest.cmake) + add_subdirectory(tests/cpp/) + endif() +endif () diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 00000000..2b93b258 --- /dev/null +++ b/README.md @@ -0,0 +1,237 @@ +

+ MLC Logo + + MLC-Python +

+ +* [:inbox_tray: Installation](#inbox_tray-installation) +* [:key: Key Features](#key-key-features) + + [:building_construction: MLC Dataclass](#building_construction-mlc-dataclass) + + [:dart: Structure-Aware Tooling](#dart-structure-aware-tooling) + + [:snake: Text Formats in Python](#snake-text-formats-in-python) + + [:zap: Zero-Copy Interoperability with C++ Plugins](#zap-zero-copy-interoperability-with-c-plugins) +* [:fuelpump: Development](#fuelpump-development) + + [:gear: Editable Build](#gear-editable-build) + + [:ferris_wheel: Create Wheels](#ferris_wheel-create-wheels) + + +MLC is a Python-first toolkit that streamlines the development of AI compilers, runtimes, and compound AI systems with its Pythonic dataclasses, structure-aware tooling, and Python-based text formats. + +Beyond pure Python, MLC natively supports zero-copy interoperation with C++ plugins, and enables a smooth engineering practice transitioning from Python to hybrid or Python-free development. + +## :inbox_tray: Installation + +```bash +pip install -U mlc-python +``` + +## :key: Key Features + +### :building_construction: MLC Dataclass + +MLC dataclass is similar to Python’s native dataclass: + +```python +import mlc.dataclasses as mlcd + +@mlcd.py_class("demo.MyClass") +class MyClass(mlcd.PyClass): + a: int + b: str + c: float | None + +instance = MyClass(12, "test", c=None) +``` + +**Type safety**. MLC dataclass enforces strict type checking using Cython and C++. + +```python +>>> instance.c = 10; print(instance) +demo.MyClass(a=12, b='test', c=10.0) + +>>> instance.c = "wrong type" +TypeError: must be real number, not str + +>>> instance.non_exist = 1 +AttributeError: 'MyClass' object has no attribute 'non_exist' and no __dict__ for setting new attributes +``` + +**Serialization**. MLC dataclasses are picklable and JSON-serializable. 
+ +```python +>>> MyClass.from_json(instance.json()) +demo.MyClass(a=12, b='test', c=None) + +>>> import pickle; pickle.loads(pickle.dumps(instance)) +demo.MyClass(a=12, b='test', c=None) +``` + +### :dart: Structure-Aware Tooling + +An extra `structure` field is used to specify a dataclass's structure, indicating def site and scoping in an IR. + +
Define a toy IR with `structure`. + +```python +import mlc.dataclasses as mlcd + +@mlcd.py_class +class Expr(mlcd.PyClass): + def __add__(self, other): + return Add(a=self, b=other) + +@mlcd.py_class(structure="nobind") +class Add(Expr): + a: Expr + b: Expr + +@mlcd.py_class(structure="var") +class Var(Expr): + name: str = mlcd.field(structure=None) # excludes `name` from defined structure + +@mlcd.py_class(structure="bind") +class Let(Expr): + rhs: Expr + lhs: Var = mlcd.field(structure="bind") # `Let.lhs` is the def-site + body: Expr +``` + +
+ +**Structural equality**. Member method `eq_s` compares the structural equality (alpha equivalence) of two IRs represented by MLC's structured dataclasses. + +```python +>>> x, y, z = Var("x"), Var("y"), Var("z") +>>> L1 = Let(rhs=x + y, lhs=z, body=z) # let z = x + y; z +>>> L2 = Let(rhs=y + z, lhs=x, body=x) # let x = y + z; x +>>> L3 = Let(rhs=x + x, lhs=z, body=z) # let z = x + x; z +>>> L1.eq_s(L2) +True +>>> L1.eq_s(L3, assert_mode=True) +ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound +``` + +**Structural hashing**. The structure of MLC dataclasses can be hashed via `hash_s`, which guarantees that if two dataclasses are alpha-equivalent, they will share the same structural hash: + +```python +>>> L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s() +>>> assert L1_hash == L2_hash +>>> assert L1_hash != L3_hash +``` + +### :snake: Text Formats in Python + +**IR Printer.** By defining an `__ir_print__` method, which converts an IR node to MLC's Python-style AST, MLC's `IRPrinter` handles variable scoping, renaming and syntax highlighting automatically for a text format based on Python syntax. + +
Defining Python-based text format on a toy IR using `__ir_print__`. + +```python +import mlc.dataclasses as mlcd +import mlc.printer as mlcp +from mlc.printer import ast as mlt + +@mlcd.py_class +class Expr(mlcd.PyClass): ... + +@mlcd.py_class +class Stmt(mlcd.PyClass): ... + +@mlcd.py_class +class Var(Expr): + name: str + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + if not printer.var_is_defined(obj=self): + printer.var_def(obj=self, frame=printer.frames[-1], name=self.name) + return printer.var_get(obj=self) + +@mlcd.py_class +class Add(Expr): + lhs: Expr + rhs: Expr + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"]) + rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"]) + return lhs + rhs + +@mlcd.py_class +class Assign(Stmt): + lhs: Var + rhs: Expr + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"]) + printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name) + lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"]) + return mlt.Assign(lhs=lhs, rhs=rhs) + +@mlcd.py_class +class Func(mlcd.PyClass): + name: str + args: list[Var] + stmts: list[Stmt] + ret: Var + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + with printer.with_frame(mlcp.DefaultFrame()): + for arg in self.args: + printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name) + args: list[mlt.Expr] = [printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)] + stmts: list[mlt.Expr] = [printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)] + ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"])) + return mlt.Function( + name=mlt.Id(self.name), + args=[mlt.Assign(lhs=arg, rhs=None) for arg in args], + decorators=[], + return_type=None, + body=[*stmts, ret_stmt], + ) + +# An 
example IR: +a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e") +f = Func( + name="f", + args=[a, b, c], + stmts=[ + Assign(lhs=d, rhs=Add(a, b)), # d = a + b + Assign(lhs=e, rhs=Add(d, c)), # e = d + c + ], + ret=e, +) +``` + +
+ +Two printer APIs are provided for Python-based text format: +- `mlc.printer.to_python` that converts an IR fragment to Python text, and +- `mlc.printer.print_python` that further renders the text with proper syntax highlighting. + +```python +>>> print(mlcp.to_python(f)) # Stringify to Python +def f(a, b, c): + d = a + b + e = d + c + return e +>>> mlcp.print_python(f) # Syntax highlighting +``` + +### :zap: Zero-Copy Interoperability with C++ Plugins + +TBD + +## :fuelpump: Development + +### :gear: Editable Build + +```bash +pip install --verbose --editable ".[dev]" +pre-commit install +``` + +### :ferris_wheel: Create Wheels + +This project uses `cibuildwheel` to build cross-platform wheels. See `.github/workflows/wheels.yml` for more details. + +```bash +export CIBW_BUILD_VERBOSITY=3 +export CIBW_BUILD="cp3*-manylinux_x86_64" +python -m pip install pipx +pipx run cibuildwheel==2.20.0 --output-dir wheelhouse +``` diff --git a/cmake/Utils/AddGoogleTest.cmake b/cmake/Utils/AddGoogleTest.cmake new file mode 100644 index 00000000..f2e52643 --- /dev/null +++ b/cmake/Utils/AddGoogleTest.cmake @@ -0,0 +1,43 @@ +set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE) +set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) +set(BUILD_GTEST ON CACHE BOOL "" FORCE) +set(GOOGLETEST_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/googletest) + +add_subdirectory(${GOOGLETEST_ROOT}) +include(GoogleTest) + +set_target_properties(gtest gtest_main + PROPERTIES + EXPORT_COMPILE_COMMANDS OFF + EXCLUDE_FROM_ALL ON + FOLDER 3rdparty + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) + +# Install gtest and gtest_main +install(TARGETS gtest gtest_main DESTINATION "lib/") + +# Mark advanced variables +mark_as_advanced( + BUILD_GMOCK BUILD_GTEST BUILD_SHARED_LIBS + gmock_build_tests gtest_build_samples gtest_build_tests + gtest_disable_pthreads gtest_force_shared_crt 
gtest_hide_internal_symbols +) + +macro(add_googletest target_name) + add_test( + NAME ${target_name} + COMMAND ${target_name} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + ) + target_link_libraries(${target_name} PRIVATE gtest_main) + gtest_discover_tests(${target_name} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + DISCOVERY_MODE PRE_TEST + PROPERTIES + VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + ) + set_target_properties(${target_name} PROPERTIES FOLDER tests) +endmacro() diff --git a/cmake/Utils/AddLibbacktrace.cmake b/cmake/Utils/AddLibbacktrace.cmake new file mode 100644 index 00000000..c9697538 --- /dev/null +++ b/cmake/Utils/AddLibbacktrace.cmake @@ -0,0 +1,50 @@ +include(ExternalProject) + +function(_libbacktrace_compile) + set(_libbacktrace_source ${CMAKE_CURRENT_LIST_DIR}/../../3rdparty/libbacktrace) + set(_libbacktrace_prefix ${CMAKE_CURRENT_BINARY_DIR}/libbacktrace) + if(CMAKE_SYSTEM_NAME MATCHES "Darwin" AND (CMAKE_C_COMPILER MATCHES "^/Library" OR CMAKE_C_COMPILER MATCHES "^/Applications")) + set(_cmake_c_compiler "/usr/bin/cc") + else() + set(_cmake_c_compiler "${CMAKE_C_COMPILER}") + endif() + + file(MAKE_DIRECTORY ${_libbacktrace_prefix}/include) + file(MAKE_DIRECTORY ${_libbacktrace_prefix}/lib) + ExternalProject_Add(project_libbacktrace + PREFIX libbacktrace + SOURCE_DIR ${_libbacktrace_source} + BINARY_DIR ${_libbacktrace_prefix} + CONFIGURE_COMMAND + ${_libbacktrace_source}/configure + "--prefix=${_libbacktrace_prefix}" + "--with-pic" + "CC=${_cmake_c_compiler}" + "CPP=${_cmake_c_compiler} -E" + "CFLAGS=${CMAKE_C_FLAGS}" + "LDFLAGS=${CMAKE_EXE_LINKER_FLAGS}" + "NM=${CMAKE_NM}" + "STRIP=${CMAKE_STRIP}" + "--host=${MACHINE_NAME}" + BUILD_COMMAND make -j + BUILD_BYPRODUCTS ${_libbacktrace_prefix}/lib/libbacktrace.a ${_libbacktrace_prefix}/include/backtrace.h + INSTALL_DIR ${_libbacktrace_prefix} + INSTALL_COMMAND make install + LOG_CONFIGURE ON + LOG_INSTALL ON + LOG_BUILD ON + LOG_OUTPUT_ON_FAILURE ON + ) + 
ExternalProject_Add_Step(project_libbacktrace checkout DEPENDERS configure DEPENDEES download) + set_target_properties(project_libbacktrace PROPERTIES EXCLUDE_FROM_ALL TRUE) + add_library(libbacktrace STATIC IMPORTED) + add_dependencies(libbacktrace project_libbacktrace) + set_target_properties(libbacktrace PROPERTIES + IMPORTED_LOCATION ${_libbacktrace_prefix}/lib/libbacktrace.a + INTERFACE_INCLUDE_DIRECTORIES ${_libbacktrace_prefix}/include + ) +endfunction() + +if(NOT MSVC) + _libbacktrace_compile() +endif() diff --git a/cmake/Utils/CxxWarning.cmake b/cmake/Utils/CxxWarning.cmake new file mode 100644 index 00000000..50ee5b61 --- /dev/null +++ b/cmake/Utils/CxxWarning.cmake @@ -0,0 +1,13 @@ +function(add_cxx_warning target_name) + # GNU, Clang, or AppleClang + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") + target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic") + return() + endif() + # MSVC + if(MSVC) + target_compile_options(${target_name} PRIVATE "/W4" "/WX") + return() + endif() + message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") +endfunction() diff --git a/cmake/Utils/DebugSymbol.cmake b/cmake/Utils/DebugSymbol.cmake new file mode 100644 index 00000000..7f5bbae2 --- /dev/null +++ b/cmake/Utils/DebugSymbol.cmake @@ -0,0 +1,14 @@ +if(APPLE) + find_program(DSYMUTIL_PROGRAM dsymutil) + mark_as_advanced(DSYMUTIL_PROGRAM) +endif() + +function(add_debug_symbol_apple _target _directory) + if (APPLE) + add_custom_command(TARGET ${_target} POST_BUILD + COMMAND ${DSYMUTIL_PROGRAM} ARGS $ + COMMENT "Running dsymutil" VERBATIM + ) + install(FILES $.dSYM DESTINATION ${_directory}) + endif (APPLE) +endfunction() diff --git a/cmake/Utils/Library.cmake b/cmake/Utils/Library.cmake new file mode 100644 index 00000000..26bdcef6 --- /dev/null +++ b/cmake/Utils/Library.cmake @@ -0,0 +1,28 @@ +function(add_target_from_obj target_name obj_target_name) + add_library(${target_name}_static STATIC $) + 
set_target_properties( + ${target_name}_static PROPERTIES + OUTPUT_NAME "${target_name}_static" + PREFIX "lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ) + add_library(${target_name}_shared SHARED $) + set_target_properties( + ${target_name}_shared PROPERTIES + OUTPUT_NAME "${target_name}" + PREFIX "lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ) + add_custom_target(${target_name}) + add_dependencies(${target_name} ${target_name}_static ${target_name}_shared) + if (MSVC) + target_compile_definitions(${obj_target_name} PRIVATE MLC_EXPORTS) + set_target_properties( + ${obj_target_name} ${target_name}_shared ${target_name}_static + PROPERTIES + MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL" + ) + endif() +endfunction() diff --git a/cmake/Utils/Sanitizer.cmake b/cmake/Utils/Sanitizer.cmake new file mode 100644 index 00000000..c1facc19 --- /dev/null +++ b/cmake/Utils/Sanitizer.cmake @@ -0,0 +1,18 @@ +function(add_sanitizer_address target_name) + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") + include(CheckCXXCompilerFlag) + set (_saved_CRF ${CMAKE_REQUIRED_FLAGS}) + set(CMAKE_REQUIRED_FLAGS "-fsanitize=address") + check_cxx_source_compiles("int main() { return 0; }" COMPILER_SUPPORTS_ASAN) + set (CMAKE_REQUIRED_FLAGS ${_saved_CRF}) + get_target_property(_saved_type ${target_name} TYPE) + if (${_saved_type} STREQUAL "INTERFACE_LIBRARY") + set(_saved_type INTERFACE) + else() + set(_saved_type PRIVATE) + endif() + target_link_options(${target_name} ${_saved_type} "-fsanitize=address") + target_compile_options(${target_name} ${_saved_type} "-fsanitize=address" "-fno-omit-frame-pointer" "-g") + return() + endif() +endfunction() diff --git a/cpp/c_api.cc b/cpp/c_api.cc new file mode 100644 index 00000000..4b4f6651 --- /dev/null +++ b/cpp/c_api.cc @@ -0,0 +1,187 @@ +#include "./registry.h" +#include + +namespace mlc { +namespace 
registry { +TypeTable *TypeTable::Global() { + static TypeTable *instance = TypeTable::New(); + return instance; +} +} // namespace registry +} // namespace mlc + +using ::mlc::Any; +using ::mlc::AnyView; +using ::mlc::ErrorObj; +using ::mlc::FuncObj; +using ::mlc::Ref; +using ::mlc::registry::TypeTable; + +namespace { +thread_local Any last_error; +MLC_REGISTER_FUNC("mlc.ffi.LoadDSO").set_body([](std::string name) { TypeTable::Get(nullptr)->LoadDSO(name); }); +} // namespace + +MLC_API MLCAny MLCGetLastError() { + MLCAny ret; + *static_cast(&ret) = std::move(last_error); + return ret; +} + +MLC_API int32_t MLCTypeRegister(MLCTypeTableHandle _self, int32_t parent_type_index, const char *type_key, + int32_t type_index, MLCTypeInfo **out_type_info) { + MLC_SAFE_CALL_BEGIN(); + *out_type_info = TypeTable::Get(_self)->TypeRegister(parent_type_index, type_index, type_key); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCTypeIndex2Info(MLCTypeTableHandle _self, int32_t type_index, MLCTypeInfo **ret) { + MLC_SAFE_CALL_BEGIN(); + *ret = TypeTable::Get(_self)->GetTypeInfo(type_index); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCTypeKey2Info(MLCTypeTableHandle _self, const char *type_key, MLCTypeInfo **ret) { + MLC_SAFE_CALL_BEGIN(); + *ret = TypeTable::Get(_self)->GetTypeInfo(type_key); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCTypeRegisterFields(MLCTypeTableHandle _self, int32_t type_index, int64_t num_fields, + MLCTypeField *fields) { + MLC_SAFE_CALL_BEGIN(); + TypeTable::Get(_self)->SetFields(type_index, num_fields, fields); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCTypeRegisterStructure(MLCTypeTableHandle _self, int32_t type_index, int32_t structure_kind, + int64_t num_sub_structures, int32_t *sub_structure_indices, + int32_t *sub_structure_kinds) { + MLC_SAFE_CALL_BEGIN(); + TypeTable::Get(_self)->SetStructure(type_index, structure_kind, num_sub_structures, sub_structure_indices, + sub_structure_kinds); + 
MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCTypeAddMethod(MLCTypeTableHandle _self, int32_t type_index, MLCTypeMethod method) { + MLC_SAFE_CALL_BEGIN(); + TypeTable::Get(_self)->AddMethod(type_index, method); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCVTableGetGlobal(MLCTypeTableHandle _self, const char *key, MLCVTableHandle *ret) { + MLC_SAFE_CALL_BEGIN(); + *ret = TypeTable::Get(_self)->GetGlobalVTable(key); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCVTableGetFunc(MLCVTableHandle vtable, int32_t type_index, int32_t allow_ancestor, MLCAny *ret) { + using ::mlc::registry::VTable; + MLC_SAFE_CALL_BEGIN(); + *static_cast(ret) = static_cast(vtable)->GetFunc(type_index, allow_ancestor); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCVTableSetFunc(MLCVTableHandle vtable, int32_t type_index, MLCFunc *func, int32_t override_mode) { + using ::mlc::registry::VTable; + MLC_SAFE_CALL_BEGIN(); + static_cast(vtable)->Set(type_index, static_cast(func), override_mode); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCDynTypeTypeTableCreate(MLCTypeTableHandle *ret) { + MLC_SAFE_CALL_BEGIN(); + *ret = TypeTable::New(); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCDynTypeTypeTableDestroy(MLCTypeTableHandle handle) { + MLC_SAFE_CALL_BEGIN(); + delete static_cast(handle); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCAnyIncRef(MLCAny *any) { + MLC_SAFE_CALL_BEGIN(); + if (!::mlc::base::IsTypeIndexPOD(any->type_index)) { + ::mlc::base::IncRef(any->v.v_obj); + } + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCAnyDecRef(MLCAny *any) { + MLC_SAFE_CALL_BEGIN(); + if (!::mlc::base::IsTypeIndexPOD(any->type_index)) { + ::mlc::base::DecRef(any->v.v_obj); + } + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCAnyInplaceViewToOwned(MLCAny *any) { + MLC_SAFE_CALL_BEGIN(); + MLCAny tmp{}; + static_cast(tmp) = static_cast(*any); + *any = tmp; + MLC_SAFE_CALL_END(&last_error); +} + 
+MLC_API int32_t MLCFuncSetGlobal(MLCTypeTableHandle _self, const char *name, MLCAny func, int allow_override) { + MLC_SAFE_CALL_BEGIN(); + TypeTable::Get(_self)->SetFunc(name, static_cast(func), allow_override); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCFuncGetGlobal(MLCTypeTableHandle _self, const char *name, MLCAny *ret) { + MLC_SAFE_CALL_BEGIN(); + *static_cast(ret) = TypeTable::Get(_self)->GetFunc(name); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCFuncSafeCall(MLCFunc *func, int32_t num_args, MLCAny *args, MLCAny *ret) { +#if defined(__APPLE__) + if (func->_mlc_header.type_index != MLCTypeIndex::kMLCFunc) { + std::cout << "func->type_index: " << func->_mlc_header.type_index << std::endl; + } +#endif + return func->safe_call(func, num_args, args, ret); +} + +MLC_API int32_t MLCFuncCreate(void *self, MLCDeleterType deleter, MLCFuncSafeCallType safe_call, MLCAny *ret) { + MLC_SAFE_CALL_BEGIN(); + *static_cast(ret) = FuncObj::FromForeign(self, deleter, safe_call); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCErrorCreate(const char *kind, int64_t num_bytes, const char *bytes, MLCAny *ret) { + MLC_SAFE_CALL_BEGIN(); + *static_cast(ret) = Ref::New(kind, num_bytes, bytes); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCErrorGetInfo(MLCAny error, int32_t *num_strs, const char ***strs) { + MLC_SAFE_CALL_BEGIN(); + thread_local std::vector ret; + static_cast(error).operator Ref()->GetInfo(&ret); + *num_strs = static_cast(ret.size()); + *strs = ret.data(); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t MLCExtObjCreate(int32_t num_bytes, int32_t type_index, MLCAny *ret) { + MLC_SAFE_CALL_BEGIN(); + *static_cast(ret) = mlc::AllocExternObject(type_index, num_bytes); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API int32_t _MLCExtObjDeleteImpl(void *objptr) { + MLC_SAFE_CALL_BEGIN(); + ::mlc::core::DeleteExternObject(static_cast<::mlc::Object *>(objptr)); + MLC_SAFE_CALL_END(&last_error); +} + +MLC_API void 
MLCExtObjDelete(void *objptr) { + if (int32_t error_code = _MLCExtObjDeleteImpl(objptr)) { + std::cerr << "Error code (" << error_code << ") when deleting external object: " << last_error << std::endl; + std::abort(); + } +} diff --git a/cpp/c_api_tests.cc b/cpp/c_api_tests.cc new file mode 100644 index 00000000..838e88c5 --- /dev/null +++ b/cpp/c_api_tests.cc @@ -0,0 +1,195 @@ +#include + +namespace mlc { +namespace { + +/**************** FFI ****************/ + +MLC_REGISTER_FUNC("mlc.testing.cxx_none").set_body([]() -> void { return; }); +MLC_REGISTER_FUNC("mlc.testing.cxx_null").set_body([]() -> void * { return nullptr; }); +MLC_REGISTER_FUNC("mlc.testing.cxx_int").set_body([](int x) -> int { return x; }); +MLC_REGISTER_FUNC("mlc.testing.cxx_float").set_body([](double x) -> double { return x; }); +MLC_REGISTER_FUNC("mlc.testing.cxx_ptr").set_body([](void *x) -> void * { return x; }); +MLC_REGISTER_FUNC("mlc.testing.cxx_dtype").set_body([](DLDataType x) { return x; }); +MLC_REGISTER_FUNC("mlc.testing.cxx_device").set_body([](DLDevice x) { return x; }); +MLC_REGISTER_FUNC("mlc.testing.cxx_raw_str").set_body([](const char *x) { return x; }); + +/**************** Reflection ****************/ + +struct TestingCClassObj : public Object { + int8_t i8; + int16_t i16; + int32_t i32; + int64_t i64; + float f32; + double f64; + void *raw_ptr; + DLDataType dtype; + DLDevice device; + Any any; + Func func; + UList ulist; + UDict udict; + Str str_; + Str str_readonly; + + List list_any; + List> list_list_int; + Dict dict_any_any; + Dict dict_str_any; + Dict dict_any_str; + Dict> dict_str_list_int; + + Optional opt_i64; + Optional opt_f64; + Optional opt_raw_ptr; + Optional opt_dtype; + Optional opt_device; + Optional opt_func; + Optional opt_ulist; + Optional opt_udict; + Optional opt_str; + + Optional> opt_list_any; + Optional>> opt_list_list_int; + Optional> opt_dict_any_any; + Optional> opt_dict_str_any; + Optional> opt_dict_any_str; + Optional>> opt_dict_str_list_int; + 
+ explicit TestingCClassObj(int8_t i8, int16_t i16, int32_t i32, int64_t i64, float f32, double f64, void *raw_ptr, + DLDataType dtype, DLDevice device, Any any, Func func, UList ulist, UDict udict, Str str_, + Str str_readonly, List list_any, List> list_list_int, + Dict dict_any_any, Dict dict_str_any, Dict dict_any_str, + Dict> dict_str_list_int, Optional opt_i64, Optional opt_f64, + Optional opt_raw_ptr, Optional opt_dtype, Optional opt_device, + Optional opt_func, Optional opt_ulist, Optional opt_udict, + Optional opt_str, Optional> opt_list_any, + Optional>> opt_list_list_int, Optional> opt_dict_any_any, + Optional> opt_dict_str_any, Optional> opt_dict_any_str, + Optional>> opt_dict_str_list_int) + : i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device), + any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any), + list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any), + dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_i64(opt_i64), opt_f64(opt_f64), + opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func), + opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any), + opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any), + opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {} + + int64_t i64_plus_one() const { return i64 + 1; } + + MLC_DEF_DYN_TYPE(TestingCClassObj, Object, "mlc.testing.c_class"); +}; + +struct TestingCClass : public ObjectRef { + MLC_DEF_OBJ_REF(TestingCClass, TestingCClassObj, ObjectRef) + .Field("i8", &TestingCClassObj::i8) + .Field("i16", &TestingCClassObj::i16) + .Field("i32", &TestingCClassObj::i32) + .Field("i64", &TestingCClassObj::i64) + .Field("f32", &TestingCClassObj::f32) + .Field("f64", &TestingCClassObj::f64) + 
.Field("raw_ptr", &TestingCClassObj::raw_ptr) + .Field("dtype", &TestingCClassObj::dtype) + .Field("device", &TestingCClassObj::device) + .Field("any", &TestingCClassObj::any) + .Field("func", &TestingCClassObj::func) + .Field("ulist", &TestingCClassObj::ulist) + .Field("udict", &TestingCClassObj::udict) + .Field("str_", &TestingCClassObj::str_) + .FieldReadOnly("str_readonly", &TestingCClassObj::str_readonly) + .Field("list_any", &TestingCClassObj::list_any) + .Field("list_list_int", &TestingCClassObj::list_list_int) + .Field("dict_any_any", &TestingCClassObj::dict_any_any) + .Field("dict_str_any", &TestingCClassObj::dict_str_any) + .Field("dict_any_str", &TestingCClassObj::dict_any_str) + .Field("dict_str_list_int", &TestingCClassObj::dict_str_list_int) + .Field("opt_i64", &TestingCClassObj::opt_i64) + .Field("opt_f64", &TestingCClassObj::opt_f64) + .Field("opt_raw_ptr", &TestingCClassObj::opt_raw_ptr) + .Field("opt_dtype", &TestingCClassObj::opt_dtype) + .Field("opt_device", &TestingCClassObj::opt_device) + .Field("opt_func", &TestingCClassObj::opt_func) + .Field("opt_ulist", &TestingCClassObj::opt_ulist) + .Field("opt_udict", &TestingCClassObj::opt_udict) + .Field("opt_str", &TestingCClassObj::opt_str) + .Field("opt_list_any", &TestingCClassObj::opt_list_any) + .Field("opt_list_list_int", &TestingCClassObj::opt_list_list_int) + .Field("opt_dict_any_any", &TestingCClassObj::opt_dict_any_any) + .Field("opt_dict_str_any", &TestingCClassObj::opt_dict_str_any) + .Field("opt_dict_any_str", &TestingCClassObj::opt_dict_any_str) + .Field("opt_dict_str_list_int", &TestingCClassObj::opt_dict_str_list_int) + .MemFn("i64_plus_one", &TestingCClassObj::i64_plus_one) + .StaticFn("__init__", + InitOf, List>, Dict, Dict, + Dict, Dict>, Optional, Optional, Optional, + Optional, Optional, Optional, Optional, Optional, + Optional, Optional>, Optional>>, Optional>, + Optional>, Optional>, Optional>>>); +}; + +/**************** Traceback ****************/ + 
+MLC_REGISTER_FUNC("mlc.testing.throw_exception_from_c").set_body([]() { + // Simply throw an exception in C++ + MLC_THROW(ValueError) << "This is an error message"; +}); + +MLC_REGISTER_FUNC("mlc.testing.throw_exception_from_ffi_in_c").set_body([](FuncObj *func) { + // call a Python function which throws an exception + (*func)(); +}); + +MLC_REGISTER_FUNC("mlc.testing.throw_exception_from_ffi").set_body([](FuncObj *func) { + // Call a Python function `func` which throws an exception and returns it + Any ret; + try { + (*func)(); + } catch (Exception &error) { + ret = std::move(error.data_); + } + return ret; +}); + +/**************** Type checking ****************/ + +MLC_REGISTER_FUNC("mlc.testing.nested_type_checking_list").set_body([](Str name) { + if (name == "list") { + using Type = UList; + return Func([](Type v) { return v; }); + } + if (name == "list[Any]") { + using Type = List; + return Func([](Type v) { return v; }); + } + if (name == "list[list[int]]") { + using Type = List>; + return Func([](Type v) { return v; }); + } + if (name == "dict") { + using Type = UDict; + return Func([](Type v) { return v; }); + } + if (name == "dict[str, Any]") { + using Type = Dict; + return Func([](Type v) { return v; }); + } + if (name == "dict[Any, str]") { + using Type = Dict; + return Func([](Type v) { return v; }); + } + if (name == "dict[Any, Any]") { + using Type = Dict; + return Func([](Type v) { return v; }); + } + if (name == "dict[str, list[int]]") { + using Type = Dict>; + return Func([](Type v) { return v; }); + } + MLC_UNREACHABLE(); +}); + +} // namespace +} // namespace mlc diff --git a/cpp/json.cc b/cpp/json.cc new file mode 100644 index 00000000..2af5b474 --- /dev/null +++ b/cpp/json.cc @@ -0,0 +1,497 @@ +#include +#include +#include +#include + +namespace mlc { +namespace core { +namespace { + +mlc::Str Serialize(Any any); +Any Deserialize(const char *json_str, int64_t json_str_len); +Any JSONLoads(const char *json_str, int64_t json_str_len); 
+MLC_INLINE Any Deserialize(const char *json_str) { return Deserialize(json_str, -1); } +MLC_INLINE Any Deserialize(const Str &json_str) { return Deserialize(json_str->data(), json_str->size()); } +MLC_INLINE Any JSONLoads(const char *json_str) { return JSONLoads(json_str, -1); } +MLC_INLINE Any JSONLoads(const Str &json_str) { return JSONLoads(json_str->data(), json_str->size()); } + +inline mlc::Str Serialize(Any any) { + using mlc::base::TypeTraits; + std::vector type_keys; + auto get_json_type_index = [type_key2index = std::unordered_map(), + &type_keys](const char *type_key) mutable -> int32_t { + if (auto it = type_key2index.find(type_key); it != type_key2index.end()) { + return it->second; + } + int32_t type_index = static_cast(type_key2index.size()); + type_key2index[type_key] = type_index; + type_keys.push_back(type_key); + return type_index; + }; + using TObj2Idx = std::unordered_map; + using TJsonTypeIndex = decltype(get_json_type_index); + struct Emitter { + // clang-format off + MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); } + MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, Optional *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); } + MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { 
EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast(*v)); } + MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); } + MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); } + MLC_INLINE void operator()(MLCTypeField *, Optional *) { MLC_THROW(TypeError) << "Unserializable type: void *"; } + MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; } + MLC_INLINE void operator()(MLCTypeField *, const char **) { MLC_THROW(TypeError) << "Unserializable type: const char *"; } + // clang-format on + inline void EmitNil() { (*os) << ", null"; } + inline void EmitFloat(double v) { (*os) << ", " << std::fixed << std::setprecision(19) << v; } + inline void EmitInt(int64_t v) { + int32_t type_int = (*get_json_type_index)(TypeTraits::type_str); + (*os) << ", [" << type_int << ", " << v << "]"; + } + inline void EmitDevice(DLDevice v) { + int32_t type_device = (*get_json_type_index)(TypeTraits::type_str); + (*os) << ", [" << type_device << ", " << ::mlc::base::TypeTraits::__str__(v) << "]"; + } + inline void EmitDType(DLDataType v) { + int32_t type_dtype = (*get_json_type_index)(TypeTraits::type_str); + (*os) << ", [" << type_dtype << ", " << ::mlc::base::TypeTraits::__str__(v) << "]"; + } + inline void EmitAny(const Any *any) { + int32_t type_index = any->type_index; + if (type_index == kMLCNone) { + EmitNil(); + } else if (type_index == kMLCInt) { + EmitInt(any->operator int64_t()); + } else if (type_index == kMLCFloat) { + EmitFloat(any->operator double()); + } else if (type_index == kMLCDevice) { + EmitDevice(any->operator DLDevice()); + } else if (type_index == 
kMLCDataType) { + EmitDType(any->operator DLDataType()); + } else if (type_index >= kMLCStaticObjectBegin) { + EmitObject(any->operator Object *()); + } else { + MLC_THROW(TypeError) << "Cannot serialize type: " << ::mlc::base::TypeIndex2TypeKey(type_index); + } + } + inline void EmitObject(Object *obj) { + if (!obj) { + MLC_THROW(InternalError) << "This should never happen: null object pointer during EmitObject"; + } + int32_t obj_idx = obj2index->at(obj); + if (obj_idx == -1) { + MLC_THROW(InternalError) << "This should never happen: topological ordering violated"; + } + (*os) << ", " << obj_idx; + } + std::ostringstream *os; + TJsonTypeIndex *get_json_type_index; + const TObj2Idx *obj2index; + }; + + std::ostringstream os; + auto on_visit = [get_json_type_index = &get_json_type_index, os = &os, is_first_object = true]( + Object *object, MLCTypeInfo *type_info, const TObj2Idx &obj2index) mutable -> void { + Emitter emitter{os, get_json_type_index, &obj2index}; + if (is_first_object) { + is_first_object = false; + } else { + os->put(','); + } + if (StrObj *str = object->TryCast()) { + str->PrintEscape(*os); + return; + } + (*os) << '[' << (*get_json_type_index)(type_info->type_key); + if (UListObj *list = object->TryCast()) { + for (Any &any : *list) { + emitter(nullptr, &any); + } + } else if (UDictObj *dict = object->TryCast()) { + for (auto &kv : *dict) { + emitter(nullptr, &kv.first); + emitter(nullptr, &kv.second); + } + } else if (object->IsInstance() || object->IsInstance()) { + MLC_THROW(TypeError) << "Unserializable type: " << object->GetTypeKey(); + } else { + VisitFields(object, type_info, emitter); + } + os->put(']'); + }; + os << "{\"values\": ["; + if (any.type_index >= kMLCStaticObjectBegin) { // Topological the objects according to dependency + TopoVisit(any.operator Object *(), nullptr, on_visit); + } else if (any.type_index == kMLCNone) { + os << "null"; + } else if (any.type_index == kMLCInt) { + int32_t type_int = 
get_json_type_index(TypeTraits::type_str); + int64_t v = any; + os << "[" << type_int << ", " << v << "]"; + } else if (any.type_index == kMLCFloat) { + double v = any; + os << v; + } else if (any.type_index == kMLCDevice) { + int32_t type_device = get_json_type_index(TypeTraits::type_str); + DLDevice v = any; + os << "[" << type_device << ", \"" << TypeTraits::__str__(v) << "\"]"; + } else if (any.type_index == kMLCDataType) { + int32_t type_dtype = get_json_type_index(TypeTraits::type_str); + DLDataType v = any; + os << "[" << type_dtype << ", \"" << TypeTraits::__str__(v) << "\"]"; + } else { + MLC_THROW(TypeError) << "Cannot serialize type: " << mlc::base::TypeIndex2TypeKey(any.type_index); + } + os << "], \"type_keys\": ["; + for (size_t i = 0; i < type_keys.size(); ++i) { + if (i > 0) { + os << ", "; + } + os << '"' << type_keys[i] << '\"'; + } + os << "]}"; + return os.str(); +} + +inline Any Deserialize(const char *json_str, int64_t json_str_len) { + MLCVTableHandle init_vtable; + MLCVTableGetGlobal(nullptr, "__init__", &init_vtable); + // Step 0. Parse JSON string + UDict json_obj = JSONLoads(json_str, json_str_len); + // Step 1. type_key => constructors + UList type_keys = json_obj->at("type_keys"); + std::vector constructors; + constructors.reserve(type_keys.size()); + for (Str type_key : type_keys) { + Any init_func; + int32_t type_index = ::mlc::base::TypeKey2TypeIndex(type_key->data()); + MLCVTableGetFunc(init_vtable, type_index, false, &init_func); + if (!::mlc::base::IsTypeIndexNone(init_func.type_index)) { + constructors.push_back(init_func.operator Func()); + } else { + MLC_THROW(InternalError) << "Method `__init__` is not defined for type " << type_key; + } + } + auto invoke_init = [&constructors](UList args) { + int32_t json_type_index = args[0]; + Any ret; + ::mlc::base::FuncCall(constructors.at(json_type_index).get(), static_cast(args.size()) - 1, + args->data() + 1, &ret); + return ret; + }; + // Step 2. 
Translate JSON object to objects + UList values = json_obj->at("values"); + for (int64_t i = 0; i < values->size(); ++i) { + Any obj = values[i]; + if (obj.type_index == kMLCList) { + UList list = obj.operator UList(); + int32_t json_type_index = list[0]; + for (int64_t j = 1; j < list.size(); ++j) { + Any arg = list[j]; + if (arg.type_index == kMLCInt) { + int64_t k = arg; + if (k < i) { + list[j] = values[k]; + } else { + MLC_THROW(ValueError) << "Invalid reference when parsing type `" << type_keys[json_type_index] + << "`: referring #" << k << " at #" << i << ". v = " << obj; + } + } else if (arg.type_index == kMLCList) { + list[j] = invoke_init(arg.operator UList()); + } else if (arg.type_index == kMLCStr || arg.type_index == kMLCFloat || arg.type_index == kMLCNone) { + // Do nothing + } else { + MLC_THROW(ValueError) << "Unexpected value: " << arg; + } + } + values[i] = invoke_init(list); + } else if (obj.type_index == kMLCInt) { + int32_t k = obj; + values[i] = values[k]; + } else if (obj.type_index == kMLCStr) { + // Do nothing + } else { + MLC_THROW(ValueError) << "Unexpected value: " << obj; + } + } + return values->back(); +} + +inline Any JSONLoads(const char *json_str, int64_t json_str_len) { + struct JSONParser { + Any Parse() { + SkipWhitespace(); + Any result = ParseValue(); + SkipWhitespace(); + if (i != json_str_len) { + MLC_THROW(ValueError) << "JSON parsing failure at position " << i + << ": Extra data after valid JSON. JSON string: " << json_str; + } + return result; + } + + void ExpectChar(char c) { + if (json_str[i] == c) { + ++i; + } else { + MLC_THROW(ValueError) << "JSON parsing failure at position " << i << ": Expected '" << c << "' but got '" + << json_str[i] << "'. JSON string: " << json_str; + } + } + + char PeekChar() { return i < json_str_len ? 
json_str[i] : '\0'; } + + void SkipWhitespace() { + while (i < json_str_len && std::isspace(json_str[i])) { + ++i; + } + } + + void ExpectString(const char *expected, int64_t len) { + if (i + len <= json_str_len && std::strncmp(json_str + i, expected, len) == 0) { + i = i + len; + } else { + MLC_THROW(ValueError) << "JSON parsing failure at position " << i << ": Expected '" << expected + << ". JSON string: " << json_str; + } + } + + Any ParseNull() { + ExpectString("null", 4); + return Any(nullptr); + } + + Any ParseBoolean() { + if (PeekChar() == 't') { + ExpectString("true", 4); + return Any(1); + } else { + ExpectString("false", 5); + return Any(0); + } + } + + Any ParseNumber() { + int64_t start = i; + // Identify the end of the numeric sequence + while (i < json_str_len) { + char c = json_str[i]; + if (c == '.' || c == 'e' || c == 'E' || c == '+' || c == '-' || std::isdigit(c)) { + ++i; + } else { + break; + } + } + std::string num_str(json_str + start, i - start); + std::size_t pos = 0; + try { + // Attempt to parse as integer + int64_t int_value = std::stoll(num_str, &pos); + if (pos == num_str.size()) { + i = start + static_cast(pos); // Update the main index + return Any(int_value); + } + } catch (const std::invalid_argument &) { // Not an integer, proceed to parse as double + } catch (const std::out_of_range &) { // Integer out of range, proceed to parse as double + } + try { + // Attempt to parse as double + double double_value = std::stod(num_str, &pos); + if (pos == num_str.size()) { + i = start + pos; // Update the main index + return Any(double_value); + } + } catch (const std::invalid_argument &) { + } catch (const std::out_of_range &) { + } + MLC_THROW(ValueError) << "JSON parsing failure at position " << i + << ": Invalid number format. 
JSON string: " << json_str; + MLC_UNREACHABLE(); + } + + Any ParseStr() { + ExpectChar('"'); + std::ostringstream oss; + while (true) { + if (i >= json_str_len) { + MLC_THROW(ValueError) << "JSON parsing failure at position " << i + << ": Unterminated string. JSON string: " << json_str; + } + char c = json_str[i++]; + if (c == '"') { + // End of string + return Any(Str(oss.str())); + } else if (c == '\\') { + // Handle escape sequences + if (i >= json_str_len) { + MLC_THROW(ValueError) << "JSON parsing failure at position " << i + << ": Incomplete escape sequence. JSON string: " << json_str; + } + char next = json_str[i++]; + switch (next) { + case 'n': + oss << '\n'; + break; + case 't': + oss << '\t'; + break; + case 'r': + oss << '\r'; + break; + case '\\': + oss << '\\'; + break; + case '"': + oss << '\"'; + break; + case 'x': { + if (i + 1 < json_str_len && std::isxdigit(json_str[i]) && std::isxdigit(json_str[i + 1])) { + int32_t value = std::stoi(std::string(json_str + i, 2), nullptr, 16); + oss << static_cast(value); + i += 2; + } else { + MLC_THROW(ValueError) << "Invalid hexadecimal escape sequence at position " << i - 2 + << " in string: " << json_str; + } + break; + } + case 'u': { + if (i + 3 < json_str_len && std::isxdigit(json_str[i]) && std::isxdigit(json_str[i + 1]) && + std::isxdigit(json_str[i + 2]) && std::isxdigit(json_str[i + 3])) { + int32_t codepoint = std::stoi(std::string(json_str + i, 4), nullptr, 16); + if (codepoint <= 0x7F) { + // 1-byte UTF-8 + oss << static_cast(codepoint); + } else if (codepoint <= 0x7FF) { + // 2-byte UTF-8 + oss << static_cast(0xC0 | (codepoint >> 6)); + oss << static_cast(0x80 | (codepoint & 0x3F)); + } else { + // 3-byte UTF-8 + oss << static_cast(0xE0 | (codepoint >> 12)); + oss << static_cast(0x80 | ((codepoint >> 6) & 0x3F)); + oss << static_cast(0x80 | (codepoint & 0x3F)); + } + i += 4; + } else { + MLC_THROW(ValueError) << "Invalid Unicode escape sequence at position " << i - 2 + << " in string: " << 
json_str; + } + break; + } + default: + // Unrecognized escape sequence, interpret literally + oss << next; + break; + } + } else { + // Regular character, copy as-is + oss << c; + } + } + } + + UList ParseArray() { + UList arr; + ExpectChar('['); + SkipWhitespace(); + if (PeekChar() == ']') { + ExpectChar(']'); + return arr; + } + while (true) { + SkipWhitespace(); + arr.push_back(ParseValue()); + SkipWhitespace(); + if (PeekChar() == ']') { + ExpectChar(']'); + return arr; + } + ExpectChar(','); + } + } + + Any ParseObject() { + UDict obj; + ExpectChar('{'); + SkipWhitespace(); + if (PeekChar() == '}') { + ExpectChar('}'); + return Any(obj); + } + while (true) { + SkipWhitespace(); + Any key = ParseStr(); + SkipWhitespace(); + ExpectChar(':'); + SkipWhitespace(); + Any value = ParseValue(); + obj[key] = value; + SkipWhitespace(); + if (PeekChar() == '}') { + ExpectChar('}'); + return Any(obj); + } + ExpectChar(','); + } + } + + Any ParseValue() { + SkipWhitespace(); + char c = PeekChar(); + if (c == '"') { + return ParseStr(); + } else if (c == '{') { + return ParseObject(); + } else if (c == '[') { + return ParseArray(); + } else if (c == 'n') { + return ParseNull(); + } else if (c == 't' || c == 'f') { + return ParseBoolean(); + } else if (std::isdigit(c) || c == '-') { + return ParseNumber(); + } else { + MLC_THROW(ValueError) << "JSON parsing failure at position " << i << ": Unexpected character: " << c + << ". 
JSON string: " << json_str; + } + MLC_UNREACHABLE(); + } + int64_t i; + int64_t json_str_len; + const char *json_str; + }; + if (json_str_len < 0) { + json_str_len = static_cast(std::strlen(json_str)); + } + return JSONParser{0, json_str_len, json_str}.Parse(); +} + +MLC_REGISTER_FUNC("mlc.core.JSONLoads").set_body([](AnyView json_str) { + if (json_str.type_index == kMLCRawStr) { + return ::mlc::core::JSONLoads(json_str.operator const char *()); + } else { + ::mlc::Str str = json_str; + return ::mlc::core::JSONLoads(str); + } +}); +MLC_REGISTER_FUNC("mlc.core.JSONSerialize").set_body(::mlc::core::Serialize); +MLC_REGISTER_FUNC("mlc.core.JSONDeserialize").set_body([](AnyView json_str) { + if (json_str.type_index == kMLCRawStr) { + return ::mlc::core::Deserialize(json_str.operator const char *()); + } else { + return ::mlc::core::Deserialize(json_str.operator ::mlc::Str()); + } +}); +} // namespace +} // namespace core +} // namespace mlc diff --git a/cpp/printer.cc b/cpp/printer.cc new file mode 100644 index 00000000..30d7db03 --- /dev/null +++ b/cpp/printer.cc @@ -0,0 +1,1044 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlc { +namespace printer { +namespace { +using ByteSpan = std::pair; + +class DocPrinter { +public: + explicit DocPrinter(const PrinterConfig &options); + virtual ~DocPrinter() = default; + + void Append(const Node &doc); + void Append(const Node &doc, const PrinterConfig &cfg); + ::mlc::Str GetString() const; + +protected: + void PrintDoc(const Node &doc); + virtual void PrintTypedDoc(const Literal &doc) = 0; + virtual void PrintTypedDoc(const Id &doc) = 0; + virtual void PrintTypedDoc(const Attr &doc) = 0; + virtual void PrintTypedDoc(const Index &doc) = 0; + virtual void PrintTypedDoc(const Operation &doc) = 0; + virtual void PrintTypedDoc(const Call &doc) = 0; + virtual void PrintTypedDoc(const Lambda &doc) = 0; + virtual void PrintTypedDoc(const List &doc) = 0; + virtual void 
PrintTypedDoc(const Tuple &doc) = 0; + virtual void PrintTypedDoc(const Dict &doc) = 0; + virtual void PrintTypedDoc(const Slice &doc) = 0; + virtual void PrintTypedDoc(const StmtBlock &doc) = 0; + virtual void PrintTypedDoc(const Assign &doc) = 0; + virtual void PrintTypedDoc(const If &doc) = 0; + virtual void PrintTypedDoc(const While &doc) = 0; + virtual void PrintTypedDoc(const For &doc) = 0; + virtual void PrintTypedDoc(const With &doc) = 0; + virtual void PrintTypedDoc(const ExprStmt &doc) = 0; + virtual void PrintTypedDoc(const Assert &doc) = 0; + virtual void PrintTypedDoc(const Return &doc) = 0; + virtual void PrintTypedDoc(const Function &doc) = 0; + virtual void PrintTypedDoc(const Class &doc) = 0; + virtual void PrintTypedDoc(const Comment &doc) = 0; + virtual void PrintTypedDoc(const DocString &doc) = 0; + + void PrintTypedDoc(const NodeObj *doc) { + using PrinterVTable = std::unordered_map>; + // clang-format off + #define MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Type) { Type::TObj::_type_index, [](DocPrinter *printer, const NodeObj *doc) { printer->PrintTypedDoc(Type(doc->Cast())); } } + // clang-format on + static PrinterVTable vtable{ + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Literal), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Id), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Attr), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Index), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Operation), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Call), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Lambda), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(List), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Tuple), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Dict), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Slice), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(StmtBlock), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Assign), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(If), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(While), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(For), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(With), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(ExprStmt), + 
MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Assert), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Return), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Function), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Class), + MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(Comment), MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_(DocString), + }; + // clang-format off + #undef MLC_PRINTER_PRINT_TYPED_DOC_ENTRY_ + // clang-format on + vtable.at(doc->GetTypeIndex())(this, doc); + } + void IncreaseIndent() { indent_ += options_->indent_spaces; } + void DecreaseIndent() { indent_ -= options_->indent_spaces; } + std::ostream &NewLine() { + size_t start_pos = output_.tellp(); + output_ << "\n"; + line_starts_.push_back(output_.tellp()); + for (int i = 0; i < indent_; ++i) { + output_ << ' '; + } + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); + return output_; + } + std::ostringstream output_; + std::vector underlines_exempted_; + +private: + using ObjectPath = ::mlc::core::ObjectPath; + + void MarkSpan(const ByteSpan &span, const ObjectPath &path); + PrinterConfig options_; + int indent_ = 0; + std::vector line_starts_; + mlc::List path_to_underline_; + std::vector> current_underline_candidates_; + std::vector current_max_path_length_; + std::vector underlines_; +}; + +inline const char *OpKindToString(OperationObj::Kind kind) { + switch (kind) { + case OperationObj::Kind::kUSub: + return "-"; + case OperationObj::Kind::kInvert: + return "~"; + case OperationObj::Kind::kNot: + return "not "; + case OperationObj::Kind::kAdd: + return "+"; + case OperationObj::Kind::kSub: + return "-"; + case OperationObj::Kind::kMult: + return "*"; + case OperationObj::Kind::kDiv: + return "/"; + case OperationObj::Kind::kFloorDiv: + return "//"; + case OperationObj::Kind::kMod: + return "%"; + case OperationObj::Kind::kPow: + return "**"; + case OperationObj::Kind::kLShift: + return "<<"; + case OperationObj::Kind::kRShift: + return ">>"; + case OperationObj::Kind::kBitAnd: + return "&"; + case 
OperationObj::Kind::kBitOr: + return "|"; + case OperationObj::Kind::kBitXor: + return "^"; + case OperationObj::Kind::kLt: + return "<"; + case OperationObj::Kind::kLtE: + return "<="; + case OperationObj::Kind::kEq: + return "=="; + case OperationObj::Kind::kNotEq: + return "!="; + case OperationObj::Kind::kGt: + return ">"; + case OperationObj::Kind::kGtE: + return ">="; + case OperationObj::Kind::kAnd: + return "and"; + case OperationObj::Kind::kOr: + return "or"; + default: + MLC_THROW(ValueError) << "Unknown operation kind: " << static_cast(kind); + } + MLC_UNREACHABLE(); +} + +inline const char *OpKindToString(int64_t kind) { return OpKindToString(static_cast(kind)); } + +/*! + * \brief Operator precedence based on https://docs.python.org/3/reference/expressions.html#operator-precedence + */ +enum class ExprPrecedence : int32_t { + /*! \brief Unknown precedence */ + kUnkown = 0, + /*! \brief Lambda Expression */ + kLambda = 1, + /*! \brief Conditional Expression */ + kIfThenElse = 2, + /*! \brief Boolean OR */ + kBooleanOr = 3, + /*! \brief Boolean AND */ + kBooleanAnd = 4, + /*! \brief Boolean NOT */ + kBooleanNot = 5, + /*! \brief Comparisons */ + kComparison = 6, + /*! \brief Bitwise OR */ + kBitwiseOr = 7, + /*! \brief Bitwise XOR */ + kBitwiseXor = 8, + /*! \brief Bitwise AND */ + kBitwiseAnd = 9, + /*! \brief Shift Operators */ + kShift = 10, + /*! \brief Addition and subtraction */ + kAdd = 11, + /*! \brief Multiplication, division, floor division, remainder */ + kMult = 12, + /*! \brief Positive negative and bitwise NOT */ + kUnary = 13, + /*! \brief Exponentiation */ + kExp = 14, + /*! 
\brief Index access, attribute access, call and atom expression */ + kIdentity = 15, +}; + +inline ExprPrecedence GetExprPrecedence(const Expr &doc) { + // Key is the type index of Doc + static const std::unordered_map doc_type_precedence = { + {LiteralObj::_type_index, ExprPrecedence::kIdentity}, {IdObj::_type_index, ExprPrecedence::kIdentity}, + {AttrObj::_type_index, ExprPrecedence::kIdentity}, {IndexObj::_type_index, ExprPrecedence::kIdentity}, + {CallObj::_type_index, ExprPrecedence::kIdentity}, {LambdaObj::_type_index, ExprPrecedence::kLambda}, + {TupleObj::_type_index, ExprPrecedence::kIdentity}, {ListObj::_type_index, ExprPrecedence::kIdentity}, + {DictObj::_type_index, ExprPrecedence::kIdentity}, + }; + // Key is the value of OperationDocNode::Kind + static const std::vector op_kind_precedence = []() { + using OpKind = OperationObj::Kind; + std::map raw_table = { + {OpKind::kUSub, ExprPrecedence::kUnary}, {OpKind::kInvert, ExprPrecedence::kUnary}, + {OpKind::kNot, ExprPrecedence::kBooleanNot}, {OpKind::kAdd, ExprPrecedence::kAdd}, + {OpKind::kSub, ExprPrecedence::kAdd}, {OpKind::kMult, ExprPrecedence::kMult}, + {OpKind::kDiv, ExprPrecedence::kMult}, {OpKind::kFloorDiv, ExprPrecedence::kMult}, + {OpKind::kMod, ExprPrecedence::kMult}, {OpKind::kPow, ExprPrecedence::kExp}, + {OpKind::kLShift, ExprPrecedence::kShift}, {OpKind::kRShift, ExprPrecedence::kShift}, + {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, {OpKind::kBitOr, ExprPrecedence::kBitwiseOr}, + {OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, {OpKind::kLt, ExprPrecedence::kComparison}, + {OpKind::kLtE, ExprPrecedence::kComparison}, {OpKind::kEq, ExprPrecedence::kComparison}, + {OpKind::kNotEq, ExprPrecedence::kComparison}, {OpKind::kGt, ExprPrecedence::kComparison}, + {OpKind::kGtE, ExprPrecedence::kComparison}, {OpKind::kAnd, ExprPrecedence::kBooleanAnd}, + {OpKind::kOr, ExprPrecedence::kBooleanOr}, {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse}, + }; + std::vector 
table(static_cast(OpKind::kSpecialEnd) + 1, ExprPrecedence::kUnkown); + for (const auto &kv : raw_table) { + table[static_cast(kv.first)] = kv.second; + } + return table; + }(); + if (const auto *op_doc = doc->TryCast()) { + ExprPrecedence precedence = op_kind_precedence.at(op_doc->op); + if (precedence == ExprPrecedence::kUnkown) { + MLC_THROW(ValueError) << "Unknown precedence for operator: " << op_doc->op; + } + return precedence; + } else if (auto it = doc_type_precedence.find(doc->GetTypeIndex()); it != doc_type_precedence.end()) { + return it->second; + } + MLC_THROW(ValueError) << "Unknown precedence for doc type: " << doc->GetTypeKey(); + MLC_UNREACHABLE(); +} + +inline std::vector MergeAndExemptSpans(const std::vector &spans, + const std::vector &spans_exempted) { + // use prefix sum to merge and exempt spans + std::vector res; + std::vector> prefix_stamp; + for (ByteSpan span : spans) { + prefix_stamp.push_back({span.first, 1}); + prefix_stamp.push_back({span.second, -1}); + } + // at most spans.size() spans accumulated in prefix sum + // use spans.size() + 1 as stamp unit to exempt all positive spans + // with only one negative span + int max_n = static_cast(spans.size()) + 1; + for (ByteSpan span : spans_exempted) { + prefix_stamp.push_back({span.first, -max_n}); + prefix_stamp.push_back({span.second, max_n}); + } + std::sort(prefix_stamp.begin(), prefix_stamp.end()); + int prefix_sum = 0; + int n = static_cast(prefix_stamp.size()); + for (int i = 0; i < n - 1; ++i) { + prefix_sum += prefix_stamp[i].second; + // positive prefix sum leads to spans without exemption + // different stamp positions guarantee the stamps in same position accumulated + if (prefix_sum > 0 && prefix_stamp[i].first < prefix_stamp[i + 1].first) { + if (res.size() && res.back().second == prefix_stamp[i].first) { + // merge to the last spans if it is successive + res.back().second = prefix_stamp[i + 1].first; + } else { + // add a new independent span + 
res.push_back({prefix_stamp[i].first, prefix_stamp[i + 1].first}); + } + } + } + return res; +} + +inline size_t GetTextWidth(const mlc::Str &text, const ByteSpan &span) { + // FIXME: this only works for ASCII characters. + // To do this "correctly", we need to parse UTF-8 into codepoints + // and call wcwidth() or equivalent for every codepoint. + size_t ret = 0; + for (size_t i = span.first; i != span.second; ++i) { + if (isprint(text[i])) { + ret += 1; + } + } + return ret; +} + +inline size_t MoveBack(size_t pos, size_t distance) { return distance > pos ? 0 : pos - distance; } + +inline size_t MoveForward(size_t pos, size_t distance, size_t max) { + return distance > max - pos ? max : pos + distance; +} + +inline size_t GetLineIndex(size_t byte_pos, const std::vector &line_starts) { + auto it = std::upper_bound(line_starts.begin(), line_starts.end(), byte_pos); + return (it - line_starts.begin()) - 1; +} + +using UnderlineIter = typename std::vector::const_iterator; + +inline ByteSpan PopNextUnderline(UnderlineIter *next_underline, UnderlineIter end_underline) { + if (*next_underline == end_underline) { + constexpr size_t kMaxSizeT = 18446744073709551615U; + return {kMaxSizeT, kMaxSizeT}; + } else { + return *(*next_underline)++; + } +} + +inline void PrintChunk(const std::pair &lines_range, + const std::pair &underlines, const mlc::Str &text, + const std::vector &line_starts, const PrinterConfig &options, size_t line_number_width, + std::ostringstream *out) { + UnderlineIter next_underline = underlines.first; + ByteSpan current_underline = PopNextUnderline(&next_underline, underlines.second); + + for (size_t line_idx = lines_range.first; line_idx < lines_range.second; ++line_idx) { + if (options->print_line_numbers) { + (*out) << std::setw(line_number_width) << std::right << std::to_string(line_idx + 1) << ' '; + } + size_t line_start = line_starts.at(line_idx); + size_t line_end = line_idx + 1 == line_starts.size() ? 
/*!
 * \brief Emit a marker line standing in for lines that were left out of the output.
 * \param num_lines_skipped How many lines were omitted; nothing is written when it is 0.
 * \param out The stream receiving the marker.
 */
inline void PrintCut(size_t num_lines_skipped, std::ostringstream *out) {
  if (num_lines_skipped == 0) {
    return;
  }
  std::ostringstream &sink = *out;
  sink << "(... " << num_lines_skipped << " lines skipped ...)\n";
}
} + + size_t last_end_line = 0; + std::pair cur_chunk = GetLinesForUnderline(underlines[0], line_starts, num_lines, options); + if (cur_chunk.first < kMinLinesToCutOut) { + cur_chunk.first = 0; + } + + auto first_underline_in_cur_chunk = underlines.begin(); + for (auto underline_it = underlines.begin() + 1; underline_it != underlines.end(); ++underline_it) { + std::pair new_chunk = GetLinesForUnderline(*underline_it, line_starts, num_lines, options); + + if (!TryMergeChunks(&cur_chunk, new_chunk)) { + PrintCut(cur_chunk.first - last_end_line, &ret); + PrintChunk(cur_chunk, {first_underline_in_cur_chunk, underline_it}, text, line_starts, options, line_number_width, + &ret); + last_end_line = cur_chunk.second; + cur_chunk = new_chunk; + first_underline_in_cur_chunk = underline_it; + } + } + + PrintCut(cur_chunk.first - last_end_line, &ret); + if (num_lines - cur_chunk.second < kMinLinesToCutOut) { + cur_chunk.second = num_lines; + } + PrintChunk(cur_chunk, {first_underline_in_cur_chunk, underlines.end()}, text, line_starts, options, line_number_width, + &ret); + PrintCut(num_lines - cur_chunk.second, &ret); + return ret.str(); +} + +inline DocPrinter::DocPrinter(const PrinterConfig &options) : options_(options) { line_starts_.push_back(0); } + +inline void DocPrinter::Append(const Node &doc) { Append(doc, PrinterConfig()); } + +inline void DocPrinter::Append(const Node &doc, const PrinterConfig &cfg) { + for (ObjectPath p : cfg->path_to_underline) { + path_to_underline_.push_back(p); + current_max_path_length_.push_back(0); + current_underline_candidates_.push_back(std::vector()); + } + PrintDoc(doc); + for (const auto &c : current_underline_candidates_) { + underlines_.insert(underlines_.end(), c.begin(), c.end()); + } +} + +inline ::mlc::Str DocPrinter::GetString() const { + mlc::Str text = output_.str(); + // Remove any trailing indentation + while (!text->empty() && std::isspace(text->back())) { + text->pop_back(); + } + return DecorateText(text, line_starts_, 
options_, MergeAndExemptSpans(underlines_, underlines_exempted_)); +} + +inline void DocPrinter::PrintDoc(const Node &doc) { + size_t start_pos = output_.tellp(); + this->PrintTypedDoc(doc.get()); + size_t end_pos = output_.tellp(); + for (ObjectPath path : doc->source_paths) { + MarkSpan({start_pos, end_pos}, path); + } +} + +inline void DocPrinter::MarkSpan(const ByteSpan &span, const ObjectPath &path) { + int n = static_cast(path_to_underline_.size()); + for (int i = 0; i < n; ++i) { + ObjectPath p = path_to_underline_[i]; + if (path->length >= current_max_path_length_[i] && path->IsPrefixOf(p.get())) { + if (path->length > current_max_path_length_[i]) { + current_max_path_length_[i] = static_cast(path->length); + current_underline_candidates_[i].clear(); + } + current_underline_candidates_[i].push_back(span); + } + } +} + +class PythonDocPrinter : public DocPrinter { +public: + explicit PythonDocPrinter(const PrinterConfig &options) : DocPrinter(options) {} + +protected: + using DocPrinter::PrintDoc; + + void PrintTypedDoc(const Literal &doc) final; + void PrintTypedDoc(const Id &doc) final; + void PrintTypedDoc(const Attr &doc) final; + void PrintTypedDoc(const Index &doc) final; + void PrintTypedDoc(const Operation &doc) final; + void PrintTypedDoc(const Call &doc) final; + void PrintTypedDoc(const Lambda &doc) final; + void PrintTypedDoc(const List &doc) final; + void PrintTypedDoc(const Dict &doc) final; + void PrintTypedDoc(const Tuple &doc) final; + void PrintTypedDoc(const Slice &doc) final; + void PrintTypedDoc(const StmtBlock &doc) final; + void PrintTypedDoc(const Assign &doc) final; + void PrintTypedDoc(const If &doc) final; + void PrintTypedDoc(const While &doc) final; + void PrintTypedDoc(const For &doc) final; + void PrintTypedDoc(const ExprStmt &doc) final; + void PrintTypedDoc(const Assert &doc) final; + void PrintTypedDoc(const Return &doc) final; + void PrintTypedDoc(const With &doc) final; + void PrintTypedDoc(const Function &doc) final; + 
void PrintTypedDoc(const Class &doc) final; + void PrintTypedDoc(const Comment &doc) final; + void PrintTypedDoc(const DocString &doc) final; + +private: + void NewLineWithoutIndent() { + size_t start_pos = output_.tellp(); + output_ << "\n"; + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); + } + + template void PrintJoinedDocs(const mlc::List &docs, const char *separator) { + bool is_first = true; + for (DocType doc : docs) { + if (is_first) { + is_first = false; + } else { + output_ << separator; + } + PrintDoc(doc); + } + } + + void PrintIndentedBlock(const mlc::List &docs) { + IncreaseIndent(); + for (Stmt d : docs) { + NewLine(); + PrintDoc(d); + } + if (docs.empty()) { + NewLine(); + output_ << "pass"; + } + DecreaseIndent(); + } + + void PrintDecorators(const mlc::List &decorators) { + for (Expr decorator : decorators) { + output_ << "@"; + PrintDoc(decorator); + NewLine(); + } + } + + /*! + * \brief Print expression and add parenthesis if needed. + */ + void PrintChildExpr(const Expr &doc, ExprPrecedence parent_precedence, bool parenthesis_for_same_precedence = false) { + ExprPrecedence doc_precedence = GetExprPrecedence(doc); + if (doc_precedence < parent_precedence || + (parenthesis_for_same_precedence && doc_precedence == parent_precedence)) { + output_ << "("; + PrintDoc(doc); + output_ << ")"; + } else { + PrintDoc(doc); + } + } + + /*! + * \brief Print expression and add parenthesis if doc has lower precedence than parent. + */ + void PrintChildExpr(const Expr &doc, const Expr &parent, bool parenthesis_for_same_precedence = false) { + ExprPrecedence parent_precedence = GetExprPrecedence(parent); + return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence); + } + + /*! + * \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent. 
+ * + * This function should be used to print an child expression that needs to be wrapped + * by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b` + * and the `b` and `c` in `a if b else c`. + */ + void PrintChildExprConservatively(const Expr &doc, const Expr &parent) { + PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true); + } + + void MaybePrintCommentInline(const Stmt &stmt) { + if (const mlc::StrObj *comment = stmt->comment.get()) { + bool has_newline = std::find(comment->begin(), comment->end(), '\n') != comment->end(); + if (has_newline) { + MLC_THROW(ValueError) << "ValueError: Comment string of " << stmt->GetTypeKey() + << " cannot have newline, but got: " << comment; + } + size_t start_pos = output_.tellp(); + output_ << " # " << comment->data(); + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); + } + } + + void MaybePrintCommenMultiLines(const Stmt &stmt, bool new_line = false) { + if (const mlc::StrObj *comment = stmt->comment.get()) { + bool first_line = true; + size_t start_pos = output_.tellp(); + for (const std::string_view &line : comment->Split('\n')) { + if (first_line) { + output_ << "# " << line; + first_line = false; + } else { + NewLine() << "# " << line; + } + } + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); + if (new_line) { + NewLine(); + } + } + } + + void PrintDocString(const mlc::Str &comment) { + size_t start_pos = output_.tellp(); + output_ << "\"\"\""; + for (const std::string_view &line : comment->Split('\n')) { + if (line.empty()) { + // No indentation on empty line + output_ << "\n"; + } else { + NewLine() << line; + } + } + NewLine() << "\"\"\""; + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); + } + void PrintBlockComment(const mlc::Str &comment) { + IncreaseIndent(); + NewLine(); + PrintDocString(comment); + DecreaseIndent(); + } +}; + +inline 
void PythonDocPrinter::PrintTypedDoc(const Literal &doc) { + Any value = doc->value; + int32_t type_index = value.GetTypeIndex(); + if (!value.defined()) { + output_ << "None"; + } else if (type_index == kMLCInt) { + output_ << value.operator int64_t(); + } else if (type_index == kMLCFloat) { + double v = value.operator double(); + // TODO(yelite): Make float number printing roundtrippable + if (std::isinf(v) || std::isnan(v)) { + output_ << '"' << v << '"'; + } else if (std::nearbyint(v) == v) { + // Special case for floating-point values which would be + // formatted using %g, are not displayed in scientific + // notation, and whose fractional part is zero. + // + // By default, using `operator<<(std::ostream&, double)` + // delegates to the %g printf formatter. This strips off any + // trailing zeros, and also strips the decimal point if no + // trailing zeros are found. When parsed in python, due to the + // missing decimal point, this would incorrectly convert a float + // to an integer. Providing the `std::showpoint` modifier + // instead delegates to the %#g printf formatter. On its own, + // this resolves the round-trip errors, but also prevents the + // trailing zeros from being stripped off. + std::showpoint(output_); + std::fixed(output_); + output_.precision(1); + output_ << v; + } else { + std::defaultfloat(output_); + std::noshowpoint(output_); + output_.precision(17); + output_ << v; + } + } else if (type_index == kMLCStr) { + (value.operator mlc::Str())->PrintEscape(output_); + } else { + MLC_THROW(TypeError) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); + } +} + +inline void PythonDocPrinter::PrintTypedDoc(const Id &doc) { output_ << doc->name; } + +inline void PythonDocPrinter::PrintTypedDoc(const Attr &doc) { + PrintChildExpr(doc->obj, doc); + output_ << "." 
<< doc->name; +} + +inline void PythonDocPrinter::PrintTypedDoc(const Index &doc) { + PrintChildExpr(doc->obj, doc); + if (doc->idx.size() == 0) { + output_ << "[()]"; + } else { + output_ << "["; + PrintJoinedDocs(doc->idx, ", "); + output_ << "]"; + } +} + +inline void PythonDocPrinter::PrintTypedDoc(const Operation &doc) { + using OpKind = OperationObj::Kind; + if (doc->op < OpKind::kUnaryEnd) { + // Unary Operators + if (doc->operands.size() != 1) { + MLC_THROW(ValueError) << "ValueError: Unary operator requires 1 operand, but got " << doc->operands.size(); + } + output_ << OpKindToString(doc->op); + PrintChildExpr(doc->operands[0], doc); + } else if (doc->op == OpKind::kPow) { + // Power operator is different than other binary operators + // It's right-associative and binds less tightly than unary operator on its right. + // https://docs.python.org/3/reference/expressions.html#the-power-operator + // https://docs.python.org/3/reference/expressions.html#operator-precedence + if (doc->operands.size() != 2) { + MLC_THROW(ValueError) << "Operator '**' requires 2 operands, but got " << doc->operands.size(); + } + PrintChildExprConservatively(doc->operands[0], doc); + output_ << " ** "; + PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary); + } else if (doc->op < OpKind::kBinaryEnd) { + // Binary Operator + if (doc->operands.size() != 2) { + MLC_THROW(ValueError) << "Binary operator requires 2 operands, but got " << doc->operands.size(); + } + PrintChildExpr(doc->operands[0], doc); + output_ << " " << OpKindToString(doc->op) << " "; + PrintChildExprConservatively(doc->operands[1], doc); + } else if (doc->op == OpKind::kIfThenElse) { + if (doc->operands.size() != 3) { + MLC_THROW(ValueError) << "IfThenElse requires 3 operands, but got " << doc->operands.size(); + } + PrintChildExpr(doc->operands[1], doc); + output_ << " if "; + PrintChildExprConservatively(doc->operands[0], doc); + output_ << " else "; + PrintChildExprConservatively(doc->operands[2], doc); + } 
else { + MLC_THROW(ValueError) << "Unknown OperationDocNode::Kind " << static_cast(doc->op); + } +} + +inline void PythonDocPrinter::PrintTypedDoc(const Call &doc) { + PrintChildExpr(doc->callee, doc); + output_ << "("; + // Print positional args + bool is_first = true; + for (Expr arg : doc->args) { + if (is_first) { + is_first = false; + } else { + output_ << ", "; + } + PrintDoc(arg); + } + // Print keyword args + if (doc->kwargs_keys.size() != doc->kwargs_values.size()) { + MLC_THROW(ValueError) << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; + } + for (int64_t i = 0; i < doc->kwargs_keys.size(); i++) { + if (is_first) { + is_first = false; + } else { + output_ << ", "; + } + const mlc::Str &keyword = doc->kwargs_keys[i]; + output_ << keyword; + output_ << "="; + PrintDoc(doc->kwargs_values[i]); + } + + output_ << ")"; +} + +inline void PythonDocPrinter::PrintTypedDoc(const Lambda &doc) { + output_ << "lambda "; + PrintJoinedDocs(doc->args, ", "); + output_ << ": "; + PrintChildExpr(doc->body, doc); +} + +inline void PythonDocPrinter::PrintTypedDoc(const List &doc) { + output_ << "["; + PrintJoinedDocs(doc->values, ", "); + output_ << "]"; +} + +inline void PythonDocPrinter::PrintTypedDoc(const Tuple &doc) { + output_ << "("; + if (doc->values.size() == 1) { + PrintDoc(doc->values[0]); + output_ << ","; + } else { + PrintJoinedDocs(doc->values, ", "); + } + output_ << ")"; +} + +inline void PythonDocPrinter::PrintTypedDoc(const Dict &doc) { + if (doc->keys.size() != doc->values.size()) { + MLC_THROW(ValueError) << "DictDoc should have equal number of elements in keys and values."; + } + output_ << "{"; + size_t idx = 0; + for (Expr key : doc->keys) { + if (idx > 0) { + output_ << ", "; + } + PrintDoc(key); + output_ << ": "; + PrintDoc(doc->values[idx]); + idx++; + } + output_ << "}"; +} + +inline void PythonDocPrinter::PrintTypedDoc(const Slice &doc) { + if (doc->start != nullptr) { + PrintDoc(doc->start.value()); + } + 
output_ << ":"; + if (doc->stop != nullptr) { + PrintDoc(doc->stop.value()); + } + if (doc->step != nullptr) { + output_ << ":"; + PrintDoc(doc->step.value()); + } +} + +inline void PythonDocPrinter::PrintTypedDoc(const StmtBlock &doc) { + bool is_first = true; + for (Stmt stmt : doc->stmts) { + if (is_first) { + is_first = false; + } else { + NewLine(); + } + PrintDoc(stmt); + } +} + +inline void PythonDocPrinter::PrintTypedDoc(const Assign &doc) { + if (const auto *tuple_doc = doc->lhs->TryCast()) { + PrintJoinedDocs(tuple_doc->values, ", "); + } else { + PrintDoc(doc->lhs); + } + + if (doc->annotation.defined()) { + output_ << ": "; + PrintDoc(doc->annotation.value()); + } + if (doc->rhs.defined()) { + output_ << " = "; + if (const auto *tuple_doc = doc->rhs.TryCast()) { + if (tuple_doc->values.size() > 1) { + PrintJoinedDocs(tuple_doc->values, ", "); + } else { + PrintDoc(doc->rhs.value()); + } + } else { + PrintDoc(doc->rhs.value()); + } + } + MaybePrintCommentInline(doc); +} + +inline void PythonDocPrinter::PrintTypedDoc(const If &doc) { + MaybePrintCommenMultiLines(doc, true); + output_ << "if "; + PrintDoc(doc->cond); + output_ << ":"; + PrintIndentedBlock(doc->then_branch); + if (!doc->else_branch.empty()) { + NewLine(); + output_ << "else:"; + PrintIndentedBlock(doc->else_branch); + } +} + +inline void PythonDocPrinter::PrintTypedDoc(const While &doc) { + MaybePrintCommenMultiLines(doc, true); + output_ << "while "; + PrintDoc(doc->cond); + output_ << ":"; + PrintIndentedBlock(doc->body); +} + +inline void PythonDocPrinter::PrintTypedDoc(const For &doc) { + MaybePrintCommenMultiLines(doc, true); + output_ << "for "; + if (const auto *tuple = doc->lhs->TryCast()) { + if (tuple->values.size() == 1) { + PrintDoc(tuple->values[0]); + output_ << ","; + } else { + PrintJoinedDocs(tuple->values, ", "); + } + } else { + PrintDoc(doc->lhs); + } + output_ << " in "; + PrintDoc(doc->rhs); + output_ << ":"; + PrintIndentedBlock(doc->body); +} + +inline void 
PythonDocPrinter::PrintTypedDoc(const With &doc) { + MaybePrintCommenMultiLines(doc, true); + output_ << "with "; + PrintDoc(doc->rhs); + if (doc->lhs != nullptr) { + output_ << " as "; + PrintDoc(doc->lhs.value()); + } + output_ << ":"; + + PrintIndentedBlock(doc->body); +} + +inline void PythonDocPrinter::PrintTypedDoc(const ExprStmt &doc) { + PrintDoc(doc->expr); + MaybePrintCommentInline(doc); +} + +inline void PythonDocPrinter::PrintTypedDoc(const Assert &doc) { + output_ << "assert "; + PrintDoc(doc->cond); + if (doc->msg.defined()) { + output_ << ", "; + PrintDoc(doc->msg.value()); + } + MaybePrintCommentInline(doc); +} + +inline void PythonDocPrinter::PrintTypedDoc(const Return &doc) { + output_ << "return"; + if (doc->value.defined()) { + output_ << " "; + PrintDoc(doc->value.value()); + } + MaybePrintCommentInline(doc); +} + +inline void PythonDocPrinter::PrintTypedDoc(const Function &doc) { + PrintDecorators(doc->decorators); + output_ << "def "; + PrintDoc(doc->name); + output_ << "("; + PrintJoinedDocs(doc->args, ", "); + output_ << ")"; + if (doc->return_type.defined()) { + output_ << " -> "; + PrintDoc(doc->return_type.value()); + } + output_ << ":"; + if (doc->comment.defined()) { + PrintBlockComment(doc->comment.value()); + } + PrintIndentedBlock(doc->body); + NewLineWithoutIndent(); +} + +inline void PythonDocPrinter::PrintTypedDoc(const Class &doc) { + PrintDecorators(doc->decorators); + output_ << "class "; + PrintDoc(doc->name); + output_ << ":"; + if (doc->comment.defined()) { + PrintBlockComment(doc->comment.value()); + } + PrintIndentedBlock(doc->body); +} + +inline void PythonDocPrinter::PrintTypedDoc(const Comment &doc) { + if (doc->comment.defined()) { + MaybePrintCommenMultiLines(doc, false); + } +} + +inline void PythonDocPrinter::PrintTypedDoc(const DocString &doc) { + if (doc->comment.defined()) { + mlc::Str comment(doc->comment.value()); + if (!comment->empty()) { + PrintDocString(comment); + } + } +} + +mlc::Str 
DocToPythonScript(Node node, PrinterConfig cfg) { + if (cfg->num_context_lines < 0) { + constexpr int32_t kMaxInt32 = 2147483647; + cfg->num_context_lines = kMaxInt32; + } + PythonDocPrinter printer(cfg); + printer.Append(node, cfg); + mlc::Str result = printer.GetString(); + while (!result->empty() && std::isspace(result->back())) { + result->pop_back(); + } + return result; +} + +MLC_REGISTER_FUNC("mlc.printer.DocToPythonScript").set_body(DocToPythonScript); +MLC_REGISTER_FUNC("mlc.printer.ToPython").set_body(::mlc::printer::ToPython); + +} // namespace +} // namespace printer +} // namespace mlc diff --git a/cpp/registry.h b/cpp/registry.h new file mode 100644 index 00000000..5675d637 --- /dev/null +++ b/cpp/registry.h @@ -0,0 +1,504 @@ +#ifndef MLC_CPP_REGISTRY_H_ +#define MLC_CPP_REGISTRY_H_ +#ifdef _MSC_VER +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlc { +namespace registry { + +struct DSOLibrary { + ~DSOLibrary() { Unload(); } + +#ifdef _MSC_VER + DSOLibrary(std::string name) { + std::wstring wname(name.begin(), name.end()); + lib_handle_ = LoadLibraryW(wname.c_str()); + if (lib_handle_ == nullptr) { + MLC_THROW(ValueError) << "Failed to load dynamic shared library " << name; + } + } + + void Unload() { + if (lib_handle_ != nullptr) { + FreeLibrary(lib_handle_); + lib_handle_ = nullptr; + } + } + + void *GetSymbol_(const char *name) { + return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) + } + + HMODULE lib_handle_{nullptr}; +#else + DSOLibrary(std::string name) { + lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); + if (lib_handle_ == nullptr) { + MLC_THROW(ValueError) << "Failed to load dynamic shared library " << name << " " << dlerror(); + } + } + + void Unload() { + if (lib_handle_ != nullptr) { + dlclose(lib_handle_); + lib_handle_ = nullptr; + } + } + + void *GetSymbol(const char *name) { return dlsym(lib_handle_, name); } + + 
void *lib_handle_{nullptr}; +#endif +}; + +struct ResourcePool { + using ObjPtr = std::unique_ptr; + + void AddObj(void *ptr) { + if (ptr != nullptr) { + MLCAny *ptr_cast = reinterpret_cast(ptr); + ::mlc::base::IncRef(ptr_cast); + this->objects.insert({ptr, ObjPtr(ptr_cast, ::mlc::base::DecRef)}); + } + } + + void DelObj(void *ptr) { + if (ptr != nullptr) { + this->objects.erase(this->objects.find(ptr)); + } + } + + template PODType *NewPODArray(int64_t size) { + return static_cast(this->NewPODArrayImpl(size, sizeof(PODType))); + } + + const char *NewStr(const char *source) { + if (source == nullptr) { + return nullptr; + } + size_t len = strlen(source); + char *ptr = this->NewPODArray(len + 1); + std::memcpy(ptr, source, len + 1); + return ptr; + } + + void DelPODArray(const void *ptr) { + if (ptr != nullptr) { + this->pod_array.erase(ptr); + } + } + +private: + MLC_INLINE void *NewPODArrayImpl(int64_t size, size_t pod_size) { + if (size == 0) { + return nullptr; + } + ::mlc::base::PODArray owned(static_cast(std::malloc(size * pod_size)), std::free); + void *ptr = owned.get(); + auto [it, success] = this->pod_array.emplace(ptr, std::move(owned)); + if (!success) { + std::cerr << "Array already registered: " << ptr; + std::abort(); + } + return ptr; + } + + std::unordered_map pod_array; + std::unordered_multimap objects; +}; + +struct TypeTable; + +struct VTable { + explicit VTable(TypeTable *type_table, std::string name) : type_table(type_table), name(std::move(name)), data() {} + void Set(int32_t type_index, FuncObj *func, int32_t override_mode); + FuncObj *GetFunc(int32_t type_index, bool allow_ancestor) const; + +private: + TypeTable *type_table; + std::string name; + std::unordered_map data; +}; + +struct TypeInfoWrapper { + MLCTypeInfo info{}; + ResourcePool *pool = nullptr; + int64_t num_fields = 0; + std::vector methods = {}; + + ~TypeInfoWrapper() { this->Reset(); } + + void Reset() { + if (this->pool) { + this->pool->DelPODArray(this->info.type_key); + 
this->pool->DelPODArray(this->info.type_ancestors); + this->ResetFields(); + this->ResetMethods(); + this->info.type_key = nullptr; + this->info.type_ancestors = nullptr; + this->pool = nullptr; + } + } + + void ResetFields() { + if (this->num_fields > 0) { + MLCTypeField *&fields = this->info.fields; + for (int32_t i = 0; i < this->num_fields; i++) { + this->pool->DelPODArray(fields[i].name); + } + this->pool->DelPODArray(fields); + this->info.fields = nullptr; + this->num_fields = 0; + } + } + + void ResetMethods() { + if (!this->methods.empty()) { + for (MLCTypeMethod &method : this->methods) { + if (method.name != nullptr) { + this->pool->DelPODArray(method.name); + this->pool->DelObj(method.func); + } + } + this->info.methods = nullptr; + this->methods.clear(); + } + } + + void ResetStructure() { + if (this->info.sub_structure_indices) { + this->pool->DelPODArray(this->info.sub_structure_indices); + this->info.sub_structure_indices = nullptr; + } + if (this->info.sub_structure_kinds) { + this->pool->DelPODArray(this->info.sub_structure_kinds); + this->info.sub_structure_kinds = nullptr; + } + } + + void SetFields(int64_t new_num_fields, MLCTypeField *fields) { + this->ResetFields(); + this->num_fields = new_num_fields; + MLCTypeField *dsts = this->info.fields = this->pool->NewPODArray(new_num_fields + 1); + for (int64_t i = 0; i < num_fields; i++) { + MLCTypeField *dst = dsts + i; + *dst = fields[i]; + this->pool->AddObj(dst->ty); + dst->name = this->pool->NewStr(dst->name); + if (dst->index != i) { + MLC_THROW(ValueError) << "Field index mismatch: " << i << " vs " << dst->index; + } + } + dsts[num_fields] = MLCTypeField{}; + std::sort(dsts, dsts + num_fields, + [](const MLCTypeField &a, const MLCTypeField &b) { return a.offset < b.offset; }); + } + + void AddMethod(MLCTypeMethod method) { + for (MLCTypeMethod &m : this->methods) { + if (m.name && std::strcmp(m.name, method.name) == 0) { + this->pool->DelObj(m.func); + 
this->pool->AddObj(reinterpret_cast(method.func)); + m.func = method.func; + m.kind = method.kind; + return; + } + } + if (this->methods.empty()) { + this->methods.emplace_back(); + } + MLCTypeMethod &dst = this->methods.back(); + dst = method; + dst.name = this->pool->NewStr(dst.name); + this->pool->AddObj(reinterpret_cast(dst.func)); + this->methods.emplace_back(); + this->info.methods = this->methods.data(); + } + + void SetStructure(int32_t structure_kind, int64_t num_sub_structures, int32_t *sub_structure_indices, + int32_t *sub_structure_kinds) { + this->ResetStructure(); + this->info.structure_kind = structure_kind; + if (num_sub_structures > 0) { + this->info.sub_structure_indices = this->pool->NewPODArray(num_sub_structures + 1); + this->info.sub_structure_kinds = this->pool->NewPODArray(num_sub_structures + 1); + std::memcpy(this->info.sub_structure_indices, sub_structure_indices, num_sub_structures * sizeof(int32_t)); + std::memcpy(this->info.sub_structure_kinds, sub_structure_kinds, num_sub_structures * sizeof(int32_t)); + std::reverse(this->info.sub_structure_indices, this->info.sub_structure_indices + num_sub_structures); + std::reverse(this->info.sub_structure_kinds, this->info.sub_structure_kinds + num_sub_structures); + this->info.sub_structure_indices[num_sub_structures] = -1; + this->info.sub_structure_kinds[num_sub_structures] = -1; + } else { + this->info.sub_structure_indices = nullptr; + this->info.sub_structure_kinds = nullptr; + } + } +}; + +struct TypeTable { + using ObjPtr = std::unique_ptr; + + int32_t num_types; + std::vector> type_table; + std::unordered_map type_key_to_info; + std::unordered_map global_funcs; + std::unordered_map> global_vtables; + std::unordered_map> dso_libraries; + ResourcePool pool; + + MLCTypeInfo *GetTypeInfo(int32_t type_index) { + if (type_index < 0 || type_index >= static_cast(this->type_table.size())) { + return nullptr; + } else { + return &this->type_table.at(type_index)->info; + } + } + + MLCTypeInfo 
*GetTypeInfo(const char *type_key) { + auto ret = this->type_key_to_info.find(type_key); + if (ret == this->type_key_to_info.end()) { + return nullptr; + } else { + return ret->second; + } + } + + FuncObj *GetVTable(int32_t type_index, const char *attr_key, bool allow_ancestor) { + if (auto it = this->global_vtables.find(attr_key); it != this->global_vtables.end()) { + VTable *vtable = it->second.get(); + return vtable->GetFunc(type_index, allow_ancestor); + } else { + return nullptr; + } + } + + VTable *GetGlobalVTable(const char *name) { + if (auto it = this->global_vtables.find(name); it != this->global_vtables.end()) { + return it->second.get(); + } else { + std::unique_ptr &vtable = this->global_vtables[name] = std::make_unique(this, name); + return vtable.get(); + } + } + + FuncObj *GetFunc(const char *name) { + if (auto it = this->global_funcs.find(name); it != this->global_funcs.end()) { + return it->second; + } else { + return nullptr; + } + } + + static TypeTable *New(); + static TypeTable *Global(); + static TypeTable *Get(MLCTypeTableHandle self) { + return (self == nullptr) ? TypeTable::Global() : static_cast(self); + } + + MLCTypeInfo *TypeRegister(int32_t parent_type_index, int32_t type_index, const char *type_key) { + // Step 1.Check if the type is already registered + if (auto it = this->type_key_to_info.find(type_key); it != this->type_key_to_info.end()) { + MLCTypeInfo *ret = it->second; + if (type_index != -1 && type_index != ret->type_index) { + MLC_THROW(KeyError) << "Type `" << type_key << "` registered with type index `" << ret->type_index + << "`, but re-registered with type index: " << type_index; + } + return ret; + } + // Step 2. 
Manipulate type table + if (type_index == -1) { + type_index = this->num_types++; + } + if (type_index >= static_cast(this->type_table.size())) { + this->type_table.resize((type_index / 1024 + 1) * 1024); + } + std::unique_ptr &wrapper = this->type_table.at(type_index) = std::make_unique(); + this->type_key_to_info[type_key] = &wrapper->info; + // Step 3. Initialize type info + MLCTypeInfo *parent = parent_type_index == -1 ? nullptr : this->GetTypeInfo(parent_type_index); + MLCTypeInfo *info = &wrapper->info; + info->type_index = type_index; + info->type_key = this->pool.NewStr(type_key); + info->type_key_hash = ::mlc::base::StrHash(type_key, std::strlen(type_key)); + info->type_depth = (parent == nullptr) ? 0 : (parent->type_depth + 1); + info->type_ancestors = this->pool.NewPODArray(info->type_depth); + if (parent) { + std::copy(parent->type_ancestors, parent->type_ancestors + parent->type_depth, info->type_ancestors); + info->type_ancestors[parent->type_depth] = parent_type_index; + } + info->fields = nullptr; + info->methods = nullptr; + info->structure_kind = 0; + info->sub_structure_indices = nullptr; + info->sub_structure_kinds = nullptr; + wrapper->pool = &this->pool; + return info; + } + + void SetFunc(const char *name, FuncObj *func, bool allow_override) { + auto [it, success] = this->global_funcs.try_emplace(std::string(name), nullptr); + if (!success && !allow_override) { + MLC_THROW(KeyError) << "Global function already registered: " << name; + } else { + it->second = func; + } + this->pool.AddObj(func); + } + + TypeInfoWrapper *GetTypeInfoWrapper(int32_t type_index) { + TypeInfoWrapper *wrapper = nullptr; + try { + wrapper = this->type_table.at(type_index).get(); + } catch (const std::out_of_range &) { + } + if (wrapper == nullptr || wrapper->pool != &this->pool) { + MLC_THROW(KeyError) << "Type index `" << type_index << "` not registered"; + } + return wrapper; + } + + void SetFields(int32_t type_index, int64_t num_fields, MLCTypeField *fields) { + 
this->GetTypeInfoWrapper(type_index)->SetFields(num_fields, fields); + } + + void SetStructure(int32_t type_index, int32_t structure_kind, int64_t num_sub_structures, + int32_t *sub_structure_indices, int32_t *sub_structure_kinds) { + this->GetTypeInfoWrapper(type_index) + ->SetStructure(structure_kind, num_sub_structures, sub_structure_indices, sub_structure_kinds); + } + + void AddMethod(int32_t type_index, MLCTypeMethod method) { + // TODO: check `override_mode` + this->GetGlobalVTable(method.name)->Set(type_index, reinterpret_cast(method.func), 0); + this->GetTypeInfoWrapper(type_index)->AddMethod(method); + } + + void LoadDSO(std::string name) { + auto [it, success] = this->dso_libraries.try_emplace(name, nullptr); + if (!success) { + return; + } + it->second = std::make_unique(name); + } +}; + +struct _POD_REG { + inline static const int32_t _none = // + ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCNone)) + .MemFn("__str__", &::mlc::base::TypeTraits::__str__); + inline static const int32_t _int = // + ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCInt)) + .StaticFn("__init__", [](AnyView value) { return value.operator int64_t(); }) + .StaticFn("__new_ref__", + [](void *dst, Optional value) { *reinterpret_cast *>(dst) = value; }) + .MemFn("__str__", &::mlc::base::TypeTraits::__str__); + inline static const int32_t _float = + ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCFloat)) + .StaticFn("__new_ref__", + [](void *dst, Optional value) { *reinterpret_cast *>(dst) = value; }) + .MemFn("__str__", &::mlc::base::TypeTraits::__str__); + inline static const int32_t _ptr = + ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCPtr)) + .StaticFn("__new_ref__", + [](void *dst, Optional value) { *reinterpret_cast *>(dst) = value; }) + .MemFn("__str__", &::mlc::base::TypeTraits::__str__); + inline static const int32_t _device = + ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCDevice)) + .StaticFn("__init__", 
[](AnyView device) { return device.operator DLDevice(); }) + .StaticFn("__new_ref__", + [](void *dst, Optional value) { *reinterpret_cast *>(dst) = value; }) + .MemFn("__str__", &::mlc::base::TypeTraits::__str__); + inline static const int32_t _dtype = + ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCDataType)) + .StaticFn("__init__", [](AnyView dtype) { return dtype.operator DLDataType(); }) + .StaticFn( + "__new_ref__", + [](void *dst, Optional value) { *reinterpret_cast *>(dst) = value; }) + .MemFn("__str__", &::mlc::base::TypeTraits::__str__); + inline static const int32_t _str = // + ::mlc::core::ReflectionHelper(static_cast(MLCTypeIndex::kMLCRawStr)) + .MemFn("__str__", &::mlc::base::TypeTraits::__str__); +}; + +inline TypeTable *TypeTable::New() { + TypeTable *self = new TypeTable(); + self->type_table.resize(1024); + self->type_key_to_info.reserve(1024); + self->num_types = static_cast(MLCTypeIndex::kMLCDynObjectBegin); +#define MLC_TYPE_TABLE_INIT_TYPE(UnderlyingType, Self) \ + { \ + using Traits = ::mlc::base::TypeTraits; \ + MLCTypeInfo *info = Self->TypeRegister(-1, Traits::type_index, Traits::type_str); \ + (void)info; \ + } + MLC_TYPE_TABLE_INIT_TYPE(std::nullptr_t, self); + MLC_TYPE_TABLE_INIT_TYPE(int64_t, self); + MLC_TYPE_TABLE_INIT_TYPE(double, self); + MLC_TYPE_TABLE_INIT_TYPE(void *, self); + MLC_TYPE_TABLE_INIT_TYPE(DLDevice, self); + MLC_TYPE_TABLE_INIT_TYPE(DLDataType, self); + MLC_TYPE_TABLE_INIT_TYPE(const char *, self); +#undef MLC_TYPE_TABLE_INIT_TYPE + return self; +} + +inline void VTable::Set(int32_t type_index, FuncObj *func, int32_t override_mode) { + auto [it, success] = this->data.try_emplace(type_index, nullptr); + if (!success) { + if (override_mode == 0) { + // No override + return; + } else if (override_mode == 1) { + // Allow override + this->type_table->pool.DelObj(it->second); + } else if (override_mode == 2) { + // TODO: throw exception + MLCTypeInfo *type_info = this->type_table->GetTypeInfo(type_index); + if 
(type_info && !name.empty()) { + MLC_THROW(KeyError) << "VTable `" << name << "` already registered for type: " << type_info->type_key; + } else if (type_info && name.empty()) { + MLC_THROW(KeyError) << "VTable already registered for type: " << type_info->type_key; + } else if (!type_info && !name.empty()) { + MLC_THROW(KeyError) << "VTable `" << name << "` already registered for type index: " << type_index; + } else { + MLC_THROW(KeyError) << "VTable already registered for type index: " << type_index; + } + MLC_UNREACHABLE(); + } + } + it->second = func; + this->type_table->pool.AddObj(func); +} + +inline FuncObj *VTable::GetFunc(int32_t type_index, bool allow_ancestor) const { + if (auto it = this->data.find(type_index); it != this->data.end()) { + return it->second; + } + if (!allow_ancestor) { + return nullptr; + } + MLCTypeInfo *info = this->type_table->GetTypeInfo(type_index); + for (int32_t i = info->type_depth - 1; i >= 0; i--) { + type_index = info->type_ancestors[i]; + if (auto it = this->data.find(type_index); it != this->data.end()) { + return it->second; + } + } + return nullptr; +} + +} // namespace registry +} // namespace mlc + +#endif // MLC_CPP_REGISTRY_H_ diff --git a/cpp/structure.cc b/cpp/structure.cc new file mode 100644 index 00000000..0401ece9 --- /dev/null +++ b/cpp/structure.cc @@ -0,0 +1,542 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlc { +namespace core { +namespace { + +uint64_t StructuralHash(Object *obj); +bool StructuralEqual(Object *lhs, Object *rhs, bool bind_free_vars, bool assert_mode); +void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars); + +struct SEqualError : public std::runtime_error { + ObjectPath path; + SEqualError(const char *msg, ObjectPath path) : std::runtime_error(msg), path(path) {} +}; + +inline bool StructuralEqual(Object *lhs, Object *rhs, bool bind_free_vars, bool assert_mode) { + try { + 
StructuralEqualImpl(lhs, rhs, bind_free_vars); + return true; + } catch (SEqualError &e) { + if (assert_mode) { + std::ostringstream os; + os << "Structural equality check failed at " << e.path << ": " << e.what(); + MLC_THROW(ValueError) << os.str(); + } + } + return false; +} + +template MLC_INLINE T *WithOffset(Object *obj, MLCTypeField *field) { + return reinterpret_cast(reinterpret_cast(obj) + field->offset); +} + +#define MLC_CORE_EQ_S_ERR(LHS, RHS, PATH) \ + { \ + std::ostringstream err; \ + err << (LHS) << " vs " << (RHS); \ + throw SEqualError(err.str().c_str(), (PATH)); \ + } +#define MLC_CORE_EQ_S_ANY(Cond, Type, EQ, LHS, RHS, PATH) \ + if (Cond) { \ + Type lhs_value = LHS->operator Type(); \ + Type rhs_value = RHS->operator Type(); \ + if (EQ(lhs_value, rhs_value)) { \ + return; \ + } else { \ + MLC_CORE_EQ_S_ERR(*lhs, *rhs, PATH); \ + } \ + } +#define MLC_CORE_EQ_S_OPT(Type, EQ) \ + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind, Optional *_lhs) { \ + const Type *lhs = _lhs->get(); \ + const Type *rhs = WithOffset>(obj_rhs, field)->get(); \ + if ((lhs != nullptr || rhs != nullptr) && (lhs == nullptr || rhs == nullptr || !EQ(*lhs, *rhs))) { \ + AnyView LHS = lhs ? AnyView(*lhs) : AnyView(nullptr); \ + AnyView RHS = rhs ? 
AnyView(*rhs) : AnyView(nullptr); \ + MLC_CORE_EQ_S_ERR(LHS, RHS, path->WithField(field->name)); \ + } \ + } +#define MLC_CORE_EQ_S_POD(Type, EQ) \ + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind, Type *lhs) { \ + const Type *rhs = WithOffset(obj_rhs, field); \ + if (!EQ(*lhs, *rhs)) { \ + MLC_CORE_EQ_S_ERR(AnyView(*lhs), AnyView(*rhs), path->WithField(field->name)); \ + } \ + } + +inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) { + using CharArray = const char *; + using VoidPtr = ::mlc::base::VoidPtr; + using mlc::base::DataTypeEqual; + using mlc::base::DeviceEqual; + struct Task { + Object *lhs; + Object *rhs; + MLCTypeInfo *type_info; + bool visited; + bool bind_free_vars; // `map_free_vars` in TVM + ObjectPath path; + std::unique_ptr err; + }; + struct Visitor { + static bool CharArrayEqual(CharArray lhs, CharArray rhs) { return std::strcmp(lhs, rhs) == 0; } + static bool FloatEqual(float lhs, float rhs) { return std::abs(lhs - rhs) < 1e-6; } + static bool DoubleEqual(double lhs, double rhs) { return std::abs(lhs - rhs) < 1e-8; } + MLC_CORE_EQ_S_OPT(int64_t, std::equal_to()); + MLC_CORE_EQ_S_OPT(double, DoubleEqual); + MLC_CORE_EQ_S_OPT(DLDevice, DeviceEqual); + MLC_CORE_EQ_S_OPT(DLDataType, DataTypeEqual); + MLC_CORE_EQ_S_OPT(VoidPtr, std::equal_to()); + MLC_CORE_EQ_S_POD(int8_t, std::equal_to()); + MLC_CORE_EQ_S_POD(int16_t, std::equal_to()); + MLC_CORE_EQ_S_POD(int32_t, std::equal_to()); + MLC_CORE_EQ_S_POD(int64_t, std::equal_to()); + MLC_CORE_EQ_S_POD(float, FloatEqual); + MLC_CORE_EQ_S_POD(double, DoubleEqual); + MLC_CORE_EQ_S_POD(DLDataType, DataTypeEqual); + MLC_CORE_EQ_S_POD(DLDevice, DeviceEqual); + MLC_CORE_EQ_S_POD(VoidPtr, std::equal_to()); + MLC_CORE_EQ_S_POD(CharArray, CharArrayEqual); + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, const Any *lhs) { + const Any *rhs = WithOffset(obj_rhs, field); + bool bind_free_vars = this->obj_bind_free_vars || field_kind == 
StructureFieldKind::kBind; + EnqueueAny(tasks, bind_free_vars, lhs, rhs, path->WithField(field->name)); + } + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, ObjectRef *_lhs) { + HandleObject(field, field_kind, _lhs->get(), WithOffset(obj_rhs, field)->get()); + } + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, Optional *_lhs) { + HandleObject(field, field_kind, _lhs->get(), WithOffset>(obj_rhs, field)->get()); + } + inline void HandleObject(MLCTypeField *field, StructureFieldKind field_kind, Object *lhs, Object *rhs) { + if (lhs || rhs) { + bool bind_free_vars = this->obj_bind_free_vars || field_kind == StructureFieldKind::kBind; + EnqueueTask(tasks, bind_free_vars, lhs, rhs, path->WithField(field->name)); + } + } + static void EnqueueAny(std::vector *tasks, bool bind_free_vars, const Any *lhs, const Any *rhs, + ObjectPath new_path) { + int32_t type_index = lhs->GetTypeIndex(); + if (type_index != rhs->GetTypeIndex()) { + MLC_CORE_EQ_S_ERR(lhs->GetTypeKey(), rhs->GetTypeKey(), new_path); + } + if (type_index == kMLCNone) { + return; + } + MLC_CORE_EQ_S_ANY(type_index == kMLCInt, int64_t, std::equal_to(), lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCFloat, double, DoubleEqual, lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCPtr, VoidPtr, std::equal_to(), lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCDataType, DLDataType, DataTypeEqual, lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCDevice, DLDevice, DeviceEqual, lhs, rhs, new_path); + MLC_CORE_EQ_S_ANY(type_index == kMLCRawStr, CharArray, CharArrayEqual, lhs, rhs, new_path); + if (type_index < kMLCStaticObjectBegin) { + MLC_THROW(InternalError) << "Unknown type key: " << lhs->GetTypeKey(); + } + EnqueueTask(tasks, bind_free_vars, lhs->operator Object *(), rhs->operator Object *(), new_path); + } + static void EnqueueTask(std::vector *tasks, bool bind_free_vars, Object *lhs, Object *rhs, + ObjectPath 
new_path) { + int32_t lhs_type_index = lhs ? lhs->GetTypeIndex() : kMLCNone; + int32_t rhs_type_index = rhs ? rhs->GetTypeIndex() : kMLCNone; + if (lhs_type_index != rhs_type_index) { + MLC_CORE_EQ_S_ERR(::mlc::base::TypeIndex2TypeKey(lhs_type_index), + ::mlc::base::TypeIndex2TypeKey(rhs_type_index), new_path); + } else if (lhs_type_index == kMLCStr) { + Str lhs_str(reinterpret_cast(lhs)); + Str rhs_str(reinterpret_cast(rhs)); + if (lhs_str != rhs_str) { + MLC_CORE_EQ_S_ERR(lhs_str, rhs_str, new_path); + } + } else if (lhs_type_index == kMLCFunc || lhs_type_index == kMLCError) { + throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", new_path); + } else { + bool visited = false; + MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(lhs_type_index); + tasks->push_back(Task{lhs, rhs, type_info, visited, bind_free_vars, new_path, nullptr}); + } + } + Object *obj_rhs; + std::vector *tasks; + bool obj_bind_free_vars; + ObjectPath path; + }; + std::vector tasks; + std::unordered_map eq_lhs_to_rhs; + std::unordered_map eq_rhs_to_lhs; + + auto check_bind = [&eq_lhs_to_rhs, &eq_rhs_to_lhs](Object *lhs, Object *rhs, const ObjectPath &path) -> bool { + // check binding consistency: lhs -> rhs, rhs -> lhs + auto it_lhs_to_rhs = eq_lhs_to_rhs.find(lhs); + auto it_rhs_to_lhs = eq_rhs_to_lhs.find(rhs); + bool exist_lhs_to_rhs = it_lhs_to_rhs != eq_lhs_to_rhs.end(); + bool exist_rhs_to_lhs = it_rhs_to_lhs != eq_rhs_to_lhs.end(); + // already proven equal + if (exist_lhs_to_rhs && exist_rhs_to_lhs) { + if (it_lhs_to_rhs->second == rhs && it_rhs_to_lhs->second == lhs) { + return true; + } + throw SEqualError("Inconsistent binding: LHS and RHS are both bound, but to different nodes", path); + } + // inconsistent binding + if (exist_lhs_to_rhs) { + throw SEqualError("Inconsistent binding. LHS has been bound to a different node while RHS is not bound", path); + } + if (exist_rhs_to_lhs) { + throw SEqualError("Inconsistent binding. 
RHS has been bound to a different node while LHS is not bound", path); + } + return false; + }; + + Visitor::EnqueueTask(&tasks, bind_free_vars, lhs, rhs, ObjectPath::Root()); + while (!tasks.empty()) { + MLCTypeInfo *type_info; + ObjectPath path{::mlc::Null}; + { + Task &task = tasks.back(); + type_info = task.type_info; + path = task.path; + lhs = task.lhs; + rhs = task.rhs; + bind_free_vars = task.bind_free_vars; + if (task.err) { + throw SEqualError(task.err->str().c_str(), path); + } else if (check_bind(lhs, rhs, path)) { + tasks.pop_back(); + continue; + } else if (task.visited) { + StructureKind kind = static_cast(type_info->structure_kind); + if (kind == StructureKind::kBind || (kind == StructureKind::kVar && bind_free_vars)) { + // bind lhs <-> rhs + eq_lhs_to_rhs[lhs] = rhs; + eq_rhs_to_lhs[rhs] = lhs; + } else if (kind == StructureKind::kVar && !bind_free_vars) { + throw SEqualError("Unbound variable", path); + } + tasks.pop_back(); + continue; + } + task.visited = true; + } + // `task.visited` was `False` + int64_t task_index = static_cast(tasks.size()) - 1; + if (type_info->type_index == kMLCList) { + UListObj *lhs_list = reinterpret_cast(lhs); + UListObj *rhs_list = reinterpret_cast(rhs); + int64_t lhs_size = lhs_list->size(); + int64_t rhs_size = rhs_list->size(); + for (int64_t i = (lhs_size < rhs_size ? 
lhs_size : rhs_size) - 1; i >= 0; --i) { + Visitor::EnqueueAny(&tasks, bind_free_vars, &lhs_list->at(i), &rhs_list->at(i), path->WithListIndex(i)); + } + if (lhs_size != rhs_size) { + auto &err = tasks[task_index].err = std::make_unique(); + (*err) << "List length mismatch: " << lhs_size << " vs " << rhs_size; + } + } else if (type_info->type_index == kMLCDict) { + UDictObj *lhs_dict = reinterpret_cast(lhs); + UDictObj *rhs_dict = reinterpret_cast(rhs); + std::vector not_found_lhs_keys; + for (auto &kv : *lhs_dict) { + AnyView lhs_key = kv.first; + int32_t type_index = lhs_key.type_index; + UDictObj::iterator rhs_it; + if (type_index < kMLCStaticObjectBegin || type_index == kMLCStr) { + rhs_it = rhs_dict->find(lhs_key); + } else if (auto it = eq_lhs_to_rhs.find(lhs_key.operator Object *()); it != eq_lhs_to_rhs.end()) { + rhs_it = rhs_dict->find(it->second); + } else { + not_found_lhs_keys.push_back(lhs_key); + continue; + } + if (rhs_it == rhs_dict->end()) { + not_found_lhs_keys.push_back(lhs_key); + continue; + } + Visitor::EnqueueAny(&tasks, bind_free_vars, &kv.second, &rhs_it->second, path->WithDictKey(lhs_key)); + } + auto &err = tasks[task_index].err; + if (!not_found_lhs_keys.empty()) { + err = std::make_unique(); + (*err) << "Dict key(s) not found in rhs: " << not_found_lhs_keys[0]; + for (size_t i = 1; i < not_found_lhs_keys.size(); ++i) { + (*err) << ", " << not_found_lhs_keys[i]; + } + } else if (lhs_dict->size() != rhs_dict->size()) { + err = std::make_unique(); + (*err) << "Dict size mismatch: " << lhs_dict->size() << " vs " << rhs_dict->size(); + } + } else { + VisitStructure(lhs, type_info, Visitor{rhs, &tasks, bind_free_vars, path}); + } + } +} + +#define MLC_CORE_HASH_S_OPT(Type, Hasher) \ + MLC_INLINE void operator()(MLCTypeField *, StructureFieldKind, Optional *_v) { \ + if (const Type *v = _v->get()) { \ + EnqueuePOD(tasks, Hasher(*v)); \ + } else { \ + EnqueuePOD(tasks, HashCache::kNoneCombined); \ + } \ + } +#define MLC_CORE_HASH_S_POD(Type, 
Hasher) \ + MLC_INLINE void operator()(MLCTypeField *, StructureFieldKind, Type *v) { EnqueuePOD(tasks, Hasher(*v)); } +#define MLC_CORE_HASH_S_ANY(Cond, Type, Hasher) \ + if (Cond) { \ + EnqueuePOD(tasks, Hasher(v->operator Type())); \ + return; \ + } + +struct HashCache { + inline static const uint64_t MLC_SYMBOL_HIDE kNoneCombined = + ::mlc::base::HashCombine(::mlc::base::TypeIndex2TypeInfo(kMLCNone)->type_key_hash, 0); + inline static const uint64_t MLC_SYMBOL_HIDE kInt = ::mlc::base::TypeIndex2TypeInfo(kMLCInt)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kFloat = ::mlc::base::TypeIndex2TypeInfo(kMLCFloat)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kPtr = ::mlc::base::TypeIndex2TypeInfo(kMLCPtr)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kDType = ::mlc::base::TypeIndex2TypeInfo(kMLCDataType)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kDevice = ::mlc::base::TypeIndex2TypeInfo(kMLCDevice)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kRawStr = ::mlc::base::TypeIndex2TypeInfo(kMLCRawStr)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kStrObj = ::mlc::base::TypeIndex2TypeInfo(kMLCStr)->type_key_hash; + inline static const uint64_t MLC_SYMBOL_HIDE kBound = ::mlc::base::StrHash("$$Bounds$$"); + inline static const uint64_t MLC_SYMBOL_HIDE kUnbound = ::mlc::base::StrHash("$$Unbound$$"); +}; + +template MLC_INLINE uint64_t HashTyped(uint64_t type_hash, T value) { + union { + T src; + uint64_t tgt; + } u; + u.tgt = 0; + u.src = value; + return ::mlc::base::HashCombine(type_hash, u.tgt); +} + +inline uint64_t StructuralHash(Object *obj) { + using CharArray = const char *; + using VoidPtr = ::mlc::base::VoidPtr; + using mlc::base::HashCombine; + struct Task { + Object *obj; + MLCTypeInfo *type_info; + bool visited; + bool bind_free_vars; + uint64_t hash_value; + size_t index_in_result_hashes{0xffffffffffffffff}; + }; + struct Visitor { + static uint64_t 
HashInteger(int64_t a) { return HashTyped(HashCache::kInt, a); } + static uint64_t HashPtr(VoidPtr a) { return HashTyped(HashCache::kPtr, a); } + static uint64_t HashDevice(DLDevice a) { return HashTyped(HashCache::kDevice, a); } + static uint64_t HashDataType(DLDataType a) { return HashTyped(HashCache::kDType, a); } + // clang-format off + static uint64_t HashFloat(float a) { return HashTyped(HashCache::kFloat, std::isnan(a) ? std::numeric_limits::quiet_NaN() : a); } + static uint64_t HashDouble(double a) { return HashTyped(HashCache::kFloat, std::isnan(a) ? std::numeric_limits::quiet_NaN() : a); } + static uint64_t HashCharArray(CharArray a) { return HashTyped(HashCache::kRawStr, ::mlc::base::StrHash(a)); } + // clang-format on + MLC_CORE_HASH_S_OPT(int64_t, HashInteger); + MLC_CORE_HASH_S_OPT(double, HashDouble); + MLC_CORE_HASH_S_OPT(DLDevice, HashDevice); + MLC_CORE_HASH_S_OPT(DLDataType, HashDataType); + MLC_CORE_HASH_S_OPT(VoidPtr, HashPtr); + MLC_CORE_HASH_S_POD(int8_t, HashInteger); + MLC_CORE_HASH_S_POD(int16_t, HashInteger); + MLC_CORE_HASH_S_POD(int32_t, HashInteger); + MLC_CORE_HASH_S_POD(int64_t, HashInteger); + MLC_CORE_HASH_S_POD(float, HashFloat); + MLC_CORE_HASH_S_POD(double, HashDouble); + MLC_CORE_HASH_S_POD(DLDataType, HashDataType); + MLC_CORE_HASH_S_POD(DLDevice, HashDevice); + MLC_CORE_HASH_S_POD(VoidPtr, HashPtr); + MLC_CORE_HASH_S_POD(CharArray, HashCharArray); + MLC_INLINE void operator()(MLCTypeField *, StructureFieldKind field_kind, const Any *v) { + bool bind_free_vars = this->obj_bind_free_vars || field_kind == StructureFieldKind::kBind; + EnqueueAny(tasks, bind_free_vars, v); + } + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, ObjectRef *_v) { + HandleObject(field, field_kind, _v->get()); + } + MLC_INLINE void operator()(MLCTypeField *field, StructureFieldKind field_kind, Optional *_v) { + HandleObject(field, field_kind, _v->get()); + } + inline void HandleObject(MLCTypeField *, StructureFieldKind 
field_kind, Object *v) { + bool bind_free_vars = this->obj_bind_free_vars || field_kind == StructureFieldKind::kBind; + EnqueueTask(tasks, bind_free_vars, v); + } + static void EnqueuePOD(std::vector *tasks, uint64_t hash_value) { + tasks->emplace_back(Task{nullptr, nullptr, false, false, hash_value}); + } + static void EnqueueAny(std::vector *tasks, bool bind_free_vars, const Any *v) { + int32_t type_index = v->GetTypeIndex(); + MLC_CORE_HASH_S_ANY(type_index == kMLCInt, int64_t, HashInteger); + MLC_CORE_HASH_S_ANY(type_index == kMLCFloat, double, HashDouble); + MLC_CORE_HASH_S_ANY(type_index == kMLCPtr, VoidPtr, HashPtr); + MLC_CORE_HASH_S_ANY(type_index == kMLCDataType, DLDataType, HashDataType); + MLC_CORE_HASH_S_ANY(type_index == kMLCDevice, DLDevice, HashDevice); + MLC_CORE_HASH_S_ANY(type_index == kMLCRawStr, CharArray, HashCharArray); + EnqueueTask(tasks, bind_free_vars, v->operator Object *()); + } + static void EnqueueTask(std::vector *tasks, bool bind_free_vars, Object *obj) { + int32_t type_index = obj ? 
obj->GetTypeIndex() : kMLCNone; + if (type_index == kMLCNone) { + EnqueuePOD(tasks, HashCache::kNoneCombined); + } else if (type_index == kMLCStr) { + const MLCStr *str = reinterpret_cast(obj); + uint64_t hash_value = ::mlc::base::StrHash(str->data, str->length); + hash_value = HashTyped(HashCache::kStrObj, hash_value); + EnqueuePOD(tasks, hash_value); + } else if (type_index == kMLCFunc || type_index == kMLCError) { + throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", ObjectPath::Root()); + } else { + MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index); + tasks->emplace_back(Task{obj, type_info, false, bind_free_vars, type_info->type_key_hash}); + } + } + + std::vector *tasks; + bool obj_bind_free_vars; + }; + std::vector tasks; + std::vector result_hashes; + std::unordered_map obj2hash; + int64_t num_bound_nodes = 0; + int64_t num_unbound_vars = 0; + Visitor::EnqueueTask(&tasks, false, obj); + while (!tasks.empty()) { + MLCTypeInfo *type_info; + bool bind_free_vars; + uint64_t hash_value; + { + Task &task = tasks.back(); + hash_value = task.hash_value; + obj = task.obj; + type_info = task.type_info; + bind_free_vars = task.bind_free_vars; + if (task.visited) { + if (result_hashes.size() < task.index_in_result_hashes) { + MLC_THROW(InternalError) + << "Internal invariant violated: `result_hashes.size() < task.index_in_result_hashes` (" + << result_hashes.size() << " vs " << task.index_in_result_hashes << ")"; + } + for (; result_hashes.size() > task.index_in_result_hashes; result_hashes.pop_back()) { + hash_value = HashCombine(hash_value, result_hashes.back()); + } + StructureKind kind = static_cast(type_info->structure_kind); + if (kind == StructureKind::kBind || (kind == StructureKind::kVar && bind_free_vars)) { + hash_value = HashCombine(hash_value, HashCache::kBound); + hash_value = HashCombine(hash_value, num_bound_nodes++); + } else if (kind == StructureKind::kVar && !bind_free_vars) { + hash_value = HashCombine(hash_value, 
HashCache::kUnbound); + hash_value = HashCombine(hash_value, num_unbound_vars++); + } + obj2hash[obj] = hash_value; + result_hashes.push_back(hash_value); + tasks.pop_back(); + continue; + } else if (auto it = obj2hash.find(obj); it != obj2hash.end()) { + result_hashes.push_back(it->second); + tasks.pop_back(); + continue; + } else if (obj == nullptr) { + result_hashes.push_back(hash_value); + tasks.pop_back(); + continue; + } + task.visited = true; + task.index_in_result_hashes = result_hashes.size(); + } + // `task.visited` was `False` + if (type_info->type_index == kMLCList) { + UListObj *list = reinterpret_cast(obj); + hash_value = HashCombine(hash_value, list->size()); + for (int64_t i = list->size() - 1; i >= 0; --i) { + Visitor::EnqueueAny(&tasks, bind_free_vars, &list->at(i)); + } + } else if (type_info->type_index == kMLCDict) { + UDictObj *dict = reinterpret_cast(obj); + hash_value = HashCombine(hash_value, dict->size()); + struct KVPair { + uint64_t hash; + AnyView key; + AnyView value; + }; + std::vector kv_pairs; + for (auto &[k, v] : *dict) { + uint64_t hash = 0; + if (k.type_index == kMLCNone) { + hash = HashCache::kNoneCombined; + } else if (k.type_index == kMLCInt) { + hash = Visitor::HashInteger(k.v.v_int64); + } else if (k.type_index == kMLCFloat) { + hash = Visitor::HashDouble(k.v.v_float64); + } else if (k.type_index == kMLCPtr) { + hash = Visitor::HashPtr(k.v.v_ptr); + } else if (k.type_index == kMLCDataType) { + hash = Visitor::HashDataType(k.v.v_dtype); + } else if (k.type_index == kMLCDevice) { + hash = Visitor::HashDevice(k.v.v_device); + } else if (k.type_index == kMLCStr) { + const StrObj *str = k; + hash = ::mlc::base::StrHash(str->data(), str->length()); + hash = HashTyped(HashCache::kStrObj, hash); + } else if (k.type_index >= kMLCStaticObjectBegin) { + obj = k; + if (auto it = obj2hash.find(obj); it != obj2hash.end()) { + hash = it->second; + } else { + continue; // Skip unbound nodes + } + } + kv_pairs.push_back(KVPair{hash, k, v}); 
+ } + std::sort(kv_pairs.begin(), kv_pairs.end(), [](const KVPair &a, const KVPair &b) { return a.hash < b.hash; }); + for (size_t i = 0; i < kv_pairs.size();) { + // [i, j) are of the same hash + size_t j = i + 1; + for (; j < kv_pairs.size() && kv_pairs[i].hash == kv_pairs[j].hash; ++j) { + } + // skip cases where multiple keys have the same hash + if (i + 1 == j) { + Any k = kv_pairs[i].key; + Any v = kv_pairs[i].value; + Visitor::EnqueueAny(&tasks, bind_free_vars, &k); + Visitor::EnqueueAny(&tasks, bind_free_vars, &v); + } + } + } else { + VisitStructure(obj, type_info, Visitor{&tasks, bind_free_vars}); + } + } + if (result_hashes.size() != 1) { + MLC_THROW(InternalError) << "Internal invariant violated: `result_hashes.size() != 1` (" << result_hashes.size() + << ")"; + } + return result_hashes[0]; +} + +#undef MLC_CORE_EQ_S_OPT +#undef MLC_CORE_EQ_S_POD +#undef MLC_CORE_EQ_S_ANY +#undef MLC_CORE_EQ_S_ERR +#undef MLC_CORE_HASH_S_OPT +#undef MLC_CORE_HASH_S_POD +#undef MLC_CORE_HASH_S_ANY + +MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual); +MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t { + uint64_t ret = ::mlc::core::StructuralHash(obj); + return static_cast(ret); +}); +} // namespace +} // namespace core +} // namespace mlc diff --git a/cpp/traceback.cc b/cpp/traceback.cc new file mode 100644 index 00000000..82635c06 --- /dev/null +++ b/cpp/traceback.cc @@ -0,0 +1,78 @@ +#ifndef _MSC_VER +#include "./traceback.h" +#include +#include +#include +#include + +namespace mlc { +namespace ffi { +namespace { + +static int32_t MLC_TRACEBACK_LIMIT = GetTracebackLimit(); +static backtrace_state *_bt_state = backtrace_create_state( + /*filename=*/nullptr, + /*threaded=*/1, + /*error_callback=*/ + +[](void * /*data*/, const char *msg, int /*errnum*/) -> void { + std::cerr << "Failed to initialize libbacktrace: " << msg << std::endl; + }, + /*data=*/nullptr); +thread_local size_t cxa_length = 
1024; +thread_local char *cxa_buffer = static_cast(std::malloc(cxa_length * sizeof(char))); +thread_local char number_buffer[32]; +thread_local TracebackStorage storage; + +const char *CxaDemangle(const char *name) { + int status = 0; + size_t length = cxa_length; + char *demangled_name = abi::__cxa_demangle(name, cxa_buffer, &length, &status); + if (length > cxa_length) { + cxa_length = length; + cxa_buffer = demangled_name; + } + return (demangled_name && length && !status) ? demangled_name : name; +} + +MLCByteArray TracebackImpl() { + storage.buffer.clear(); + + auto callback = +[](void *data, uintptr_t pc, const char *filename, int lineno, const char *symbol) -> int { + if (!filename) { + filename = ""; + } + if (!symbol) { + backtrace_syminfo( + /*state=*/_bt_state, /*addr=*/pc, /*callback=*/ + +[](void *data, uintptr_t /*pc*/, const char *symname, uintptr_t /*symval*/, uintptr_t /*symsize*/) { + *reinterpret_cast(data) = symname; + }, + /*error_callback=*/+[](void * /*data*/, const char * /*msg*/, int /*errnum*/) {}, + /*data=*/&symbol); + } + symbol = symbol ? 
CxaDemangle(symbol) : StringifyPointer(pc, number_buffer, sizeof(number_buffer)); + if (IsForeignFrame(filename, lineno, symbol)) { + return 1; + } + reinterpret_cast(data)->Append(filename)->Append(lineno)->Append(symbol); + return 0; + }; + backtrace_full( + /*state=*/_bt_state, /*skip=*/1, /*callback=*/callback, + /*error_callback=*/[](void * /*data*/, const char * /*msg*/, int /*errnum*/) {}, + /*data=*/&storage); + return {static_cast(storage.buffer.size()), storage.buffer.data()}; +} + +} // namespace +} // namespace ffi +} // namespace mlc + +MLC_API MLCByteArray MLCTraceback(const char *, const char *, const char *) { + using namespace mlc::ffi; + if (_bt_state) { + return TracebackImpl(); + } + return {0, nullptr}; +} +#endif // _MSC_VER diff --git a/cpp/traceback.h b/cpp/traceback.h new file mode 100644 index 00000000..ea605cbd --- /dev/null +++ b/cpp/traceback.h @@ -0,0 +1,78 @@ +#ifndef MLC_CPP_TRACEBACK_H_ +#define MLC_CPP_TRACEBACK_H_ + +#include +#include +#include +#include + +namespace mlc { +namespace ffi { + +inline const char *StringifyPointer(uintptr_t ptr, char *buffer, size_t buffer_size) { + snprintf(buffer, buffer_size, "0x%016" PRIxPTR, ptr); + return buffer; +} + +inline bool EndsWith(const char *str, const char *suffix) { + size_t str_len = strlen(str); + size_t suffix_len = strlen(suffix); + return str_len >= suffix_len && strcmp(str + str_len - suffix_len, suffix) == 0; +} + +inline bool StartsWith(const char *str, const char *prefix) { + size_t str_len = strlen(str); + size_t prefix_len = strlen(prefix); + return str_len >= prefix_len && strncmp(str, prefix, prefix_len) == 0; +} + +inline bool IsForeignFrame(const char *filename, int32_t lineno, const char *func_name) { + (void)lineno; + (void)func_name; + if (EndsWith(filename, "core_cython.cc")) { + return true; + } +#if defined(__APPLE__) + if (strcmp(func_name, "MLCFuncSafeCall") == 0) { + return true; + } +#endif + return false; +} + +#ifdef _MSC_VER +#pragma warning(push) 
+#pragma warning(disable : 4996) // std::getenv is unsafe +#endif + +inline int32_t GetTracebackLimit() { + if (const char *env = std::getenv("MLC_TRACEBACK_LIMIT")) { + return std::stoi(env); + } + return 512; +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +struct TracebackStorage { + std::vector buffer = std::vector(1024 * 1024, '\0'); // 1 MB + char lineno_buffer[32]; + + TracebackStorage *Append(const char *str) { + buffer.insert(buffer.end(), str, str + strlen(str)); + buffer.push_back('\0'); + return this; + } + + TracebackStorage *Append(int lineno) { + snprintf(lineno_buffer, sizeof(lineno_buffer), "%d", lineno); + return this->Append(lineno_buffer); + } +}; + +} // namespace ffi +} // namespace mlc + +#endif // MLC_CPP_TRACEBACK_H_ diff --git a/cpp/traceback_win.cc b/cpp/traceback_win.cc new file mode 100644 index 00000000..7cf62b9d --- /dev/null +++ b/cpp/traceback_win.cc @@ -0,0 +1,104 @@ +#ifdef _MSC_VER +// clang-format off +#include +#include +// clang-format on +#include "./traceback.h" +#include + +namespace mlc { +namespace ffi { +namespace { + +static int32_t MLC_TRACEBACK_LIMIT = GetTracebackLimit(); +thread_local TracebackStorage storage; +thread_local char number_buffer[32]; +thread_local char undecorated_name[1024]; +thread_local char symbol_buffer[sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR)]; +#pragma comment(lib, "dbghelp.lib") + +thread_local std::pair process_thread = []() { + HANDLE process = GetCurrentProcess(); + HANDLE thread = GetCurrentThread(); + SymInitialize(process, NULL, TRUE); + SymSetOptions(SYMOPT_LOAD_LINES); + return std::make_pair(process, thread); +}(); + +MLCByteArray TracebackImpl() { + HANDLE process = process_thread.first; + HANDLE thread = process_thread.second; + CONTEXT context; + STACKFRAME64 frame; + DWORD machine; + + memset(&context, 0, sizeof(CONTEXT)); + context.ContextFlags = CONTEXT_FULL; + RtlCaptureContext(&context); + + memset(&frame, 0, sizeof(STACKFRAME64)); +#ifdef _M_IX86 + machine = 
IMAGE_FILE_MACHINE_I386; + frame.AddrPC.Offset = context.Eip; + frame.AddrStack.Offset = context.Esp; + frame.AddrFrame.Offset = context.Ebp; +#elif _M_X64 + machine = IMAGE_FILE_MACHINE_AMD64; + frame.AddrPC.Offset = context.Rip; + frame.AddrStack.Offset = context.Rsp; + frame.AddrFrame.Offset = context.Rbp; +#elif _M_ARM64 + machine = IMAGE_FILE_MACHINE_ARM64; + frame.AddrPC.Offset = context.Pc; + frame.AddrStack.Offset = context.Sp; + frame.AddrFrame.Offset = context.Fp; +#else +#error "Unsupported architecture" +#endif + frame.AddrPC.Mode = AddrModeFlat; + frame.AddrStack.Mode = AddrModeFlat; + frame.AddrFrame.Mode = AddrModeFlat; + storage.buffer.clear(); + while (StackWalk64(machine, process, thread, &frame, &context, NULL, SymFunctionTableAccess64, SymGetModuleBase64, + NULL)) { + PSYMBOL_INFO symbol = reinterpret_cast(symbol_buffer); + symbol->SizeOfStruct = sizeof(SYMBOL_INFO); + symbol->MaxNameLen = MAX_SYM_NAME; + + const char *filename = ""; + int lineno = 0; + const char *symbol_name; + + IMAGEHLP_LINE64 line; + line.SizeOfStruct = sizeof(IMAGEHLP_LINE64); + DWORD displacement_line = 0; + if (SymGetLineFromAddr64(process, frame.AddrPC.Offset, &displacement_line, &line)) { + filename = line.FileName; + lineno = line.LineNumber; + } + DWORD64 displacement = 0; + if (SymFromAddr(process, frame.AddrPC.Offset, &displacement, symbol)) { + symbol_name = UnDecorateSymbolName(symbol->Name, undecorated_name, sizeof(undecorated_name), UNDNAME_COMPLETE) + ? 
undecorated_name + : symbol->Name; + } else { + symbol_name = StringifyPointer(frame.AddrPC.Offset, number_buffer, sizeof(number_buffer)); + } + if (IsForeignFrame(filename, lineno, symbol_name)) { + break; + } + if (!EndsWith(filename, "traceback_win.cc")) { + storage.Append(filename); + storage.Append(lineno); + storage.Append(symbol_name); + } + } + return {static_cast(storage.buffer.size()), storage.buffer.data()}; +} + +} // namespace +} // namespace ffi +} // namespace mlc + +MLC_API MLCByteArray MLCTraceback(const char *, const char *, const char *) { return ::mlc::ffi::TracebackImpl(); } +#endif // _MSC_VER diff --git a/include/mlc/base/all.h b/include/mlc/base/all.h new file mode 100644 index 00000000..9543a129 --- /dev/null +++ b/include/mlc/base/all.h @@ -0,0 +1,253 @@ +#ifndef MLC_BASE_ALL_H_ +#define MLC_BASE_ALL_H_ + +#include "./alloc.h" // IWYU pragma: export +#include "./any.h" // IWYU pragma: export +#include "./base_traits.h" // IWYU pragma: export +#include "./optional.h" // IWYU pragma: export +#include "./ref.h" // IWYU pragma: export +#include "./traits_device.h" // IWYU pragma: export +#include "./traits_dtype.h" // IWYU pragma: export +#include "./traits_object.h" // IWYU pragma: export +#include "./traits_scalar.h" // IWYU pragma: export +#include "./traits_str.h" // IWYU pragma: export +#include // IWYU pragma: export +#include + +namespace mlc { + +/*********** New stuff ***********/ + +MLC_INLINE AnyView::AnyView(const Any &src) : MLCAny(static_cast(src)) {} +MLC_INLINE AnyView::AnyView(Any &&src) : MLCAny(static_cast(src)) {} + +template AnyView::AnyView(_T &&src) : MLCAny() { + using namespace ::mlc::base; + constexpr TypeKind kind = TypeKindOf<_T>; + if constexpr (kind == TypeKind::kObjRef) { + using TObjRef = RemoveCR<_T>; + using TObj = typename TObjRef::TObj; + TypeTraits::TypeToAny(const_cast(src.get()), this); + } else if constexpr (kind == TypeKind::kObjPtr) { + using TObjPtr = RemoveCR<_T>; + using TObj = 
std::remove_pointer_t; + TypeTraits::TypeToAny(src, this); + } else { // POD + using TPOD = RemoveCR<_T>; + TypeTraits::TypeToAny(src, this); + } +} + +template MLC_INLINE AnyView::AnyView(const T *src) : MLCAny() { + ::mlc::base::TypeTraits::TypeToAny(const_cast(src), this); +} + +template MLC_INLINE AnyView::AnyView(const Ref &src) : MLCAny() { + if (const auto *value = src.get()) { + if constexpr (::mlc::base::IsPOD) { + using TPOD = T; + ::mlc::base::TypeTraits::TypeToAny(*value, this); + } else { + using TObj = T; + ::mlc::base::TypeTraits::TypeToAny(const_cast(value), this); + } + } +} + +template MLC_INLINE AnyView::AnyView(Ref &&src) : AnyView(static_cast &>(src)) { + // `src` is not reset here because `AnyView` does not take ownership of the object +} + +template MLC_INLINE AnyView::AnyView(const Optional &src) { + if (const auto *value = src.get()) { + if constexpr (::mlc::base::IsPOD) { + using TPOD = T; + ::mlc::base::TypeTraits::TypeToAny(*value, this); + } else { + using TObj = typename T::TObj; + ::mlc::base::TypeTraits::TypeToAny(const_cast(value), this); + } + } +} + +template MLC_INLINE AnyView::AnyView(Optional &&src) : AnyView(static_cast &>(src)) { + // `src` is not reset here because `AnyView` does not take ownership of the object +} + +template MLC_INLINE AnyView &AnyView::operator=(T &&src) { + AnyView(std::forward(src)).Swap(*this); + return *this; +} + +MLC_INLINE AnyView &AnyView::operator=(const Any &src) { + AnyView(src).Swap(*this); + return *this; +} + +MLC_INLINE AnyView &AnyView::operator=(Any &&src) { + AnyView(std::move(src)).Swap(*this); + return *this; +} + +template MLC_INLINE AnyView &AnyView::operator=(const T *src) { + AnyView(src).Swap(*this); + return *this; +} + +template MLC_INLINE AnyView &AnyView::operator=(const Ref &src) { + AnyView(src).Swap(*this); + return *this; +} + +template MLC_INLINE AnyView &AnyView::operator=(Ref &&src) { + AnyView(std::move(src)).Swap(*this); + return *this; +} + +template MLC_INLINE 
AnyView &AnyView::operator=(const Optional &src) { + AnyView(src).Swap(*this); + return *this; +} + +template MLC_INLINE AnyView &AnyView::operator=(Optional &&src) { + AnyView(std::move(src)).Swap(*this); + return *this; +} + +template inline AnyView::operator T() const { + using namespace ::mlc::base; + if constexpr (IsObjRef) { + using TObjRef = T; + using TObj = typename TObjRef::TObj; + TObj *value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeOwned(this), this->type_index, Type2Str::Run()); + }(); + return TObjRef(value); + } else { + // POD or ObjPtr + MLC_TRY_CONVERT(TypeTraits::AnyToTypeUnowned(this), this->type_index, Type2Str::Run()); + } +} + +template inline Any::operator T() const { + using namespace ::mlc::base; + if constexpr (IsObjRef) { + using TObjRef = T; + using TObj = typename TObjRef::TObj; + TObj *value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeOwned(this), this->type_index, Type2Str::Run()); + }(); + return TObjRef(value); + } else { + // POD or ObjPtr + MLC_TRY_CONVERT(TypeTraits::AnyToTypeUnowned(this), this->type_index, Type2Str::Run()); + } +} + +template inline AnyView::operator Ref() const { + using namespace ::mlc::base; + if (IsTypeIndexNone(this->type_index)) { + return Ref(nullptr); + } + if constexpr (IsPOD) { + using TPOD = T; + TPOD value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeUnowned(this), this->type_index, Type2Str::Run()); + }(); + return Ref(value); + } else { + static_assert(IsObj, "Unsupported type. 
Expect `Ref` or `Ref`"); + using TObj = T; + TObj *value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeOwned(this), this->type_index, Type2Str::Run()); + }(); + return Ref(value); + } +} + +template inline Any::operator Ref() const { + using namespace ::mlc::base; + if (IsTypeIndexNone(this->type_index)) { + return Ref(nullptr); + } + if constexpr (IsPOD) { + using TPOD = T; + TPOD value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeUnowned(this), this->type_index, Type2Str::Run()); + }(); + return Ref(value); + } else { + static_assert(IsObj, "Unsupported type. Expect `Ref` or `Ref`"); + using TObj = T; + TObj *value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeOwned(this), this->type_index, Type2Str::Run()); + }(); + return Ref(value); + } +} + +template inline AnyView::operator Optional() const { + using namespace ::mlc::base; + if (IsTypeIndexNone(this->type_index)) { + return Optional(Null); + } + if constexpr (IsPOD) { + using TPOD = T; + TPOD value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeUnowned(this), this->type_index, Type2Str::Run()); + }(); + return Optional(value); + } else { + static_assert(IsObjRef, "Unsupported type. Expect `Optional` or `Optional`"); + using TObj = typename T::TObj; + TObj *value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeOwned(this), this->type_index, Type2Str::Run()); + }(); + return Optional(value); + } +} + +template inline Any::operator Optional() const { + using namespace ::mlc::base; + if (IsTypeIndexNone(this->type_index)) { + return Optional(Null); + } + if constexpr (IsPOD) { + using TPOD = T; + TPOD value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeUnowned(this), this->type_index, Type2Str::Run()); + }(); + return Optional(value); + } else { + static_assert(IsObjRef, "Unsupported type. 
Expect `Optional` or `Optional`"); + using TObj = typename T::TObj; + TObj *value = [this]() { + MLC_TRY_CONVERT(TypeTraits::AnyToTypeOwned(this), this->type_index, Type2Str::Run()); + }(); + return Optional(value); + } +} + +/*********** Section 2. Conversion between Any/AnyView <=> POD ***********/ + +template MLC_INLINE_NO_MSVC _T AnyView::_CastWithStorage(Any *storage) const { + using namespace ::mlc::base; + using T = RemoveCR<_T>; + MLC_TRY_CONVERT(TypeTraits::AnyToTypeWithStorage(this, storage), this->type_index, Type2Str::Run()); +} + +/*********** Section 3. AnyViewArray ***********/ + +template template MLC_INLINE void AnyViewArray::Fill(Args &&...args) { + static_assert(sizeof...(args) == N, "Invalid number of arguments"); + if constexpr (N > 0) { + using IdxSeq = std::make_index_sequence; + ::mlc::base::AnyViewArrayFill::Run(v, std::forward(args)..., IdxSeq{}); + } +} +template MLC_INLINE void AnyViewArray<0>::Fill(Args &&...args) { + static_assert(sizeof...(args) == 0, "Invalid number of arguments"); +} +} // namespace mlc + +#endif // MLC_BASE_ALL_H_ diff --git a/include/mlc/base/alloc.h b/include/mlc/base/alloc.h new file mode 100644 index 00000000..826f0f94 --- /dev/null +++ b/include/mlc/base/alloc.h @@ -0,0 +1,93 @@ +#ifndef MLC_BASE_ALLOC_H_ +#define MLC_BASE_ALLOC_H_ + +#include "./utils.h" +#include +#include + +namespace mlc { + +template struct DefaultObjectAllocator { + using Storage = typename std::aligned_storage::type; + + template >> + MLC_INLINE_NO_MSVC static T *New(Args &&...args) { + Storage *data = new Storage; + try { + new (data) T(std::forward(args)...); + } catch (...) 
{ + delete data; + throw; + } + T *ret = reinterpret_cast(data); + ret->_mlc_header.type_index = T::_type_index; + ret->_mlc_header.ref_cnt = 0; + ret->_mlc_header.v.deleter = DefaultObjectAllocator::Deleter; + return ret; + } + + template >> + MLC_INLINE_NO_MSVC static T *NewWithPad(size_t pad_size, Args &&...args) { + size_t num_storages = (sizeof(T) + pad_size * sizeof(PadType) + sizeof(Storage) - 1) / sizeof(Storage); + Storage *data = new Storage[num_storages]; + try { + new (data) T(std::forward(args)...); + } catch (...) { + delete[] data; + throw; + } + T *ret = reinterpret_cast(data); + ret->_mlc_header.type_index = T::_type_index; + ret->_mlc_header.ref_cnt = 0; + ret->_mlc_header.v.deleter = DefaultObjectAllocator::DeleterArray; + return ret; + } + + static void Deleter(void *objptr) { + T *tptr = static_cast(objptr); + tptr->T::~T(); + delete reinterpret_cast(tptr); + } + + static void DeleterArray(void *objptr) { + T *tptr = static_cast(objptr); + tptr->T::~T(); + delete[] reinterpret_cast(tptr); + } +}; + +template struct PODAllocator; + +#define MLC_DEF_POD_ALLOCATOR(Type, TypeIndex, Field) \ + template <> struct PODAllocator { \ + MLC_INLINE_NO_MSVC static MLCAny *New(Type data) { \ + MLCBoxedPOD *ret = new MLCBoxedPOD; \ + ret->_mlc_header.type_index = static_cast(TypeIndex); \ + ret->_mlc_header.ref_cnt = 0; \ + ret->_mlc_header.v.deleter = PODAllocator::Deleter; \ + ret->data.Field = data; \ + return reinterpret_cast(ret); \ + } \ + static void Deleter(void *objptr) { delete static_cast(objptr); } \ + } + +MLC_DEF_POD_ALLOCATOR(int64_t, MLCTypeIndex::kMLCInt, v_int64); +MLC_DEF_POD_ALLOCATOR(double, MLCTypeIndex::kMLCFloat, v_float64); +MLC_DEF_POD_ALLOCATOR(DLDataType, MLCTypeIndex::kMLCDataType, v_dtype); +MLC_DEF_POD_ALLOCATOR(DLDevice, MLCTypeIndex::kMLCDevice, v_device); +MLC_DEF_POD_ALLOCATOR(::mlc::base::VoidPtr, MLCTypeIndex::kMLCPtr, v_ptr); + +#undef MLC_DEF_POD_ALLOCATOR + +MLC_INLINE mlc::Object *AllocExternObject(int32_t type_index, 
int32_t num_bytes) { + MLCAny *ptr = reinterpret_cast(std::malloc(num_bytes)); + std::memset(ptr, 0, num_bytes); + ptr->type_index = type_index; + ptr->ref_cnt = 0; + ptr->v.deleter = MLCExtObjDelete; + return reinterpret_cast(ptr); +} + +} // namespace mlc + +#endif // MLC_BASE_ALLOC_H_ diff --git a/include/mlc/base/any.h b/include/mlc/base/any.h new file mode 100644 index 00000000..9cc01872 --- /dev/null +++ b/include/mlc/base/any.h @@ -0,0 +1,234 @@ +#ifndef MLC_BASE_ANY_H_ +#define MLC_BASE_ANY_H_ +#include "./base_traits.h" +#include "./utils.h" +#include +#include +#include + +namespace mlc { + +struct AnyView : public MLCAny { + static constexpr ::mlc::base::TypeKind _type_kind = ::mlc::base::TypeKind::kAny; + template using DObj = std::enable_if_t<::mlc::base::IsObj>; + template using DPODOrObjPtrOrObjRef = std::enable_if_t<::mlc::base::IsPODOrObjPtrOrObjRef>; + /***** Section 1. Default constructor/destructors *****/ + MLC_INLINE AnyView() : MLCAny() {} + MLC_INLINE AnyView(std::nullptr_t) : MLCAny() {} + MLC_INLINE AnyView(NullType) : MLCAny() {} + MLC_INLINE AnyView(const AnyView &src) = default; + MLC_INLINE AnyView(AnyView &&src) = default; + /* Defining "operator=" */ + // clang-format off + MLC_INLINE AnyView &operator=(const AnyView &src) = default; + MLC_INLINE AnyView &operator=(AnyView &&src) = default; + MLC_INLINE AnyView &operator=(std::nullptr_t) { this->Reset(); return *this; } + MLC_INLINE AnyView &operator=(NullType) { this->Reset(); return *this; } + MLC_INLINE ~AnyView() = default; + MLC_INLINE AnyView &Reset() { *(static_cast(this)) = MLCAny(); return *this; } + // clang-format on + /***** Section 2. 
Conversion from Any, POD, ObjRef, ObjPtr, `Ref`, `Optional` *****/ + MLC_INLINE AnyView(const Any &src); + MLC_INLINE AnyView(Any &&src); + template > AnyView(T &&src); + template > MLC_INLINE AnyView(const T *src); + template MLC_INLINE AnyView(const Ref &src); + template MLC_INLINE AnyView(Ref &&src); + template MLC_INLINE AnyView(const Optional &src); + template MLC_INLINE AnyView(Optional &&src); + /* Defining "operator=" */ + MLC_INLINE AnyView &operator=(const Any &src); + MLC_INLINE AnyView &operator=(Any &&src); + template > MLC_INLINE AnyView &operator=(T &&src); + template > MLC_INLINE AnyView &operator=(const T *src); + template MLC_INLINE AnyView &operator=(const Ref &src); + template MLC_INLINE AnyView &operator=(Ref &&src); + template MLC_INLINE AnyView &operator=(const Optional &src); + template MLC_INLINE AnyView &operator=(Optional &&src); + /***** Section 3. Conversion to `TPOD` `TObj *`, `Ref`, `Optional` *****/ + template > operator T() const; + template operator Ref() const; + template operator Optional() const; + /***** Section 4. accessors, comparators and stringify *****/ + MLC_INLINE bool defined() const { return !::mlc::base::IsTypeIndexNone(this->type_index); } + MLC_INLINE bool has_value() const { return !::mlc::base::IsTypeIndexNone(this->type_index); } + MLC_INLINE bool operator==(std::nullptr_t) const { return !this->defined(); } + MLC_INLINE bool operator!=(std::nullptr_t) const { return this->defined(); } + Str str() const; + friend std::ostream &operator<<(std::ostream &os, const AnyView &src); + /***** Section 5. Runtime-type information *****/ + template MLC_INLINE bool IsInstance() const { + static_assert(::mlc::base::IsObj, "DerivedObj must be an object type"); + if (::mlc::base::IsTypeIndexPOD(this->type_index)) { + return false; + } + return ::mlc::base::IsInstanceOf(this->v.v_obj); + } + template MLC_INLINE DerivedObj *TryCast() { + return this->IsInstance() ? 
reinterpret_cast(this->v.v_obj) : nullptr; + } + template MLC_INLINE const DerivedObj *TryCast() const { + return this->IsInstance() ? reinterpret_cast(this->v.v_obj) : nullptr; + } + template inline DerivedObj *Cast() { + if (!this->IsInstance()) { + MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index) + << "` to type `" << ::mlc::base::Type2Str::Run() << "`"; + } + return reinterpret_cast(this->v.v_obj); + } + template MLC_INLINE const DerivedObj *Cast() const { + if (!this->IsInstance()) { + MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index) + << "` to type `" << ::mlc::base::Type2Str::Run() << "`"; + } + return reinterpret_cast(this->v.v_obj); + } + int32_t GetTypeIndex() const { return this->type_index; } + const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; } + + template >>> + MLC_INLINE_NO_MSVC T _CastWithStorage(Any *storage) const; // TODO: reemove this + +protected: + MLC_INLINE void Swap(MLCAny &src) { + MLCAny tmp = *this; + *static_cast(this) = src; + src = tmp; + } +}; + +struct Any : public MLCAny { + static constexpr ::mlc::base::TypeKind _type_kind = ::mlc::base::TypeKind::kAny; + template using DObj = std::enable_if_t<::mlc::base::IsObj>; + template using DPODOrObjPtrOrObjRef = std::enable_if_t<::mlc::base::IsPODOrObjPtrOrObjRef>; + /***** Section 1. 
Default constructor/destructors *****/ + MLC_INLINE Any() : MLCAny() {} + MLC_INLINE Any(std::nullptr_t) : MLCAny() {} + MLC_INLINE Any(NullType) : MLCAny() {} + MLC_INLINE Any(const Any &src) : MLCAny(static_cast(src)) { this->IncRef(); } + MLC_INLINE Any(Any &&src) : MLCAny(static_cast(src)) { *static_cast(&src) = MLCAny(); } + /* Defining "operator=" */ + // clang-format off + MLC_INLINE Any &operator=(const Any &src) { Any(src).Swap(*this); return *this; } + MLC_INLINE Any &operator=(Any &&src) { Any(std::move(src)).Swap(*this); return *this; } + MLC_INLINE Any &operator=(std::nullptr_t) { this->Reset(); return *this; } + MLC_INLINE Any &operator=(NullType) { this->Reset(); return *this; } + MLC_INLINE ~Any() { this->Reset(); } + MLC_INLINE Any &Reset() { this->DecRef(); *(static_cast(this)) = MLCAny(); return *this; } + // clang-format on + /***** Section 2. Conversion from AnyView, POD, ObjRef, ObjPtr, `Ref`, `Optional` *****/ + // clang-format off + MLC_INLINE Any(const AnyView &src) : MLCAny(static_cast(src)) { this->SwitchFromRawStr(); this->IncRef(); } + MLC_INLINE Any(AnyView &&src) : MLCAny(static_cast(src)) { *static_cast(&src) = MLCAny(); this->SwitchFromRawStr(); this->IncRef(); } + // clang-format on + template > Any(T &&src) : Any(AnyView(std::forward(src))) {} + template > MLC_INLINE Any(const T *src) : Any(AnyView(src)) {} + template MLC_INLINE Any(const Ref &src) : Any(AnyView(src)) {} + template MLC_INLINE Any(Ref &&src) : Any(AnyView(std::move(src))) {} + template MLC_INLINE Any(const Optional &src) : Any(AnyView(src)) {} + template MLC_INLINE Any(Optional &&src) : Any(AnyView(std::move(src))) {} + /* Defining "operator=" */ + // clang-format off + MLC_INLINE Any &operator=(const AnyView &src) { Any(src).Swap(*this); return *this; } + MLC_INLINE Any &operator=(AnyView &&src) { Any(std::move(src)).Swap(*this); return *this; } + template > MLC_INLINE Any &operator=(T &&src) { Any(std::forward(src)).Swap(*this); return *this; } + template > 
MLC_INLINE Any &operator=(const T *src) { Any(AnyView(src)).Swap(*this); return *this; } + template MLC_INLINE Any &operator=(const Ref &src) { Any(AnyView(src)).Swap(*this); return *this; } + template MLC_INLINE Any &operator=(Ref &&src) { Any(AnyView(std::move(src))).Swap(*this); return *this; } + template MLC_INLINE Any &operator=(const Optional &src) { Any(AnyView(src)).Swap(*this); return *this; } + template MLC_INLINE Any &operator=(Optional &&src) { Any(AnyView(std::move(src))).Swap(*this); return *this; } + // clang-format on + /***** Section 3. Conversion to `TPOD` `TObj *`, `Ref`, `Optional` *****/ + template > operator T() const; + template operator Ref() const; + template operator Optional() const; + /***** Section 4. accessors, comparators and stringify *****/ + MLC_INLINE bool defined() const { return !::mlc::base::IsTypeIndexNone(this->type_index); } + MLC_INLINE bool has_value() const { return !::mlc::base::IsTypeIndexNone(this->type_index); } + MLC_INLINE bool operator==(std::nullptr_t) const { return !this->defined(); } + MLC_INLINE bool operator!=(std::nullptr_t) const { return this->defined(); } + Str str() const; + friend std::ostream &operator<<(std::ostream &os, const Any &src); + /***** Section 5. Runtime-type information *****/ + template MLC_INLINE bool IsInstance() const { + static_assert(::mlc::base::IsObj, "DerivedObj must be an object type"); + if (::mlc::base::IsTypeIndexPOD(this->type_index)) { + return false; + } + return ::mlc::base::IsInstanceOf(this->v.v_obj); + } + template MLC_INLINE DerivedObj *TryCast() { + return this->IsInstance() ? reinterpret_cast(this->v.v_obj) : nullptr; + } + template MLC_INLINE const DerivedObj *TryCast() const { + return this->IsInstance() ? 
reinterpret_cast(this->v.v_obj) : nullptr; + } + template inline DerivedObj *Cast() { + if (!this->IsInstance()) { + MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index) + << "` to type `" << ::mlc::base::Type2Str::Run() << "`"; + } + return reinterpret_cast(this->v.v_obj); + } + template MLC_INLINE const DerivedObj *Cast() const { + if (!this->IsInstance()) { + MLC_THROW(TypeError) << "Cannot cast from type `" << ::mlc::base::TypeIndex2TypeKey(this->type_index) + << "` to type `" << ::mlc::base::Type2Str::Run() << "`"; + } + return reinterpret_cast(this->v.v_obj); + } + int32_t GetTypeIndex() const { return this->type_index; } + const char *GetTypeKey() const { return ::mlc::base::TypeIndex2TypeInfo(this->type_index)->type_key; } + +protected: + MLC_INLINE void Swap(MLCAny &src) { + MLCAny tmp = *this; + *static_cast(this) = src; + src = tmp; + } + MLC_INLINE void IncRef() { + if (!::mlc::base::IsTypeIndexPOD(this->type_index)) { + ::mlc::base::IncRef(this->v.v_obj); + } + } + MLC_INLINE void DecRef() { + if (!::mlc::base::IsTypeIndexPOD(this->type_index)) { + ::mlc::base::DecRef(this->v.v_obj); + } + } + MLC_INLINE void SwitchFromRawStr() { + if (this->type_index == static_cast(MLCTypeIndex::kMLCRawStr)) { + this->type_index = static_cast(MLCTypeIndex::kMLCStr); + this->v.v_obj = + reinterpret_cast(::mlc::base::StrCopyFromCharArray(this->v.v_str, std::strlen(this->v.v_str))); + } + } +}; + +template struct AnyViewArray { + template void Fill(Args &&...args); + AnyView v[N]; +}; + +template <> struct AnyViewArray<0> { + template void Fill(Args &&...args); + AnyView *v = nullptr; +}; +} // namespace mlc + +namespace mlc { +namespace base { +template struct AnyViewArrayFill { + template MLC_INLINE static void Apply(AnyView *v, Arg &&arg) { + v[i] = AnyView(std::forward(arg)); + } + template MLC_INLINE static void Run(AnyView *v, Args &&...args, std::index_sequence) { + using TExpander = int[]; + (void)TExpander{0, 
(Apply(v, std::forward(args)), 0)...}; + } +}; +} // namespace base +} // namespace mlc + +#endif // MLC_BASE_ANY_H_ diff --git a/include/mlc/base/base_traits.h b/include/mlc/base/base_traits.h new file mode 100644 index 00000000..7c07bd5c --- /dev/null +++ b/include/mlc/base/base_traits.h @@ -0,0 +1,329 @@ +#ifndef MLC_BASE_TRAITS_H_ +#define MLC_BASE_TRAITS_H_ +#include +#include +#include +#include +#include +#include + +namespace mlc { +struct Exception; +struct NullType {}; +constexpr NullType Null{}; + +struct AnyView; +struct Any; +template struct AnyViewArray; +template struct Ref; +template struct Optional; +struct Object; +struct ObjectRef; +struct ErrorObj; +struct Error; +struct StrObj; +struct Str; +struct FuncObj; +struct Func; +struct UListObj; +struct UList; +struct UDictObj; +struct UDict; +template struct ListObj; +template struct List; +template struct DictObj; +template struct Dict; + +template Ref InitOf(Args &&...args); +template struct DefaultObjectAllocator; +template struct PODAllocator; + +enum class StructureKind : int32_t { + kNone = 0, + kNoBind = 1, + kBind = 2, + kVar = 3, +}; +enum class StructureFieldKind : int32_t { + kNoBind = 0, + kBind = 1, +}; + +} // namespace mlc + +namespace mlc { +namespace base { + +// Basic types, tuples and other utils +using VoidPtr = void *; +template using CharArray = char[N]; +template using RemoveCR = std::remove_cv_t>; +// clang-format off +template struct _TupleAppend; +template struct _TupleAppend, Append> { using type = std::tuple; }; +template struct _TupleElementEq { static constexpr bool Run() { if constexpr (i < tuple_len) return std::is_same_v, expected>; else return false; } }; +// clang-format on +template using TupleAppend = typename _TupleAppend::type; +template +constexpr bool TupleElementEq = _TupleElementEq::Run(); + +// IsTemplate +template struct IsTemplateImpl : std::false_type {}; +template