From db36cbead045445fa3a867f1ef2e8fdf3617daee Mon Sep 17 00:00:00 2001
From: Kim Tran <17498395+kxtran@users.noreply.github.com>
Date: Fri, 17 May 2024 14:51:40 -0400
Subject: [PATCH] Add cronjob tests via GitHub Actions (#158)

* Remove non-existent files
* Add tests for LLM providers and some frameworks
* Make pytest marks more explicit
* Fix format
* Add GitHub workflow for running tests daily
* Assert on output instead of stdout for streaming cases
* Reformat files
* Pass log10 creds to the test workflow
* Add optional google-generativeai module and modify some tests
* Support switching models for tests
* Set default models across LLM providers and frameworks
* Use @pytest.mark.asyncio in tests and add asyncio module
* Update GitHub workflows to run individual or all tests via dispatch
* Fix typo and update tests
* Update poetry lock file
* Update litellm tests to initialize callbacks once
---
 .github/workflows/test.yml | 122 +++++++++++++++
 Makefile                   |   4 -
 poetry.lock                | 152 +++++++++++++++++--
 pyproject.toml             |   4 +
 tests/conftest.py          |  61 ++++++++
 tests/pytest.ini           |  18 +++
 tests/test_anthropic.py    |  87 +++++++++++
 tests/test_google.py       |  42 ++++++
 tests/test_lamini.py       |  16 ++
 tests/test_langchain.py    |  32 ++++
 tests/test_litellm.py      | 143 ++++++++++++++++++
 tests/test_magentic.py     | 147 ++++++++++++++++++
 tests/test_mistralai.py    |  39 +++++
 tests/test_openai.py       | 296 +++++++++++++++++++++++++++++++++++++
 14 files changed, 1144 insertions(+), 19 deletions(-)
 create mode 100644 .github/workflows/test.yml
 create mode 100644 tests/conftest.py
 create mode 100644 tests/pytest.ini
 create mode 100644 tests/test_anthropic.py
 create mode 100644 tests/test_google.py
 create mode 100644 tests/test_lamini.py
 create mode 100644 tests/test_langchain.py
 create mode 100644 tests/test_litellm.py
 create mode 100644 tests/test_magentic.py
 create mode 100644 tests/test_mistralai.py
 create mode 100644 tests/test_openai.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 00000000..db07a32c
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,122 @@
+name: Test
+
+on:
+  schedule:
+    - cron: "0 13 * * *"
+  workflow_dispatch:
+    inputs:
+      openai_model:
+        description: 'Model name for OpenAI tests'
+        type: string
+        required: false
+      openai_vision_model:
+        description: 'Model name for OpenAI vision tests'
+        type: string
+        required: false
+      anthropic_model:
+        description: 'Model name for Anthropic tests'
+        type: string
+        required: false
+      google_model:
+        description: 'Model name for Google tests'
+        type: string
+        required: false
+      mistralai_model:
+        description: 'Model name for Mistralai tests'
+        type: string
+        required: false
+      lamini_model:
+        description: 'Model name for Lamini tests'
+        type: string
+        required: false
+      magentic_model:
+        description: 'Model name for Magentic tests'
+        type: string
+        required: false
+
+env:
+  PYTHON_VERSION: 3.11.4
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    env:
+      LOG10_URL: "https://log10.io"
+      LOG10_ORG_ID: ${{ secrets.LOG10_ORG_ID }}
+      LOG10_TOKEN: ${{ secrets.LOG10_TOKEN }}
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+      ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+      MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
+      LAMINI_API_KEY: ${{ secrets.LAMINI_API_KEY }}
+      GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install poetry
+        run: pipx install poetry
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{env.PYTHON_VERSION}}
+          cache: "poetry"
architecture: 'x64' + - name: Install dependencies + run: poetry install --all-extras + + - name: Run dispatch tests + if: ${{ github.event_name == 'workflow_dispatch' }} + run: | + echo "This is a dispatch event" + openai_model_input=${{ github.event.inputs.openai_model }} + openai_vision_model_input=${{ github.event.inputs.openai_vision_model }} + anthropic_model_input=${{ github.event.inputs.anthropic_model }} + google_model_input=${{ github.event.inputs.google_model }} + mistralai_model_input=${{ github.event.inputs.mistralai_model }} + lamini_model_input=${{ github.event.inputs.lamini_model }} + magentic_model_input=${{ github.event.inputs.magentic_model }} + + empty_inputs=true + if [[ -n "$openai_model_input" ]]; then + empty_inputs=false + poetry run pytest --openai_model=$openai_model_input -vv tests/test_openai.py + fi + + if [[ -n "$openai_vision_model_input" ]]; then + empty_inputs=false + poetry run pytest --openai_vision_model=$openai_vision_model_input -m vision -vv tests/test_openai.py + fi + + if [[ -n "$anthropic_model_input" ]]; then + empty_inputs=false + poetry run pytest --anthropic_model=$anthropic_model_input -vv tests/test_anthropic.py + fi + + if [[ -n "$google_model_input" ]]; then + empty_inputs=false + poetry run pytest --google_model=$google_model_input -vv tests/test_google.py + fi + + if [[ -n "$mistralai_model_input" ]]; then + empty_inputs=false + poetry run pytest --mistralai_model=$mistralai_model_input -vv tests/test_mistralai.py + fi + + if [[ -n "$lamini_model_input" ]]; then + empty_inputs=false + poetry run pytest --lamini_model=$lamini_model_input -vv tests/test_lamini.py + fi + + if [[ -n "$magentic_model_input" ]]; then + empty_inputs=false + poetry run pytest --magentic_model=$magentic_model_input -vv tests/test_magentic.py + fi + + if $empty_inputs; then + echo "All variables are empty" + poetry run pytest -vv tests/ + fi + + - name: Run scheduled tests + if: ${{ github.event_name == 'schedule' }} + run: | + echo "This is a schedule event" + poetry run pytest -vv tests/ + poetry run pytest --openai_model=gpt-4o -m chat -vv tests/test_openai.py diff --git a/Makefile b/Makefile index db0a5f47..b97f9615 100644 --- a/Makefile +++ b/Makefile @@ -34,10 +34,6 @@ logging: python examples/logging/get_url.py # python examples/logging/langchain_babyagi.py python examples/logging/langchain_model_logger.py - python examples/logging/langchain_multiple_tools.py - python examples/logging/langchain_qa.py - python examples/logging/langchain_simple_sequential.py - python examples/logging/langchain_sqlagent.py evals: (cd examples/evals && python basic_eval.py) diff --git a/poetry.lock b/poetry.lock index 08821ab6..136b6e62 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -804,6 +804,23 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "google-ai-generativelanguage" +version = "0.6.4" +description = "Google Ai Generativelanguage API client library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google-ai-generativelanguage-0.6.4.tar.gz", hash = "sha256:1750848c12af96cb24ae1c3dd05e4bfe24867dc4577009ed03e1042d8421e874"}, + {file = "google_ai_generativelanguage-0.6.4-py3-none-any.whl", hash = "sha256:730e471aa549797118fb1c88421ba1957741433ada575cf5dd08d3aebf903ab1"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" + [[package]] name = "google-api-core" version = "2.19.0" @@ -819,12 +836,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -835,6 +852,24 @@ grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +[[package]] +name = "google-api-python-client" +version = "2.129.0" +description = "Google API Client Library for Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google-api-python-client-2.129.0.tar.gz", hash = "sha256:984cc8cc8eb4923468b1926d2b8effc5b459a4dda3c845896eb87c153b28ef84"}, + {file = "google_api_python_client-2.129.0-py2.py3-none-any.whl", hash = "sha256:d50f7e2dfdbb7fc2732f6a0cba1c54d7bb676390679526c6bb628c901e43ec86"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + [[package]] name = "google-auth" version = "2.29.0" @@ -858,6 +893,21 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = true +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = 
"sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + [[package]] name = "google-cloud-aiplatform" version = "1.49.0" @@ -1074,6 +1124,29 @@ files = [ [package.extras] testing = ["pytest"] +[[package]] +name = "google-generativeai" +version = "0.5.4" +description = "Google Generative AI High level API client library and tools." +optional = true +python-versions = ">=3.9" +files = [ + {file = "google_generativeai-0.5.4-py3-none-any.whl", hash = "sha256:036d63ee35e7c8aedceda4f81c390a5102808af09ff3a6e57e27ed0be0708f3c"}, +] + +[package.dependencies] +google-ai-generativelanguage = "0.6.4" +google-api-core = "*" +google-api-python-client = "*" +google-auth = ">=2.15.0" +protobuf = "*" +pydantic = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] + [[package]] name = "google-resumable-media" version = "2.7.0" @@ -1343,6 +1416,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<0.26.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.25.2" @@ -2195,9 +2282,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2474,6 +2561,20 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyparsing" +version = "3.1.2" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = true +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pyproject-hooks" version = "1.1.0" @@ -2507,6 +2608,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.6" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"}, + {file = 
"pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2558,7 +2677,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2566,15 +2684,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2591,7 +2702,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2599,7 +2709,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3486,6 +3595,17 @@ files = [ {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = true +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "2.2.1" @@ -3725,6 +3845,8 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] autofeedback-icl = ["magentic"] gemini = ["google-cloud-aiplatform"] +google-generativeai = ["google-generativeai"] +lamini = ["lamini"] langchain = ["langchain"] litellm = ["litellm"] mistralai = ["mistralai"] @@ -3734,4 +3856,4 @@ together = ["together"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "488912dd5daf23c3c35b8a9e394f7d4b42562a6c7414cd4a7134ba770004a90f" +content-hash = 
"6246821d7d245e3166a3759da50fbe20ce66ea6a22c0dbd09dd8f6c74d35af9a" diff --git a/pyproject.toml b/pyproject.toml index c8d5f019..c224adfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ pytest = "^8.0.0" requests-mock = "^1.11.0" respx = "^0.20.2" ruff = "^0.3.2" +pytest-asyncio = "^0.23.6" [project.urls] "Homepage" = "https://github.com/log10-io/log10" @@ -50,6 +51,7 @@ mistralai = {version = "^0.1.5", optional = true} together = {version = "^0.2.7", optional = true} mosaicml-cli = {version = "^0.5.30", optional = true} google-cloud-bigquery = {version = "^3.11.4", optional = true} +google-generativeai = {version = "^0.5.2", optional = true} [tool.poetry.extras] autofeedback_icl = ["magentic"] @@ -59,6 +61,8 @@ gemini = ["google-cloud-aiplatform"] mistralai = ["mistralai"] together = ["together"] mosaicml = ["mosaicml-cli"] +google-generativeai = ["google-generativeai"] +lamini = ["lamini"] [tool.ruff] # Never enforce `E501` (line length violations). diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..c2241638 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,61 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption("--openai_model", action="store", help="Model name for OpenAI tests") + + parser.addoption("--openai_vision_model", action="store", help="Model name for OpenAI vision tests") + + parser.addoption("--anthropic_model", action="store", help="Model name for Message API Anthropic tests") + + parser.addoption("--anthropic_legacy_model", action="store", help="Model name for legacy Anthropic tests") + + parser.addoption("--lamini_model", action="store", help="Model name for Lamini tests") + + parser.addoption( + "--mistralai_model", action="store", default="mistral-tiny", help="Model name for Mistralai tests" + ) + + parser.addoption("--google_model", action="store", help="Model name for Google tests") + + parser.addoption("--magentic_model", action="store", help="Model name for Magentic tests") + + +@pytest.fixture +def openai_model(request): + return request.config.getoption("--openai_model") + + +@pytest.fixture +def openai_vision_model(request): + return request.config.getoption("--openai_vision_model") + + +@pytest.fixture +def anthropic_model(request): + return request.config.getoption("--anthropic_model") + + +@pytest.fixture +def anthropic_legacy_model(request): + return request.config.getoption("--anthropic_legacy_model") + + +@pytest.fixture +def mistralai_model(request): + return request.config.getoption("--mistralai_model") + + +@pytest.fixture +def lamini_model(request): + return request.config.getoption("--lamini_model") + + +@pytest.fixture +def google_model(request): + return request.config.getoption("--google_model") + + +@pytest.fixture +def magentic_model(request): + return request.config.getoption("--magentic_model") diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..956f1de7 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,18 @@ +[pytest] +addopts = + --openai_model=gpt-3.5-turbo + --openai_vision_model=gpt-4o + --anthropic_model=claude-3-haiku-20240307 + --anthropic_legacy_model=claude-2.1 + --google_model=gemini-1.5-pro-latest + --mistralai_model=mistral-tiny + --lamini_model=meta-llama/Llama-2-7b-chat-hf + --magentic_model=gpt-3.5-turbo + +markers = + async_client: mark a test as an async test + chat: mark a test as a chat test + stream: mark a test as a stream test + tools: mark a test as a tools test + vision: mark a test as a vision test + widget: mark a 
test as a widget test diff --git a/tests/test_anthropic.py b/tests/test_anthropic.py new file mode 100644 index 00000000..166b3b75 --- /dev/null +++ b/tests/test_anthropic.py @@ -0,0 +1,87 @@ +import base64 + +import anthropic +import httpx +import pytest + +from log10.load import log10 + + +log10(anthropic) +client = anthropic.Anthropic() + + +@pytest.mark.chat +def test_messages(anthropic_model): + message = client.messages.create( + model=anthropic_model, + max_tokens=1000, + temperature=0.0, + system="Respond only in Yoda-speak.", + messages=[{"role": "user", "content": "How are you today?"}], + ) + + text = message.content[0].text + assert isinstance(text, str) + assert text, "No output from the model." + + +@pytest.mark.chat +@pytest.mark.stream +def test_messages_stream(anthropic_model): + stream = client.messages.create( + model=anthropic_model, + messages=[ + { + "role": "user", + "content": "Count to 10", + } + ], + max_tokens=128, + temperature=0.9, + stream=True, + ) + + output = "" + for event in stream: + if event.type == "content_block_delta": + text = event.delta.text + output += text + if text.isdigit(): + assert int(text) <= 10 + + assert output, "No output from the model." + + +@pytest.mark.vision +def test_messages_image(anthropic_model): + client = anthropic.Anthropic() + + image1_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg" + image1_media_type = "image/jpeg" + image1_data = base64.b64encode(httpx.get(image1_url).content).decode("utf-8") + + message = client.messages.create( + model=anthropic_model, + max_tokens=1024, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": image1_media_type, + "data": image1_data, + }, + }, + {"type": "text", "text": "Describe this image."}, + ], + } + ], + ) + + text = message.content[0].text + assert text, "No output from the model." + assert "ant" in text diff --git a/tests/test_google.py b/tests/test_google.py new file mode 100644 index 00000000..a562ef39 --- /dev/null +++ b/tests/test_google.py @@ -0,0 +1,42 @@ +import google.generativeai as genai +import pytest + +from log10.load import log10 + + +log10(genai) + + +@pytest.mark.chat +def test_genai_chat(google_model): + model = genai.GenerativeModel(google_model) + chat = model.start_chat() + + prompt = "Say this is a test" + generation_config = genai.GenerationConfig( + temperature=0.9, + max_output_tokens=512, + ) + response = chat.send_message(prompt, generation_config=generation_config) + + text = response.text + assert isinstance(text, str) + assert "this is a test" in text.lower() + + +@pytest.mark.chat +def test_genai_chat_w_history(google_model): + model = genai.GenerativeModel(google_model, system_instruction="You are a cat. Your name is Neko.") + chat = model.start_chat( + history=[ + {"role": "user", "parts": [{"text": "please say yes."}]}, + {"role": "model", "parts": [{"text": "Yes yes yes?"}]}, + ] + ) + + prompt = "please say no." 
+    response = chat.send_message(prompt)
+
+    text = response.text
+    assert isinstance(text, str)
+    assert "no" in text.lower()
diff --git a/tests/test_lamini.py b/tests/test_lamini.py
new file mode 100644
index 00000000..f67ac63c
--- /dev/null
+++ b/tests/test_lamini.py
@@ -0,0 +1,16 @@
+import lamini
+import pytest
+
+from log10.load import log10
+
+
+log10(lamini)
+
+
+@pytest.mark.chat
+def test_generate(lamini_model):
+    llm = lamini.Lamini(lamini_model)
+    response = llm.generate("What's 2 + 9 * 3?")
+
+    assert isinstance(response, str)
+    assert "29" in response
diff --git a/tests/test_langchain.py b/tests/test_langchain.py
new file mode 100644
index 00000000..189012a5
--- /dev/null
+++ b/tests/test_langchain.py
@@ -0,0 +1,32 @@
+import anthropic
+import openai
+import pytest
+from langchain.chat_models import ChatAnthropic, ChatOpenAI
+from langchain.schema import HumanMessage, SystemMessage
+
+from log10.load import log10
+
+
+@pytest.mark.chat
+def test_chat_openai_messages(openai_model):
+    log10(openai)
+    llm = ChatOpenAI(
+        model_name=openai_model,
+        temperature=0.5,
+    )
+    messages = [SystemMessage(content="You are a ping pong machine"), HumanMessage(content="Ping?")]
+    completion = llm.predict_messages(messages)
+
+    assert isinstance(completion.content, str)
+    assert "pong" in completion.content.lower()
+
+
+@pytest.mark.chat
+def test_chat_anthropic_messages(anthropic_legacy_model):
+    log10(anthropic)
+    llm = ChatAnthropic(model=anthropic_legacy_model, temperature=0.7)
+    messages = [SystemMessage(content="You are a ping pong machine"), HumanMessage(content="Ping?")]
+    completion = llm.predict_messages(messages)
+
+    assert isinstance(completion.content, str)
+    assert "pong" in completion.content.lower()
diff --git a/tests/test_litellm.py b/tests/test_litellm.py
new file mode 100644
index 00000000..e0a37ab2
--- /dev/null
+++ b/tests/test_litellm.py
@@ -0,0 +1,143 @@
+import base64
+
+import httpx
+import litellm
+import pytest
+
+from log10.litellm import Log10LitellmLogger
+
+
+### litellm appears to allow multiple callbacks to be registered, which can cause tags from a
+### previous callback to be attached to the next request when multiple tests run at the same time,
+### so the Log10 handler is registered once at module level
+
+log10_handler = Log10LitellmLogger(tags=["litellm_test"])
+litellm.callbacks = [log10_handler]
+
+
+@pytest.mark.chat
+@pytest.mark.stream
+def test_completion_stream(openai_model):
+    response = litellm.completion(
+        model=openai_model, messages=[{"role": "user", "content": "Count to 10."}], stream=True
+    )
+
+    output = ""
+    for chunk in response:
+        if chunk.choices[0].delta.content:
+            output += chunk.choices[0].delta.content
+
+    assert output, "No output from the model."
+
+
+@pytest.mark.async_client
+@pytest.mark.chat
+@pytest.mark.stream
+@pytest.mark.asyncio
+async def test_completion_async_stream(anthropic_model):
+    response = await litellm.acompletion(
+        model=anthropic_model, messages=[{"role": "user", "content": "count to 10"}], stream=True
+    )
+
+    output = ""
+    async for chunk in response:
+        if chunk.choices[0].delta.content:
+            output += chunk.choices[0].delta.content.strip()
+
+    assert output, "No output from the model."
+ + +@pytest.mark.vision +def test_image(openai_vision_model): + image_url = "https://upload.wikimedia.org/wikipedia/commons/e/e8/Log10.png" + image_media_type = "image/png" + image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8") + + resp = litellm.completion( + model=openai_vision_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:{image_media_type};base64,{image_data}"}, + }, + {"type": "text", "text": "What's the red curve in the figure, is it log2 or log10? Be concise."}, + ], + } + ], + ) + + content = resp.choices[0].message.content + assert isinstance(content, str) + assert content, "No output from the model." + + +@pytest.mark.stream +@pytest.mark.vision +def test_image_stream(anthropic_model): + image_url = "https://upload.wikimedia.org/wikipedia/commons/e/e8/Log10.png" + image_media_type = "image/png" + image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8") + + resp = litellm.completion( + model=anthropic_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:{image_media_type};base64,{image_data}"}, + }, + {"type": "text", "text": "What's the red curve in the figure, is it log2 or log10? Be concise."}, + ], + } + ], + stream=True, + ) + + output = "" + for chunk in resp: + if chunk.choices[0].delta.content: + output += chunk.choices[0].delta.content + + assert output, "No output from the model." + + +@pytest.mark.async_client +@pytest.mark.stream +@pytest.mark.vision +@pytest.mark.asyncio +async def test_image_async_stream(anthropic_model): + image_url = "https://upload.wikimedia.org/wikipedia/commons/e/e8/Log10.png" + image_media_type = "image/png" + image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8") + + resp = litellm.completion( + model=anthropic_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:{image_media_type};base64,{image_data}"}, + }, + { + "type": "text", + "text": "What's the red curve in the figure, is it log2 or log10? Be concise.", + }, + ], + } + ], + stream=True, + ) + + output = "" + for chunk in resp: + if chunk.choices[0].delta.content: + output += chunk.choices[0].delta.content + + assert output, "No output from the model." diff --git a/tests/test_magentic.py b/tests/test_magentic.py new file mode 100644 index 00000000..71be2449 --- /dev/null +++ b/tests/test_magentic.py @@ -0,0 +1,147 @@ +from typing import Literal + +import openai +import pytest +from magentic import AsyncParallelFunctionCall, AsyncStreamedStr, FunctionCall, OpenaiChatModel, StreamedStr, prompt +from pydantic import BaseModel + +from log10.load import log10, log10_session + + +log10(openai) + + +@pytest.mark.chat +def test_prompt(magentic_model): + @prompt("Tell me a short joke", model=OpenaiChatModel(model=magentic_model)) + def llm() -> str: ... + + output = llm() + assert isinstance(output, str) + assert output, "No output from the model." + + +@pytest.mark.chat +@pytest.mark.stream +def test_prompt_stream(magentic_model): + @prompt("Tell me a short joke", model=OpenaiChatModel(model=magentic_model)) + def llm() -> StreamedStr: ... + + response = llm() + output = "" + for chunk in response: + output += chunk + + assert output, "No output from the model." 
+ + +@pytest.mark.tools +def test_function_logging(magentic_model): + def activate_oven(temperature: int, mode: Literal["broil", "bake", "roast"]) -> str: + """Turn the oven on with the provided settings.""" + return f"Preheating to {temperature} F with mode {mode}" + + @prompt( + "Prepare the oven so I can make {food}", functions=[activate_oven], model=OpenaiChatModel(model=magentic_model) + ) + def configure_oven(food: str) -> FunctionCall[str]: # ruff: ignore + ... + + output = configure_oven("cookies!") + assert output(), "No output from the model." + + +@pytest.mark.async_client +@pytest.mark.stream +@pytest.mark.asyncio +async def test_async_stream_logging(magentic_model): + @prompt("Tell me a 50-word story about {topic}", model=OpenaiChatModel(model=magentic_model)) + async def tell_story(topic: str) -> AsyncStreamedStr: # ruff: ignore + ... + + with log10_session(tags=["async_tag"]): + output = await tell_story("Europe.") + result = "" + async for chunk in output: + result += chunk + + assert result, "No output from the model." + + +@pytest.mark.async_client +@pytest.mark.tools +@pytest.mark.asyncio +async def test_async_parallel_stream_logging(magentic_model): + def plus(a: int, b: int) -> int: + return a + b + + async def minus(a: int, b: int) -> int: + return a - b + + @prompt( + "Sum {a} and {b}. Also subtract {a} from {b}.", + functions=[plus, minus], + model=OpenaiChatModel(model=magentic_model), + ) + async def plus_and_minus(a: int, b: int) -> AsyncParallelFunctionCall[int]: ... + + output = await plus_and_minus(2, 3) + async for chunk in output: + assert isinstance(chunk, FunctionCall), "chunk is not an instance of FunctionCall" + + +@pytest.mark.async_client +@pytest.mark.stream +@pytest.mark.asyncio +async def test_async_multi_session_tags(magentic_model): + @prompt("What is {a} * {b}?", model=OpenaiChatModel(model=magentic_model)) + async def do_math_with_llm_async(a: int, b: int) -> AsyncStreamedStr: # ruff: ignore + ... + + output = "" + + with log10_session(tags=["test_tag_a"]): + result = await do_math_with_llm_async(2, 2) + async for chunk in result: + output += chunk + + result = await do_math_with_llm_async(2.5, 2.5) + async for chunk in result: + output += chunk + + with log10_session(tags=["test_tag_b"]): + result = await do_math_with_llm_async(3, 3) + async for chunk in result: + output += chunk + + assert output, "No output from the model." + assert "4" in output + assert "6.25" in output + assert "9" in output + + +@pytest.mark.async_client +@pytest.mark.widget +@pytest.mark.asyncio +async def test_async_widget(magentic_model): + class WidgetInfo(BaseModel): + title: str + description: str + + @prompt( + """ + Generate a descriptive title and short description for a widget, given the user's query and the data contained in the widget. + + Data: {widget_data} + Query: {query} + """, # noqa: E501 + model=OpenaiChatModel(magentic_model, temperature=0.1, max_tokens=1000), + ) + async def _generate_title_and_description(query: str, widget_data: str) -> WidgetInfo: ... 
+ + r = await _generate_title_and_description(query="Give me a summary of AAPL", widget_data="") + + assert isinstance(r, WidgetInfo) + assert isinstance(r.title, str) + assert "AAPL" in r.title + assert isinstance(r.description, str) diff --git a/tests/test_mistralai.py b/tests/test_mistralai.py new file mode 100644 index 00000000..001c79ba --- /dev/null +++ b/tests/test_mistralai.py @@ -0,0 +1,39 @@ +import mistralai +import pytest +from mistralai.client import MistralClient +from mistralai.models.chat_completion import ChatMessage + +from log10.load import log10 + + +log10(mistralai) +client = MistralClient() + + +@pytest.mark.chat +def test_chat(mistralai_model): + chat_response = client.chat( + model=mistralai_model, + messages=[ChatMessage(role="user", content="10 + 2 * 3=?")], + ) + + content = chat_response.choices[0].message.content + assert content, "No output from the model." + assert "16" in content + + +@pytest.mark.chat +@pytest.mark.stream +def test_chat_stream(mistralai_model): + response = client.chat_stream( + model=mistralai_model, + messages=[ChatMessage(role="user", content="Count the odd numbers from 1 to 20.")], + ) + + output = "" + for chunk in response: + content = chunk.choices[0].delta.content + if chunk.choices[0].delta.content is not None: + output += content + + assert output, "No output from the model." diff --git a/tests/test_openai.py b/tests/test_openai.py new file mode 100644 index 00000000..da5bd2c7 --- /dev/null +++ b/tests/test_openai.py @@ -0,0 +1,296 @@ +import base64 +import json + +import httpx +import openai +import pytest +from openai import NOT_GIVEN, AsyncOpenAI + +from log10.load import log10 + + +log10(openai) +client = openai.OpenAI() + + +@pytest.mark.chat +def test_chat(openai_model): + completion = client.chat.completions.create( + model=openai_model, + messages=[ + { + "role": "system", + "content": "You will be provided with statements, and your task is to convert them to standard English.", + }, + { + "role": "user", + "content": "He no went to the market.", + }, + ], + ) + + content = completion.choices[0].message.content + assert isinstance(content, str) + assert content, "No output from the model." + assert "did not go to the market" in content + + +@pytest.mark.chat +def test_chat_not_given(openai_model): + completion = client.chat.completions.create( + model=openai_model, + messages=[ + { + "role": "user", + "content": "tell a short joke.", + }, + ], + tools=NOT_GIVEN, + tool_choice=NOT_GIVEN, + ) + + content = completion.choices[0].message.content + assert isinstance(content, str) + assert content, "No output from the model." + + +@pytest.mark.chat +@pytest.mark.async_client +@pytest.mark.asyncio +async def test_chat_async(openai_model): + client = AsyncOpenAI() + + completion = await client.chat.completions.create( + model=openai_model, + messages=[{"role": "user", "content": "Say this is a test"}], + ) + + content = completion.choices[0].message.content + assert isinstance(content, str) + assert content, "No output from the model." + assert "this is a test" in content.lower() + + +@pytest.mark.chat +@pytest.mark.stream +def test_chat_stream(openai_model): + response = client.chat.completions.create( + model=openai_model, + messages=[{"role": "user", "content": "Count to 10"}], + temperature=0, + stream=True, + ) + + output = "" + for chunk in response: + content = chunk.choices[0].delta.content + if content: + output += content.strip() + + assert output, "No output from the model." 
+
+
+@pytest.mark.async_client
+@pytest.mark.stream
+@pytest.mark.asyncio
+async def test_chat_async_stream(openai_model):
+    client = AsyncOpenAI()
+
+    output = ""
+    stream = await client.chat.completions.create(
+        model=openai_model,
+        messages=[{"role": "user", "content": "Count to 10"}],
+        stream=True,
+    )
+    async for chunk in stream:
+        content = chunk.choices[0].delta.content
+        if content:
+            output += content.strip()
+
+    assert output, "No output from the model."
+
+
+@pytest.mark.vision
+def test_chat_image(openai_vision_model):
+    image1_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
+    image1_media_type = "image/jpeg"
+    image1_data = base64.b64encode(httpx.get(image1_url).content).decode("utf-8")
+
+    response = client.chat.completions.create(
+        model=openai_vision_model,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "What is in this image?",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:{image1_media_type};base64,{image1_data}"},
+                    },
+                ],
+            }
+        ],
+        max_tokens=300,
+    )
+
+    content = response.choices[0].message.content
+    assert isinstance(content, str)
+    assert content, "No output from the model."
+    assert "ant" in content
+
+
+def get_current_weather(location, unit="fahrenheit"):
+    """Get the current weather in a given location"""
+    if "tokyo" in location.lower():
+        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
+    elif "san francisco" in location.lower():
+        return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
+    elif "paris" in location.lower():
+        return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
+    else:
+        return json.dumps({"location": location, "temperature": "unknown"})
+
+
+def setup_tools_messages() -> dict:
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
+        }
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g.
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + + return { + "messages": messages, + "tools": tools, + } + + +@pytest.mark.tools +def test_tools(openai_model): + # Step 1: send the conversation and available functions to the model + result = setup_tools_messages() + messages = result["messages"] + tools = result["tools"] + + response = client.chat.completions.create( + model=openai_model, + messages=messages, + tools=tools, + tool_choice="auto", # auto is default, but we'll be explicit + ) + response_message = response.choices[0].message + tool_calls = response_message.tool_calls + # Step 2: check if the model wanted to call a function + if tool_calls: + # Step 3: call the function + # Note: the JSON response may not always be valid; be sure to handle errors + available_functions = { + "get_current_weather": get_current_weather, + } # only one function in this example, but you can have multiple + # extend conversation with assistant's reply + messages.append(response_message) + # Step 4: send the info for each function call and function response to the model + for tool_call in tool_calls: + function_name = tool_call.function.name + function_to_call = available_functions[function_name] + function_args = json.loads(tool_call.function.arguments) + function_response = function_to_call( + location=function_args.get("location"), + unit=function_args.get("unit"), + ) + messages.append( + { + "tool_call_id": tool_call.id, + "role": "tool", + "name": function_name, + "content": function_response, + } + ) # extend conversation with function response + second_response = client.chat.completions.create( + model=openai_model, + messages=messages, + ) # get a new response from the model where it can see the function response + content = second_response.choices[0].message.content + assert isinstance(content, str) + assert content, "No output from the model." 
+ + +@pytest.mark.stream +@pytest.mark.tools +def test_tools_stream(openai_model): + # Step 1: send the conversation and available functions to the model + result = setup_tools_messages() + messages = result["messages"] + tools = result["tools"] + response = client.chat.completions.create( + model=openai_model, + messages=messages, + tools=tools, + tool_choice="auto", # auto is default, but we'll be explicit + stream=True, + ) + tool_calls = [] + for chunk in response: + if chunk.choices[0].delta: + if tc := chunk.choices[0].delta.tool_calls: + if tc[0].id: + tool_calls.append(tc[0]) + else: + tool_calls[-1].function.arguments += tc[0].function.arguments + function_args = [{"function": t.function.name, "arguments": t.function.arguments} for t in tool_calls] + assert len(function_args) == 3 + + +@pytest.mark.tools +@pytest.mark.stream +@pytest.mark.async_client +@pytest.mark.asyncio +async def test_tools_stream_async(openai_model): + client = AsyncOpenAI() + # Step 1: send the conversation and available functions to the model + result = setup_tools_messages() + messages = result["messages"] + tools = result["tools"] + + response = await client.chat.completions.create( + model=openai_model, + messages=messages, + tools=tools, + tool_choice="auto", # auto is default, but we'll be explicit + stream=True, + ) + + tool_calls = [] + async for chunk in response: + if chunk.choices[0].delta: + if tc := chunk.choices[0].delta.tool_calls: + if tc[0].id: + tool_calls.append(tc[0]) + else: + tool_calls[-1].function.arguments += tc[0].function.arguments + + function_args = [{"function": t.function.name, "arguments": t.function.arguments} for t in tool_calls] + assert len(function_args) == 3
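
A note on exercising this patch: the workflow_dispatch inputs above map one-to-one onto the pytest options registered in tests/conftest.py, so a run can be reproduced locally or triggered remotely. A minimal sketch, assuming the workflow file has landed on the default branch and the GitHub CLI is authenticated; the model names are the defaults from tests/pytest.ini and the scheduled job:

    # run one provider's tests locally with a specific model (option defined in tests/conftest.py)
    poetry run pytest --openai_model=gpt-4o -m chat -vv tests/test_openai.py

    # trigger the daily workflow manually with a model override
    gh workflow run test.yml -f anthropic_model=claude-3-haiku-20240307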