From c702109342db73e6794ac9c7b52d4a8b1b529df4 Mon Sep 17 00:00:00 2001 From: inoue0426 <8393063+inoue0426@users.noreply.github.com> Date: Thu, 29 Aug 2024 22:03:09 -0400 Subject: [PATCH] add tutorial --- Tutorial.ipynb | 159 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 Tutorial.ipynb diff --git a/Tutorial.ipynb b/Tutorial.ipynb new file mode 100644 index 0000000..0aeedb2 --- /dev/null +++ b/Tutorial.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "premier-closing", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "documented-viewer", + "metadata": {}, + "outputs": [], + "source": [ + "import drGAT" + ] + }, + { + "cell_type": "markdown", + "id": "protected-regular", + "metadata": {}, + "source": [ + "# model evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "869740ff-e2fc-43b8-9a0f-49e637522ec4", + "metadata": {}, + "outputs": [], + "source": [ + "test = torch.load('test.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "favorite-saturday", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "tmp = !ls | grep pt\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "model = torch.load('sample.pt', map_location=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "subject-allen", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " | Accuracy | \n", + "Precision | \n", + "Recall | \n", + "F1 Score | \n", + "True Positive | \n", + "True Negative | \n", + "False Positive | \n", + "False Negative | \n", + "
---|---|---|---|---|---|---|---|---|
0 | \n", + "0.771375 | \n", + "0.740881 | \n", + "0.783245 | \n", + "0.761474 | \n", + "1178 | \n", + "1312 | \n", + "412 | \n", + "326 | \n", + "