From 63c76333ab4a4dfc0e37cca66acefd5ad74924be Mon Sep 17 00:00:00 2001 From: KepingYan Date: Tue, 3 Jan 2023 13:00:45 +0800 Subject: [PATCH] fixed version, add torchmetrics for pytorch example (#295) --- tutorials/pytorch_example.ipynb | 2 +- tutorials/raytrain_example.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/pytorch_example.ipynb b/tutorials/pytorch_example.ipynb index a83fd219..14083ad8 100644 --- a/tutorials/pytorch_example.ipynb +++ b/tutorials/pytorch_example.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{"id":"2TxtrYkpljYO"},"source":["# **How RayDP works together with Pytorch**"]},{"cell_type":"markdown","metadata":{"id":"PVGNVwUU9lGW"},"source":["RayDP is a distributed data processing library that provides simple APIs for running Spark on Ray and integrating Spark with distributed deep learning and machine learning frameworks. This document builds an end-to-end deep learning pipeline on a single Ray cluster by using Spark for data preprocessing, and uses distributed estimator based on the raydp api to complete the training and evaluation."]},{"cell_type":"markdown","metadata":{"id":"hQ0v0n-PhnhY"},"source":["[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/oap-project/raydp/blob/master/tutorials/pytorch_example.ipynb)"]},{"cell_type":"markdown","metadata":{"id":"JJCwQHcvlmqF"},"source":["## 1. Colab enviroment Setup"]},{"cell_type":"markdown","metadata":{"id":"iMRuDWw9qh29"},"source":["RayDP requires Ray and PySpark. 
At the same time, pytorch is used to build deep learning model."]},{"cell_type":"code","execution_count":74,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":13503,"status":"ok","timestamp":1653033884645,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"SQplP7vYlCp2","outputId":"40a8aaf5-0187-4a58-e4f7-1175f2744d30"},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: ray==1.9 in /usr/local/lib/python3.7/dist-packages (1.9.0)\n","Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (1.46.1)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (3.13)\n","Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (21.4.0)\n","Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (1.21.6)\n","Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (7.1.2)\n","Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (1.0.3)\n","Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (3.17.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (3.7.0)\n","Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (4.3.3)\n","Requirement already satisfied: redis>=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (4.3.1)\n","Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio>=1.28.1->ray==1.9) (1.15.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (4.2.0)\n","Requirement already satisfied: async-timeout>=4.0.2 in 
/usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (4.0.2)\n","Requirement already satisfied: importlib-metadata>=1.0 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (4.11.3)\n","Requirement already satisfied: deprecated>=1.2.3 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (1.2.13)\n","Requirement already satisfied: packaging>=20.4 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (21.3)\n","Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.7/dist-packages (from deprecated>=1.2.3->redis>=3.5.0->ray==1.9) (1.14.1)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=1.0->redis>=3.5.0->ray==1.9) (3.8.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.4->redis>=3.5.0->ray==1.9) (3.0.9)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==1.9) (0.18.1)\n","Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==1.9) (5.7.1)\n","Requirement already satisfied: raydp-nightly in /usr/local/lib/python3.7/dist-packages (2022.5.12.dev0)\n","Requirement already satisfied: pandas>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (1.3.5)\n","Requirement already satisfied: typing in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (3.7.4.3)\n","Requirement already satisfied: pyspark>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (3.2.1)\n","Requirement already satisfied: ray>=1.8.0 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (1.9.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (1.21.6)\n","Requirement already satisfied: netifaces in 
/usr/local/lib/python3.7/dist-packages (from raydp-nightly) (0.11.0)\n","Requirement already satisfied: pyarrow<7.0.0,>=4.0.1 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (6.0.1)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (5.4.8)\n","Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.1.4->raydp-nightly) (2.8.2)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.1.4->raydp-nightly) (2022.1)\n","Requirement already satisfied: py4j==0.10.9.3 in /usr/local/lib/python3.7/dist-packages (from pyspark>=3.2.0->raydp-nightly) (0.10.9.3)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas>=1.1.4->raydp-nightly) (1.15.0)\n","Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (4.3.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (3.7.0)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (3.13)\n","Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (21.4.0)\n","Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (3.17.3)\n","Requirement already satisfied: redis>=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (4.3.1)\n","Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (1.46.1)\n","Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (1.0.3)\n","Requirement already satisfied: click>=7.0 in 
/usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (7.1.2)\n","Requirement already satisfied: packaging>=20.4 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (21.3)\n","Requirement already satisfied: importlib-metadata>=1.0 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (4.11.3)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (4.2.0)\n","Requirement already satisfied: deprecated>=1.2.3 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (1.2.13)\n","Requirement already satisfied: async-timeout>=4.0.2 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (4.0.2)\n","Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.7/dist-packages (from deprecated>=1.2.3->redis>=3.5.0->ray>=1.8.0->raydp-nightly) (1.14.1)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=1.0->redis>=3.5.0->ray>=1.8.0->raydp-nightly) (3.8.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.4->redis>=3.5.0->ray>=1.8.0->raydp-nightly) (3.0.9)\n","Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray>=1.8.0->raydp-nightly) (5.7.1)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray>=1.8.0->raydp-nightly) (0.18.1)\n","Requirement already satisfied: ray[tune] in /usr/local/lib/python3.7/dist-packages (1.9.0)\n","Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.46.1)\n","Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) 
(1.21.6)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.7.0)\n","Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (4.3.3)\n","Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.0.3)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.13)\n","Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (21.4.0)\n","Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (7.1.2)\n","Requirement already satisfied: redis>=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (4.3.1)\n","Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.17.3)\n","Requirement already satisfied: tabulate in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (0.8.9)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (2.23.0)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.3.5)\n","Requirement already satisfied: tensorboardX>=1.9 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (2.5)\n","Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio>=1.28.1->ray[tune]) (1.15.0)\n","Requirement already satisfied: deprecated>=1.2.3 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (1.2.13)\n","Requirement already satisfied: importlib-metadata>=1.0 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (4.11.3)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (4.2.0)\n","Requirement already satisfied: packaging>=20.4 in 
/usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (21.3)\n","Requirement already satisfied: async-timeout>=4.0.2 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (4.0.2)\n","Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.7/dist-packages (from deprecated>=1.2.3->redis>=3.5.0->ray[tune]) (1.14.1)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=1.0->redis>=3.5.0->ray[tune]) (3.8.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.4->redis>=3.5.0->ray[tune]) (3.0.9)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray[tune]) (0.18.1)\n","Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray[tune]) (5.7.1)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[tune]) (2022.1)\n","Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[tune]) (2.8.2)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (3.0.4)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (2.10)\n","Looking in links: https://download.pytorch.org/whl/torch_stable.html\n","Requirement already satisfied: torch==1.8.1+cpu in /usr/local/lib/python3.7/dist-packages (1.8.1+cpu)\n","Requirement already satisfied: 
typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.8.1+cpu) (4.2.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch==1.8.1+cpu) (1.21.6)\n"]}],"source":["! pip install ray==1.9\n","# install RayDP nightly build\n","! pip install raydp-nightly\n","# or use the released version\n","# ! pip install raydp\n","! pip install ray[tune]\n","! pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html"]},{"cell_type":"markdown","metadata":{"id":"XWw2fBhfqvME"},"source":["## 2. Get the data file"]},{"cell_type":"markdown","metadata":{"id":"pFZtArpXqyS2"},"source":["The dataset is from: https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset, and we store the file in github repository. It's used to predict whether a patient is likely to get stroke based on the input parameters like gender, age, various diseases, and smoking status. Each row in the data provides relavant information about the patient. "]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":17,"status":"ok","timestamp":1653033884646,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"J-Hzp_q7qzz5","outputId":"5dae58ce-a40a-47f1-cb1b-6485cf15b949"},"outputs":[],"source":["! wget https://raw.githubusercontent.com/oap-project/raydp/master/tutorials/dataset/healthcare-dataset-stroke-data.csv -O healthcare-dataset-stroke-data.csv"]},{"cell_type":"markdown","metadata":{"id":"Y-ceLv2-q__G"},"source":["## 3. 
Init or connect to a ray cluster"]},{"cell_type":"code","execution_count":76,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3705,"status":"ok","timestamp":1653033888338,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"CVFQcpiJrC8Y","outputId":"bbbbc37c-9834-4c39-e652-0178842f0e7a"},"outputs":[{"data":{"text/plain":["{'metrics_export_port': 59076,\n"," 'node_id': '4a3abd8dce03a2b3e9e76b935b01c20eb501d3d5135929e1eadeccae',\n"," 'node_ip_address': '172.28.0.2',\n"," 'object_store_address': '/tmp/ray/session_2022-05-20_08-04-44_153620_58/sockets/plasma_store',\n"," 'raylet_ip_address': '172.28.0.2',\n"," 'raylet_socket_name': '/tmp/ray/session_2022-05-20_08-04-44_153620_58/sockets/raylet',\n"," 'redis_address': '172.28.0.2:6379',\n"," 'session_dir': '/tmp/ray/session_2022-05-20_08-04-44_153620_58',\n"," 'webui_url': None}"]},"execution_count":76,"metadata":{},"output_type":"execute_result"}],"source":["import ray\n","\n","ray.init(num_cpus=6)"]},{"cell_type":"markdown","metadata":{"id":"MvuatpSsrGHD"},"source":["## 4. Get a spark session"]},{"cell_type":"code","execution_count":77,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6744,"status":"ok","timestamp":1653033895074,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"ze-waQ7zrH79","outputId":"df61b5b8-4c10-4a08-b5b9-d2ead5a1e6a2"},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: sun.reflect.Reflection.getCallerClass is not supported. 
This will impact performance.\n"]},{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: An illegal reflective access operation has occurred\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/lib/python3.7/dist-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: All illegal access operations will be denied in a future release\n"]},{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,275 WARN NativeCodeLoader [Thread-2]: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,602 INFO SecurityManager [Thread-2]: Changing view acls to: root\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,604 INFO SecurityManager [Thread-2]: Changing modify acls to: root\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,605 INFO SecurityManager [Thread-2]: Changing view acls groups to: \n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,606 INFO SecurityManager [Thread-2]: Changing modify acls groups to: \n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,607 INFO SecurityManager [Thread-2]: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(root); groups with view permissions: Set(); users with modify permissions: Set(root); groups with modify permissions: Set()\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:53,404 INFO Utils [Thread-2]: Successfully started service 'RAY_RPC_ENV' on port 45691.\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:54,308 INFO RayAppMaster$RayAppMasterEndpoint [dispatcher-event-loop-1]: Registering app Stoke Prediction with RayDP\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:54,326 INFO RayAppMaster$RayAppMasterEndpoint [dispatcher-event-loop-1]: Registered app Stoke Prediction with RayDP with ID app-20220520080454-0000\n"]}],"source":["import raydp\n","\n","app_name = \"Stoke Prediction with RayDP\"\n","num_executors = 1\n","cores_per_executor = 1\n","memory_per_executor = \"500M\"\n","spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor)"]},{"cell_type":"markdown","metadata":{"id":"AYdTc0wQrMSQ"},"source":["## 5. 
Get data from .csv file via 'spark' created by **raydp**"]},{"cell_type":"code","execution_count":78,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":14661,"status":"ok","timestamp":1653033909730,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"2vVAO9gwrNY9","outputId":"ef0717b9-6b25-4327-8b66-b33789440895"},"outputs":[{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: An illegal reflective access operation has occurred\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/lib/python3.7/dist-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: All illegal access operations will be denied in a future release\n"]}],"source":["data = spark.read.format(\"csv\").option(\"header\", \"true\") \\\n"," .option(\"inferSchema\", \"true\") \\\n"," .load(\"/content/healthcare-dataset-stroke-data.csv\")"]},{"cell_type":"markdown","metadata":{"id":"n8UO4brnrP_V"},"source":["## 6. Define the data_process function"]},{"cell_type":"markdown","metadata":{"id":"T1TJVdSCrSfg"},"source":["The dataset is converted to `pyspark.sql.dataframe.DataFrame`. 
Before feeding into the deep learning model, we can use raydp to do some transformation operations on dataset."]},{"cell_type":"markdown","metadata":{"id":"xwC6GDVrrU0s"},"source":["### 6.1 Data Analysis"]},{"cell_type":"markdown","metadata":{"id":"K_VMeCUDrZOd"},"source":["Here is a part of the data analysis."]},{"cell_type":"code","execution_count":79,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6469,"status":"ok","timestamp":1653033916189,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"rjF-QfKsrXTF","outputId":"41f2a242-caf7-4102-f9ff-0e50f1790736"},"outputs":[{"name":"stdout","output_type":"stream","text":["+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| id|gender| age|hypertension|heart_disease|ever_married| work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| 9046| Male|67.0| 0| 1| Yes| Private| Urban| 228.69|36.6|formerly smoked| 1|\n","|51676|Female|61.0| 0| 0| Yes|Self-employed| Rural| 202.21| N/A| never smoked| 1|\n","|31112| Male|80.0| 0| 1| Yes| Private| Rural| 105.92|32.5| never smoked| 1|\n","|60182|Female|49.0| 0| 0| Yes| Private| Urban| 171.23|34.4| smokes| 1|\n","| 1665|Female|79.0| 1| 0| Yes|Self-employed| Rural| 174.12| 24| never smoked| 1|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","only showing top 5 rows\n","\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","|summary| id|gender| age| hypertension| heart_disease|ever_married|work_type|Residence_type| 
avg_glucose_level| bmi|smoking_status| stroke|\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","| count| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110|\n","| mean|36517.82935420744| null|43.226614481409015|0.0974559686888454|0.05401174168297456| null| null| null|106.14767710371804|28.893236911794673| null| 0.0487279843444227|\n","| stddev|21161.72162482715| null| 22.61264672311348| 0.296606674233791|0.22606298750336554| null| null| null| 45.28356015058193| 7.85406672968016| null|0.21531985698023753|\n","| min| 67|Female| 0.08| 0| 0| No| Govt_job| Rural| 55.12| 10.3| Unknown| 0|\n","| max| 72940| Other| 82.0| 1| 1| Yes| children| Urban| 271.74| N/A| smokes| 1|\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","\n","+------+-----+\n","|gender|count|\n","+------+-----+\n","| Male| 2115|\n","| null| 5110|\n","|Female| 2994|\n","| Other| 1|\n","+------+-----+\n","\n","+------+-----+\n","|stroke|count|\n","+------+-----+\n","| 1| 249|\n","| 0| 4861|\n","| null| 5110|\n","+------+-----+\n","\n"]}],"source":["# Data overview\n","data.show(5)\n","# Statistical N/A distribution\n","# There are 201 'N/A' value in column 'bmi column',\n","# we can update them the mean of the column\n","data.describe().show()\n","data.filter(data.bmi=='N/A').count()\n","# Observe the distribution of the column 'gender'\n","# Then we should remove the outliers 'Other'\n","data.rollup(data.gender).count().show()\n","# Observe the proportion of positive and negative samples.\n","data.rollup(data.stroke).count().show()"]},{"cell_type":"markdown","metadata":{"id":"BwQIFmPArb4t"},"source":["### 6.2 Define 
operations"]},{"cell_type":"markdown","metadata":{"id":"nWQp7YiJsYU2"},"source":["Define data processing operations based on data analysis results."]},{"cell_type":"code","execution_count":80,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1653033916189,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"pYZLIkqwrduU"},"outputs":[],"source":["from pyspark.sql.functions import hour, quarter, month, year, dayofweek, dayofmonth, weekofyear, col, lit, udf, abs as functions_abs, avg"]},{"cell_type":"code","execution_count":81,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1653033916189,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"YqnWEwlMrkJ3"},"outputs":[],"source":["# Delete the useless column 'id'\n","def drop_col(data):\n"," data = data.drop('id')\n"," return data"]},{"cell_type":"code","execution_count":82,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1653033916190,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"x4FGXwNxrlK_"},"outputs":[],"source":["# Replace the value N/A in 'bmi'\n","def replace_nan(data):\n"," bmi_avg = data.agg(avg(col(\"bmi\"))).head()[0]\n","\n"," @udf(\"float\")\n"," def replace_nan(value):\n"," if value=='N/A':\n"," return float(bmi_avg)\n"," else:\n"," return float(value)\n","\n"," # Replace the value N/A\n"," data = data.withColumn('bmi', replace_nan(col(\"bmi\")))\n"," return data"]},{"cell_type":"code","execution_count":83,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1653033916190,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"JSAv6q_3rnmn"},"outputs":[],"source":["# Drop the only one value 'Other' in column 'gender'\n","def clean_value(data):\n"," data = data.filter(data.gender != 'Other')\n"," return 
data"]},{"cell_type":"code","execution_count":84,"metadata":{"executionInfo":{"elapsed":463,"status":"ok","timestamp":1653033916640,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"24PMi6OHrpDy"},"outputs":[],"source":["# Transform the category columns\n","def trans_category(data):\n"," @udf(\"int\")\n"," def trans_gender(value):\n"," gender = {'Female': 0,\n"," 'Male': 1}\n"," return int(gender[value])\n","\n"," @udf(\"int\")\n"," def trans_ever_married(value):\n"," residence_type = {'No': 0,\n"," 'Yes': 1}\n"," return int(residence_type[value])\n","\n"," @udf(\"int\")\n"," def trans_work_type(value):\n"," work_type = {'children': 0,\n"," 'Govt_job': 1,\n"," 'Never_worked': 2,\n"," 'Private': 3,\n"," 'Self-employed': 4}\n"," return int(work_type[value])\n","\n"," @udf(\"int\")\n"," def trans_residence_type(value):\n"," residence_type = {'Rural': 0,\n"," 'Urban': 1}\n"," return int(residence_type[value])\n","\n"," @udf(\"int\")\n"," def trans_smoking_status(value):\n"," smoking_status = {'formerly smoked': 0,\n"," 'never smoked': 1,\n"," 'smokes': 2,\n"," 'Unknown': 3}\n"," return int(smoking_status[value])\n","\n"," data = data.withColumn('gender', trans_gender(col('gender'))) \\\n"," .withColumn('ever_married', trans_ever_married(col('ever_married'))) \\\n"," .withColumn('work_type', trans_work_type(col('work_type'))) \\\n"," .withColumn('Residence_type', trans_residence_type(col('Residence_type'))) \\\n"," .withColumn('smoking_status', trans_smoking_status(col('smoking_status')))\n"," return data"]},{"cell_type":"code","execution_count":85,"metadata":{"executionInfo":{"elapsed":5,"status":"ok","timestamp":1653033916641,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"_V3UsmQwrq3I"},"outputs":[],"source":["# Add the discretized column of 'Age'\n","def map_age(data):\n"," @udf(\"int\")\n"," def get_value(value):\n"," if value >= 18 and value < 26:\n"," return int(0)\n"," elif value >=26 and 
value < 36:\n"," return int(1)\n"," elif value >=36 and value < 46:\n"," return int(2)\n"," elif value >=46 and value < 56:\n"," return int(3)\n"," else:\n"," return int(4)\n","\n"," data = data.withColumn('age_dis', get_value(col('age')))\n"," return data"]},{"cell_type":"code","execution_count":86,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1653033916641,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"-X0yYDz1rsnA"},"outputs":[],"source":["# Preprocess the data\n","def data_preprocess(data):\n"," data = drop_col(data)\n"," data = replace_nan(data)\n"," data = clean_value(data)\n"," data = trans_category(data)\n"," data = map_age(data)\n"," return data"]},{"cell_type":"markdown","metadata":{"id":"dfgFQdRRruS5"},"source":["## 7. Data processing"]},{"cell_type":"code","execution_count":87,"metadata":{"executionInfo":{"elapsed":10575,"status":"ok","timestamp":1653033927213,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"vrudeEb1ryy2"},"outputs":[],"source":["import torch\n","from raydp.utils import random_split\n","\n","# Transform the dataset\n","data = data_preprocess(data)\n","# Split data into train_dataset and test_dataset\n","train_df, test_df = random_split(data, [0.8, 0.2], 0)\n","# Balance the positive and negative samples\n","train_df_neg = train_df.filter(train_df.stroke == '1')\n","train_df = train_df.unionByName(train_df_neg)\n","train_df = train_df.unionByName(train_df_neg)\n","features = [field.name for field in list(train_df.schema) if field.name != \"stroke\"]\n","# Convert spark dataframe into ray Dataset\n","# Remember to align ``parallelism`` with ``num_workers`` of ray train\n","train_dataset = ray.data.from_spark(train_df, parallelism = 8)\n","test_dataset = ray.data.from_spark(test_df, parallelism = 8)\n","feature_dtype = [torch.float] * len(features)"]},{"cell_type":"markdown","metadata":{"id":"B42j2xAfr3Wi"},"source":["## 8. 
Define a neural network model"]},{"cell_type":"code","execution_count":88,"metadata":{"executionInfo":{"elapsed":9,"status":"ok","timestamp":1653033927213,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"Zd69EFv9r58U"},"outputs":[],"source":["import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class NET_Model(nn.Module):\n"," def __init__(self, cols):\n"," super().__init__()\n"," self.emb_layer_gender = nn.Embedding(2, 1) # gender\n"," self.emb_layer_hypertension = nn.Embedding(2,1) # hypertension\n"," self.emb_layer_heart_disease = nn.Embedding(2,1) # heart_disease\n"," self.emb_layer_ever_married = nn.Embedding(2, 1) # ever_married\n"," self.emb_layer_work = nn.Embedding(5, 1) # work_type\n"," self.emb_layer_residence = nn.Embedding(2, 1) # Residence_type\n"," self.emb_layer_smoking_status = nn.Embedding(4, 1) # smoking_status\n"," self.emb_layer_age = nn.Embedding(5, 1) # age column after discretization\n"," self.fc1 = nn.Linear(cols, 256)\n"," self.fc2 = nn.Linear(256, 128)\n"," self.fc3 = nn.Linear(128, 64)\n"," self.fc4 = nn.Linear(64, 16)\n"," self.fc5 = nn.Linear(16, 2)\n"," self.bn1 = nn.BatchNorm1d(256)\n"," self.bn2 = nn.BatchNorm1d(128)\n"," self.bn3 = nn.BatchNorm1d(64)\n"," self.bn4 = nn.BatchNorm1d(16)\n","\n"," def forward(self, *x):\n"," x = torch.cat(x, dim=1)\n"," # pick the dense attribute columns\n"," dense_columns = x[:, [1,7,8]]\n"," # Embedding operation on sparse attribute columns\n"," sparse_col_1 = self.emb_layer_gender(x[:, 0].long())\n"," sparse_col_2 = self.emb_layer_hypertension(x[:, 2].long())\n"," sparse_col_3 = self.emb_layer_heart_disease(x[:, 3].long())\n"," sparse_col_4 = self.emb_layer_ever_married(x[:, 4].long())\n"," sparse_col_5 = self.emb_layer_work(x[:, 5].long())\n"," sparse_col_6 = self.emb_layer_residence(x[:, 6].long())\n"," sparse_col_7 = self.emb_layer_smoking_status(x[:, 9].long())\n"," sparse_col_8 = self.emb_layer_age(x[:, 10].long())\n"," # Splice sparse 
attribute columns and dense attribute columns\n"," x = torch.cat([dense_columns, sparse_col_1, sparse_col_2, sparse_col_3, sparse_col_4, sparse_col_5, sparse_col_6, sparse_col_7, sparse_col_8], dim=1)\n","\n"," x = F.relu(self.fc1(x))\n"," x = self.bn1(x)\n"," x = F.relu(self.fc2(x))\n"," x = self.bn2(x)\n"," x = F.relu(self.fc3(x))\n"," x = self.bn3(x)\n"," x = F.relu(self.fc4(x))\n"," x = self.bn4(x)\n"," x = self.fc5(x)\n"," return x\n"]},{"cell_type":"markdown","metadata":{"id":"Z85SeUPNmdXj"},"source":["## 9. Create model, critetion and optimizer"]},{"cell_type":"code","execution_count":89,"metadata":{"executionInfo":{"elapsed":7,"status":"ok","timestamp":1653033927214,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"BM1BOwrgmjzN"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","\n","net_model = NET_Model(len(features))\n","criterion = nn.SmoothL1Loss()\n","optimizer = torch.optim.Adam(net_model.parameters(), lr=0.001)"]},{"cell_type":"markdown","metadata":{"id":"HwJdanB8TwCE"},"source":["## 10. Define the Callback which will be executed during training."]},{"cell_type":"code","execution_count":90,"metadata":{"executionInfo":{"elapsed":7,"status":"ok","timestamp":1653033927215,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"emnm-i8YTwhS"},"outputs":[],"source":["from ray.train import TrainingCallback\n","from typing import List, Dict\n","\n","class PrintingCallback(TrainingCallback):\n"," def handle_result(self, results: List[Dict], **info):\n"," print(results)"]},{"cell_type":"markdown","metadata":{"id":"_csycx41mpuj"},"source":["## 11. 
Create distributed estimator and train"]},{"cell_type":"code","execution_count":91,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":37833,"status":"ok","timestamp":1653033965041,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"heY9ullomyzh","outputId":"9f882156-65a4-4d89-cafd-5265bae4ef2d"},"outputs":[{"name":"stderr","output_type":"stream","text":["2022-05-20 08:05:28,725\tINFO trainer.py:172 -- Trainer logs will be logged in: /root/ray_results/train_2022-05-20_08-05-28\n","2022-05-20 08:05:30,430\tINFO trainer.py:178 -- Run results will be logged in: /root/ray_results/train_2022-05-20_08-05-28/run_001\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m 2022-05-20 08:05:30,427\tINFO torch.py:67 -- Setting up process group for: env:// [rank=0, world_size=1]\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m 2022-05-20 08:05:30,701\tINFO torch.py:239 -- Moving model to device: cpu\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m /usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py:907: UserWarning: Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64, 2])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)\n"]},{"name":"stdout","output_type":"stream","text":["[{'epoch': 0, 'train_acc': 0.0, 'train_loss': 0.09247553693130613, '_timestamp': 1653033932, '_time_this_iter_s': 1.3748996257781982, '_training_iteration': 1}]\n","[{'epoch': 0, 'evaluate_acc': 0.0, 'test_loss': 0.04560316858046195, '_timestamp': 1653033932, '_time_this_iter_s': 0.09927511215209961, '_training_iteration': 2}]\n"]},{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m /usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py:907: UserWarning: Using a target size (torch.Size([38, 1])) that is different to the input size (torch.Size([38, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m /usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py:907: UserWarning: Using a target size (torch.Size([29, 1])) that is different to the input size (torch.Size([29, 2])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)\n"]},{"name":"stdout","output_type":"stream","text":["[{'epoch': 1, 'train_acc': 0.0, 'train_loss': 0.06568372739878084, '_timestamp': 1653033933, '_time_this_iter_s': 1.012242317199707, '_training_iteration': 3}]\n","[{'epoch': 1, 'evaluate_acc': 0.0, 'test_loss': 0.051923071308171045, '_timestamp': 1653033933, '_time_this_iter_s': 0.10902786254882812, '_training_iteration': 4}]\n","[{'epoch': 2, 'train_acc': 0.0, 'train_loss': 0.062333274386557086, '_timestamp': 1653033934, '_time_this_iter_s': 0.9877479076385498, '_training_iteration': 5}]\n","[{'epoch': 2, 'evaluate_acc': 0.0, 'test_loss': 0.042880474863683474, '_timestamp': 1653033934, '_time_this_iter_s': 0.10124373435974121, '_training_iteration': 6}]\n","[{'epoch': 3, 'train_acc': 0.0, 'train_loss': 0.062327381250049385, '_timestamp': 1653033935, '_time_this_iter_s': 1.005267858505249, '_training_iteration': 7}]\n","[{'epoch': 3, 'evaluate_acc': 0.0, 'test_loss': 0.04614457756500034, '_timestamp': 1653033935, '_time_this_iter_s': 0.10248780250549316, '_training_iteration': 8}]\n","[{'epoch': 4, 'train_acc': 0.0, 'train_loss': 0.06244271249909486, '_timestamp': 1653033936, '_time_this_iter_s': 1.0482499599456787, '_training_iteration': 9}]\n","[{'epoch': 4, 'evaluate_acc': 0.0, 'test_loss': 0.04239819656290552, '_timestamp': 1653033936, '_time_this_iter_s': 0.09851694107055664, '_training_iteration': 10}]\n","[{'epoch': 5, 'train_acc': 0.0, 'train_loss': 0.062068206830216306, '_timestamp': 1653033937, '_time_this_iter_s': 1.0023341178894043, '_training_iteration': 11}]\n","[{'epoch': 5, 'evaluate_acc': 0.0, 'test_loss': 0.04345962876344428, '_timestamp': 1653033937, '_time_this_iter_s': 0.10348176956176758, '_training_iteration': 12}]\n","[{'epoch': 6, 'train_acc': 0.0, 'train_loss': 0.0617491880032633, '_timestamp': 1653033938, 
'_time_this_iter_s': 1.002626657485962, '_training_iteration': 13}]\n","[{'epoch': 6, 'evaluate_acc': 0.0, 'test_loss': 0.040240987368366295, '_timestamp': 1653033938, '_time_this_iter_s': 0.10725951194763184, '_training_iteration': 14}]\n","[{'epoch': 7, 'train_acc': 0.0, 'train_loss': 0.06170101864263415, '_timestamp': 1653033939, '_time_this_iter_s': 0.9948806762695312, '_training_iteration': 15}]\n","[{'epoch': 7, 'evaluate_acc': 0.0, 'test_loss': 0.04413437528316589, '_timestamp': 1653033939, '_time_this_iter_s': 0.11219525337219238, '_training_iteration': 16}]\n","[{'epoch': 8, 'train_acc': 0.0, 'train_loss': 0.06111902300534504, '_timestamp': 1653033940, '_time_this_iter_s': 1.0135900974273682, '_training_iteration': 17}]\n","[{'epoch': 8, 'evaluate_acc': 0.0, 'test_loss': 0.0439556289519019, '_timestamp': 1653033941, '_time_this_iter_s': 0.11552095413208008, '_training_iteration': 18}]\n","[{'epoch': 9, 'train_acc': 0.0, 'train_loss': 0.06117622062031712, '_timestamp': 1653033942, '_time_this_iter_s': 1.0200884342193604, '_training_iteration': 19}]\n","[{'epoch': 9, 'evaluate_acc': 0.0, 'test_loss': 0.04169800117447534, '_timestamp': 1653033942, '_time_this_iter_s': 0.09994888305664062, '_training_iteration': 20}]\n","[{'epoch': 10, 'train_acc': 0.0, 'train_loss': 0.06109746584136571, '_timestamp': 1653033943, '_time_this_iter_s': 1.0086157321929932, '_training_iteration': 21}]\n","[{'epoch': 10, 'evaluate_acc': 0.0, 'test_loss': 0.04427622495602597, '_timestamp': 1653033943, '_time_this_iter_s': 0.09854817390441895, '_training_iteration': 22}]\n","[{'epoch': 11, 'train_acc': 0.0, 'train_loss': 0.06094546751784427, '_timestamp': 1653033944, '_time_this_iter_s': 1.0076568126678467, '_training_iteration': 23}]\n","[{'epoch': 11, 'evaluate_acc': 0.0, 'test_loss': 0.044042169828625286, '_timestamp': 1653033944, '_time_this_iter_s': 0.10073971748352051, '_training_iteration': 24}]\n","[{'epoch': 12, 'train_acc': 0.0, 'train_loss': 0.060892132802733354, 
'_timestamp': 1653033945, '_time_this_iter_s': 1.006847858428955, '_training_iteration': 25}]\n","[{'epoch': 12, 'evaluate_acc': 0.0, 'test_loss': 0.04097760750857346, '_timestamp': 1653033945, '_time_this_iter_s': 0.10092830657958984, '_training_iteration': 26}]\n","[{'epoch': 13, 'train_acc': 0.0, 'train_loss': 0.06074612067480172, '_timestamp': 1653033946, '_time_this_iter_s': 1.0279877185821533, '_training_iteration': 27}]\n","[{'epoch': 13, 'evaluate_acc': 0.0, 'test_loss': 0.05008233364616685, '_timestamp': 1653033946, '_time_this_iter_s': 0.1041874885559082, '_training_iteration': 28}]\n","[{'epoch': 14, 'train_acc': 0.0, 'train_loss': 0.06069977636049901, '_timestamp': 1653033947, '_time_this_iter_s': 1.0180799961090088, '_training_iteration': 29}]\n","[{'epoch': 14, 'evaluate_acc': 0.0, 'test_loss': 0.04758017571807346, '_timestamp': 1653033947, '_time_this_iter_s': 0.09862518310546875, '_training_iteration': 30}]\n","[{'epoch': 15, 'train_acc': 0.0, 'train_loss': 0.06055774477177433, '_timestamp': 1653033948, '_time_this_iter_s': 0.9882826805114746, '_training_iteration': 31}]\n","[{'epoch': 15, 'evaluate_acc': 0.0, 'test_loss': 0.047914604396175814, '_timestamp': 1653033948, '_time_this_iter_s': 0.10061764717102051, '_training_iteration': 32}]\n","[{'epoch': 16, 'train_acc': 0.0, 'train_loss': 0.06053171257621476, '_timestamp': 1653033949, '_time_this_iter_s': 1.041496753692627, '_training_iteration': 33}]\n","[{'epoch': 16, 'evaluate_acc': 0.0, 'test_loss': 0.04322298049159786, '_timestamp': 1653033950, '_time_this_iter_s': 0.09923505783081055, '_training_iteration': 34}]\n","[{'epoch': 17, 'train_acc': 0.0, 'train_loss': 0.060558477350111516, '_timestamp': 1653033951, '_time_this_iter_s': 1.015479326248169, '_training_iteration': 35}]\n","[{'epoch': 17, 'evaluate_acc': 0.0, 'test_loss': 0.04752382685375564, '_timestamp': 1653033951, '_time_this_iter_s': 0.0949404239654541, '_training_iteration': 36}]\n","[{'epoch': 18, 'train_acc': 0.0, 'train_loss': 
0.06059950442452516, '_timestamp': 1653033952, '_time_this_iter_s': 1.024864912033081, '_training_iteration': 37}]\n","[{'epoch': 18, 'evaluate_acc': 0.0, 'test_loss': 0.04356259904692278, '_timestamp': 1653033952, '_time_this_iter_s': 0.10841035842895508, '_training_iteration': 38}]\n","[{'epoch': 19, 'train_acc': 0.0, 'train_loss': 0.060317250567355325, '_timestamp': 1653033953, '_time_this_iter_s': 0.9851441383361816, '_training_iteration': 39}]\n","[{'epoch': 19, 'evaluate_acc': 0.0, 'test_loss': 0.04302686020074522, '_timestamp': 1653033953, '_time_this_iter_s': 0.1081991195678711, '_training_iteration': 40}]\n","[{'epoch': 20, 'train_acc': 0.0, 'train_loss': 0.060491774523896834, '_timestamp': 1653033954, '_time_this_iter_s': 0.9950211048126221, '_training_iteration': 41}]\n","[{'epoch': 20, 'evaluate_acc': 0.0, 'test_loss': 0.04074172014096642, '_timestamp': 1653033954, '_time_this_iter_s': 0.09588193893432617, '_training_iteration': 42}]\n","[{'epoch': 21, 'train_acc': 0.0, 'train_loss': 0.060222738861505476, '_timestamp': 1653033955, '_time_this_iter_s': 0.9660444259643555, '_training_iteration': 43}]\n","[{'epoch': 21, 'evaluate_acc': 0.0, 'test_loss': 0.04606892182217801, '_timestamp': 1653033955, '_time_this_iter_s': 0.09844017028808594, '_training_iteration': 44}]\n","[{'epoch': 22, 'train_acc': 0.0, 'train_loss': 0.060002223788095374, '_timestamp': 1653033956, '_time_this_iter_s': 0.9830670356750488, '_training_iteration': 45}]\n","[{'epoch': 22, 'evaluate_acc': 0.0, 'test_loss': 0.04337873318068245, '_timestamp': 1653033956, '_time_this_iter_s': 0.10351777076721191, '_training_iteration': 46}]\n","[{'epoch': 23, 'train_acc': 0.0, 'train_loss': 0.0600206001794764, '_timestamp': 1653033957, '_time_this_iter_s': 0.991854190826416, '_training_iteration': 47}]\n","[{'epoch': 23, 'evaluate_acc': 0.0, 'test_loss': 0.04502974217757583, '_timestamp': 1653033957, '_time_this_iter_s': 0.10401058197021484, '_training_iteration': 48}]\n","[{'epoch': 24, 
'train_acc': 0.0, 'train_loss': 0.06031114665259208, '_timestamp': 1653033958, '_time_this_iter_s': 1.0301451683044434, '_training_iteration': 49}]\n","[{'epoch': 24, 'evaluate_acc': 0.0, 'test_loss': 0.04111280385404825, '_timestamp': 1653033958, '_time_this_iter_s': 0.0955967903137207, '_training_iteration': 50}]\n","[{'epoch': 25, 'train_acc': 0.0, 'train_loss': 0.060170894036335604, '_timestamp': 1653033959, '_time_this_iter_s': 0.978858470916748, '_training_iteration': 51}]\n","[{'epoch': 25, 'evaluate_acc': 0.0, 'test_loss': 0.041907111744341606, '_timestamp': 1653033959, '_time_this_iter_s': 0.10189700126647949, '_training_iteration': 52}]\n","[{'epoch': 26, 'train_acc': 0.0, 'train_loss': 0.059950036276131866, '_timestamp': 1653033960, '_time_this_iter_s': 1.0169415473937988, '_training_iteration': 53}]\n","[{'epoch': 26, 'evaluate_acc': 0.0, 'test_loss': 0.04607771421947023, '_timestamp': 1653033961, '_time_this_iter_s': 0.09785127639770508, '_training_iteration': 54}]\n","[{'epoch': 27, 'train_acc': 0.0, 'train_loss': 0.06003867073782853, '_timestamp': 1653033962, '_time_this_iter_s': 1.0707457065582275, '_training_iteration': 55}]\n","[{'epoch': 27, 'evaluate_acc': 0.0, 'test_loss': 0.047183933256960964, '_timestamp': 1653033962, '_time_this_iter_s': 0.10374617576599121, '_training_iteration': 56}]\n","[{'epoch': 28, 'train_acc': 0.0, 'train_loss': 0.05980566338236843, '_timestamp': 1653033963, '_time_this_iter_s': 1.0395805835723877, '_training_iteration': 57}]\n","[{'epoch': 28, 'evaluate_acc': 0.0, 'test_loss': 0.049405371825046396, '_timestamp': 1653033963, '_time_this_iter_s': 0.103973388671875, '_training_iteration': 58}]\n","[{'epoch': 29, 'train_acc': 0.0, 'train_loss': 0.05959741675427982, '_timestamp': 1653033964, '_time_this_iter_s': 1.002065896987915, '_training_iteration': 59}]\n","[{'epoch': 29, 'evaluate_acc': 0.0, 'test_loss': 0.03843427149524145, '_timestamp': 1653033964, '_time_this_iter_s': 0.09805178642272949, '_training_iteration': 
60}]\n"]}],"source":["from raydp.torch import TorchEstimator\n","\n","estimator = TorchEstimator(num_workers=1, model=net_model, optimizer=optimizer, loss=criterion,\n"," feature_columns=features, feature_types=feature_dtype,\n"," label_column=\"stroke\", label_type=torch.float,\n"," batch_size=64, num_epochs=30, callbacks=[PrintingCallback()])\n","# Train the model\n","estimator.fit_on_spark(train_df, test_df)"]},{"cell_type":"markdown","metadata":{"id":"nHRY731sm4nR"},"source":["## 12. shut down ray and raydp"]},{"cell_type":"code","execution_count":92,"metadata":{"executionInfo":{"elapsed":1727,"status":"ok","timestamp":1653033966765,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"dMt8Om94m9iH"},"outputs":[],"source":["raydp.stop_spark()\n","ray.shutdown()"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyN/DQ/PIfLSmV6Kjq2Hn/Y5","collapsed_sections":[],"mount_file_id":"1zvWvMhBNUolMOVMfzYqepXB671v2KcPk","name":"pytorch_nyctaxi.ipynb","provenance":[]},"interpreter":{"hash":"4592069f3f0e7e931529bda2eb12f695b39a5cc01058a1b879fa2b8939b3a972"},"kernelspec":{"display_name":"Python 3.7.12 ('raydp')","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.12"}},"nbformat":4,"nbformat_minor":0} +{"cells":[{"cell_type":"markdown","metadata":{"id":"2TxtrYkpljYO"},"source":["# **How RayDP works together with Pytorch**"]},{"cell_type":"markdown","metadata":{"id":"PVGNVwUU9lGW"},"source":["RayDP is a distributed data processing library that provides simple APIs for running Spark on Ray and integrating Spark with distributed deep learning and machine learning frameworks. 
This document builds an end-to-end deep learning pipeline on a single Ray cluster by using Spark for data preprocessing, and uses a distributed estimator based on the RayDP API to complete the training and evaluation."]},{"cell_type":"markdown","metadata":{"id":"hQ0v0n-PhnhY"},"source":["[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/oap-project/raydp/blob/master/tutorials/pytorch_example.ipynb)"]},{"cell_type":"markdown","metadata":{"id":"JJCwQHcvlmqF"},"source":["## 1. Colab environment Setup"]},{"cell_type":"markdown","metadata":{"id":"iMRuDWw9qh29"},"source":["RayDP requires Ray and PySpark. At the same time, PyTorch is used to build the deep learning model."]},{"cell_type":"code","execution_count":74,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":13503,"status":"ok","timestamp":1653033884645,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"SQplP7vYlCp2","outputId":"40a8aaf5-0187-4a58-e4f7-1175f2744d30"},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: ray==1.9 in /usr/local/lib/python3.7/dist-packages (1.9.0)\n","Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (1.46.1)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (3.13)\n","Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (21.4.0)\n","Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (1.21.6)\n","Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (7.1.2)\n","Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (1.0.3)\n","Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from 
ray==1.9) (3.17.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (3.7.0)\n","Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (4.3.3)\n","Requirement already satisfied: redis>=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ray==1.9) (4.3.1)\n","Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio>=1.28.1->ray==1.9) (1.15.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (4.2.0)\n","Requirement already satisfied: async-timeout>=4.0.2 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (4.0.2)\n","Requirement already satisfied: importlib-metadata>=1.0 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (4.11.3)\n","Requirement already satisfied: deprecated>=1.2.3 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (1.2.13)\n","Requirement already satisfied: packaging>=20.4 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray==1.9) (21.3)\n","Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.7/dist-packages (from deprecated>=1.2.3->redis>=3.5.0->ray==1.9) (1.14.1)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=1.0->redis>=3.5.0->ray==1.9) (3.8.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.4->redis>=3.5.0->ray==1.9) (3.0.9)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==1.9) (0.18.1)\n","Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==1.9) (5.7.1)\n","Requirement already satisfied: raydp-nightly in 
/usr/local/lib/python3.7/dist-packages (2022.5.12.dev0)\n","Requirement already satisfied: pandas>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (1.3.5)\n","Requirement already satisfied: typing in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (3.7.4.3)\n","Requirement already satisfied: pyspark>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (3.2.1)\n","Requirement already satisfied: ray>=1.8.0 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (1.9.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (1.21.6)\n","Requirement already satisfied: netifaces in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (0.11.0)\n","Requirement already satisfied: pyarrow<7.0.0,>=4.0.1 in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (6.0.1)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from raydp-nightly) (5.4.8)\n","Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.1.4->raydp-nightly) (2.8.2)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.1.4->raydp-nightly) (2022.1)\n","Requirement already satisfied: py4j==0.10.9.3 in /usr/local/lib/python3.7/dist-packages (from pyspark>=3.2.0->raydp-nightly) (0.10.9.3)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas>=1.1.4->raydp-nightly) (1.15.0)\n","Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (4.3.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (3.7.0)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (3.13)\n","Requirement already satisfied: attrs in 
/usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (21.4.0)\n","Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (3.17.3)\n","Requirement already satisfied: redis>=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (4.3.1)\n","Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (1.46.1)\n","Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (1.0.3)\n","Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray>=1.8.0->raydp-nightly) (7.1.2)\n","Requirement already satisfied: packaging>=20.4 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (21.3)\n","Requirement already satisfied: importlib-metadata>=1.0 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (4.11.3)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (4.2.0)\n","Requirement already satisfied: deprecated>=1.2.3 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (1.2.13)\n","Requirement already satisfied: async-timeout>=4.0.2 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray>=1.8.0->raydp-nightly) (4.0.2)\n","Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.7/dist-packages (from deprecated>=1.2.3->redis>=3.5.0->ray>=1.8.0->raydp-nightly) (1.14.1)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=1.0->redis>=3.5.0->ray>=1.8.0->raydp-nightly) (3.8.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from 
packaging>=20.4->redis>=3.5.0->ray>=1.8.0->raydp-nightly) (3.0.9)\n","Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray>=1.8.0->raydp-nightly) (5.7.1)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray>=1.8.0->raydp-nightly) (0.18.1)\n","Requirement already satisfied: ray[tune] in /usr/local/lib/python3.7/dist-packages (1.9.0)\n","Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.46.1)\n","Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.21.6)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.7.0)\n","Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (4.3.3)\n","Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.0.3)\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.13)\n","Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (21.4.0)\n","Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (7.1.2)\n","Requirement already satisfied: redis>=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (4.3.1)\n","Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.17.3)\n","Requirement already satisfied: tabulate in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (0.8.9)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (2.23.0)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.3.5)\n","Requirement already 
satisfied: tensorboardX>=1.9 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (2.5)\n","Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio>=1.28.1->ray[tune]) (1.15.0)\n","Requirement already satisfied: deprecated>=1.2.3 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (1.2.13)\n","Requirement already satisfied: importlib-metadata>=1.0 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (4.11.3)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (4.2.0)\n","Requirement already satisfied: packaging>=20.4 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (21.3)\n","Requirement already satisfied: async-timeout>=4.0.2 in /usr/local/lib/python3.7/dist-packages (from redis>=3.5.0->ray[tune]) (4.0.2)\n","Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.7/dist-packages (from deprecated>=1.2.3->redis>=3.5.0->ray[tune]) (1.14.1)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=1.0->redis>=3.5.0->ray[tune]) (3.8.0)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.4->redis>=3.5.0->ray[tune]) (3.0.9)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray[tune]) (0.18.1)\n","Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray[tune]) (5.7.1)\n","Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[tune]) (2022.1)\n","Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[tune]) (2.8.2)\n","Requirement already satisfied: certifi>=2017.4.17 in 
/usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (3.0.4)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (2.10)\n","Looking in links: https://download.pytorch.org/whl/torch_stable.html\n","Requirement already satisfied: torch==1.8.1+cpu in /usr/local/lib/python3.7/dist-packages (1.8.1+cpu)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch==1.8.1+cpu) (4.2.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch==1.8.1+cpu) (1.21.6)\n"]}],"source":["! pip install ray==1.9\n","! pip install raydp==0.5.0\n","! pip install ray[tune]\n","! pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html\n","! pip install torchmetrics"]},{"cell_type":"markdown","metadata":{"id":"XWw2fBhfqvME"},"source":["## 2. Get the data file"]},{"cell_type":"markdown","metadata":{"id":"pFZtArpXqyS2"},"source":["The dataset is from: https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset, and we store the file in the GitHub repository. It's used to predict whether a patient is likely to get a stroke based on the input parameters like gender, age, various diseases, and smoking status. Each row in the data provides relevant information about the patient. "]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":17,"status":"ok","timestamp":1653033884646,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"J-Hzp_q7qzz5","outputId":"5dae58ce-a40a-47f1-cb1b-6485cf15b949"},"outputs":[],"source":["! 
wget https://raw.githubusercontent.com/oap-project/raydp/master/tutorials/dataset/healthcare-dataset-stroke-data.csv -O healthcare-dataset-stroke-data.csv"]},{"cell_type":"markdown","metadata":{"id":"Y-ceLv2-q__G"},"source":["## 3. Init or connect to a ray cluster"]},{"cell_type":"code","execution_count":76,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3705,"status":"ok","timestamp":1653033888338,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"CVFQcpiJrC8Y","outputId":"bbbbc37c-9834-4c39-e652-0178842f0e7a"},"outputs":[{"data":{"text/plain":["{'metrics_export_port': 59076,\n"," 'node_id': '4a3abd8dce03a2b3e9e76b935b01c20eb501d3d5135929e1eadeccae',\n"," 'node_ip_address': '172.28.0.2',\n"," 'object_store_address': '/tmp/ray/session_2022-05-20_08-04-44_153620_58/sockets/plasma_store',\n"," 'raylet_ip_address': '172.28.0.2',\n"," 'raylet_socket_name': '/tmp/ray/session_2022-05-20_08-04-44_153620_58/sockets/raylet',\n"," 'redis_address': '172.28.0.2:6379',\n"," 'session_dir': '/tmp/ray/session_2022-05-20_08-04-44_153620_58',\n"," 'webui_url': None}"]},"execution_count":76,"metadata":{},"output_type":"execute_result"}],"source":["import ray\n","\n","ray.init(num_cpus=6)"]},{"cell_type":"markdown","metadata":{"id":"MvuatpSsrGHD"},"source":["## 4. Get a spark session"]},{"cell_type":"code","execution_count":77,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6744,"status":"ok","timestamp":1653033895074,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"ze-waQ7zrH79","outputId":"df61b5b8-4c10-4a08-b5b9-d2ead5a1e6a2"},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: sun.reflect.Reflection.getCallerClass is not supported. 
This will impact performance.\n"]},{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: An illegal reflective access operation has occurred\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/lib/python3.7/dist-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m WARNING: All illegal access operations will be denied in a future release\n"]},{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,275 WARN NativeCodeLoader [Thread-2]: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,602 INFO SecurityManager [Thread-2]: Changing view acls to: root\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,604 INFO SecurityManager [Thread-2]: Changing modify acls to: root\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,605 INFO SecurityManager [Thread-2]: Changing view acls groups to: \n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,606 INFO SecurityManager [Thread-2]: Changing modify acls groups to: \n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:52,607 INFO SecurityManager [Thread-2]: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(root); groups with view permissions: Set(); users with modify permissions: Set(root); groups with modify permissions: Set()\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:53,404 INFO Utils [Thread-2]: Successfully started service 'RAY_RPC_ENV' on port 45691.\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:54,308 INFO RayAppMaster$RayAppMasterEndpoint [dispatcher-event-loop-1]: Registering app Stoke Prediction with RayDP\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=4514)\u001b[0m 2022-05-20 08:04:54,326 INFO RayAppMaster$RayAppMasterEndpoint [dispatcher-event-loop-1]: Registered app Stoke Prediction with RayDP with ID app-20220520080454-0000\n"]}],"source":["import raydp\n","\n","app_name = \"Stoke Prediction with RayDP\"\n","num_executors = 1\n","cores_per_executor = 1\n","memory_per_executor = \"500M\"\n","spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor)"]},{"cell_type":"markdown","metadata":{"id":"AYdTc0wQrMSQ"},"source":["## 5. 
Get data from .csv file via 'spark' created by **raydp**"]},{"cell_type":"code","execution_count":78,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":14661,"status":"ok","timestamp":1653033909730,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"2vVAO9gwrNY9","outputId":"ef0717b9-6b25-4327-8b66-b33789440895"},"outputs":[{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: An illegal reflective access operation has occurred\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/lib/python3.7/dist-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: All illegal access operations will be denied in a future release\n"]}],"source":["data = spark.read.format(\"csv\").option(\"header\", \"true\") \\\n"," .option(\"inferSchema\", \"true\") \\\n"," .load(\"/content/healthcare-dataset-stroke-data.csv\")"]},{"cell_type":"markdown","metadata":{"id":"n8UO4brnrP_V"},"source":["## 6. Define the data_process function"]},{"cell_type":"markdown","metadata":{"id":"T1TJVdSCrSfg"},"source":["The dataset is converted to `pyspark.sql.dataframe.DataFrame`. 
Before feeding into the deep learning model, we can use raydp to do some transformation operations on dataset."]},{"cell_type":"markdown","metadata":{"id":"xwC6GDVrrU0s"},"source":["### 6.1 Data Analysis"]},{"cell_type":"markdown","metadata":{"id":"K_VMeCUDrZOd"},"source":["Here is a part of the data analysis."]},{"cell_type":"code","execution_count":79,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6469,"status":"ok","timestamp":1653033916189,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"rjF-QfKsrXTF","outputId":"41f2a242-caf7-4102-f9ff-0e50f1790736"},"outputs":[{"name":"stdout","output_type":"stream","text":["+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| id|gender| age|hypertension|heart_disease|ever_married| work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| 9046| Male|67.0| 0| 1| Yes| Private| Urban| 228.69|36.6|formerly smoked| 1|\n","|51676|Female|61.0| 0| 0| Yes|Self-employed| Rural| 202.21| N/A| never smoked| 1|\n","|31112| Male|80.0| 0| 1| Yes| Private| Rural| 105.92|32.5| never smoked| 1|\n","|60182|Female|49.0| 0| 0| Yes| Private| Urban| 171.23|34.4| smokes| 1|\n","| 1665|Female|79.0| 1| 0| Yes|Self-employed| Rural| 174.12| 24| never smoked| 1|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","only showing top 5 rows\n","\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","|summary| id|gender| age| hypertension| heart_disease|ever_married|work_type|Residence_type| 
avg_glucose_level| bmi|smoking_status| stroke|\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","| count| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110|\n","| mean|36517.82935420744| null|43.226614481409015|0.0974559686888454|0.05401174168297456| null| null| null|106.14767710371804|28.893236911794673| null| 0.0487279843444227|\n","| stddev|21161.72162482715| null| 22.61264672311348| 0.296606674233791|0.22606298750336554| null| null| null| 45.28356015058193| 7.85406672968016| null|0.21531985698023753|\n","| min| 67|Female| 0.08| 0| 0| No| Govt_job| Rural| 55.12| 10.3| Unknown| 0|\n","| max| 72940| Other| 82.0| 1| 1| Yes| children| Urban| 271.74| N/A| smokes| 1|\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","\n","+------+-----+\n","|gender|count|\n","+------+-----+\n","| Male| 2115|\n","| null| 5110|\n","|Female| 2994|\n","| Other| 1|\n","+------+-----+\n","\n","+------+-----+\n","|stroke|count|\n","+------+-----+\n","| 1| 249|\n","| 0| 4861|\n","| null| 5110|\n","+------+-----+\n","\n"]}],"source":["# Data overview\n","data.show(5)\n","# Statistical N/A distribution\n","# There are 201 'N/A' value in column 'bmi column',\n","# we can update them the mean of the column\n","data.describe().show()\n","data.filter(data.bmi=='N/A').count()\n","# Observe the distribution of the column 'gender'\n","# Then we should remove the outliers 'Other'\n","data.rollup(data.gender).count().show()\n","# Observe the proportion of positive and negative samples.\n","data.rollup(data.stroke).count().show()"]},{"cell_type":"markdown","metadata":{"id":"BwQIFmPArb4t"},"source":["### 6.2 Define 
operations"]},{"cell_type":"markdown","metadata":{"id":"nWQp7YiJsYU2"},"source":["Define data processing operations based on data analysis results."]},{"cell_type":"code","execution_count":80,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1653033916189,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"pYZLIkqwrduU"},"outputs":[],"source":["from pyspark.sql.functions import hour, quarter, month, year, dayofweek, dayofmonth, weekofyear, col, lit, udf, abs as functions_abs, avg"]},{"cell_type":"code","execution_count":81,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1653033916189,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"YqnWEwlMrkJ3"},"outputs":[],"source":["# Delete the useless column 'id'\n","def drop_col(data):\n"," data = data.drop('id')\n"," return data"]},{"cell_type":"code","execution_count":82,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1653033916190,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"x4FGXwNxrlK_"},"outputs":[],"source":["# Replace the value N/A in 'bmi'\n","def replace_nan(data):\n"," bmi_avg = data.agg(avg(col(\"bmi\"))).head()[0]\n","\n"," @udf(\"float\")\n"," def replace_nan(value):\n"," if value=='N/A':\n"," return float(bmi_avg)\n"," else:\n"," return float(value)\n","\n"," # Replace the value N/A\n"," data = data.withColumn('bmi', replace_nan(col(\"bmi\")))\n"," return data"]},{"cell_type":"code","execution_count":83,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1653033916190,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"JSAv6q_3rnmn"},"outputs":[],"source":["# Drop the only one value 'Other' in column 'gender'\n","def clean_value(data):\n"," data = data.filter(data.gender != 'Other')\n"," return 
data"]},{"cell_type":"code","execution_count":84,"metadata":{"executionInfo":{"elapsed":463,"status":"ok","timestamp":1653033916640,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"24PMi6OHrpDy"},"outputs":[],"source":["# Transform the category columns\n","def trans_category(data):\n"," @udf(\"int\")\n"," def trans_gender(value):\n"," gender = {'Female': 0,\n"," 'Male': 1}\n"," return int(gender[value])\n","\n"," @udf(\"int\")\n"," def trans_ever_married(value):\n"," residence_type = {'No': 0,\n"," 'Yes': 1}\n"," return int(residence_type[value])\n","\n"," @udf(\"int\")\n"," def trans_work_type(value):\n"," work_type = {'children': 0,\n"," 'Govt_job': 1,\n"," 'Never_worked': 2,\n"," 'Private': 3,\n"," 'Self-employed': 4}\n"," return int(work_type[value])\n","\n"," @udf(\"int\")\n"," def trans_residence_type(value):\n"," residence_type = {'Rural': 0,\n"," 'Urban': 1}\n"," return int(residence_type[value])\n","\n"," @udf(\"int\")\n"," def trans_smoking_status(value):\n"," smoking_status = {'formerly smoked': 0,\n"," 'never smoked': 1,\n"," 'smokes': 2,\n"," 'Unknown': 3}\n"," return int(smoking_status[value])\n","\n"," data = data.withColumn('gender', trans_gender(col('gender'))) \\\n"," .withColumn('ever_married', trans_ever_married(col('ever_married'))) \\\n"," .withColumn('work_type', trans_work_type(col('work_type'))) \\\n"," .withColumn('Residence_type', trans_residence_type(col('Residence_type'))) \\\n"," .withColumn('smoking_status', trans_smoking_status(col('smoking_status')))\n"," return data"]},{"cell_type":"code","execution_count":85,"metadata":{"executionInfo":{"elapsed":5,"status":"ok","timestamp":1653033916641,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"_V3UsmQwrq3I"},"outputs":[],"source":["# Add the discretized column of 'Age'\n","def map_age(data):\n"," @udf(\"int\")\n"," def get_value(value):\n"," if value >= 18 and value < 26:\n"," return int(0)\n"," elif value >=26 and 
value < 36:\n"," return int(1)\n"," elif value >=36 and value < 46:\n"," return int(2)\n"," elif value >=46 and value < 56:\n"," return int(3)\n"," else:\n"," return int(4)\n","\n"," data = data.withColumn('age_dis', get_value(col('age')))\n"," return data"]},{"cell_type":"code","execution_count":86,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1653033916641,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"-X0yYDz1rsnA"},"outputs":[],"source":["# Preprocess the data\n","def data_preprocess(data):\n"," data = drop_col(data)\n"," data = replace_nan(data)\n"," data = clean_value(data)\n"," data = trans_category(data)\n"," data = map_age(data)\n"," return data"]},{"cell_type":"markdown","metadata":{"id":"dfgFQdRRruS5"},"source":["## 7. Data processing"]},{"cell_type":"code","execution_count":87,"metadata":{"executionInfo":{"elapsed":10575,"status":"ok","timestamp":1653033927213,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"vrudeEb1ryy2"},"outputs":[],"source":["import torch\n","from raydp.utils import random_split\n","\n","# Transform the dataset\n","data = data_preprocess(data)\n","# Split data into train_dataset and test_dataset\n","train_df, test_df = random_split(data, [0.8, 0.2], 0)\n","# Balance the positive and negative samples\n","train_df_neg = train_df.filter(train_df.stroke == '1')\n","train_df = train_df.unionByName(train_df_neg)\n","train_df = train_df.unionByName(train_df_neg)\n","train_df = train_df.unionByName(train_df_neg)\n","features = [field.name for field in list(train_df.schema) if field.name != \"stroke\"]\n","# Convert spark dataframe into ray Dataset\n","# Remember to align ``parallelism`` with ``num_workers`` of ray train\n","train_dataset = ray.data.from_spark(train_df, parallelism = 8)\n","test_dataset = ray.data.from_spark(test_df, parallelism = 8)\n","feature_dtype = [torch.float] * 
len(features)"]},{"cell_type":"markdown","metadata":{"id":"B42j2xAfr3Wi"},"source":["## 8. Define a neural network model"]},{"cell_type":"code","execution_count":88,"metadata":{"executionInfo":{"elapsed":9,"status":"ok","timestamp":1653033927213,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"Zd69EFv9r58U"},"outputs":[],"source":["import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class NET_Model(nn.Module):\n"," def __init__(self, cols):\n"," super().__init__()\n"," self.emb_layer_gender = nn.Embedding(2, 1) # gender\n"," self.emb_layer_hypertension = nn.Embedding(2,1) # hypertension\n"," self.emb_layer_heart_disease = nn.Embedding(2,1) # heart_disease\n"," self.emb_layer_ever_married = nn.Embedding(2, 1) # ever_married\n"," self.emb_layer_work = nn.Embedding(5, 1) # work_type\n"," self.emb_layer_residence = nn.Embedding(2, 1) # Residence_type\n"," self.emb_layer_smoking_status = nn.Embedding(4, 1) # smoking_status\n"," self.emb_layer_age = nn.Embedding(5, 1) # age column after discretization\n"," self.fc1 = nn.Linear(cols, 256)\n"," self.fc2 = nn.Linear(256, 128)\n"," self.fc3 = nn.Linear(128, 64)\n"," self.fc4 = nn.Linear(64, 16)\n"," self.fc5 = nn.Linear(16, 1)\n"," self.bn1 = nn.BatchNorm1d(256)\n"," self.bn2 = nn.BatchNorm1d(128)\n"," self.bn3 = nn.BatchNorm1d(64)\n"," self.bn4 = nn.BatchNorm1d(16)\n","\n"," def forward(self, *x):\n"," x = torch.cat(x, dim=1)\n"," # pick the dense attribute columns\n"," dense_columns = x[:, [1,7,8]]\n"," # Embedding operation on sparse attribute columns\n"," sparse_col_1 = self.emb_layer_gender(x[:, 0].long())\n"," sparse_col_2 = self.emb_layer_hypertension(x[:, 2].long())\n"," sparse_col_3 = self.emb_layer_heart_disease(x[:, 3].long())\n"," sparse_col_4 = self.emb_layer_ever_married(x[:, 4].long())\n"," sparse_col_5 = self.emb_layer_work(x[:, 5].long())\n"," sparse_col_6 = self.emb_layer_residence(x[:, 6].long())\n"," sparse_col_7 = self.emb_layer_smoking_status(x[:, 
9].long())\n","        sparse_col_8 = self.emb_layer_age(x[:, 10].long())\n","        # Splice sparse attribute columns and dense attribute columns\n","        x = torch.cat([dense_columns, sparse_col_1, sparse_col_2, sparse_col_3, sparse_col_4, sparse_col_5, sparse_col_6, sparse_col_7, sparse_col_8], dim=1)\n","\n","        x = F.relu(self.fc1(x))\n","        x = self.bn1(x)\n","        x = F.relu(self.fc2(x))\n","        x = self.bn2(x)\n","        x = F.relu(self.fc3(x))\n","        x = self.bn3(x)\n","        x = F.relu(self.fc4(x))\n","        x = self.bn4(x)\n","        x = self.fc5(x)\n","        return x\n"]},{"cell_type":"markdown","metadata":{"id":"Z85SeUPNmdXj"},"source":["## 9. Create model, criterion and optimizer"]},{"cell_type":"code","execution_count":89,"metadata":{"executionInfo":{"elapsed":7,"status":"ok","timestamp":1653033927214,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"BM1BOwrgmjzN"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","\n","net_model = NET_Model(len(features))\n","criterion = nn.SmoothL1Loss()\n","optimizer = torch.optim.Adam(net_model.parameters(), lr=0.001)"]},{"cell_type":"markdown","metadata":{"id":"HwJdanB8TwCE"},"source":["## 10. Define the Callback which will be executed during training."]},{"cell_type":"code","execution_count":90,"metadata":{"executionInfo":{"elapsed":7,"status":"ok","timestamp":1653033927215,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"emnm-i8YTwhS"},"outputs":[],"source":["from ray.train import TrainingCallback\n","from typing import List, Dict\n","\n","class PrintingCallback(TrainingCallback):\n","    def handle_result(self, results: List[Dict], **info):\n","        print(results)"]},{"cell_type":"markdown","metadata":{"id":"_csycx41mpuj"},"source":["## 11. 
Create distributed estimator and train"]},{"cell_type":"code","execution_count":91,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":37833,"status":"ok","timestamp":1653033965041,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"heY9ullomyzh","outputId":"9f882156-65a4-4d89-cafd-5265bae4ef2d"},"outputs":[{"name":"stderr","output_type":"stream","text":["2022-05-20 08:05:28,725\tINFO trainer.py:172 -- Trainer logs will be logged in: /root/ray_results/train_2022-05-20_08-05-28\n","2022-05-20 08:05:30,430\tINFO trainer.py:178 -- Run results will be logged in: /root/ray_results/train_2022-05-20_08-05-28/run_001\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m 2022-05-20 08:05:30,427\tINFO torch.py:67 -- Setting up process group for: env:// [rank=0, world_size=1]\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m 2022-05-20 08:05:30,701\tINFO torch.py:239 -- Moving model to device: cpu\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m /usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py:907: UserWarning: Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64, 2])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)\n"]},{"name":"stdout","output_type":"stream","text":["[{'epoch': 0, 'train_acc': 0.0, 'train_loss': 0.09247553693130613, '_timestamp': 1653033932, '_time_this_iter_s': 1.3748996257781982, '_training_iteration': 1}]\n","[{'epoch': 0, 'evaluate_acc': 0.0, 'test_loss': 0.04560316858046195, '_timestamp': 1653033932, '_time_this_iter_s': 0.09927511215209961, '_training_iteration': 2}]\n"]},{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m /usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py:907: UserWarning: Using a target size (torch.Size([38, 1])) that is different to the input size (torch.Size([38, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m /usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py:907: UserWarning: Using a target size (torch.Size([29, 1])) that is different to the input size (torch.Size([29, 2])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=5164)\u001b[0m return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)\n"]},{"name":"stdout","output_type":"stream","text":["[{'epoch': 1, 'train_acc': 0.0, 'train_loss': 0.06568372739878084, '_timestamp': 1653033933, '_time_this_iter_s': 1.012242317199707, '_training_iteration': 3}]\n","[{'epoch': 1, 'evaluate_acc': 0.0, 'test_loss': 0.051923071308171045, '_timestamp': 1653033933, '_time_this_iter_s': 0.10902786254882812, '_training_iteration': 4}]\n","[{'epoch': 2, 'train_acc': 0.0, 'train_loss': 0.062333274386557086, '_timestamp': 1653033934, '_time_this_iter_s': 0.9877479076385498, '_training_iteration': 5}]\n","[{'epoch': 2, 'evaluate_acc': 0.0, 'test_loss': 0.042880474863683474, '_timestamp': 1653033934, '_time_this_iter_s': 0.10124373435974121, '_training_iteration': 6}]\n","[{'epoch': 3, 'train_acc': 0.0, 'train_loss': 0.062327381250049385, '_timestamp': 1653033935, '_time_this_iter_s': 1.005267858505249, '_training_iteration': 7}]\n","[{'epoch': 3, 'evaluate_acc': 0.0, 'test_loss': 0.04614457756500034, '_timestamp': 1653033935, '_time_this_iter_s': 0.10248780250549316, '_training_iteration': 8}]\n","[{'epoch': 4, 'train_acc': 0.0, 'train_loss': 0.06244271249909486, '_timestamp': 1653033936, '_time_this_iter_s': 1.0482499599456787, '_training_iteration': 9}]\n","[{'epoch': 4, 'evaluate_acc': 0.0, 'test_loss': 0.04239819656290552, '_timestamp': 1653033936, '_time_this_iter_s': 0.09851694107055664, '_training_iteration': 10}]\n","[{'epoch': 5, 'train_acc': 0.0, 'train_loss': 0.062068206830216306, '_timestamp': 1653033937, '_time_this_iter_s': 1.0023341178894043, '_training_iteration': 11}]\n","[{'epoch': 5, 'evaluate_acc': 0.0, 'test_loss': 0.04345962876344428, '_timestamp': 1653033937, '_time_this_iter_s': 0.10348176956176758, '_training_iteration': 12}]\n","[{'epoch': 6, 'train_acc': 0.0, 'train_loss': 0.0617491880032633, '_timestamp': 1653033938, 
'_time_this_iter_s': 1.002626657485962, '_training_iteration': 13}]\n","[{'epoch': 6, 'evaluate_acc': 0.0, 'test_loss': 0.040240987368366295, '_timestamp': 1653033938, '_time_this_iter_s': 0.10725951194763184, '_training_iteration': 14}]\n","[{'epoch': 7, 'train_acc': 0.0, 'train_loss': 0.06170101864263415, '_timestamp': 1653033939, '_time_this_iter_s': 0.9948806762695312, '_training_iteration': 15}]\n","[{'epoch': 7, 'evaluate_acc': 0.0, 'test_loss': 0.04413437528316589, '_timestamp': 1653033939, '_time_this_iter_s': 0.11219525337219238, '_training_iteration': 16}]\n","[{'epoch': 8, 'train_acc': 0.0, 'train_loss': 0.06111902300534504, '_timestamp': 1653033940, '_time_this_iter_s': 1.0135900974273682, '_training_iteration': 17}]\n","[{'epoch': 8, 'evaluate_acc': 0.0, 'test_loss': 0.0439556289519019, '_timestamp': 1653033941, '_time_this_iter_s': 0.11552095413208008, '_training_iteration': 18}]\n","[{'epoch': 9, 'train_acc': 0.0, 'train_loss': 0.06117622062031712, '_timestamp': 1653033942, '_time_this_iter_s': 1.0200884342193604, '_training_iteration': 19}]\n","[{'epoch': 9, 'evaluate_acc': 0.0, 'test_loss': 0.04169800117447534, '_timestamp': 1653033942, '_time_this_iter_s': 0.09994888305664062, '_training_iteration': 20}]\n","[{'epoch': 10, 'train_acc': 0.0, 'train_loss': 0.06109746584136571, '_timestamp': 1653033943, '_time_this_iter_s': 1.0086157321929932, '_training_iteration': 21}]\n","[{'epoch': 10, 'evaluate_acc': 0.0, 'test_loss': 0.04427622495602597, '_timestamp': 1653033943, '_time_this_iter_s': 0.09854817390441895, '_training_iteration': 22}]\n","[{'epoch': 11, 'train_acc': 0.0, 'train_loss': 0.06094546751784427, '_timestamp': 1653033944, '_time_this_iter_s': 1.0076568126678467, '_training_iteration': 23}]\n","[{'epoch': 11, 'evaluate_acc': 0.0, 'test_loss': 0.044042169828625286, '_timestamp': 1653033944, '_time_this_iter_s': 0.10073971748352051, '_training_iteration': 24}]\n","[{'epoch': 12, 'train_acc': 0.0, 'train_loss': 0.060892132802733354, 
'_timestamp': 1653033945, '_time_this_iter_s': 1.006847858428955, '_training_iteration': 25}]\n","[{'epoch': 12, 'evaluate_acc': 0.0, 'test_loss': 0.04097760750857346, '_timestamp': 1653033945, '_time_this_iter_s': 0.10092830657958984, '_training_iteration': 26}]\n","[{'epoch': 13, 'train_acc': 0.0, 'train_loss': 0.06074612067480172, '_timestamp': 1653033946, '_time_this_iter_s': 1.0279877185821533, '_training_iteration': 27}]\n","[{'epoch': 13, 'evaluate_acc': 0.0, 'test_loss': 0.05008233364616685, '_timestamp': 1653033946, '_time_this_iter_s': 0.1041874885559082, '_training_iteration': 28}]\n","[{'epoch': 14, 'train_acc': 0.0, 'train_loss': 0.06069977636049901, '_timestamp': 1653033947, '_time_this_iter_s': 1.0180799961090088, '_training_iteration': 29}]\n","[{'epoch': 14, 'evaluate_acc': 0.0, 'test_loss': 0.04758017571807346, '_timestamp': 1653033947, '_time_this_iter_s': 0.09862518310546875, '_training_iteration': 30}]\n","[{'epoch': 15, 'train_acc': 0.0, 'train_loss': 0.06055774477177433, '_timestamp': 1653033948, '_time_this_iter_s': 0.9882826805114746, '_training_iteration': 31}]\n","[{'epoch': 15, 'evaluate_acc': 0.0, 'test_loss': 0.047914604396175814, '_timestamp': 1653033948, '_time_this_iter_s': 0.10061764717102051, '_training_iteration': 32}]\n","[{'epoch': 16, 'train_acc': 0.0, 'train_loss': 0.06053171257621476, '_timestamp': 1653033949, '_time_this_iter_s': 1.041496753692627, '_training_iteration': 33}]\n","[{'epoch': 16, 'evaluate_acc': 0.0, 'test_loss': 0.04322298049159786, '_timestamp': 1653033950, '_time_this_iter_s': 0.09923505783081055, '_training_iteration': 34}]\n","[{'epoch': 17, 'train_acc': 0.0, 'train_loss': 0.060558477350111516, '_timestamp': 1653033951, '_time_this_iter_s': 1.015479326248169, '_training_iteration': 35}]\n","[{'epoch': 17, 'evaluate_acc': 0.0, 'test_loss': 0.04752382685375564, '_timestamp': 1653033951, '_time_this_iter_s': 0.0949404239654541, '_training_iteration': 36}]\n","[{'epoch': 18, 'train_acc': 0.0, 'train_loss': 
0.06059950442452516, '_timestamp': 1653033952, '_time_this_iter_s': 1.024864912033081, '_training_iteration': 37}]\n","[{'epoch': 18, 'evaluate_acc': 0.0, 'test_loss': 0.04356259904692278, '_timestamp': 1653033952, '_time_this_iter_s': 0.10841035842895508, '_training_iteration': 38}]\n","[{'epoch': 19, 'train_acc': 0.0, 'train_loss': 0.060317250567355325, '_timestamp': 1653033953, '_time_this_iter_s': 0.9851441383361816, '_training_iteration': 39}]\n","[{'epoch': 19, 'evaluate_acc': 0.0, 'test_loss': 0.04302686020074522, '_timestamp': 1653033953, '_time_this_iter_s': 0.1081991195678711, '_training_iteration': 40}]\n","[{'epoch': 20, 'train_acc': 0.0, 'train_loss': 0.060491774523896834, '_timestamp': 1653033954, '_time_this_iter_s': 0.9950211048126221, '_training_iteration': 41}]\n","[{'epoch': 20, 'evaluate_acc': 0.0, 'test_loss': 0.04074172014096642, '_timestamp': 1653033954, '_time_this_iter_s': 0.09588193893432617, '_training_iteration': 42}]\n","[{'epoch': 21, 'train_acc': 0.0, 'train_loss': 0.060222738861505476, '_timestamp': 1653033955, '_time_this_iter_s': 0.9660444259643555, '_training_iteration': 43}]\n","[{'epoch': 21, 'evaluate_acc': 0.0, 'test_loss': 0.04606892182217801, '_timestamp': 1653033955, '_time_this_iter_s': 0.09844017028808594, '_training_iteration': 44}]\n","[{'epoch': 22, 'train_acc': 0.0, 'train_loss': 0.060002223788095374, '_timestamp': 1653033956, '_time_this_iter_s': 0.9830670356750488, '_training_iteration': 45}]\n","[{'epoch': 22, 'evaluate_acc': 0.0, 'test_loss': 0.04337873318068245, '_timestamp': 1653033956, '_time_this_iter_s': 0.10351777076721191, '_training_iteration': 46}]\n","[{'epoch': 23, 'train_acc': 0.0, 'train_loss': 0.0600206001794764, '_timestamp': 1653033957, '_time_this_iter_s': 0.991854190826416, '_training_iteration': 47}]\n","[{'epoch': 23, 'evaluate_acc': 0.0, 'test_loss': 0.04502974217757583, '_timestamp': 1653033957, '_time_this_iter_s': 0.10401058197021484, '_training_iteration': 48}]\n","[{'epoch': 24, 
'train_acc': 0.0, 'train_loss': 0.06031114665259208, '_timestamp': 1653033958, '_time_this_iter_s': 1.0301451683044434, '_training_iteration': 49}]\n","[{'epoch': 24, 'evaluate_acc': 0.0, 'test_loss': 0.04111280385404825, '_timestamp': 1653033958, '_time_this_iter_s': 0.0955967903137207, '_training_iteration': 50}]\n","[{'epoch': 25, 'train_acc': 0.0, 'train_loss': 0.060170894036335604, '_timestamp': 1653033959, '_time_this_iter_s': 0.978858470916748, '_training_iteration': 51}]\n","[{'epoch': 25, 'evaluate_acc': 0.0, 'test_loss': 0.041907111744341606, '_timestamp': 1653033959, '_time_this_iter_s': 0.10189700126647949, '_training_iteration': 52}]\n","[{'epoch': 26, 'train_acc': 0.0, 'train_loss': 0.059950036276131866, '_timestamp': 1653033960, '_time_this_iter_s': 1.0169415473937988, '_training_iteration': 53}]\n","[{'epoch': 26, 'evaluate_acc': 0.0, 'test_loss': 0.04607771421947023, '_timestamp': 1653033961, '_time_this_iter_s': 0.09785127639770508, '_training_iteration': 54}]\n","[{'epoch': 27, 'train_acc': 0.0, 'train_loss': 0.06003867073782853, '_timestamp': 1653033962, '_time_this_iter_s': 1.0707457065582275, '_training_iteration': 55}]\n","[{'epoch': 27, 'evaluate_acc': 0.0, 'test_loss': 0.047183933256960964, '_timestamp': 1653033962, '_time_this_iter_s': 0.10374617576599121, '_training_iteration': 56}]\n","[{'epoch': 28, 'train_acc': 0.0, 'train_loss': 0.05980566338236843, '_timestamp': 1653033963, '_time_this_iter_s': 1.0395805835723877, '_training_iteration': 57}]\n","[{'epoch': 28, 'evaluate_acc': 0.0, 'test_loss': 0.049405371825046396, '_timestamp': 1653033963, '_time_this_iter_s': 0.103973388671875, '_training_iteration': 58}]\n","[{'epoch': 29, 'train_acc': 0.0, 'train_loss': 0.05959741675427982, '_timestamp': 1653033964, '_time_this_iter_s': 1.002065896987915, '_training_iteration': 59}]\n","[{'epoch': 29, 'evaluate_acc': 0.0, 'test_loss': 0.03843427149524145, '_timestamp': 1653033964, '_time_this_iter_s': 0.09805178642272949, '_training_iteration': 
60}]\n"]}],"source":["from raydp.torch import TorchEstimator\n","\n","estimator = TorchEstimator(num_workers=1, model=net_model, optimizer=optimizer, loss=criterion,\n"," feature_columns=features, feature_types=feature_dtype,\n"," label_column=\"stroke\", label_type=torch.float,\n"," batch_size=64, num_epochs=10, callbacks=[PrintingCallback()],\n"," metrics_name=[\"Accuracy\"], metrics_config={\"Accuracy\": {\"task\": \"binary\"}})\n","# Train the model\n","estimator.fit_on_spark(train_df, test_df)"]},{"cell_type":"markdown","metadata":{"id":"nHRY731sm4nR"},"source":["## 12. shut down ray and raydp"]},{"cell_type":"code","execution_count":92,"metadata":{"executionInfo":{"elapsed":1727,"status":"ok","timestamp":1653033966765,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"dMt8Om94m9iH"},"outputs":[],"source":["raydp.stop_spark()\n","ray.shutdown()"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyN/DQ/PIfLSmV6Kjq2Hn/Y5","collapsed_sections":[],"mount_file_id":"1zvWvMhBNUolMOVMfzYqepXB671v2KcPk","name":"pytorch_nyctaxi.ipynb","provenance":[]},"interpreter":{"hash":"4592069f3f0e7e931529bda2eb12f695b39a5cc01058a1b879fa2b8939b3a972"},"kernelspec":{"display_name":"Python 3.7.12 ('raydp')","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.12"}},"nbformat":4,"nbformat_minor":0} diff --git a/tutorials/raytrain_example.ipynb b/tutorials/raytrain_example.ipynb index 47aa241f..9419a7bc 100644 --- a/tutorials/raytrain_example.ipynb +++ b/tutorials/raytrain_example.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{"id":"pJGi7Gb-DIRB"},"source":["# **How RayDP works together with Ray**"]},{"cell_type":"markdown","metadata":{"id":"3W1tUlKEMzm2"},"source":["RayDP is a distributed data processing library that provides simple APIs for running 
Spark on Ray and integrating Spark with distributed deep learning and machine learning frameworks. This document builds an end-to-end deep learning pipeline on a single Ray cluster by using Spark for data preprocessing, and uses ray train to complete the training and evaluation."]},{"cell_type":"markdown","metadata":{"id":"cNZOlzR-ldrE"},"source":["[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/oap-project/raydp/blob/master/tutorials/raytrain_example.ipynb)"]},{"cell_type":"markdown","metadata":{"id":"NxuZwK3IDr6i"},"source":["## 1. Colab enviroment Setup"]},{"cell_type":"markdown","metadata":{"id":"2vMP2OEik9pO"},"source":["RayDP requires Ray and PySpark. At the same time, pytorch is used to build deep learning model."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":679,"status":"ok","timestamp":1653029465793,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"Z697GdNUDoDG","outputId":"5447afbd-9657-4705-f1ca-e9791f889e6d"},"outputs":[{"name":"stdout","output_type":"stream","text":["Package Version\n","----------------------------- ------------------------------\n","absl-py 1.0.0\n","alabaster 0.7.12\n","albumentations 0.1.12\n","altair 4.2.0\n","appdirs 1.4.4\n","argon2-cffi 21.3.0\n","argon2-cffi-bindings 21.2.0\n","arviz 0.12.1\n","astor 0.8.1\n","astropy 4.3.1\n","astunparse 1.6.3\n","async-timeout 4.0.2\n","atari-py 0.2.9\n","atomicwrites 1.4.0\n","attrs 21.4.0\n","audioread 2.1.9\n","autograd 1.4\n","Babel 2.10.1\n","backcall 0.2.0\n","beautifulsoup4 4.6.3\n","bleach 5.0.0\n","blis 0.4.1\n","bokeh 2.3.3\n","Bottleneck 1.3.4\n","branca 0.5.0\n","bs4 0.0.1\n","CacheControl 0.12.11\n","cached-property 1.5.2\n","cachetools 4.2.4\n","catalogue 1.0.0\n","certifi 2021.10.8\n","cffi 1.15.0\n","cftime 1.6.0\n","chardet 3.0.4\n","charset-normalizer 2.0.12\n","click 
7.1.2\n","cloudpickle 1.3.0\n","cmake 3.22.4\n","cmdstanpy 0.9.5\n","colorcet 3.0.0\n","colorlover 0.3.0\n","community 1.0.0b1\n","contextlib2 0.5.5\n","convertdate 2.4.0\n","coverage 3.7.1\n","coveralls 0.5\n","crcmod 1.7\n","cufflinks 0.17.3\n","cvxopt 1.2.7\n","cvxpy 1.0.31\n","cycler 0.11.0\n","cymem 2.0.6\n","Cython 0.29.30\n","daft 0.0.4\n","dask 2.12.0\n","datascience 0.10.6\n","debugpy 1.0.0\n","decorator 4.4.2\n","defusedxml 0.7.1\n","Deprecated 1.2.13\n","descartes 1.1.0\n","dill 0.3.4\n","distributed 1.25.3\n","dlib 19.18.0+zzzcolab20220513001918\n","dm-tree 0.1.7\n","docopt 0.6.2\n","docutils 0.17.1\n","dopamine-rl 1.0.5\n","earthengine-api 0.1.309\n","easydict 1.9\n","ecos 2.0.10\n","editdistance 0.5.3\n","en-core-web-sm 2.2.5\n","entrypoints 0.4\n","ephem 4.1.3\n","et-xmlfile 1.1.0\n","fa2 0.3.5\n","fastai 1.0.61\n","fastdtw 0.3.4\n","fastjsonschema 2.15.3\n","fastprogress 1.0.2\n","fastrlock 0.8\n","fbprophet 0.7.1\n","feather-format 0.4.1\n","filelock 3.7.0\n","firebase-admin 4.4.0\n","fix-yahoo-finance 0.0.22\n","Flask 1.1.4\n","flatbuffers 2.0\n","folium 0.8.3\n","future 0.16.0\n","gast 0.5.3\n","GDAL 2.2.2\n","gdown 4.4.0\n","gensim 3.6.0\n","geographiclib 1.52\n","geopy 1.17.0\n","gin-config 0.5.0\n","glob2 0.7\n","google 2.0.3\n","google-api-core 1.31.5\n","google-api-python-client 1.12.11\n","google-auth 1.35.0\n","google-auth-httplib2 0.0.4\n","google-auth-oauthlib 0.4.6\n","google-cloud-bigquery 1.21.0\n","google-cloud-bigquery-storage 1.1.1\n","google-cloud-core 1.0.3\n","google-cloud-datastore 1.8.0\n","google-cloud-firestore 1.7.0\n","google-cloud-language 1.2.0\n","google-cloud-storage 1.18.1\n","google-cloud-translate 1.5.0\n","google-colab 1.0.0\n","google-pasta 0.2.0\n","google-resumable-media 0.4.1\n","googleapis-common-protos 1.56.1\n","googledrivedownloader 0.4\n","graphviz 0.10.1\n","greenlet 1.1.2\n","grpcio 1.46.1\n","gspread 3.4.2\n","gspread-dataframe 3.0.8\n","gym 0.17.3\n","h5py 3.1.0\n","HeapDict 1.0.1\n","hijri-converter 
2.2.3\n","holidays 0.10.5.2\n","holoviews 1.14.9\n","html5lib 1.0.1\n","httpimport 0.5.18\n","httplib2 0.17.4\n","httplib2shim 0.0.3\n","humanize 0.5.1\n","hyperopt 0.1.2\n","ideep4py 2.0.0.post3\n","idna 2.10\n","imageio 2.4.1\n","imagesize 1.3.0\n","imbalanced-learn 0.8.1\n","imblearn 0.0\n","imgaug 0.2.9\n","importlib-metadata 4.11.3\n","importlib-resources 5.7.1\n","imutils 0.5.4\n","inflect 2.1.0\n","iniconfig 1.1.1\n","intel-openmp 2022.1.0\n","intervaltree 2.1.0\n","ipykernel 4.10.1\n","ipython 5.5.0\n","ipython-genutils 0.2.0\n","ipython-sql 0.3.9\n","ipywidgets 7.7.0\n","itsdangerous 1.1.0\n","jax 0.3.8\n","jaxlib 0.3.7+cuda11.cudnn805\n","jedi 0.18.1\n","jieba 0.42.1\n","Jinja2 2.11.3\n","joblib 1.1.0\n","jpeg4py 0.1.4\n","jsonschema 4.3.3\n","jupyter 1.0.0\n","jupyter-client 5.3.5\n","jupyter-console 5.2.0\n","jupyter-core 4.10.0\n","jupyterlab-pygments 0.2.2\n","jupyterlab-widgets 1.1.0\n","kaggle 1.5.12\n","kapre 0.3.7\n","keras 2.8.0\n","Keras-Preprocessing 1.1.2\n","keras-vis 0.4.1\n","kiwisolver 1.4.2\n","korean-lunar-calendar 0.2.1\n","libclang 14.0.1\n","librosa 0.8.1\n","lightgbm 2.2.3\n","llvmlite 0.34.0\n","lmdb 0.99\n","LunarCalendar 0.0.9\n","lxml 4.2.6\n","Markdown 3.3.7\n","MarkupSafe 2.0.1\n","matplotlib 3.2.2\n","matplotlib-inline 0.1.3\n","matplotlib-venn 0.11.7\n","missingno 0.5.1\n","mistune 0.8.4\n","mizani 0.6.0\n","mkl 2019.0\n","mlxtend 0.14.0\n","more-itertools 8.13.0\n","moviepy 0.2.3.5\n","mpmath 1.2.1\n","msgpack 1.0.3\n","multiprocess 0.70.12.2\n","multitasking 0.0.10\n","murmurhash 1.0.7\n","music21 5.5.0\n","natsort 5.5.0\n","nbclient 0.6.3\n","nbconvert 5.6.1\n","nbformat 5.4.0\n","nest-asyncio 1.5.5\n","netCDF4 1.5.8\n","netifaces 0.11.0\n","networkx 2.6.3\n","nibabel 3.0.2\n","nltk 3.2.5\n","notebook 5.3.1\n","numba 0.51.2\n","numexpr 2.8.1\n","numpy 1.21.6\n","nvidia-ml-py3 7.352.0\n","oauth2client 4.1.3\n","oauthlib 3.2.0\n","okgrade 0.4.3\n","opencv-contrib-python 4.1.2.30\n","opencv-python 4.1.2.30\n","openpyxl 
3.0.9\n","opt-einsum 3.3.0\n","osqp 0.6.2.post0\n","packaging 21.3\n","palettable 3.3.0\n","pandas 1.3.5\n","pandas-datareader 0.9.0\n","pandas-gbq 0.13.3\n","pandas-profiling 1.4.1\n","pandocfilters 1.5.0\n","panel 0.12.1\n","param 1.12.1\n","parso 0.8.3\n","pathlib 1.0.1\n","patsy 0.5.2\n","pep517 0.12.0\n","pexpect 4.8.0\n","pickleshare 0.7.5\n","Pillow 7.1.2\n","pip 21.1.3\n","pip-tools 6.2.0\n","plac 1.1.3\n","plotly 5.5.0\n","plotnine 0.6.0\n","pluggy 0.7.1\n","pooch 1.6.0\n","portpicker 1.3.9\n","prefetch-generator 1.0.1\n","preshed 3.0.6\n","prettytable 3.3.0\n","progressbar2 3.38.0\n","prometheus-client 0.14.1\n","promise 2.3\n","prompt-toolkit 1.0.18\n","protobuf 3.17.3\n","psutil 5.4.8\n","psycopg2 2.7.6.1\n","ptyprocess 0.7.0\n","py 1.11.0\n","py4j 0.10.9.3\n","pyarrow 6.0.1\n","pyasn1 0.4.8\n","pyasn1-modules 0.2.8\n","pycocotools 2.0.4\n","pycparser 2.21\n","pyct 0.4.8\n","pydata-google-auth 1.4.0\n","pydot 1.3.0\n","pydot-ng 2.0.0\n","pydotplus 2.0.2\n","PyDrive 1.3.1\n","pyemd 0.5.1\n","pyerfa 2.0.0.1\n","pyglet 1.5.0\n","Pygments 2.6.1\n","pygobject 3.26.1\n","pymc3 3.11.4\n","PyMeeus 0.5.11\n","pymongo 4.1.1\n","pymystem3 0.2.0\n","PyOpenGL 3.1.6\n","pyparsing 3.0.9\n","pyrsistent 0.18.1\n","pysndfile 1.3.8\n","PySocks 1.7.1\n","pyspark 3.2.1\n","pystan 2.19.1.1\n","pytest 3.6.4\n","python-apt 0.0.0\n","python-chess 0.23.11\n","python-dateutil 2.8.2\n","python-louvain 0.16\n","python-slugify 6.1.2\n","python-utils 3.2.2\n","pytz 2022.1\n","pyviz-comms 2.2.0\n","PyWavelets 1.3.0\n","PyYAML 3.13\n","pyzmq 22.3.0\n","qdldl 0.1.5.post2\n","qtconsole 5.3.0\n","QtPy 2.1.0\n","ray 1.9.0\n","raydp-nightly 2022.5.12.dev0\n","redis 4.3.1\n","regex 2019.12.20\n","requests 2.23.0\n","requests-oauthlib 1.3.1\n","resampy 0.2.2\n","rpy2 3.4.5\n","rsa 4.8\n","scikit-image 0.18.3\n","scikit-learn 1.0.2\n","scipy 1.4.1\n","screen-resolution-extra 0.0.0\n","scs 3.2.0\n","seaborn 0.11.2\n","semver 2.13.0\n","Send2Trash 1.8.0\n","setuptools 57.4.0\n","setuptools-git 
1.2\n","Shapely 1.8.2\n","simplegeneric 0.8.1\n","six 1.15.0\n","sklearn 0.0\n","sklearn-pandas 1.8.0\n","smart-open 6.0.0\n","snowballstemmer 2.2.0\n","sortedcontainers 2.4.0\n","SoundFile 0.10.3.post1\n","soupsieve 2.3.2.post1\n","spacy 2.2.4\n","Sphinx 1.8.6\n","sphinxcontrib-serializinghtml 1.1.5\n","sphinxcontrib-websupport 1.2.4\n","SQLAlchemy 1.4.36\n","sqlparse 0.4.2\n","srsly 1.0.5\n","statsmodels 0.10.2\n","sympy 1.7.1\n","tables 3.7.0\n","tabulate 0.8.9\n","tblib 1.7.0\n","tenacity 8.0.1\n","tensorboard 2.8.0\n","tensorboard-data-server 0.6.1\n","tensorboard-plugin-wit 1.8.1\n","tensorboardX 2.5\n","tensorflow 2.8.0+zzzcolab20220506162203\n","tensorflow-datasets 4.0.1\n","tensorflow-estimator 2.8.0\n","tensorflow-gcs-config 2.8.0\n","tensorflow-hub 0.12.0\n","tensorflow-io-gcs-filesystem 0.25.0\n","tensorflow-metadata 1.8.0\n","tensorflow-probability 0.16.0\n","termcolor 1.1.0\n","terminado 0.13.3\n","testpath 0.6.0\n","text-unidecode 1.3\n","textblob 0.15.3\n","Theano-PyMC 1.1.2\n","thinc 7.4.0\n","threadpoolctl 3.1.0\n","tifffile 2021.11.2\n","tinycss2 1.1.1\n","tomli 2.0.1\n","toolz 0.11.2\n","torch 1.8.1+cpu\n","torchaudio 0.11.0+cu113\n","torchsummary 1.5.1\n","torchtext 0.12.0\n","torchvision 0.12.0+cu113\n","tornado 5.1.1\n","tqdm 4.64.0\n","traitlets 5.1.1\n","tweepy 3.10.0\n","typeguard 2.7.1\n","typing 3.7.4.3\n","typing-extensions 4.2.0\n","tzlocal 1.5.1\n","uritemplate 3.0.1\n","urllib3 1.24.3\n","vega-datasets 0.9.0\n","wasabi 0.9.1\n","wcwidth 0.2.5\n","webencodings 0.5.1\n","Werkzeug 1.0.1\n","wheel 0.37.1\n","widgetsnbextension 3.6.0\n","wordcloud 1.5.0\n","wrapt 1.14.1\n","xarray 0.20.2\n","xarray-einstats 0.2.2\n","xgboost 0.90\n","xkit 0.0.0\n","xlrd 1.1.0\n","xlwt 1.3.0\n","yellowbrick 1.4\n","zict 2.2.0\n","zipp 3.8.0\n"]}],"source":["! pip install ray==1.9\n","# install RayDP nightly build\n","! pip install raydp-nightly\n","# or use the released version\n","# ! pip install raydp\n","! pip install ray[tune]\n","! 
pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html"]},{"cell_type":"markdown","metadata":{"id":"x3o7NPecpbeA"},"source":["## 2. Get the data file"]},{"cell_type":"markdown","metadata":{"id":"JpiXje_Jmkp3"},"source":["The dataset is from: https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset, and we store the file in the GitHub repository. It's used to predict whether a patient is likely to get stroke based on the input parameters like gender, age, various diseases, and smoking status. Each row in the data provides relevant information about the patient. "]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":706,"status":"ok","timestamp":1653029466495,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"ywec9GENpe7-","outputId":"01a68d70-e4eb-4343-8605-3116e0dde429"},"outputs":[],"source":["! wget https://raw.githubusercontent.com/oap-project/raydp/master/tutorials/dataset/healthcare-dataset-stroke-data.csv -O healthcare-dataset-stroke-data.csv"]},{"cell_type":"markdown","metadata":{"id":"fSD4v0zKEx7d"},"source":["## 3. 
Init or connect to a ray cluster"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4310,"status":"ok","timestamp":1653029470802,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"kOXeU_qmE1uH","outputId":"de23e901-c1a9-4582-bf41-d5fd7abc2b64"},"outputs":[{"data":{"text/plain":["{'metrics_export_port': 44425,\n"," 'node_id': '1e8d9137f4f4a9343edcce64f2d61d72ec70d624899d07e818bd34e1',\n"," 'node_ip_address': '172.28.0.2',\n"," 'object_store_address': '/tmp/ray/session_2022-05-20_06-51-06_717688_61/sockets/plasma_store',\n"," 'raylet_ip_address': '172.28.0.2',\n"," 'raylet_socket_name': '/tmp/ray/session_2022-05-20_06-51-06_717688_61/sockets/raylet',\n"," 'redis_address': '172.28.0.2:6379',\n"," 'session_dir': '/tmp/ray/session_2022-05-20_06-51-06_717688_61',\n"," 'webui_url': None}"]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["import ray\n","\n","ray.init(num_cpus=6)"]},{"cell_type":"markdown","metadata":{"id":"xyzXEJg7FVrg"},"source":["## 4. Get a spark session"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":20879,"status":"ok","timestamp":1653029491677,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"R9YydfatFfuF","outputId":"fb4c7b7b-f30f-42c3-daf9-fa472916f93f"},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: sun.reflect.Reflection.getCallerClass is not supported. 
This will impact performance.\n"]},{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: An illegal reflective access operation has occurred\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/lib/python3.7/dist-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: All illegal access operations will be denied in a future release\n"]},{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:17,557 WARN NativeCodeLoader [Thread-2]: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:17,872 INFO SecurityManager [Thread-2]: Changing view acls to: root\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:17,879 INFO SecurityManager [Thread-2]: Changing modify acls to: root\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:17,880 INFO SecurityManager [Thread-2]: Changing view acls groups to: \n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:17,882 INFO SecurityManager [Thread-2]: Changing modify acls groups to: \n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:17,883 INFO SecurityManager [Thread-2]: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(root); groups with view permissions: Set(); users with modify permissions: Set(root); groups with modify permissions: Set()\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:19,091 INFO Utils [Thread-2]: Successfully started service 'RAY_RPC_ENV' on port 40877.\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:27,127 INFO RayAppMaster$RayAppMasterEndpoint [dispatcher-event-loop-1]: Registering app Stoke Prediction with RayDP\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:27,132 INFO RayAppMaster$RayAppMasterEndpoint [dispatcher-event-loop-1]: Registered app Stoke Prediction with RayDP with ID app-20220520065127-0000\n"]}],"source":["import raydp\n","\n","app_name = \"Stoke Prediction with RayDP\"\n","num_executors = 1\n","cores_per_executor = 1\n","memory_per_executor = \"500M\"\n","spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor)"]},{"cell_type":"markdown","metadata":{"id":"jse6pc7OL2pA"},"source":["## 5. 
Get data from .csv file via 'spark' created by **raydp**"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":15155,"status":"ok","timestamp":1653029506826,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"DhyOSZhfL8_0","outputId":"904d2de2-d736-4cdc-8741-8abfb08e48ab"},"outputs":[{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: An illegal reflective access operation has occurred\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/lib/python3.7/dist-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n","\u001b[2m\u001b[33m(raylet)\u001b[0m WARNING: All illegal access operations will be denied in a future release\n"]}],"source":["data = spark.read.format(\"csv\").option(\"header\", \"true\") \\\n"," .option(\"inferSchema\", \"true\") \\\n"," .load(\"/content/healthcare-dataset-stroke-data.csv\")"]},{"cell_type":"markdown","metadata":{"id":"yYUc524PVI2W"},"source":["## 6. Define the data_process function"]},{"cell_type":"markdown","metadata":{"id":"unbR1vvSqytm"},"source":["The dataset is converted to `pyspark.sql.dataframe.DataFrame`. 
Before feeding into the deep learning model, we can use raydp to do some transformation operations on dataset."]},{"cell_type":"markdown","metadata":{"id":"jJXWdonutekH"},"source":["### 6.1 Data Analysis"]},{"cell_type":"markdown","metadata":{"id":"RuZuZeyZSHQu"},"source":["Here is a part of the data analysis."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":9734,"status":"ok","timestamp":1653029516546,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"BUcjEczateOW","outputId":"96a6191b-225a-4320-cd01-0145b1c8e966"},"outputs":[{"name":"stdout","output_type":"stream","text":["+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| id|gender| age|hypertension|heart_disease|ever_married| work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| 9046| Male|67.0| 0| 1| Yes| Private| Urban| 228.69|36.6|formerly smoked| 1|\n","|51676|Female|61.0| 0| 0| Yes|Self-employed| Rural| 202.21| N/A| never smoked| 1|\n","|31112| Male|80.0| 0| 1| Yes| Private| Rural| 105.92|32.5| never smoked| 1|\n","|60182|Female|49.0| 0| 0| Yes| Private| Urban| 171.23|34.4| smokes| 1|\n","| 1665|Female|79.0| 1| 0| Yes|Self-employed| Rural| 174.12| 24| never smoked| 1|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","only showing top 5 rows\n","\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","|summary| id|gender| age| hypertension| heart_disease|ever_married|work_type|Residence_type| 
avg_glucose_level| bmi|smoking_status| stroke|\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","| count| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110| 5110|\n","| mean|36517.82935420744| null|43.226614481409015|0.0974559686888454|0.05401174168297456| null| null| null|106.14767710371804|28.893236911794673| null| 0.0487279843444227|\n","| stddev|21161.72162482715| null| 22.61264672311348| 0.296606674233791|0.22606298750336554| null| null| null| 45.28356015058193| 7.85406672968016| null|0.21531985698023753|\n","| min| 67|Female| 0.08| 0| 0| No| Govt_job| Rural| 55.12| 10.3| Unknown| 0|\n","| max| 72940| Other| 82.0| 1| 1| Yes| children| Urban| 271.74| N/A| smokes| 1|\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","\n","+------+-----+\n","|gender|count|\n","+------+-----+\n","| Male| 2115|\n","| null| 5110|\n","|Female| 2994|\n","| Other| 1|\n","+------+-----+\n","\n","+------+-----+\n","|stroke|count|\n","+------+-----+\n","| 1| 249|\n","| 0| 4861|\n","| null| 5110|\n","+------+-----+\n","\n"]}],"source":["# Data overview\n","data.show(5)\n","# Statistical N/A distribution\n","# There are 201 'N/A' value in column 'bmi column',\n","# we can update them the mean of the column\n","data.describe().show()\n","data.filter(data.bmi=='N/A').count()\n","# Observe the distribution of the column 'gender'\n","# Then we should remove the outliers 'Other'\n","data.rollup(data.gender).count().show()\n","# Observe the proportion of positive and negative samples.\n","data.rollup(data.stroke).count().show()"]},{"cell_type":"markdown","metadata":{"id":"cF9XEK1AtsH3"},"source":["### 6.2 Define 
operations"]},{"cell_type":"markdown","metadata":{"id":"gaPmPn7TsMFP"},"source":["Define data processing operations based on data analysis results."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"PZn2UuWGxLsO"},"outputs":[],"source":["from pyspark.sql.functions import hour, quarter, month, year, dayofweek, dayofmonth, weekofyear, col, lit, udf, abs as functions_abs, avg"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"SSubHn45seR3"},"outputs":[],"source":["# Delete the useless column 'id'\n","def drop_col(data):\n"," data = data.drop('id')\n"," return data"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ybPMTUtcshbm"},"outputs":[],"source":["# Replace the value N/A in 'bmi'\n","def replace_nan(data):\n"," bmi_avg = data.agg(avg(col(\"bmi\"))).head()[0]\n","\n"," @udf(\"float\")\n"," def replace_nan(value):\n"," if value=='N/A':\n"," return float(bmi_avg)\n"," else:\n"," return float(value)\n","\n"," # Replace the value N/A\n"," data = data.withColumn('bmi', replace_nan(col(\"bmi\")))\n"," return data"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"WOzxnQp7sqOw"},"outputs":[],"source":["# Drop the only one value 'Other' in column 'gender'\n","def clean_value(data):\n"," data = data.filter(data.gender != 'Other')\n"," return data"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"EGjPMaAbstlZ"},"outputs":[],"source":["# Transform the category columns\n","def trans_category(data):\n"," @udf(\"int\")\n"," def trans_gender(value):\n"," gender = {'Female': 0,\n"," 'Male': 1}\n"," return int(gender[value])\n","\n"," @udf(\"int\")\n"," def trans_ever_married(value):\n"," residence_type = {'No': 0,\n"," 'Yes': 1}\n"," return int(residence_type[value])\n","\n"," @udf(\"int\")\n"," def trans_work_type(value):\n"," work_type = {'children': 0,\n"," 'Govt_job': 1,\n"," 'Never_worked': 2,\n"," 'Private': 3,\n"," 'Self-employed': 4}\n"," return int(work_type[value])\n","\n"," @udf(\"int\")\n"," def 
trans_residence_type(value):\n"," residence_type = {'Rural': 0,\n"," 'Urban': 1}\n"," return int(residence_type[value])\n","\n"," @udf(\"int\")\n"," def trans_smoking_status(value):\n"," smoking_status = {'formerly smoked': 0,\n"," 'never smoked': 1,\n"," 'smokes': 2,\n"," 'Unknown': 3}\n"," return int(smoking_status[value])\n","\n"," data = data.withColumn('gender', trans_gender(col('gender'))) \\\n"," .withColumn('ever_married', trans_ever_married(col('ever_married'))) \\\n"," .withColumn('work_type', trans_work_type(col('work_type'))) \\\n"," .withColumn('Residence_type', trans_residence_type(col('Residence_type'))) \\\n"," .withColumn('smoking_status', trans_smoking_status(col('smoking_status')))\n"," return data"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"s3b_s_aqsxqo"},"outputs":[],"source":["# Add the discretized column of 'Age'\n","def map_age(data):\n"," @udf(\"int\")\n"," def get_value(value):\n"," if value >= 18 and value < 26:\n"," return int(0)\n"," elif value >=26 and value < 36:\n"," return int(1)\n"," elif value >=36 and value < 46:\n"," return int(2)\n"," elif value >=46 and value < 56:\n"," return int(3)\n"," else:\n"," return int(4)\n","\n"," data = data.withColumn('age_dis', get_value(col('age')))\n"," return data"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Jk5b9MNwVNNk"},"outputs":[],"source":["# Preprocess the data\n","def data_preprocess(data):\n"," data = drop_col(data)\n"," data = replace_nan(data)\n"," data = clean_value(data)\n"," data = trans_category(data)\n"," data = map_age(data)\n"," return data"]},{"cell_type":"markdown","metadata":{"id":"ZPgzXrdUFtL4"},"source":["## 7. 
Data processing"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ztuBpv9SF0_R"},"outputs":[],"source":["import torch\n","from raydp.utils import random_split\n","\n","# Transform the dataset\n","data = data_preprocess(data)\n","# Split data into train_dataset and test_dataset\n","train_df, test_df = random_split(data, [0.8, 0.2], 0)\n","# Balance the positive and negative samples\n","train_df_neg = train_df.filter(train_df.stroke == '1')\n","train_df = train_df.unionByName(train_df_neg)\n","train_df = train_df.unionByName(train_df_neg)\n","features = [field.name for field in list(train_df.schema) if field.name != \"stroke\"]\n","# Convert spark dataframe into ray Dataset\n","# Remember to align ``parallelism`` with ``num_workers`` of ray train\n","train_dataset = ray.data.from_spark(train_df, parallelism = 8)\n","test_dataset = ray.data.from_spark(test_df, parallelism = 8)\n","feature_dtype = [torch.float] * len(features)"]},{"cell_type":"markdown","metadata":{"id":"9PujezwdGJ87"},"source":["## 8. 
Define a neural network model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"RcsoN3tWGUEM"},"outputs":[],"source":["import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class NET_Model(nn.Module):\n"," def __init__(self, cols):\n"," super().__init__()\n"," self.emb_layer_gender = nn.Embedding(2, 2) # gender\n"," self.emb_layer_hypertension = nn.Embedding(2,2) # hypertension\n"," self.emb_layer_heart_disease = nn.Embedding(2,2) # heart_disease\n"," self.emb_layer_ever_married = nn.Embedding(2, 2) # ever_married\n"," self.emb_layer_work = nn.Embedding(5, 5) # work_type\n"," self.emb_layer_residence = nn.Embedding(2, 2) # Residence_type\n"," self.emb_layer_smoking_status = nn.Embedding(4, 4) # smoking_status\n"," self.emb_layer_age = nn.Embedding(5, 5) # age column after discretization\n"," self.fc1 = nn.Linear(cols, 256)\n"," self.fc2 = nn.Linear(256, 128)\n"," self.fc3 = nn.Linear(128, 64)\n"," self.fc4 = nn.Linear(64, 16)\n"," self.fc5 = nn.Linear(16, 2)\n","\n"," self.fc_sparse = nn.Linear(24, 16)\n"," self.fc_dense = nn.Linear(3, 8)\n"," self.fc = nn.Linear(24, 2)\n"," \n"," self.bn1 = nn.BatchNorm1d(256)\n"," self.bn2 = nn.BatchNorm1d(128)\n"," self.bn3 = nn.BatchNorm1d(64)\n"," self.bn4 = nn.BatchNorm1d(16)\n","\n"," def forward(self, *x):\n"," x = torch.cat(x, dim=1)\n"," # pick the dense attribute columns\n"," dense_columns = x[:, [1,7,8]]\n"," # Embedding operation on sparse attribute columns\n"," sparse_col_1 = self.emb_layer_gender(x[:, 0].long())\n"," sparse_col_2 = self.emb_layer_hypertension(x[:, 2].long())\n"," sparse_col_3 = self.emb_layer_heart_disease(x[:, 3].long())\n"," sparse_col_4 = self.emb_layer_ever_married(x[:, 4].long())\n"," sparse_col_5 = self.emb_layer_work(x[:, 5].long())\n"," sparse_col_6 = self.emb_layer_residence(x[:, 6].long())\n"," sparse_col_7 = self.emb_layer_smoking_status(x[:, 9].long())\n"," sparse_col_8 = self.emb_layer_age(x[:, 10].long())\n"," # Splice sparse attribute columns and dense 
attribute columns\n"," x = torch.cat([dense_columns, sparse_col_1, sparse_col_2, sparse_col_3, sparse_col_4, sparse_col_5, sparse_col_6, sparse_col_7, sparse_col_8], dim=1)\n","\n"," sparse_columns = torch.cat([sparse_col_1, sparse_col_2, sparse_col_3, sparse_col_4, sparse_col_5, sparse_col_6, sparse_col_7, sparse_col_8], dim=1)\n"," dense_feat = self.fc_dense(dense_columns)\n"," sparse_feat = self.fc_sparse(sparse_columns)\n"," return self.fc(torch.cat([dense_feat, sparse_feat], dim=1))\n","\n"," x = F.relu(self.fc1(x))\n"," x = self.bn1(x)\n"," x = F.relu(self.fc2(x))\n"," x = self.bn2(x)\n"," x = F.relu(self.fc3(x))\n"," x = self.bn3(x)\n"," x = F.relu(self.fc4(x))\n"," x = self.bn4(x)\n"," x = self.fc5(x)\n"," return x\n"]},{"cell_type":"markdown","metadata":{"id":"DZ81HRrOGVVj"},"source":["## 9. Define train and test function"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"0Y_08OhMGd4a"},"outputs":[],"source":["def train_epoch(dataset, model, criterion, optimizer):\n"," model.train()\n"," train_loss, correct, data_size, batch_idx = 0, 0, 0, 0\n"," for batch_idx, (inputs, targets) in enumerate(dataset):\n"," # Compute prediction error\n"," inputs = [inputs[:,i].unsqueeze(1) for i in range(inputs.size(1))]\n"," targets = targets.reshape(-1)\n"," outputs = model(*inputs)\n"," loss = criterion(outputs, targets)\n"," train_loss += loss.item()\n"," _, predicted = torch.max(outputs.data, 1)\n"," data_size += targets.size(0)\n"," correct += (predicted == targets).sum().item()\n"," # Backpropagation\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n"," # Caculate the train_loss and train_acc\n"," train_loss /= (batch_idx + 1)\n"," train_acc = correct/data_size\n"," return train_acc, train_loss\n","\n","def test_epoch(dataset, model, criterion):\n"," model.eval()\n"," test_loss, correct, data_size, batch_idx = 0, 0, 0, 0\n"," with torch.no_grad():\n"," for batch_idx, (inputs, targets) in enumerate(dataset):\n"," # Compute prediction 
error\n"," inputs = [inputs[:,i].unsqueeze(1) for i in range(inputs.size(1))]\n"," targets = targets.reshape(-1)\n"," outputs = model(*inputs)\n"," test_loss += criterion(outputs, targets).item()\n"," _, predicted = torch.max(outputs.data, 1)\n"," data_size += targets.size(0)\n"," correct += (predicted == targets).sum().item()\n"," # Caculate the test_loss and test_acc\n"," test_loss /= (batch_idx + 1)\n"," test_acc = correct/data_size\n"," return test_acc, test_loss"]},{"cell_type":"markdown","metadata":{"id":"dRAamyi3GewE"},"source":["## 10. Define train function"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"76x_ZEdKGmd3"},"outputs":[],"source":["from ray import train\n","from ray.train import get_dataset_shard\n","\n","def train_func(config):\n"," num_epochs = config[\"num_epochs\"]\n"," lr = config[\"lr\"]\n"," batch_size = config[\"batch_size\"]\n"," # Then convert to torch datasets\n"," # Get the corresponging shard\n"," train_data_shard = get_dataset_shard(\"train\")\n"," train_dataset = train_data_shard.to_torch(feature_columns=features,\n"," label_column=\"stroke\",\n"," label_column_dtype=torch.long,\n"," feature_column_dtypes=feature_dtype,\n"," batch_size=batch_size)\n"," test_data_shard = get_dataset_shard(\"test\")\n"," test_dataset = test_data_shard.to_torch(feature_columns=features,\n"," label_column=\"stroke\",\n"," label_column_dtype=torch.long,\n"," feature_column_dtypes=feature_dtype,\n"," batch_size=batch_size)\n"," model = NET_Model(len(features))\n"," model = train.torch.prepare_model(model)\n"," criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.35, 0.65]))\n"," optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n"," loss_results = []\n"," for epoch in range(num_epochs):\n"," train_acc, train_loss = train_epoch(train_dataset, model, criterion, optimizer)\n"," test_acc, test_loss = test_epoch(test_dataset, model, criterion)\n"," train.report(epoch = epoch, train_acc = train_acc, train_loss = train_loss)\n"," 
train.report(epoch = epoch, test_acc=test_acc, test_loss=test_loss)\n"," loss_results.append(test_loss)"]},{"cell_type":"markdown","metadata":{"id":"d-wfpCJ_NL4-"},"source":["## 11. Define the callback function"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"omYjJ7WTNZUo"},"outputs":[],"source":["from ray.train import TrainingCallback\n","from typing import List, Dict\n","\n","# log the train results\n","class PrintingCallback(TrainingCallback):\n"," def handle_result(self, results: List[Dict], **info):\n"," print(results)"]},{"cell_type":"markdown","metadata":{"id":"kaW_aC5zGoWZ"},"source":["## 12. Train model via ray train"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":8380,"status":"ok","timestamp":1653029538050,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"Rerm_OwoG0g9","outputId":"62322657-f23d-4485-cbbf-10ec5d4f84de"},"outputs":[{"name":"stderr","output_type":"stream","text":["2022-05-20 06:52:09,475\tINFO trainer.py:172 -- Trainer logs will be logged in: /root/ray_results/train_2022-05-20_06-52-09\n","2022-05-20 06:52:10,966\tINFO trainer.py:178 -- Run results will be logged in: /root/ray_results/train_2022-05-20_06-52-09/run_001\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=1234)\u001b[0m 2022-05-20 06:52:10,963\tINFO torch.py:67 -- Setting up process group for: env:// [rank=0, world_size=1]\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=1234)\u001b[0m 2022-05-20 06:52:11,020\tINFO torch.py:239 -- Moving model to device: cpu\n"]},{"name":"stdout","output_type":"stream","text":["[{'epoch': 0, 'train_acc': 0.6178715761113606, 'train_loss': 2.4214471597756657, '_timestamp': 1653029531, '_time_this_iter_s': 0.8797011375427246, '_training_iteration': 1}]\n","[{'epoch': 0, 'test_acc': 0.7587844254510921, 'test_loss': 0.7357403390547809, '_timestamp': 1653029531, '_time_this_iter_s': 0.0008981227874755859, '_training_iteration': 
2}]\n","[{'epoch': 1, 'train_acc': 0.797485406376291, 'train_loss': 0.6913646668195724, '_timestamp': 1653029532, '_time_this_iter_s': 0.5976800918579102, '_training_iteration': 3}]\n","[{'epoch': 1, 'test_acc': 0.9021842355175689, 'test_loss': 0.31034330848385305, '_timestamp': 1653029532, '_time_this_iter_s': 0.0007421970367431641, '_training_iteration': 4}]\n","[{'epoch': 2, 'train_acc': 0.8432869330938483, 'train_loss': 0.4305501192808151, '_timestamp': 1653029533, '_time_this_iter_s': 0.5868339538574219, '_training_iteration': 5}]\n","[{'epoch': 2, 'test_acc': 0.9069325735992403, 'test_loss': 0.2829096720499151, '_timestamp': 1653029533, '_time_this_iter_s': 0.0007574558258056641, '_training_iteration': 6}]\n","[{'epoch': 3, 'train_acc': 0.8414907947911989, 'train_loss': 0.41592263281345365, '_timestamp': 1653029533, '_time_this_iter_s': 0.618511438369751, '_training_iteration': 7}]\n","[{'epoch': 3, 'test_acc': 0.9050332383665717, 'test_loss': 0.2828731712173013, '_timestamp': 1653029533, '_time_this_iter_s': 0.001013040542602539, '_training_iteration': 8}]\n","[{'epoch': 4, 'train_acc': 0.8428378985181859, 'train_loss': 0.4120848593967302, '_timestamp': 1653029534, '_time_this_iter_s': 0.6071341037750244, '_training_iteration': 9}]\n","[{'epoch': 4, 'test_acc': 0.9050332383665717, 'test_loss': 0.281910487834145, '_timestamp': 1653029534, '_time_this_iter_s': 0.0007085800170898438, '_training_iteration': 10}]\n","[{'epoch': 5, 'train_acc': 0.843062415806017, 'train_loss': 0.4090804461921964, '_timestamp': 1653029534, '_time_this_iter_s': 0.59639573097229, '_training_iteration': 11}]\n","[{'epoch': 5, 'test_acc': 0.905982905982906, 'test_loss': 0.2811230000327615, '_timestamp': 1653029534, '_time_this_iter_s': 0.0007147789001464844, '_training_iteration': 12}]\n","[{'epoch': 6, 'train_acc': 0.8426133812303548, 'train_loss': 0.40661772191524503, '_timestamp': 1653029535, '_time_this_iter_s': 0.5747692584991455, '_training_iteration': 13}]\n","[{'epoch': 6, 
from ray.train import Trainer

# Hyper-parameters handed to every training worker.
run_config = {"num_epochs": 10, "lr": 0.001, "batch_size": 64}
# Ray Datasets that Ray Train will shard across workers.
datasets = {"train": train_dataset, "test": test_dataset}

trainer = Trainer(backend="torch", num_workers=num_executors)
trainer.start()
results = trainer.run(train_func,
                      config=run_config,
                      callbacks=[PrintingCallback()],
                      dataset=datasets)
trainer.shutdown()
# Release cluster resources in reverse order of startup:
# stop the Spark-on-Ray session first, then the Ray runtime itself.
raydp.stop_spark()
ray.shutdown()
At the same time, pytorch is used to build deep learning model."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":679,"status":"ok","timestamp":1653029465793,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"Z697GdNUDoDG","outputId":"5447afbd-9657-4705-f1ca-e9791f889e6d"},"outputs":[{"name":"stdout","output_type":"stream","text":["Package Version\n","----------------------------- ------------------------------\n","absl-py 1.0.0\n","alabaster 0.7.12\n","albumentations 0.1.12\n","altair 4.2.0\n","appdirs 1.4.4\n","argon2-cffi 21.3.0\n","argon2-cffi-bindings 21.2.0\n","arviz 0.12.1\n","astor 0.8.1\n","astropy 4.3.1\n","astunparse 1.6.3\n","async-timeout 4.0.2\n","atari-py 0.2.9\n","atomicwrites 1.4.0\n","attrs 21.4.0\n","audioread 2.1.9\n","autograd 1.4\n","Babel 2.10.1\n","backcall 0.2.0\n","beautifulsoup4 4.6.3\n","bleach 5.0.0\n","blis 0.4.1\n","bokeh 2.3.3\n","Bottleneck 1.3.4\n","branca 0.5.0\n","bs4 0.0.1\n","CacheControl 0.12.11\n","cached-property 1.5.2\n","cachetools 4.2.4\n","catalogue 1.0.0\n","certifi 2021.10.8\n","cffi 1.15.0\n","cftime 1.6.0\n","chardet 3.0.4\n","charset-normalizer 2.0.12\n","click 7.1.2\n","cloudpickle 1.3.0\n","cmake 3.22.4\n","cmdstanpy 0.9.5\n","colorcet 3.0.0\n","colorlover 0.3.0\n","community 1.0.0b1\n","contextlib2 0.5.5\n","convertdate 2.4.0\n","coverage 3.7.1\n","coveralls 0.5\n","crcmod 1.7\n","cufflinks 0.17.3\n","cvxopt 1.2.7\n","cvxpy 1.0.31\n","cycler 0.11.0\n","cymem 2.0.6\n","Cython 0.29.30\n","daft 0.0.4\n","dask 2.12.0\n","datascience 0.10.6\n","debugpy 1.0.0\n","decorator 4.4.2\n","defusedxml 0.7.1\n","Deprecated 1.2.13\n","descartes 1.1.0\n","dill 0.3.4\n","distributed 1.25.3\n","dlib 19.18.0+zzzcolab20220513001918\n","dm-tree 0.1.7\n","docopt 0.6.2\n","docutils 0.17.1\n","dopamine-rl 1.0.5\n","earthengine-api 0.1.309\n","easydict 1.9\n","ecos 2.0.10\n","editdistance 0.5.3\n","en-core-web-sm 
2.2.5\n","entrypoints 0.4\n","ephem 4.1.3\n","et-xmlfile 1.1.0\n","fa2 0.3.5\n","fastai 1.0.61\n","fastdtw 0.3.4\n","fastjsonschema 2.15.3\n","fastprogress 1.0.2\n","fastrlock 0.8\n","fbprophet 0.7.1\n","feather-format 0.4.1\n","filelock 3.7.0\n","firebase-admin 4.4.0\n","fix-yahoo-finance 0.0.22\n","Flask 1.1.4\n","flatbuffers 2.0\n","folium 0.8.3\n","future 0.16.0\n","gast 0.5.3\n","GDAL 2.2.2\n","gdown 4.4.0\n","gensim 3.6.0\n","geographiclib 1.52\n","geopy 1.17.0\n","gin-config 0.5.0\n","glob2 0.7\n","google 2.0.3\n","google-api-core 1.31.5\n","google-api-python-client 1.12.11\n","google-auth 1.35.0\n","google-auth-httplib2 0.0.4\n","google-auth-oauthlib 0.4.6\n","google-cloud-bigquery 1.21.0\n","google-cloud-bigquery-storage 1.1.1\n","google-cloud-core 1.0.3\n","google-cloud-datastore 1.8.0\n","google-cloud-firestore 1.7.0\n","google-cloud-language 1.2.0\n","google-cloud-storage 1.18.1\n","google-cloud-translate 1.5.0\n","google-colab 1.0.0\n","google-pasta 0.2.0\n","google-resumable-media 0.4.1\n","googleapis-common-protos 1.56.1\n","googledrivedownloader 0.4\n","graphviz 0.10.1\n","greenlet 1.1.2\n","grpcio 1.46.1\n","gspread 3.4.2\n","gspread-dataframe 3.0.8\n","gym 0.17.3\n","h5py 3.1.0\n","HeapDict 1.0.1\n","hijri-converter 2.2.3\n","holidays 0.10.5.2\n","holoviews 1.14.9\n","html5lib 1.0.1\n","httpimport 0.5.18\n","httplib2 0.17.4\n","httplib2shim 0.0.3\n","humanize 0.5.1\n","hyperopt 0.1.2\n","ideep4py 2.0.0.post3\n","idna 2.10\n","imageio 2.4.1\n","imagesize 1.3.0\n","imbalanced-learn 0.8.1\n","imblearn 0.0\n","imgaug 0.2.9\n","importlib-metadata 4.11.3\n","importlib-resources 5.7.1\n","imutils 0.5.4\n","inflect 2.1.0\n","iniconfig 1.1.1\n","intel-openmp 2022.1.0\n","intervaltree 2.1.0\n","ipykernel 4.10.1\n","ipython 5.5.0\n","ipython-genutils 0.2.0\n","ipython-sql 0.3.9\n","ipywidgets 7.7.0\n","itsdangerous 1.1.0\n","jax 0.3.8\n","jaxlib 0.3.7+cuda11.cudnn805\n","jedi 0.18.1\n","jieba 0.42.1\n","Jinja2 2.11.3\n","joblib 1.1.0\n","jpeg4py 
0.1.4\n","jsonschema 4.3.3\n","jupyter 1.0.0\n","jupyter-client 5.3.5\n","jupyter-console 5.2.0\n","jupyter-core 4.10.0\n","jupyterlab-pygments 0.2.2\n","jupyterlab-widgets 1.1.0\n","kaggle 1.5.12\n","kapre 0.3.7\n","keras 2.8.0\n","Keras-Preprocessing 1.1.2\n","keras-vis 0.4.1\n","kiwisolver 1.4.2\n","korean-lunar-calendar 0.2.1\n","libclang 14.0.1\n","librosa 0.8.1\n","lightgbm 2.2.3\n","llvmlite 0.34.0\n","lmdb 0.99\n","LunarCalendar 0.0.9\n","lxml 4.2.6\n","Markdown 3.3.7\n","MarkupSafe 2.0.1\n","matplotlib 3.2.2\n","matplotlib-inline 0.1.3\n","matplotlib-venn 0.11.7\n","missingno 0.5.1\n","mistune 0.8.4\n","mizani 0.6.0\n","mkl 2019.0\n","mlxtend 0.14.0\n","more-itertools 8.13.0\n","moviepy 0.2.3.5\n","mpmath 1.2.1\n","msgpack 1.0.3\n","multiprocess 0.70.12.2\n","multitasking 0.0.10\n","murmurhash 1.0.7\n","music21 5.5.0\n","natsort 5.5.0\n","nbclient 0.6.3\n","nbconvert 5.6.1\n","nbformat 5.4.0\n","nest-asyncio 1.5.5\n","netCDF4 1.5.8\n","netifaces 0.11.0\n","networkx 2.6.3\n","nibabel 3.0.2\n","nltk 3.2.5\n","notebook 5.3.1\n","numba 0.51.2\n","numexpr 2.8.1\n","numpy 1.21.6\n","nvidia-ml-py3 7.352.0\n","oauth2client 4.1.3\n","oauthlib 3.2.0\n","okgrade 0.4.3\n","opencv-contrib-python 4.1.2.30\n","opencv-python 4.1.2.30\n","openpyxl 3.0.9\n","opt-einsum 3.3.0\n","osqp 0.6.2.post0\n","packaging 21.3\n","palettable 3.3.0\n","pandas 1.3.5\n","pandas-datareader 0.9.0\n","pandas-gbq 0.13.3\n","pandas-profiling 1.4.1\n","pandocfilters 1.5.0\n","panel 0.12.1\n","param 1.12.1\n","parso 0.8.3\n","pathlib 1.0.1\n","patsy 0.5.2\n","pep517 0.12.0\n","pexpect 4.8.0\n","pickleshare 0.7.5\n","Pillow 7.1.2\n","pip 21.1.3\n","pip-tools 6.2.0\n","plac 1.1.3\n","plotly 5.5.0\n","plotnine 0.6.0\n","pluggy 0.7.1\n","pooch 1.6.0\n","portpicker 1.3.9\n","prefetch-generator 1.0.1\n","preshed 3.0.6\n","prettytable 3.3.0\n","progressbar2 3.38.0\n","prometheus-client 0.14.1\n","promise 2.3\n","prompt-toolkit 1.0.18\n","protobuf 3.17.3\n","psutil 5.4.8\n","psycopg2 
2.7.6.1\n","ptyprocess 0.7.0\n","py 1.11.0\n","py4j 0.10.9.3\n","pyarrow 6.0.1\n","pyasn1 0.4.8\n","pyasn1-modules 0.2.8\n","pycocotools 2.0.4\n","pycparser 2.21\n","pyct 0.4.8\n","pydata-google-auth 1.4.0\n","pydot 1.3.0\n","pydot-ng 2.0.0\n","pydotplus 2.0.2\n","PyDrive 1.3.1\n","pyemd 0.5.1\n","pyerfa 2.0.0.1\n","pyglet 1.5.0\n","Pygments 2.6.1\n","pygobject 3.26.1\n","pymc3 3.11.4\n","PyMeeus 0.5.11\n","pymongo 4.1.1\n","pymystem3 0.2.0\n","PyOpenGL 3.1.6\n","pyparsing 3.0.9\n","pyrsistent 0.18.1\n","pysndfile 1.3.8\n","PySocks 1.7.1\n","pyspark 3.2.1\n","pystan 2.19.1.1\n","pytest 3.6.4\n","python-apt 0.0.0\n","python-chess 0.23.11\n","python-dateutil 2.8.2\n","python-louvain 0.16\n","python-slugify 6.1.2\n","python-utils 3.2.2\n","pytz 2022.1\n","pyviz-comms 2.2.0\n","PyWavelets 1.3.0\n","PyYAML 3.13\n","pyzmq 22.3.0\n","qdldl 0.1.5.post2\n","qtconsole 5.3.0\n","QtPy 2.1.0\n","ray 1.9.0\n","raydp-nightly 2022.5.12.dev0\n","redis 4.3.1\n","regex 2019.12.20\n","requests 2.23.0\n","requests-oauthlib 1.3.1\n","resampy 0.2.2\n","rpy2 3.4.5\n","rsa 4.8\n","scikit-image 0.18.3\n","scikit-learn 1.0.2\n","scipy 1.4.1\n","screen-resolution-extra 0.0.0\n","scs 3.2.0\n","seaborn 0.11.2\n","semver 2.13.0\n","Send2Trash 1.8.0\n","setuptools 57.4.0\n","setuptools-git 1.2\n","Shapely 1.8.2\n","simplegeneric 0.8.1\n","six 1.15.0\n","sklearn 0.0\n","sklearn-pandas 1.8.0\n","smart-open 6.0.0\n","snowballstemmer 2.2.0\n","sortedcontainers 2.4.0\n","SoundFile 0.10.3.post1\n","soupsieve 2.3.2.post1\n","spacy 2.2.4\n","Sphinx 1.8.6\n","sphinxcontrib-serializinghtml 1.1.5\n","sphinxcontrib-websupport 1.2.4\n","SQLAlchemy 1.4.36\n","sqlparse 0.4.2\n","srsly 1.0.5\n","statsmodels 0.10.2\n","sympy 1.7.1\n","tables 3.7.0\n","tabulate 0.8.9\n","tblib 1.7.0\n","tenacity 8.0.1\n","tensorboard 2.8.0\n","tensorboard-data-server 0.6.1\n","tensorboard-plugin-wit 1.8.1\n","tensorboardX 2.5\n","tensorflow 2.8.0+zzzcolab20220506162203\n","tensorflow-datasets 4.0.1\n","tensorflow-estimator 
2.8.0\n","tensorflow-gcs-config 2.8.0\n","tensorflow-hub 0.12.0\n","tensorflow-io-gcs-filesystem 0.25.0\n","tensorflow-metadata 1.8.0\n","tensorflow-probability 0.16.0\n","termcolor 1.1.0\n","terminado 0.13.3\n","testpath 0.6.0\n","text-unidecode 1.3\n","textblob 0.15.3\n","Theano-PyMC 1.1.2\n","thinc 7.4.0\n","threadpoolctl 3.1.0\n","tifffile 2021.11.2\n","tinycss2 1.1.1\n","tomli 2.0.1\n","toolz 0.11.2\n","torch 1.8.1+cpu\n","torchaudio 0.11.0+cu113\n","torchsummary 1.5.1\n","torchtext 0.12.0\n","torchvision 0.12.0+cu113\n","tornado 5.1.1\n","tqdm 4.64.0\n","traitlets 5.1.1\n","tweepy 3.10.0\n","typeguard 2.7.1\n","typing 3.7.4.3\n","typing-extensions 4.2.0\n","tzlocal 1.5.1\n","uritemplate 3.0.1\n","urllib3 1.24.3\n","vega-datasets 0.9.0\n","wasabi 0.9.1\n","wcwidth 0.2.5\n","webencodings 0.5.1\n","Werkzeug 1.0.1\n","wheel 0.37.1\n","widgetsnbextension 3.6.0\n","wordcloud 1.5.0\n","wrapt 1.14.1\n","xarray 0.20.2\n","xarray-einstats 0.2.2\n","xgboost 0.90\n","xkit 0.0.0\n","xlrd 1.1.0\n","xlwt 1.3.0\n","yellowbrick 1.4\n","zict 2.2.0\n","zipp 3.8.0\n"]}],"source":["! pip install ray==1.9\n","! pip install raydp==0.5.0\n","! pip install ray[tune]\n","! pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html\n","! pip install torchmetrics"]},{"cell_type":"markdown","metadata":{"id":"x3o7NPecpbeA"},"source":["## 2. Get the data file"]},{"cell_type":"markdown","metadata":{"id":"JpiXje_Jmkp3"},"source":["The dataset is from: https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset, and we store the file in github repository. It's used to predict whether a patient is likely to get stroke based on the input parameters like gender, age, various diseases, and smoking status. Each row in the data provides relavant information about the patient. 
"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":706,"status":"ok","timestamp":1653029466495,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"ywec9GENpe7-","outputId":"01a68d70-e4eb-4343-8605-3116e0dde429"},"outputs":[],"source":["! wget https://raw.githubusercontent.com/oap-project/raydp/master/tutorials/dataset/healthcare-dataset-stroke-data.csv -O healthcare-dataset-stroke-data.csv"]},{"cell_type":"markdown","metadata":{"id":"fSD4v0zKEx7d"},"source":["## 3. Init or connect to a ray cluster"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4310,"status":"ok","timestamp":1653029470802,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"kOXeU_qmE1uH","outputId":"de23e901-c1a9-4582-bf41-d5fd7abc2b64"},"outputs":[{"data":{"text/plain":["{'metrics_export_port': 44425,\n"," 'node_id': '1e8d9137f4f4a9343edcce64f2d61d72ec70d624899d07e818bd34e1',\n"," 'node_ip_address': '172.28.0.2',\n"," 'object_store_address': '/tmp/ray/session_2022-05-20_06-51-06_717688_61/sockets/plasma_store',\n"," 'raylet_ip_address': '172.28.0.2',\n"," 'raylet_socket_name': '/tmp/ray/session_2022-05-20_06-51-06_717688_61/sockets/raylet',\n"," 'redis_address': '172.28.0.2:6379',\n"," 'session_dir': '/tmp/ray/session_2022-05-20_06-51-06_717688_61',\n"," 'webui_url': None}"]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["import ray\n","\n","ray.init(num_cpus=6)"]},{"cell_type":"markdown","metadata":{"id":"xyzXEJg7FVrg"},"source":["## 4. 
Get a spark session"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":20879,"status":"ok","timestamp":1653029491677,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"R9YydfatFfuF","outputId":"fb4c7b7b-f30f-42c3-daf9-fa472916f93f"},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: sun.reflect.Reflection.getCallerClass is not supported. This will impact performance.\n"]},{"name":"stderr","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: An illegal reflective access operation has occurred\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/lib/python3.7/dist-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n","\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m WARNING: All illegal access operations will be denied in a future release\n"]},{"name":"stdout","output_type":"stream","text":["\u001b[2m\u001b[36m(RayDPSparkMaster pid=529)\u001b[0m 2022-05-20 06:51:17,557 WARN NativeCodeLoader [Thread-2]: Unable to load native-hadoop library for your platform... 
import raydp

# Spark-on-Ray session configuration: a single 1-core, 500 MB executor
# keeps the footprint small enough for a Colab machine.
app_name = "Stroke Prediction with RayDP"  # FIX: typo "Stoke" -> "Stroke"
num_executors = 1
cores_per_executor = 1
memory_per_executor = "500M"
spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor)
# Read the raw CSV with a header row, letting Spark infer column types.
# spark.read.csv(...) is the shorthand for format("csv") plus the
# header/inferSchema options.
data = spark.read.csv(
    "/content/healthcare-dataset-stroke-data.csv",
    header=True,
    inferSchema=True,
)
Before feeding into the deep learning model, we can use raydp to do some transformation operations on dataset."]},{"cell_type":"markdown","metadata":{"id":"jJXWdonutekH"},"source":["### 6.1 Data Analysis"]},{"cell_type":"markdown","metadata":{"id":"RuZuZeyZSHQu"},"source":["Here is a part of the data analysis."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":9734,"status":"ok","timestamp":1653029516546,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"BUcjEczateOW","outputId":"96a6191b-225a-4320-cd01-0145b1c8e966"},"outputs":[{"name":"stdout","output_type":"stream","text":["+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| id|gender| age|hypertension|heart_disease|ever_married| work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","| 9046| Male|67.0| 0| 1| Yes| Private| Urban| 228.69|36.6|formerly smoked| 1|\n","|51676|Female|61.0| 0| 0| Yes|Self-employed| Rural| 202.21| N/A| never smoked| 1|\n","|31112| Male|80.0| 0| 1| Yes| Private| Rural| 105.92|32.5| never smoked| 1|\n","|60182|Female|49.0| 0| 0| Yes| Private| Urban| 171.23|34.4| smokes| 1|\n","| 1665|Female|79.0| 1| 0| Yes|Self-employed| Rural| 174.12| 24| never smoked| 1|\n","+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+\n","only showing top 5 rows\n","\n","+-------+-----------------+------+------------------+------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+\n","|summary| id|gender| age| hypertension| heart_disease|ever_married|work_type|Residence_type| 
# Peek at the first rows of the raw frame.
data.show(5)
# Column statistics; 'bmi' contains the literal string 'N/A'
# (201 rows), which is replaced by the column mean during preprocessing.
data.describe().show()
# Trigger the N/A count (result intentionally not displayed).
data.filter(data.bmi == 'N/A').count()
# Distribution of 'gender': the single 'Other' row is dropped later
# as an outlier.
data.rollup(data.gender).count().show()
# Positive/negative balance of the target column 'stroke'.
data.rollup(data.stroke).count().show()
from pyspark.sql.functions import hour, quarter, month, year, dayofweek, dayofmonth, weekofyear, col, lit, udf, abs as functions_abs, avg


def drop_col(data):
    """Drop the 'id' column, which carries no predictive signal."""
    return data.drop('id')


def replace_nan(data):
    """Replace the literal string 'N/A' in 'bmi' with the column mean."""
    bmi_avg = data.agg(avg(col("bmi"))).head()[0]

    # Renamed from the original inner 'replace_nan', which shadowed the
    # enclosing function's name.
    @udf("float")
    def fill_bmi(value):
        return float(bmi_avg) if value == 'N/A' else float(value)

    return data.withColumn('bmi', fill_bmi(col("bmi")))


def clean_value(data):
    """Remove the single outlier row whose gender is 'Other'."""
    return data.filter(data.gender != 'Other')


def trans_category(data):
    """Encode every categorical string column as a small integer."""
    gender_map = {'Female': 0, 'Male': 1}
    married_map = {'No': 0, 'Yes': 1}
    work_map = {'children': 0, 'Govt_job': 1, 'Never_worked': 2,
                'Private': 3, 'Self-employed': 4}
    residence_map = {'Rural': 0, 'Urban': 1}
    smoking_map = {'formerly smoked': 0, 'never smoked': 1,
                   'smokes': 2, 'Unknown': 3}

    def make_mapper(mapping):
        # One int UDF per column, closed over its own lookup table.
        # Unknown values raise KeyError, as in the original per-column UDFs.
        @udf("int")
        def mapper(value):
            return int(mapping[value])
        return mapper

    data = data.withColumn('gender', make_mapper(gender_map)(col('gender'))) \
        .withColumn('ever_married', make_mapper(married_map)(col('ever_married'))) \
        .withColumn('work_type', make_mapper(work_map)(col('work_type'))) \
        .withColumn('Residence_type', make_mapper(residence_map)(col('Residence_type'))) \
        .withColumn('smoking_status', make_mapper(smoking_map)(col('smoking_status')))
    return data


def map_age(data):
    """Add 'age_dis': age discretized into 10-year buckets.

    NOTE(review): ages below 18 fall through to bucket 4, the same
    bucket as ages 56+ - confirm this is intentional.
    """
    @udf("int")
    def bucket(value):
        if 18 <= value < 26:
            return 0
        if 26 <= value < 36:
            return 1
        if 36 <= value < 46:
            return 2
        if 46 <= value < 56:
            return 3
        return 4

    return data.withColumn('age_dis', bucket(col('age')))


def data_preprocess(data):
    """Run the full cleaning pipeline, in order."""
    data = drop_col(data)
    data = replace_nan(data)
    data = clean_value(data)
    data = trans_category(data)
    data = map_age(data)
    return data
import torch
from raydp.utils import random_split

# Clean the frame, then split 80/20 with seed 0.
data = data_preprocess(data)
train_df, test_df = random_split(data, [0.8, 0.2], 0)

# Oversample the minority class: appending the positive (stroke == '1')
# rows twice triples their weight in the training set.
train_df_neg = train_df.filter(train_df.stroke == '1')
for _ in range(2):
    train_df = train_df.unionByName(train_df_neg)

# Every column except the label is a model feature.
features = [f.name for f in list(train_df.schema) if f.name != "stroke"]

# Convert the Spark frames to Ray Datasets.
# NOTE(review): the original comment says parallelism should match ray
# train's num_workers, yet num_executors above is 1 while parallelism
# is 8 - confirm the intended value.
train_dataset = ray.data.from_spark(train_df, parallelism=8)
test_dataset = ray.data.from_spark(test_df, parallelism=8)
feature_dtype = [torch.float] * len(features)
Define a neural network model"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"RcsoN3tWGUEM"},"outputs":[],"source":["import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class NET_Model(nn.Module):\n"," def __init__(self, cols):\n"," super().__init__()\n"," self.emb_layer_gender = nn.Embedding(2, 2) # gender\n"," self.emb_layer_hypertension = nn.Embedding(2,2) # hypertension\n"," self.emb_layer_heart_disease = nn.Embedding(2,2) # heart_disease\n"," self.emb_layer_ever_married = nn.Embedding(2, 2) # ever_married\n"," self.emb_layer_work = nn.Embedding(5, 5) # work_type\n"," self.emb_layer_residence = nn.Embedding(2, 2) # Residence_type\n"," self.emb_layer_smoking_status = nn.Embedding(4, 4) # smoking_status\n"," self.emb_layer_age = nn.Embedding(5, 5) # age column after discretization\n"," self.fc1 = nn.Linear(cols, 256)\n"," self.fc2 = nn.Linear(256, 128)\n"," self.fc3 = nn.Linear(128, 64)\n"," self.fc4 = nn.Linear(64, 16)\n"," self.fc5 = nn.Linear(16, 2)\n","\n"," self.fc_sparse = nn.Linear(24, 16)\n"," self.fc_dense = nn.Linear(3, 8)\n"," self.fc = nn.Linear(24, 2)\n"," \n"," self.bn1 = nn.BatchNorm1d(256)\n"," self.bn2 = nn.BatchNorm1d(128)\n"," self.bn3 = nn.BatchNorm1d(64)\n"," self.bn4 = nn.BatchNorm1d(16)\n","\n"," def forward(self, *x):\n"," x = torch.cat(x, dim=1)\n"," # pick the dense attribute columns\n"," dense_columns = x[:, [1,7,8]]\n"," # Embedding operation on sparse attribute columns\n"," sparse_col_1 = self.emb_layer_gender(x[:, 0].long())\n"," sparse_col_2 = self.emb_layer_hypertension(x[:, 2].long())\n"," sparse_col_3 = self.emb_layer_heart_disease(x[:, 3].long())\n"," sparse_col_4 = self.emb_layer_ever_married(x[:, 4].long())\n"," sparse_col_5 = self.emb_layer_work(x[:, 5].long())\n"," sparse_col_6 = self.emb_layer_residence(x[:, 6].long())\n"," sparse_col_7 = self.emb_layer_smoking_status(x[:, 9].long())\n"," sparse_col_8 = self.emb_layer_age(x[:, 10].long())\n"," # Splice sparse attribute columns and dense 
attribute columns\n","        x = torch.cat([dense_columns, sparse_col_1, sparse_col_2, sparse_col_3, sparse_col_4, sparse_col_5, sparse_col_6, sparse_col_7, sparse_col_8], dim=1)\n","\n","        sparse_columns = torch.cat([sparse_col_1, sparse_col_2, sparse_col_3, sparse_col_4, sparse_col_5, sparse_col_6, sparse_col_7, sparse_col_8], dim=1)\n","        dense_feat = self.fc_dense(dense_columns)\n","        sparse_feat = self.fc_sparse(sparse_columns)\n","        return self.fc(torch.cat([dense_feat, sparse_feat], dim=1))\n","\n","        # NOTE(review): the lines below are unreachable due to the early return above\n","        x = F.relu(self.fc1(x))\n","        x = self.bn1(x)\n","        x = F.relu(self.fc2(x))\n","        x = self.bn2(x)\n","        x = F.relu(self.fc3(x))\n","        x = self.bn3(x)\n","        x = F.relu(self.fc4(x))\n","        x = self.bn4(x)\n","        x = self.fc5(x)\n","        return x\n"]},{"cell_type":"markdown","metadata":{"id":"DZ81HRrOGVVj"},"source":["## 9. Define train and test function"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"0Y_08OhMGd4a"},"outputs":[],"source":["def train_epoch(dataset, model, criterion, optimizer):\n","    model.train()\n","    train_loss, correct, data_size, batch_idx = 0, 0, 0, 0\n","    for batch_idx, (inputs, targets) in enumerate(dataset):\n","        # Compute prediction error\n","        inputs = [inputs[:,i].unsqueeze(1) for i in range(inputs.size(1))]\n","        targets = targets.reshape(-1)\n","        outputs = model(*inputs)\n","        loss = criterion(outputs, targets)\n","        train_loss += loss.item()\n","        _, predicted = torch.max(outputs.data, 1)\n","        data_size += targets.size(0)\n","        correct += (predicted == targets).sum().item()\n","        # Backpropagation\n","        optimizer.zero_grad()\n","        loss.backward()\n","        optimizer.step()\n","    # Calculate the train_loss and train_acc\n","    train_loss /= (batch_idx + 1)\n","    train_acc = correct/data_size\n","    return train_acc, train_loss\n","\n","def test_epoch(dataset, model, criterion):\n","    model.eval()\n","    test_loss, correct, data_size, batch_idx = 0, 0, 0, 0\n","    with torch.no_grad():\n","        for batch_idx, (inputs, targets) in enumerate(dataset):\n","            # Compute prediction 
error\n","            inputs = [inputs[:,i].unsqueeze(1) for i in range(inputs.size(1))]\n","            targets = targets.reshape(-1)\n","            outputs = model(*inputs)\n","            test_loss += criterion(outputs, targets).item()\n","            _, predicted = torch.max(outputs.data, 1)\n","            data_size += targets.size(0)\n","            correct += (predicted == targets).sum().item()\n","    # Calculate the test_loss and test_acc\n","    test_loss /= (batch_idx + 1)\n","    test_acc = correct/data_size\n","    return test_acc, test_loss"]},{"cell_type":"markdown","metadata":{"id":"dRAamyi3GewE"},"source":["## 10. Define train function"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"76x_ZEdKGmd3"},"outputs":[],"source":["from ray import train\n","from ray.train import get_dataset_shard\n","\n","def train_func(config):\n","    num_epochs = config[\"num_epochs\"]\n","    lr = config[\"lr\"]\n","    batch_size = config[\"batch_size\"]\n","    # Then convert to torch datasets\n","    # Get the corresponding shard\n","    train_data_shard = get_dataset_shard(\"train\")\n","    train_dataset = train_data_shard.to_torch(feature_columns=features,\n","                                              label_column=\"stroke\",\n","                                              label_column_dtype=torch.long,\n","                                              feature_column_dtypes=feature_dtype,\n","                                              batch_size=batch_size)\n","    test_data_shard = get_dataset_shard(\"test\")\n","    test_dataset = test_data_shard.to_torch(feature_columns=features,\n","                                            label_column=\"stroke\",\n","                                            label_column_dtype=torch.long,\n","                                            feature_column_dtypes=feature_dtype,\n","                                            batch_size=batch_size)\n","    model = NET_Model(len(features))\n","    model = train.torch.prepare_model(model)\n","    criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.35, 0.65]))\n","    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n","    loss_results = []\n","    for epoch in range(num_epochs):\n","        train_acc, train_loss = train_epoch(train_dataset, model, criterion, optimizer)\n","        test_acc, test_loss = test_epoch(test_dataset, model, criterion)\n","        train.report(epoch = epoch, train_acc = train_acc, train_loss = train_loss)\n","        
train.report(epoch = epoch, test_acc=test_acc, test_loss=test_loss)\n"," loss_results.append(test_loss)"]},{"cell_type":"markdown","metadata":{"id":"d-wfpCJ_NL4-"},"source":["## 11. Define the callback function"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"omYjJ7WTNZUo"},"outputs":[],"source":["from ray.train import TrainingCallback\n","from typing import List, Dict\n","\n","# log the train results\n","class PrintingCallback(TrainingCallback):\n"," def handle_result(self, results: List[Dict], **info):\n"," print(results)"]},{"cell_type":"markdown","metadata":{"id":"kaW_aC5zGoWZ"},"source":["## 12. Train model via ray train"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":8380,"status":"ok","timestamp":1653029538050,"user":{"displayName":"Ke Yan","userId":"12157985255311286389"},"user_tz":-480},"id":"Rerm_OwoG0g9","outputId":"62322657-f23d-4485-cbbf-10ec5d4f84de"},"outputs":[{"name":"stderr","output_type":"stream","text":["2022-05-20 06:52:09,475\tINFO trainer.py:172 -- Trainer logs will be logged in: /root/ray_results/train_2022-05-20_06-52-09\n","2022-05-20 06:52:10,966\tINFO trainer.py:178 -- Run results will be logged in: /root/ray_results/train_2022-05-20_06-52-09/run_001\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=1234)\u001b[0m 2022-05-20 06:52:10,963\tINFO torch.py:67 -- Setting up process group for: env:// [rank=0, world_size=1]\n","\u001b[2m\u001b[36m(BaseWorkerMixin pid=1234)\u001b[0m 2022-05-20 06:52:11,020\tINFO torch.py:239 -- Moving model to device: cpu\n"]},{"name":"stdout","output_type":"stream","text":["[{'epoch': 0, 'train_acc': 0.6178715761113606, 'train_loss': 2.4214471597756657, '_timestamp': 1653029531, '_time_this_iter_s': 0.8797011375427246, '_training_iteration': 1}]\n","[{'epoch': 0, 'test_acc': 0.7587844254510921, 'test_loss': 0.7357403390547809, '_timestamp': 1653029531, '_time_this_iter_s': 0.0008981227874755859, '_training_iteration': 
2}]\n","[{'epoch': 1, 'train_acc': 0.797485406376291, 'train_loss': 0.6913646668195724, '_timestamp': 1653029532, '_time_this_iter_s': 0.5976800918579102, '_training_iteration': 3}]\n","[{'epoch': 1, 'test_acc': 0.9021842355175689, 'test_loss': 0.31034330848385305, '_timestamp': 1653029532, '_time_this_iter_s': 0.0007421970367431641, '_training_iteration': 4}]\n","[{'epoch': 2, 'train_acc': 0.8432869330938483, 'train_loss': 0.4305501192808151, '_timestamp': 1653029533, '_time_this_iter_s': 0.5868339538574219, '_training_iteration': 5}]\n","[{'epoch': 2, 'test_acc': 0.9069325735992403, 'test_loss': 0.2829096720499151, '_timestamp': 1653029533, '_time_this_iter_s': 0.0007574558258056641, '_training_iteration': 6}]\n","[{'epoch': 3, 'train_acc': 0.8414907947911989, 'train_loss': 0.41592263281345365, '_timestamp': 1653029533, '_time_this_iter_s': 0.618511438369751, '_training_iteration': 7}]\n","[{'epoch': 3, 'test_acc': 0.9050332383665717, 'test_loss': 0.2828731712173013, '_timestamp': 1653029533, '_time_this_iter_s': 0.001013040542602539, '_training_iteration': 8}]\n","[{'epoch': 4, 'train_acc': 0.8428378985181859, 'train_loss': 0.4120848593967302, '_timestamp': 1653029534, '_time_this_iter_s': 0.6071341037750244, '_training_iteration': 9}]\n","[{'epoch': 4, 'test_acc': 0.9050332383665717, 'test_loss': 0.281910487834145, '_timestamp': 1653029534, '_time_this_iter_s': 0.0007085800170898438, '_training_iteration': 10}]\n","[{'epoch': 5, 'train_acc': 0.843062415806017, 'train_loss': 0.4090804461921964, '_timestamp': 1653029534, '_time_this_iter_s': 0.59639573097229, '_training_iteration': 11}]\n","[{'epoch': 5, 'test_acc': 0.905982905982906, 'test_loss': 0.2811230000327615, '_timestamp': 1653029534, '_time_this_iter_s': 0.0007147789001464844, '_training_iteration': 12}]\n","[{'epoch': 6, 'train_acc': 0.8426133812303548, 'train_loss': 0.40661772191524503, '_timestamp': 1653029535, '_time_this_iter_s': 0.5747692584991455, '_training_iteration': 13}]\n","[{'epoch': 6, 
'test_acc': 0.9040835707502374, 'test_loss': 0.28044137358665466, '_timestamp': 1653029535, '_time_this_iter_s': 0.001348733901977539, '_training_iteration': 14}]\n","[{'epoch': 7, 'train_acc': 0.843062415806017, 'train_loss': 0.40452814304402895, '_timestamp': 1653029536, '_time_this_iter_s': 0.5875253677368164, '_training_iteration': 15}]\n","[{'epoch': 7, 'test_acc': 0.905982905982906, 'test_loss': 0.2798162909115062, '_timestamp': 1653029536, '_time_this_iter_s': 0.0007505416870117188, '_training_iteration': 16}]\n","[{'epoch': 8, 'train_acc': 0.8421643466546924, 'train_loss': 0.40271366930433683, '_timestamp': 1653029536, '_time_this_iter_s': 0.5620317459106445, '_training_iteration': 17}]\n","[{'epoch': 8, 'test_acc': 0.9050332383665717, 'test_loss': 0.27922289161121144, '_timestamp': 1653029536, '_time_this_iter_s': 0.0009264945983886719, '_training_iteration': 18}]\n","[{'epoch': 9, 'train_acc': 0.8414907947911989, 'train_loss': 0.40110651957137244, '_timestamp': 1653029537, '_time_this_iter_s': 0.8353402614593506, '_training_iteration': 19}]\n","[{'epoch': 9, 'test_acc': 0.9031339031339032, 'test_loss': 0.2786439481903525, '_timestamp': 1653029537, '_time_this_iter_s': 0.0008196830749511719, '_training_iteration': 20}]\n"]}],"source":["from ray.train import Trainer\n","\n","trainer = Trainer(backend=\"torch\", num_workers=num_executors)\n","trainer.start()\n","results = trainer.run(\n"," train_func, config={\"num_epochs\": 10, \"lr\": 0.001, \"batch_size\": 64},\n"," callbacks=[PrintingCallback()],\n"," dataset={\n"," \"train\": train_dataset,\n"," \"test\": test_dataset\n"," }\n",")\n","trainer.shutdown()"]},{"cell_type":"markdown","metadata":{"id":"MnUdAC7FG5HT"},"source":["## 13. 
shut down ray and raydp"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"riDLFnZDG8Hr"},"outputs":[],"source":["raydp.stop_spark()\n","ray.shutdown()"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyOEVOflyDV6AaKwjNTZxpH+","collapsed_sections":[],"mount_file_id":"1-5XUDsPtQB8uaXJSrEDI38sYBahG2Rgk","name":"pytorch_example.ipynb","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}