Skip to content

Commit

Permalink
feat: add attention notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
ex3ndr committed Jul 10, 2024
1 parent 6b9e1b6 commit dea4c61
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions attention.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
"With padding\n",
"torch 0.0\n",
"direct 0.00048828125\n",
"xformers 0.000244140625\n",
"flash 0.000244140625\n"
"xformers 0.0\n",
"flash 0.0\n"
]
}
],
Expand Down Expand Up @@ -81,15 +81,15 @@
"output_type": "stream",
"text": [
"Without padding\n",
"torch 0.4039938449859619\n",
"direct 0.9630444049835205\n",
"xformers 0.9124343395233154\n",
"flash 0.3562278747558594\n",
"torch 4.43998122215271\n",
"direct 10.454708576202393\n",
"xformers 8.981621742248535\n",
"flash 3.2445619106292725\n",
"With padding\n",
"torch 1.6574337482452393\n",
"direct 2.504969835281372\n",
"xformers 1.5768730640411377\n",
"flash 1.4189743995666504\n"
"torch 16.73130464553833\n",
"direct 25.77487087249756\n",
"xformers 16.095849990844727\n",
"flash 15.358261585235596\n"
]
}
],
Expand All @@ -98,14 +98,14 @@
"print(\"Without padding\")\n",
"for a in attentions:\n",
" start = time.time()\n",
" for i in range(10000):\n",
" for i in range(100000):\n",
" a(query, key, value)\n",
" print(a.engine, time.time() - start)\n",
"\n",
"print(\"With padding\")\n",
"for a in attentions:\n",
" start = time.time()\n",
" for i in range(10000):\n",
" for i in range(100000):\n",
" a(query, key, value, lenghts = lengths)\n",
" print(a.engine, time.time() - start)"
]
Expand Down

0 comments on commit dea4c61

Please sign in to comment.