{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimization down to the bone\n", "\n", "We will learn how to optimize functions. We start with a pure Python implementation and then try to solve the problem with NumPy. After that we show how to move step by step to lower-level languages such as Fortran or C, and we also introduce the Numba package.\n", "" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:17:08.363327Z", "start_time": "2024-04-18T08:17:08.072665Z" }, "collapsed": false }, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic concepts of optimization\n", "\n", "We already know that Python is a language in which programming is very efficient. Thanks to packages such as NumPy, SciPy, or SymPy, many scientific problems can be solved very quickly. Understandably, with this level of abstraction the resulting program may be slower than a well-written one in a compiled language such as C/C++ or Fortran. We have to keep in mind, however, that the efficiency of a program is measured by the *total time spent on developing and running the program* (for a given set of tasks). Schematically, the dependence of a program's execution speed on development time looks roughly like this:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:17:10.193420Z", "start_time": "2024-04-18T08:17:10.188542Z" }, "collapsed": false }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Image(filename='optimizing-what.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice, for example, that there is typically an offset in development time, i.e. in a low-level language it takes longer before we get the first result at all. 
The need for optimization therefore *depends strongly on the volume of computation we are going to do with the program*.\n", "\n", "This does *not mean, however, that if the volume is large we should immediately start programming in C or Fortran*. In a moment we will show how to approach optimization more cleverly and incrementally. A rule of thumb says that 90 % of the computing time is spent in 10 % of the source code. It is a specific instance of the general [Pareto 80 / 20 principle](https://cs.wikipedia.org/wiki/Paret%C5%AFv_princip). It is therefore sensible to find those 10 % first and only then start optimizing them. Python provides very powerful tools for this purpose." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Profiling\n", "\n", "Profiling is a tool that lets us find the critical places in our program, those 10 % that are worth optimizing. Let us demonstrate it on a simple example." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:17:12.089139Z", "start_time": "2024-04-18T08:17:12.084944Z" }, "collapsed": false }, "outputs": [], "source": [ "def heavy_calc(X):\n", "    Y = X.copy()\n", "    for i in range(10):\n", "        Y = Y**i\n", "    return Y\n", "\n", "def heavy_loop(inputs):\n", "    res = []\n", "    for X in inputs:\n", "        res.append(heavy_calc(X))\n", "    return res\n", "\n", "def code_setup():\n", "    from numpy.random import rand\n", "    N = 20\n", "    M = 1000\n", "    print(\"Will generate {} random arrays\".format(N))\n", "    inputs = [rand(M, M) for n in range(N)]\n", "    print(\"Will calculate now\")\n", "    result = heavy_loop(inputs)\n", "    print(\"Finished calculation\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Python ships two basic profiling modules, `profile` and `cProfile`, of which the latter is faster. Using the `run` function we execute the computation under cProfile's supervision and save the results to a file." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:17:16.723579Z", "start_time": "2024-04-18T08:17:15.663576Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Will generate 20 random arrays\n", "Will calculate now\n", "Finished calculation\n" ] } ], "source": [ "import cProfile\n", "cProfile.run('code_setup()', 'pstats')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we need the `pstats` module, which lets us work with the results through its `Stats` class." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:17:18.050402Z", "start_time": "2024-04-18T08:17:18.047259Z" }, "collapsed": false }, "outputs": [], "source": [ "from pstats import Stats\n", "p = Stats('pstats')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`print_stats` shows us the first n records." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:17:20.961915Z", "start_time": "2024-04-18T08:17:20.956328Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sun Jan 19 13:23:06 2025 pstats\n", "\n", " 3892 function calls (3755 primitive calls) in 0.537 seconds\n", "\n", " Random listing order was used\n", " List reduced from 378 to 10 due to restriction <10>\n", "\n", " ncalls tottime percall cumtime percall filename:lineno(function)\n", " 13 0.000 0.000 0.000 0.000 /Users/kuba/workspace/fjfi/python-fjfi/.venv/lib/python3.12/site-packages/_distutils_hack/__init__.py:101(find_spec)\n", " 25 0.000 0.000 0.000 0.000 /Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/enum.py:1551(__or__)\n", " 17 0.000 0.000 0.000 0.000 /Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/enum.py:1562(__and__)\n", " 3 0.000 0.000 0.000 0.000 
/Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/threading.py:1155(_wait_for_tstate_lock)\n", " 126 0.000 0.000 0.000 0.000 /Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/enum.py:1544(_get_value)\n", " 4 0.000 0.000 0.000 0.000 /Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/json/encoder.py:183(encode)\n", " 2 0.000 0.000 0.000 0.000 /Users/kuba/workspace/fjfi/python-fjfi/.venv/lib/python3.12/site-packages/_distutils_hack/__init__.py:108()\n", " 1 0.000 0.000 0.000 0.000 /Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/re/__init__.py:226(compile)\n", " 58 0.000 0.000 0.000 0.000 /Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/enum.py:726(__call__)\n", " 1 0.000 0.000 0.000 0.000 /Users/kuba/.local/share/uv/python/cpython-3.12.5-macos-aarch64-none/lib/python3.12/functools.py:35(update_wrapper)\n", "\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.print_stats(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These records are unsorted, however. The following output is more useful because the records are sorted by the cumulative time spent in each function. In addition, `strip_dirs` strips the directory paths from the file names." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:17:31.098424Z", "start_time": "2024-04-18T08:17:31.093655Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sun Jan 19 13:23:06 2025 pstats\n", "\n", " 3892 function calls (3755 primitive calls) in 0.537 seconds\n", "\n", " Ordered by: cumulative time\n", " List reduced from 378 to 10 due to restriction <10>\n", "\n", " ncalls tottime percall cumtime percall filename:lineno(function)\n", " 8 0.079 0.010 0.762 0.095 base_events.py:1909(_run_once)\n", " 4/1 0.000 0.000 0.530 0.530 {built-in method builtins.exec}\n", " 20 0.416 0.021 0.431 0.022 2469674079.py:1(heavy_calc)\n", " 1 0.000 0.000 0.320 0.320 2469674079.py:13(code_setup)\n", " 1 0.000 0.000 0.320 0.320 2469674079.py:7(heavy_loop)\n", " 20 0.015 0.001 0.015 0.001 {method 'copy' of 'numpy.ndarray' objects}\n", " 13/1 0.000 0.000 0.014 0.014 :1349(_find_and_load)\n", " 13/1 0.000 0.000 0.014 0.014 :1304(_find_and_load_unlocked)\n", " 12/1 0.000 0.000 0.014 0.014 :911(_load_unlocked)\n", " 3/1 0.000 0.000 0.014 0.014 :989(exec_module)\n", "\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.strip_dirs().sort_stats('cumulative').print_stats(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is what the output looks like when sorted by non-cumulative (internal) time." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2024-04-11T12:49:44.698688Z", "start_time": "2024-04-11T12:49:44.695290Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sun Jan 19 13:23:06 2025 pstats\n", "\n", " 3892 function calls (3755 primitive calls) in 0.537 seconds\n", "\n", " Ordered by: internal time\n", " List reduced from 378 to 10 due to restriction <10>\n", "\n", " ncalls tottime percall cumtime percall filename:lineno(function)\n", " 20 0.416 0.021 0.431 0.022 2469674079.py:1(heavy_calc)\n", " 8 0.079 0.010 0.762 0.095 base_events.py:1909(_run_once)\n", " 20 0.015 0.001 0.015 0.001 {method 'copy' of 'numpy.ndarray' objects}\n", " 7/0 0.011 0.002 0.000 {method 'control' of 'select.kqueue' objects}\n", " 9 0.007 0.001 0.007 0.001 {built-in method _imp.create_dynamic}\n", " 9/6 0.002 0.000 0.006 0.001 {built-in method _imp.exec_dynamic}\n", " 35/2 0.001 0.000 0.012 0.006 :480(_call_with_frames_removed)\n", " 9/8 0.001 0.000 0.002 0.000 events.py:86(_run)\n", " 36 0.001 0.000 0.001 0.000 {built-in method posix.stat}\n", " 24 0.000 0.000 0.000 0.000 socket.py:626(send)\n", "\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p.sort_stats('time').print_stats(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Jupyter can make this easier with the `%prun` and `%%prun` magics. For example:" 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:21:57.247796Z", "start_time": "2024-04-18T08:21:55.778256Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Will generate 20 random arrays\n", "Will calculate now\n", "Finished calculation\n", " " ] }, { "name": "stdout", "output_type": "stream", "text": [ " 2184 function calls (2126 primitive calls) in 0.755 seconds\n", "\n", " Ordered by: cumulative time\n", " List reduced from 219 to 10 due to restriction <10>\n", "\n", " ncalls tottime percall cumtime percall filename:lineno(function)\n", " 1 0.000 0.000 0.589 0.589 {built-in method builtins.exec}\n", " 1 0.016 0.016 0.589 0.589 :1()\n", " 20 0.507 0.025 0.520 0.026 2469674079.py:1(heavy_calc)\n", " 14/13 0.009 0.001 0.375 0.029 base_events.py:1909(_run_once)\n", " 1 0.000 0.000 0.363 0.363 2469674079.py:13(code_setup)\n", " 1 0.000 0.000 0.363 0.363 2469674079.py:7(heavy_loop)\n", " 20 0.016 0.001 0.016 0.001 {method 'copy' of 'numpy.ndarray' objects}\n", " 15/13 0.000 0.000 0.003 0.000 {method 'run' of '_contextvars.Context' objects}\n", " 7 0.000 0.000 0.002 0.000 zmqstream.py:583(_handle_events)\n", " 5 0.000 0.000 0.002 0.000 asyncio.py:200(_handle_events)" ] } ], "source": [ "%prun -s cumulative -l 10 code_setup()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Both outputs show quite clearly that we spend the vast majority of the time in the `heavy_calc` function. If we want to start optimizing, this is the part of the program we must focus on.\n", "\n", "The results can also be fed into visualization tools such as [SnakeViz](http://jiffyclub.github.io/snakeviz/) or [vprof](https://github.com/nvdv/vprof), or into an advanced IDE such as [PyCharm](https://www.jetbrains.com/pycharm/)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sample problem: pairwise distances of a set of points in multidimensional space\n", "\n", "(This example is taken from http://jakevdp.github.io/blog/2013/06/15/numba-vs-cython-take-2.)\n", "\n", "The task is simple: for $M$ points in $N$-dimensional space, compute the pairwise distances $d$, where the distance between two points $x, y$ is defined as $d(x, y) = \\sqrt{\\sum_{i=1}^N (x_i - y_i)^2}$. The result is therefore a (symmetric) $M\\times M$ matrix." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:00.402824Z", "start_time": "2024-04-18T08:22:00.399751Z" }, "collapsed": false }, "outputs": [], "source": [ "# let these be our input data\n", "M = 1000\n", "N = 3\n", "X = np.random.random((M, N))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pure Python implementation\n", "We can hardly expect this to be the fastest or the simplest version of our program. It is still worth trying, and we will need it again later." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:01.347164Z", "start_time": "2024-04-18T08:22:01.342692Z" }, "collapsed": false }, "outputs": [], "source": [ "def pairwise_python(X):\n", "    M = X.shape[0]\n", "    N = X.shape[1]\n", "    D = np.empty((M, M), dtype=float)\n", "    for i in range(M):\n", "        for j in range(M):\n", "            d = 0.0\n", "            for k in range(N):\n", "                tmp = X[i, k] - X[j, k]\n", "                d += tmp * tmp\n", "            D[i, j] = np.sqrt(d)\n", "    return D" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will store the timings returned by `%timeit -o` in the `pairwise_times` dictionary." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:02.796241Z", "start_time": "2024-04-18T08:22:02.792442Z" }, "collapsed": false }, "outputs": [], "source": [ "pairwise_times = {}" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:15.299522Z", "start_time": "2024-04-18T08:22:03.284684Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.13 s ± 68.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_python(X)\n", "pairwise_times['plain_python'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The same with NumPy\n", "With NumPy we can take advantage of broadcasting here. The whole function body then fits on a single line." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:15.303807Z", "start_time": "2024-04-18T08:22:15.300664Z" }, "collapsed": false }, "outputs": [], "source": [ "def pairwise_numpy(X):\n", "    return np.sqrt(((X[:, np.newaxis, :] - X) ** 2).sum(-1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We check that the results agree using `assert` and `numpy.allclose`." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:19.469565Z", "start_time": "2024-04-18T08:22:17.419668Z" } }, "outputs": [], "source": [ "assert np.allclose(pairwise_numpy(X), pairwise_python(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results agree up to very small differences, which is an inherent hazard of finite-precision numerical computation." 
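, "\n", "\n", "For reference, `numpy.allclose(a, b, rtol, atol)` compares the two arrays element by element and requires\n", "\n", "$$|a - b| \\le \\mathrm{atol} + \\mathrm{rtol} \\cdot |b|$$\n", "\n", "for every element. With `rtol=1e-10` and `atol=1e-15`, the check above therefore tolerates relative differences of about $10^{-10}$, which is still far stricter than anything we would notice in the distances themselves."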
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:23.531700Z", "start_time": "2024-04-18T08:22:21.409836Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "18.3 ms ± 221 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_numpy(X)\n", "pairwise_times['numpy'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have shortened the run time roughly 60 times (from 1.13 s to 18.3 ms). That is not bad, and the implementation is much simpler as well." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Enter Cython\n", "\n", "[Cython](https://cython.org/) is a tool that takes a Python program enriched with some Cython directives and produces a C (or C++) program, which can be compiled and immediately used as a Python module. Static types are a typical example of Cython directives. Cython also, of course, lets us call functions from binary libraries with a C interface.\n", "\n", "Let us try to optimize our `pairwise_python` function.\n", "\n", "* A Cython source file has the `.pyx` extension (it all started with Pyrex).\n", "* Cython can compile any Python code. The result, however, is minimal (or rather no) optimization.\n", "* `cimport` is the analogue of `import`, but it works with Cython function definitions (`.pxd` files).\n", "* Cython provides declarations for NumPy (shipped as a `.pxd` file) with additional information for compiling modules that use NumPy. That is why we call `cimport numpy`.\n", "* Similarly, `libc` is a special Cython module.\n", "\n", "* Functions are (or rather can be) declared with static types for their input parameters. 
We will use `np.ndarray[np.float64_t, ndim=2]`.\n", "* Variables are declared with `cdef`.\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:30.642196Z", "start_time": "2024-04-18T08:22:30.639708Z" } }, "outputs": [], "source": [ "# Uncomment to install Cython\n", "# !pip install cython" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:22:32.130939Z", "start_time": "2024-04-18T08:22:32.127198Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting cyfuncs.pyx\n" ] } ], "source": [ "%%file cyfuncs.pyx\n", "\n", "language_level = \"3str\"\n", "\n", "import numpy as np\n", "# NumPy for Cython\n", "cimport numpy as np\n", "from libc.math cimport sqrt\n", "\n", "# this is pure Python\n", "def pairwise0(X):\n", "    M = X.shape[0]\n", "    N = X.shape[1]\n", "    D = np.empty((M, M), dtype=float)\n", "    for i in range(M):\n", "        for j in range(M):\n", "            d = 0.0\n", "            for k in range(N):\n", "                tmp = X[i, k] - X[j, k]\n", "                d += tmp * tmp\n", "            D[i, j] = np.sqrt(d)\n", "    return D\n", "\n", "# here we start optimizing, though the changes are not drastic\n", "def pairwise1(np.ndarray[np.float64_t, ndim=2] X):\n", "    cdef int M = X.shape[0]\n", "    cdef int N = X.shape[1]\n", "    cdef double tmp, d\n", "    cdef np.ndarray D = np.empty((M, M), dtype=np.float64)\n", "    for i in range(M):\n", "        for j in range(M):\n", "            d = 0.0\n", "            for k in range(N):\n", "                tmp = X[i, k] - X[j, k]\n", "                d += tmp * tmp\n", "            D[i, j] = sqrt(d)\n", "    return D" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:10.420625Z", "start_time": "2024-04-18T08:23:10.417392Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting setup.py\n" ] } ], "source": [ "%%file setup.py\n", "\n", "from setuptools import setup\n", "from 
Cython.Build import cythonize\n", "import numpy\n", "\n", "setup(\n", " name='cyfuncs',\n", " include_dirs=[numpy.get_include()],\n", " ext_modules=cythonize(\"cyfuncs.pyx\"),\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:13.073072Z", "start_time": "2024-04-18T08:23:10.984125Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Compiling cyfuncs.pyx because it changed.\n", "[1/1] Cythonizing cyfuncs.pyx\n", "/Users/kuba/workspace/fjfi/python-fjfi/.venv/lib/python3.12/site-packages/Cython/Compiler/Main.py:381: FutureWarning: Cython directive 'language_level' not set, using '3str' for now (Py3). This has changed from earlier releases! File: /Users/kuba/workspace/fjfi/python-fjfi/numerical_python_course/lecture_notes.cz/cyfuncs.pyx\n", " tree = Parsing.p_module(s, pxd, full_module_name)\n", "In file included from cyfuncs.c:1240:\n", "In file included from /Users/kuba/workspace/fjfi/python-fjfi/.venv/lib/python3.12/site-packages/numpy/_core/include/numpy/arrayobject.h:5:\n", "In file included from /Users/kuba/workspace/fjfi/python-fjfi/.venv/lib/python3.12/site-packages/numpy/_core/include/numpy/ndarrayobject.h:12:\n", "In file included from /Users/kuba/workspace/fjfi/python-fjfi/.venv/lib/python3.12/site-packages/numpy/_core/include/numpy/ndarraytypes.h:1909:\n", "\u001b[1m/Users/kuba/workspace/fjfi/python-fjfi/.venv/lib/python3.12/site-packages/numpy/_core/include/numpy/npy_1_7_deprecated_api.h:17:2: \u001b[0m\u001b[0;1;35mwarning: \u001b[0m\u001b[1m\"Using deprecated NumPy API, disable it with \" \"#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION\" [-W#warnings]\u001b[0m\n", " 17 | #warning \"Using deprecated NumPy API, disable it with \" \\\u001b[0m\n", " | \u001b[0;1;32m ^\n", "\u001b[0m\u001b[1mcyfuncs.c:8531:26: \u001b[0m\u001b[0;1;35mwarning: \u001b[0m\u001b[1mcode will never be executed [-Wunreachable-code]\u001b[0m\n", " 8531 | module = 
PyImport_ImportModuleLevelObject(\u001b[0m\n", " | \u001b[0;1;32m ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\u001b[0m2 warnings generated.\n", "ld: warning: search path 'Modules/_hacl' not found\n", "ld: warning: search path '/install/lib' not found\n" ] } ], "source": [ "!python setup.py build_ext --inplace" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As mentioned above, Cython generates a C source file, which is then compiled with a regular compiler (e.g. gcc). Let us have a look at this file." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:17.343320Z", "start_time": "2024-04-18T08:23:17.339205Z" }, "collapsed": false }, "outputs": [ { "data": { "text/html": [ "cyfuncs.c
" ], "text/plain": [ "/Users/kuba/workspace/fjfi/python-fjfi/numerical_python_course/lecture_notes.cz/cyfuncs.c" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import FileLink\n", "FileLink('cyfuncs.c')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That file is long ... It contains a lot of Python \"ballast\", which shows how CPython itself is actually programmed. Fortunately, the file also contains comments saying which source line a given block corresponds to. For example:\n", "\n", "    /* \"cyfuncs.pyx\":16\n", "     *             tmp = X[i, k] - X[j, k]\n", "     *             d += tmp * tmp\n", "     *         D[i, j] = np.sqrt(d)             # <<<<<<<<<<<<<<\n", "     *     return D\n", "     * \n", "     */\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:19.619108Z", "start_time": "2024-04-18T08:23:19.615667Z" }, "collapsed": false }, "outputs": [], "source": [ "import cyfuncs" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:20.331403Z", "start_time": "2024-04-18T08:23:20.328054Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cyfuncs contains: language_level, np, pairwise0, pairwise1\n" ] } ], "source": [ "print(\"cyfuncs contains: \" + \", \".join((d for d in dir(cyfuncs) if not d.startswith(\"_\"))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us check whether we still get the same results." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:28.253271Z", "start_time": "2024-04-18T08:23:28.160670Z" }, "collapsed": false }, "outputs": [], "source": [ "assert np.allclose(pairwise_numpy(X), cyfuncs.pairwise1(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And how are we doing on time?" 
] }, { "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:39.234005Z", "start_time": "2024-04-18T08:23:29.580814Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.13 s ± 41.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "timings = %timeit -o cyfuncs.pairwise0(X)\n", "pairwise_times['cython0'] = timings" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:44.674128Z", "start_time": "2024-04-18T08:23:39.235362Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "79.6 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], "source": [ "timings = %timeit -o cyfuncs.pairwise1(X)\n", "pairwise_times['cython1'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### IPython `%%cython` magic\n", "\n", "As in many other cases, IPython makes working with Cython easier, here through the `%%cython` magic. Let us try it. At the same time we will try to optimize our code further, because so far it is still slower than NumPy (79.6 ms versus 18.3 ms)." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:45.671385Z", "start_time": "2024-04-18T08:23:45.405330Z" }, "collapsed": false }, "outputs": [], "source": [ "%load_ext Cython" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:23:49.316037Z", "start_time": "2024-04-18T08:23:45.830809Z" }, "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", " \n", " Cython: _cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51.pyx\n", " \n", "\n", "\n", "

Generated by Cython 3.0.11

\n", "

\n", " Yellow lines hint at Python interaction.
\n", " Click on a line that starts with a \"+\" to see the C code that Cython generated for it.\n", "

\n", "
 01: 
\n", "
+02: import numpy as np
\n", "
  __pyx_t_7 = __Pyx_ImportDottedModule(__pyx_n_s_numpy, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 2, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_7);\n",
       "  if (PyDict_SetItem(__pyx_d, __pyx_n_s_np, __pyx_t_7) < 0) __PYX_ERR(0, 2, __pyx_L1_error)\n",
       "  __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0;\n",
       "/* … */\n",
       "  __pyx_t_7 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 2, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_7);\n",
       "  if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_7) < 0) __PYX_ERR(0, 2, __pyx_L1_error)\n",
       "  __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0;\n",
       "
 03: cimport numpy as np
\n", "
 04: cimport cython
\n", "
 05: from libc.math cimport sqrt
\n", "
 06: 
\n", "
+07: @cython.boundscheck(False)
\n", "
/* Python wrapper */\n",
       "static PyObject *__pyx_pw_54_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51_1pairwise_cython(PyObject *__pyx_self, \n",
       "#if CYTHON_METH_FASTCALL\n",
       "PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds\n",
       "#else\n",
       "PyObject *__pyx_args, PyObject *__pyx_kwds\n",
       "#endif\n",
       "); /*proto*/\n",
       "static PyMethodDef __pyx_mdef_54_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51_1pairwise_cython = {\"pairwise_cython\", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_54_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51_1pairwise_cython, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};\n",
       "static PyObject *__pyx_pw_54_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51_1pairwise_cython(PyObject *__pyx_self, \n",
       "#if CYTHON_METH_FASTCALL\n",
       "PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds\n",
       "#else\n",
       "PyObject *__pyx_args, PyObject *__pyx_kwds\n",
       "#endif\n",
       ") {\n",
       "  __Pyx_memviewslice __pyx_v_X = { 0, 0, { 0 }, { 0 }, { 0 } };\n",
       "  #if !CYTHON_METH_FASTCALL\n",
       "  CYTHON_UNUSED Py_ssize_t __pyx_nargs;\n",
       "  #endif\n",
       "  CYTHON_UNUSED PyObject *const *__pyx_kwvalues;\n",
       "  PyObject *__pyx_r = 0;\n",
       "  __Pyx_RefNannyDeclarations\n",
       "  __Pyx_RefNannySetupContext(\"pairwise_cython (wrapper)\", 0);\n",
       "  #if !CYTHON_METH_FASTCALL\n",
       "  #if CYTHON_ASSUME_SAFE_MACROS\n",
       "  __pyx_nargs = PyTuple_GET_SIZE(__pyx_args);\n",
       "  #else\n",
       "  __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;\n",
       "  #endif\n",
       "  #endif\n",
       "  __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);\n",
       "  {\n",
       "    PyObject **__pyx_pyargnames[] = {&__pyx_n_s_X,0};\n",
       "  PyObject* values[1] = {0};\n",
       "    if (__pyx_kwds) {\n",
       "      Py_ssize_t kw_args;\n",
       "      switch (__pyx_nargs) {\n",
       "        case  1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0);\n",
       "        CYTHON_FALLTHROUGH;\n",
       "        case  0: break;\n",
       "        default: goto __pyx_L5_argtuple_error;\n",
       "      }\n",
       "      kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds);\n",
       "      switch (__pyx_nargs) {\n",
       "        case  0:\n",
       "        if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_X)) != 0)) {\n",
       "          (void)__Pyx_Arg_NewRef_FASTCALL(values[0]);\n",
       "          kw_args--;\n",
       "        }\n",
       "        else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 7, __pyx_L3_error)\n",
       "        else goto __pyx_L5_argtuple_error;\n",
       "      }\n",
       "      if (unlikely(kw_args > 0)) {\n",
       "        const Py_ssize_t kwd_pos_args = __pyx_nargs;\n",
       "        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, \"pairwise_cython\") < 0)) __PYX_ERR(0, 7, __pyx_L3_error)\n",
       "      }\n",
       "    } else if (unlikely(__pyx_nargs != 1)) {\n",
       "      goto __pyx_L5_argtuple_error;\n",
       "    } else {\n",
       "      values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0);\n",
       "    }\n",
       "    __pyx_v_X = __Pyx_PyObject_to_MemoryviewSlice_d_dc_double(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_X.memview)) __PYX_ERR(0, 9, __pyx_L3_error)\n",
       "  }\n",
       "  goto __pyx_L6_skip;\n",
       "  __pyx_L5_argtuple_error:;\n",
       "  __Pyx_RaiseArgtupleInvalid(\"pairwise_cython\", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 7, __pyx_L3_error)\n",
       "  __pyx_L6_skip:;\n",
       "  goto __pyx_L4_argument_unpacking_done;\n",
       "  __pyx_L3_error:;\n",
       "  {\n",
       "    Py_ssize_t __pyx_temp;\n",
       "    for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {\n",
       "      __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]);\n",
       "    }\n",
       "  }\n",
       "  __PYX_XCLEAR_MEMVIEW(&__pyx_v_X, 1);\n",
       "  __Pyx_AddTraceback(\"_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51.pairwise_cython\", __pyx_clineno, __pyx_lineno, __pyx_filename);\n",
       "  __Pyx_RefNannyFinishContext();\n",
       "  return NULL;\n",
       "  __pyx_L4_argument_unpacking_done:;\n",
       "  __pyx_r = __pyx_pf_54_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51_pairwise_cython(__pyx_self, __pyx_v_X);\n",
       "  int __pyx_lineno = 0;\n",
       "  const char *__pyx_filename = NULL;\n",
       "  int __pyx_clineno = 0;\n",
       "\n",
       "  /* function exit code */\n",
       "  __PYX_XCLEAR_MEMVIEW(&__pyx_v_X, 1);\n",
       "  {\n",
       "    Py_ssize_t __pyx_temp;\n",
       "    for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {\n",
       "      __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]);\n",
       "    }\n",
       "  }\n",
       "  __Pyx_RefNannyFinishContext();\n",
       "  return __pyx_r;\n",
       "}\n",
       "\n",
       "static PyObject *__pyx_pf_54_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51_pairwise_cython(CYTHON_UNUSED PyObject *__pyx_self, __Pyx_memviewslice __pyx_v_X) {\n",
       "  int __pyx_v_M;\n",
       "  int __pyx_v_N;\n",
       "  double __pyx_v_tmp;\n",
       "  double __pyx_v_d;\n",
       "  __Pyx_memviewslice __pyx_v_D = { 0, 0, { 0 }, { 0 }, { 0 } };\n",
       "  int __pyx_v_i;\n",
       "  int __pyx_v_j;\n",
       "  int __pyx_v_k;\n",
       "  PyObject *__pyx_r = NULL;\n",
       "/* … */\n",
       "  /* function exit code */\n",
       "  __pyx_L1_error:;\n",
       "  __Pyx_XDECREF(__pyx_t_1);\n",
       "  __Pyx_XDECREF(__pyx_t_2);\n",
       "  __Pyx_XDECREF(__pyx_t_3);\n",
       "  __Pyx_XDECREF(__pyx_t_4);\n",
       "  __Pyx_XDECREF(__pyx_t_5);\n",
       "  __PYX_XCLEAR_MEMVIEW(&__pyx_t_6, 1);\n",
       "  __Pyx_AddTraceback(\"_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51.pairwise_cython\", __pyx_clineno, __pyx_lineno, __pyx_filename);\n",
       "  __pyx_r = NULL;\n",
       "  __pyx_L0:;\n",
       "  __PYX_XCLEAR_MEMVIEW(&__pyx_v_D, 1);\n",
       "  __Pyx_XGIVEREF(__pyx_r);\n",
       "  __Pyx_RefNannyFinishContext();\n",
       "  return __pyx_r;\n",
       "}\n",
       "/* … */\n",
       "  __pyx_tuple__22 = PyTuple_Pack(9, __pyx_n_s_X, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_tmp, __pyx_n_s_d, __pyx_n_s_D, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_k); if (unlikely(!__pyx_tuple__22)) __PYX_ERR(0, 7, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_tuple__22);\n",
       "  __Pyx_GIVEREF(__pyx_tuple__22);\n",
       "/* … */\n",
       "  __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_54_cython_magic_72b25b6e366aa2ac2fa9f6e8b1834b878e03df51_1pairwise_cython, 0, __pyx_n_s_pairwise_cython, NULL, __pyx_n_s_cython_magic_72b25b6e366aa2ac2f, __pyx_d, ((PyObject *)__pyx_codeobj__23)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 7, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_7);\n",
       "  if (PyDict_SetItem(__pyx_d, __pyx_n_s_pairwise_cython, __pyx_t_7) < 0) __PYX_ERR(0, 7, __pyx_L1_error)\n",
       "  __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0;\n",
       "
 08: @cython.wraparound(False)
\n", "
 09: def pairwise_cython(double[:, ::1] X):
\n", "
+10:     cdef int M = X.shape[0]
\n", "
  __pyx_v_M = (__pyx_v_X.shape[0]);\n",
       "
+11:     cdef int N = X.shape[1]
\n", "
  __pyx_v_N = (__pyx_v_X.shape[1]);\n",
       "
 12:     cdef double tmp, d
\n", "
+13:     cdef double[:, ::1] D = np.empty((M, M), dtype=np.float64)
\n", "
  __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_1);\n",
       "  __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_empty); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_2);\n",
       "  __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;\n",
       "  __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_M); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_1);\n",
       "  __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_M); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_3);\n",
       "  __pyx_t_4 = PyTuple_New(2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_4);\n",
       "  __Pyx_GIVEREF(__pyx_t_1);\n",
       "  if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_1)) __PYX_ERR(0, 13, __pyx_L1_error);\n",
       "  __Pyx_GIVEREF(__pyx_t_3);\n",
       "  if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_3)) __PYX_ERR(0, 13, __pyx_L1_error);\n",
       "  __pyx_t_1 = 0;\n",
       "  __pyx_t_3 = 0;\n",
       "  __pyx_t_3 = PyTuple_New(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_3);\n",
       "  __Pyx_GIVEREF(__pyx_t_4);\n",
       "  if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_4)) __PYX_ERR(0, 13, __pyx_L1_error);\n",
       "  __pyx_t_4 = 0;\n",
       "  __pyx_t_4 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_4);\n",
       "  __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_1);\n",
       "  __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_float64); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_5);\n",
       "  __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;\n",
       "  if (PyDict_SetItem(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;\n",
       "  __pyx_t_5 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_3, __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_5);\n",
       "  __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;\n",
       "  __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;\n",
       "  __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;\n",
       "  __pyx_t_6 = __Pyx_PyObject_to_MemoryviewSlice_d_dc_double(__pyx_t_5, PyBUF_WRITABLE); if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 13, __pyx_L1_error)\n",
       "  __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;\n",
       "  __pyx_v_D = __pyx_t_6;\n",
       "  __pyx_t_6.memview = NULL;\n",
       "  __pyx_t_6.data = NULL;\n",
       "
+14:     for i in range(M):
\n", "
  __pyx_t_7 = __pyx_v_M;\n",
       "  __pyx_t_8 = __pyx_t_7;\n",
       "  for (__pyx_t_9 = 0; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {\n",
       "    __pyx_v_i = __pyx_t_9;\n",
       "
+15:         for j in range(M):
\n", "
    __pyx_t_10 = __pyx_v_M;\n",
       "    __pyx_t_11 = __pyx_t_10;\n",
       "    for (__pyx_t_12 = 0; __pyx_t_12 < __pyx_t_11; __pyx_t_12+=1) {\n",
       "      __pyx_v_j = __pyx_t_12;\n",
       "
+16:             d = 0.0
\n", "
      __pyx_v_d = 0.0;\n",
       "
+17:             for k in range(N):
\n", "
      __pyx_t_13 = __pyx_v_N;\n",
       "      __pyx_t_14 = __pyx_t_13;\n",
       "      for (__pyx_t_15 = 0; __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {\n",
       "        __pyx_v_k = __pyx_t_15;\n",
       "
+18:                 tmp = X[i, k] - X[j, k]
\n", "
        __pyx_t_16 = __pyx_v_i;\n",
       "        __pyx_t_17 = __pyx_v_k;\n",
       "        __pyx_t_18 = __pyx_v_j;\n",
       "        __pyx_t_19 = __pyx_v_k;\n",
       "        __pyx_v_tmp = ((*((double *) ( /* dim=1 */ ((char *) (((double *) ( /* dim=0 */ (__pyx_v_X.data + __pyx_t_16 * __pyx_v_X.strides[0]) )) + __pyx_t_17)) ))) - (*((double *) ( /* dim=1 */ ((char *) (((double *) ( /* dim=0 */ (__pyx_v_X.data + __pyx_t_18 * __pyx_v_X.strides[0]) )) + __pyx_t_19)) ))));\n",
       "
+19:                 d += tmp * tmp
\n", "
        __pyx_v_d = (__pyx_v_d + (__pyx_v_tmp * __pyx_v_tmp));\n",
       "      }\n",
       "
+20:             D[i, j] = sqrt(d)
\n", "
      __pyx_t_19 = __pyx_v_i;\n",
       "      __pyx_t_18 = __pyx_v_j;\n",
       "      *((double *) ( /* dim=1 */ ((char *) (((double *) ( /* dim=0 */ (__pyx_v_D.data + __pyx_t_19 * __pyx_v_D.strides[0]) )) + __pyx_t_18)) )) = sqrt(__pyx_v_d);\n",
       "    }\n",
       "  }\n",
       "
+21:     return np.asarray(D)
\n", "
  __Pyx_XDECREF(__pyx_r);\n",
       "  __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_np); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 21, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_4);\n",
       "  __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_asarray); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 21, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_3);\n",
       "  __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;\n",
       "  __pyx_t_4 = __pyx_memoryview_fromslice(__pyx_v_D, 2, (PyObject *(*)(char *)) __pyx_memview_get_double, (int (*)(char *, PyObject *)) __pyx_memview_set_double, 0);; if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 21, __pyx_L1_error)\n",
       "  __Pyx_GOTREF(__pyx_t_4);\n",
       "  __pyx_t_2 = NULL;\n",
       "  __pyx_t_20 = 0;\n",
       "  #if CYTHON_UNPACK_METHODS\n",
       "  if (unlikely(PyMethod_Check(__pyx_t_3))) {\n",
       "    __pyx_t_2 = PyMethod_GET_SELF(__pyx_t_3);\n",
       "    if (likely(__pyx_t_2)) {\n",
       "      PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3);\n",
       "      __Pyx_INCREF(__pyx_t_2);\n",
       "      __Pyx_INCREF(function);\n",
       "      __Pyx_DECREF_SET(__pyx_t_3, function);\n",
       "      __pyx_t_20 = 1;\n",
       "    }\n",
       "  }\n",
       "  #endif\n",
       "  {\n",
       "    PyObject *__pyx_callargs[2] = {__pyx_t_2, __pyx_t_4};\n",
       "    __pyx_t_5 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_20, 1+__pyx_t_20);\n",
       "    __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;\n",
       "    __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;\n",
       "    if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 21, __pyx_L1_error)\n",
       "    __Pyx_GOTREF(__pyx_t_5);\n",
       "    __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;\n",
       "  }\n",
       "  __pyx_r = __pyx_t_5;\n",
       "  __pyx_t_5 = 0;\n",
       "  goto __pyx_L0;\n",
       "
" ], "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%cython -a\n", "\n", "import numpy as np\n", "cimport numpy as np\n", "cimport cython\n", "from libc.math cimport sqrt\n", "\n", "@cython.boundscheck(False)\n", "@cython.wraparound(False)\n", "def pairwise_cython(double[:, ::1] X):\n", " cdef int M = X.shape[0]\n", " cdef int N = X.shape[1]\n", " cdef double tmp, d\n", " cdef double[:, ::1] D = np.empty((M, M), dtype=np.float64)\n", " for i in range(M):\n", " for j in range(M):\n", " d = 0.0\n", " for k in range(N):\n", " tmp = X[i, k] - X[j, k]\n", " d += tmp * tmp\n", " D[i, j] = sqrt(d)\n", " return np.asarray(D)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:24:09.786064Z", "start_time": "2024-04-18T08:24:09.723085Z" }, "collapsed": false }, "outputs": [], "source": [ "assert np.allclose(pairwise_numpy(X), pairwise_cython(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:24:23.467435Z", "start_time": "2024-04-18T08:24:10.196820Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.98 ms ± 40.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_cython(X)\n", "pairwise_times['cython2'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tohle je už konečně výrazné zrychlení oproti NumPy." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Cython has a lot to offer\n", "Have a look at http://docs.cython.org to see everything Cython offers. It is quite a lot, for example:\n", "\n", "* using C++\n", "* templates\n", "* OpenMP (we may get to that later)\n", "* creating a C API\n", "* classes (cdef classes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### From Fortran (or C) to Python with F2PY\n", "\n", "F2PY is a tool that was essentially created for NumPy and SciPy because, as we know, these modules call external libraries written in Fortran or C. It was therefore very convenient to have a tool that makes this easier, and so F2PY was born. The (somewhat outdated) documentation can be found [here](http://cens.ioc.ee/projects/f2py2e/usersguide/index.html). In short, F2PY makes it very easy to build a Python module from Fortran or C functions. It also uses NumPy features for passing multidimensional arrays.\n", "\n", "Let's program in Fortran for a while :)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:02.527832Z", "start_time": "2024-04-18T08:25:02.523778Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting pairwise_fort.f90\n" ] } ], "source": [ "%%file pairwise_fort.f90\n", "\n", "subroutine pairwise_fort(X, D, m, n)\n", "    integer :: n,m\n", "    double precision, intent(in) :: X(m, n)\n", "    double precision, intent(out) :: D(m, m)\n", "    integer :: i, j, k\n", "    double precision :: r\n", "\n", "    do j = 1,m\n", "        do i = 1,m\n", "            r = 0\n", "            do k = 1,n\n", "                r = r + (X(i,k) - X(j,k)) * (X(i,k) - X(j,k))\n", "            end do\n", "            D(i,j) = sqrt(r)\n", "        end do\n", "    end do\n", "end subroutine pairwise_fort" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Where does f2py get the information it needs to build the module?\n", "\n", "1. Arrays (double precision) are converted to NumPy arrays.\n", "2. 
`intent(in)` = input argument.\n", "3. `intent(out)` = output argument.\n", "4. f2py hides the explicitly passed array dimensions (m, n): they become optional arguments inferred from the shape of `X`.\n", "\n", "If we were programming in C, we would have to give f2py some extra information, because, for example, `intent` does not exist in C.\n", "\n", "We compile this file with `f2py`:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:05.699853Z", "start_time": "2024-04-18T08:25:04.807425Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cannot use distutils backend with Python>=3.12, using meson backend instead.\n", "Using meson backend\n", "Will pass --lower to f2py\n", "See https://numpy.org/doc/stable/f2py/buildtools/meson.html\n", "Reading fortran codes...\n", "\tReading file 'pairwise_fort.f90' (format:free)\n", "Post-processing...\n", "\tBlock: pairwise_fort\n", "\t\t\tBlock: pairwise_fort\n", "Applying post-processing hooks...\n", " character_backward_compatibility_hook\n", "Post-processing (stage 2)...\n", "Building modules...\n", " Building module \"pairwise_fort\"...\n", " Generating possibly empty wrappers\"\n", " Maybe empty \"pairwise_fort-f2pywrappers.f\"\n", " Constructing wrapper function \"pairwise_fort\"...\n", " d = pairwise_fort(x,[m,n])\n", " Wrote C/API module \"pairwise_fort\" to file \"./pairwise_fortmodule.c\"\n", "\u001b[1mThe Meson build system\u001b[0m\n", "Version: 1.6.1\n", "Source dir: \u001b[1m/private/var/folders/dm/gbbql3p121z0tr22r2z98vy00000gn/T/tmp5ugidbi4\u001b[0m\n", "Build dir: \u001b[1m/private/var/folders/dm/gbbql3p121z0tr22r2z98vy00000gn/T/tmp5ugidbi4/bbdir\u001b[0m\n", "Build type: \u001b[1mnative build\u001b[0m\n", "Project name: \u001b[1mpairwise_fort\u001b[0m\n", "Project version: \u001b[1m0.1\u001b[0m\n", "Fortran compiler for the host machine: \u001b[1mgfortran\u001b[0m (gcc 14.2.0 \"GNU Fortran (Homebrew GCC 14.2.0_1) 14.2.0\")\n", "Fortran linker for the host machine: 
\u001b[1mgfortran\u001b[0m ld64 1115.7.3\n", "C compiler for the host machine: \u001b[1mcc\u001b[0m (clang 16.0.0 \"Apple clang version 16.0.0 (clang-1600.0.26.6)\")\n", "C linker for the host machine: \u001b[1mcc\u001b[0m ld64 1115.7.3\n", "Host machine cpu family: \u001b[1maarch64\u001b[0m\n", "Host machine cpu: \u001b[1maarch64\u001b[0m\n", "Program /Users/kuba/workspace/fjfi/python-fjfi/.venv/bin/python3 found: \u001b[1;32mYES\u001b[0m (/Users/kuba/workspace/fjfi/python-fjfi/.venv/bin/python3)\n", "Found pkg-config: \u001b[1;32mYES\u001b[0m \u001b[1m(/opt/homebrew/bin/pkg-config)\u001b[0m \u001b[1;34m2.3.0\u001b[0m\n", "Run-time dependency \u001b[1mpython\u001b[0m found: \u001b[1;32mYES\u001b[0m \u001b[36m3.12\u001b[0m\n", "Library \u001b[1mquadmath\u001b[0m found: \u001b[1;32mYES\u001b[0m\n", "Build targets in project: \u001b[1m1\u001b[0m\n", "\n", "Found ninja-1.11.1.git.kitware.jobserver-1 at /Users/kuba/workspace/fjfi/python-fjfi/.venv/bin/ninja\n", "\u001b[1;32mINFO:\u001b[0m autodetecting backend as ninja\n", "\u001b[1;32mINFO:\u001b[0m calculating backend command to run: /Users/kuba/workspace/fjfi/python-fjfi/.venv/bin/ninja -C /private/var/folders/dm/gbbql3p121z0tr22r2z98vy00000gn/T/tmp5ugidbi4/bbdir\n", "ninja: Entering directory `/private/var/folders/dm/gbbql3p121z0tr22r2z98vy00000gn/T/tmp5ugidbi4/bbdir'\n", "[3/6] Compiling Fortran object pairwis...in.so.p/pairwise_fort-f2pywrappers.f.o\u001b[K\n", "\u001b[01m\u001b[Kf951:\u001b[m\u001b[K \u001b[01;35m\u001b[KWarning:\u001b[m\u001b[K Nonexistent include directory '\u001b[01m\u001b[K/install/include/python3.12\u001b[m\u001b[K' [\u001b[01;35m\u001b[K-Wmissing-include-dirs\u001b[m\u001b[K]\n", "[4/6] Compiling Fortran object pairwis...on-312-darwin.so.p/pairwise_fort.f90.o\u001b[K\n", "\u001b[01m\u001b[Kf951:\u001b[m\u001b[K \u001b[01;35m\u001b[KWarning:\u001b[m\u001b[K Nonexistent include directory '\u001b[01m\u001b[K/install/include/python3.12\u001b[m\u001b[K' 
[\u001b[01;35m\u001b[K-Wmissing-include-dirs\u001b[m\u001b[K]\n", "[6/6] Linking target pairwise_fort.cpython-312-darwin.so\u001b[Krc_fortranobject.c.o\u001b[K\n" ] } ], "source": [ "!f2py -c pairwise_fort.f90 -m pairwise_fort" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-m pairwise_fort` sets the name of the module to build. We can import it right away, or rather its function of the same name." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:11.219983Z", "start_time": "2024-04-18T08:25:11.216408Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "d = pairwise_fort(x,[m,n])\n", "\n", "Wrapper for ``pairwise_fort``.\n", "\n", "Parameters\n", "----------\n", "x : input rank-2 array('d') with bounds (m,n)\n", "\n", "Other Parameters\n", "----------------\n", "m : input int, optional\n", " Default: shape(x, 0)\n", "n : input int, optional\n", " Default: shape(x, 1)\n", "\n", "Returns\n", "-------\n", "d : rank-2 array('d') with bounds (m,m)\n", "\n" ] } ], "source": [ "from pairwise_fort import pairwise_fort\n", "print(pairwise_fort.__doc__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fortran and C use different memory layouts for storing multidimensional arrays. Fortran is \"column-major\" whereas C is \"row-major\". NumPy can work with both, and for the user it usually makes no difference. However, if we want to pass a multidimensional array to a Fortran function, it is better to have the elements laid out in memory the way Fortran does it. In that case f2py passes only a reference (pointer) to the data. Otherwise f2py first has to transpose the array, i.e. 
*create a copy* with the other layout, which can of course be expensive in both memory and CPU time.\n", "\n", "We create a variable XF with Fortran layout using `numpy.asfortranarray` (a prosaic name :)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:15.329108Z", "start_time": "2024-04-18T08:25:15.326631Z" }, "collapsed": false }, "outputs": [], "source": [ "XF = np.asfortranarray(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check that we still get the same results." ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:17.031710Z", "start_time": "2024-04-18T08:25:16.997015Z" }, "collapsed": false }, "outputs": [], "source": [ "assert np.allclose(pairwise_numpy(X), pairwise_fort(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And finally we can look at the speed ..." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:34.844746Z", "start_time": "2024-04-18T08:25:20.063009Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.43 ms ± 251 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_fort(X)\n", "pairwise_times['fortran'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introducing `numba`\n", "\n", "Numba compiles Python code using [LLVM](http://llvm.org/). It supports just-in-time compilation via the `jit` decorator (http://numba.pydata.org/numba-doc/latest/user/jit.html)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "@numba.jit(\n", "    signature=None,\n", "    nopython=False,\n", "    nogil=False,\n", "    cache=False,\n", "    forceobj=False,\n", "    parallel=False,\n", "    error_model='python',\n", "    fastmath=False, locals={}\n", ")\n", "```" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:41.072952Z", "start_time": "2024-04-18T08:25:41.070431Z" } }, "outputs": [], "source": [ "# Uncomment to install the Numba package\n", "# !pip install numba" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:25:59.325306Z", "start_time": "2024-04-18T08:25:59.202454Z" } }, "outputs": [], "source": [ "import numba" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:00.185673Z", "start_time": "2024-04-18T08:26:00.177083Z" } }, "outputs": [], "source": [ "pairwise_numba = numba.jit(pairwise_python)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The usual check. This first call also triggers Numba's compilation of the function." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:01.849016Z", "start_time": "2024-04-18T08:26:01.381482Z" }, "collapsed": false }, "outputs": [], "source": [ "assert np.allclose(pairwise_numpy(X), pairwise_numba(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What timing can we expect from such a \"simple\" optimization?" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:05.311692Z", "start_time": "2024-04-18T08:26:03.651744Z" }, "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.48 ms ± 65 μs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_numba(X)\n", "pairwise_times['numba'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The speed-up is excellent: we are on par with the best result so far, and we achieved it with a single line (apart from the imports)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can try to improve the result further with parallelization and the `nopython` and/or `fastmath` modes. In this version the output NumPy array is created outside the compiled function and passed in. Parallelization is enabled with `parallel=True` and [`numba.prange`](http://numba.pydata.org/numba-doc/latest/user/parallel.html?highlight=prange). Note the use of `@jit` as a decorator." ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:07.690637Z", "start_time": "2024-04-18T08:26:07.686997Z" } }, "outputs": [], "source": [ "@numba.jit(nopython=True, parallel=True, fastmath=True)\n", "def _pairwise_nopython(X: np.ndarray, D: np.ndarray) -> np.ndarray:\n", "    M = X.shape[0]\n", "    N = X.shape[1]\n", "    for i in numba.prange(M):\n", "        for j in numba.prange(M):\n", "            d = 0.0\n", "            for k in range(N):\n", "                tmp = X[i, k] - X[j, k]\n", "                d += tmp * tmp\n", "            D[i, j] = np.sqrt(d)\n", "    return D\n", "\n", "\n", "def pairwise_numba_fast_parallel(X: np.ndarray) -> np.ndarray:\n", "    D = np.empty((X.shape[0], X.shape[0]), dtype=float)\n", "    _pairwise_nopython(X, D)\n", "    return D" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:08.828091Z", "start_time": "2024-04-18T08:26:08.498511Z" } }, "outputs": [], "source": [ "assert np.allclose(pairwise_numpy(X), pairwise_numba_fast_parallel(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:13.446459Z", "start_time": "2024-04-18T08:26:08.927949Z" } }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "711 μs ± 33.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_numba_fast_parallel(X)\n", "pairwise_times['numba_fast_parallel'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computing on GPU and CPU: JAX\n", "\n", "_JAX: High-Performance Array Computing_ provides a full range of tools for computing on both CPU and GPU.\n", "\n", "Its main advantages include:\n", "\n", "* NumPy-like API - JAX supports most NumPy operations.\n", "* JIT compilation - like `numba`, JAX can compile functions, using `jax.jit`.\n", "* Runs everywhere - the same code can run on CPU, GPU, and TPU.\n", "* Automatic differentiation - JAX can compute gradients and Hessians." ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:14.973233Z", "start_time": "2024-04-18T08:26:14.970674Z" } }, "outputs": [], "source": [ "# Uncomment to install the JAX package\n", "# !pip install jax[cpu] jaxlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The JAX NumPy API is almost identical to the NumPy API. The main visible difference is that JAX functions return their own array type, which represents an array on a particular device (CPU, GPU, TPU).\n", "It is therefore convenient to import JAX as `import jax.numpy as jnp` alongside NumPy as `import numpy as np`." ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:16.447427Z", "start_time": "2024-04-18T08:26:15.990599Z" } }, "outputs": [], "source": [ "import jax.numpy as jnp\n", "import jax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If a GPU (AMD/NVIDIA) were available, we would see it in the following list. 
" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:18.876047Z", "start_time": "2024-04-18T08:26:18.859776Z" } }, "outputs": [ { "data": { "text/plain": [ "[CpuDevice(id=0)]" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.devices()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To run the computation on a GPU it is enough to replace `np` with `jnp`; JAX takes care of the rest." ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:21.002794Z", "start_time": "2024-04-18T08:26:20.999454Z" } }, "outputs": [], "source": [ "def pairwise_jax(X):\n", "    return jnp.sqrt(jnp.power(X[:, jnp.newaxis, :] - X, 2).sum(-1))\n" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:22.355890Z", "start_time": "2024-04-18T08:26:22.156958Z" } }, "outputs": [ { "ename": "AssertionError", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[50], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m np\u001b[38;5;241m.\u001b[39mallclose(pairwise_numpy(X), pairwise_jax(X), rtol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-10\u001b[39m, atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-15\u001b[39m)\n", "\u001b[0;31mAssertionError\u001b[0m: " ] } ], "source": [ "assert np.allclose(pairwise_numpy(X), pairwise_jax(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Oops, what went wrong? Because GPUs typically work in lower precision, JAX uses `float32` by default instead of `float64`." 
] }, { "cell_type": "code", "execution_count": 51, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:34.792567Z", "start_time": "2024-04-18T08:26:34.768471Z" } }, "outputs": [ { "data": { "text/plain": [ "dtype('float32')" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pairwise_jax(X).dtype" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For computations in 64-bit precision, `jax_enable_x64` must be enabled." ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:36.962400Z", "start_time": "2024-04-18T08:26:36.959692Z" } }, "outputs": [], "source": [ "from jax import config\n", "config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:40.204841Z", "start_time": "2024-04-18T08:26:40.108146Z" } }, "outputs": [ { "data": { "text/plain": [ "dtype('float64')" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pairwise_jax(X).dtype" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:41.465030Z", "start_time": "2024-04-18T08:26:41.383496Z" } }, "outputs": [], "source": [ "assert np.allclose(pairwise_numpy(X), pairwise_jax(X), rtol=1e-10, atol=1e-15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For JIT compilation, just add the `jax.jit` decorator.\n", "\n", "But beware:\n", "* The first call of a JIT-compiled function can be slower than calling the function without compilation, because it includes the compilation itself.\n", "* JIT needs to know the types and shapes of the input arguments at compile time; a call with a new shape or dtype triggers a recompilation." 
] }, { "cell_type": "code", "execution_count": 55, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:43.629525Z", "start_time": "2024-04-18T08:26:43.626418Z" } }, "outputs": [], "source": [ "@jax.jit\n", "def pairwise_jax_jit(X):\n", "    # Sum the squared differences first, then take the square root.\n", "    res = jnp.sqrt(jnp.power(X[:, jnp.newaxis, :] - X, 2).sum(-1))\n", "    return res\n" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:44.614037Z", "start_time": "2024-04-18T08:26:44.611318Z" } }, "outputs": [], "source": [ "X_jnp = jnp.asarray(X, dtype=jnp.float64)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:47.551431Z", "start_time": "2024-04-18T08:26:45.087129Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.7 ms ± 17 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_jax(X_jnp)\n", "pairwise_times['jax'] = timings" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:55.285580Z", "start_time": "2024-04-18T08:26:50.089028Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "663 μs ± 9.51 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "timings = %timeit -o pairwise_jax_jit(X_jnp)\n", "pairwise_times['jax_jit'] = timings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparison of results\n", "\n", "We can compare the results in a plot." 
] }, { "cell_type": "code", "execution_count": 60, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:26:59.171339Z", "start_time": "2024-04-18T08:26:59.167016Z" } }, "outputs": [ { "data": { "text/plain": [ "{'plain_python': ,\n", " 'numpy': ,\n", " 'cython0': ,\n", " 'cython1': ,\n", " 'cython2': ,\n", " 'fortran': ,\n", " 'numba': ,\n", " 'numba_fast_parallel': ,\n", " 'jax': ,\n", " 'jax_jit': }" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pairwise_times" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:27:01.532941Z", "start_time": "2024-04-18T08:27:01.208395Z" }, "collapsed": false }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "values = np.array([t.average for t in pairwise_times.values()])\n", "x = range(len(pairwise_times))\n", "ax.bar(x, values)\n", "ax.set_xticks(x)\n", "ax.set_xticklabels(tuple(pairwise_times.keys()), rotation='vertical')\n", "ax.set_ylabel('time [s]')\n", "ax.set_yscale('log')" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:31:41.312579Z", "start_time": "2024-04-18T08:31:41.309966Z" } }, "outputs": [], "source": [ "# A more thorough comparison:\n", "pairwise_functions = {\n", "    'plain_python': pairwise_python,\n", "    'numpy': pairwise_numpy,\n", "    'cython0': cyfuncs.pairwise0,\n", "    'cython1': cyfuncs.pairwise1,\n", "    'cython2': pairwise_cython,\n", "    'fortran': pairwise_fort,\n", "    'numba': pairwise_numba,\n", "    'numba_fast_parallel': pairwise_numba_fast_parallel,\n", "    'jax': pairwise_jax,\n", "    'jax_jit': pairwise_jax_jit\n", "}" ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:43:49.848782Z", "start_time": "2024-04-18T08:41:23.293147Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "plain_python\n", "10\n", "111 μs ± 544 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "31\n", "1.1 ms ± 45.6 μs per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "100\n", "11.4 ms ± 415 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n", "316\n", "109 ms ± 658 μs per loop (mean ± std. dev. of 2 runs, 10 loops each)\n", "1000\n", "1.1 s ± 4.75 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)\n", "3162\n", "11.2 s ± 51.6 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)\n", "numpy\n", "10\n", "3.96 μs ± 0.608 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "31\n", "19 μs ± 527 ns per loop (mean ± std. dev. 
of 2 runs, 100,000 loops each)\n", "100\n", "208 μs ± 10.8 μs per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "316\n", "1.85 ms ± 1.49 μs per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "1000\n", "20.5 ms ± 382 μs per loop (mean ± std. dev. of 2 runs, 10 loops each)\n", "3162\n", "199 ms ± 3.42 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)\n", "cython0\n", "10\n", "106 μs ± 828 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "31\n", "1.01 ms ± 13 μs per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "100\n", "10.3 ms ± 5.73 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n", "316\n", "104 ms ± 71.3 μs per loop (mean ± std. dev. of 2 runs, 10 loops each)\n", "1000\n", "1.05 s ± 6.76 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)\n", "3162\n", "10.6 s ± 110 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)\n", "cython1\n", "10\n", "6.62 μs ± 4.32 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "31\n", "60.4 μs ± 9.07 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "100\n", "623 μs ± 1.36 μs per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "316\n", "6.65 ms ± 48.7 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n", "1000\n", "81.7 ms ± 2.61 ms per loop (mean ± std. dev. of 2 runs, 10 loops each)\n", "3162\n", "787 ms ± 20.3 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)\n", "cython2\n", "10\n", "982 ns ± 4.6 ns per loop (mean ± std. dev. of 2 runs, 1,000,000 loops each)\n", "31\n", "2.56 μs ± 1.39 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "100\n", "17.1 μs ± 1.1 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "316\n", "160 μs ± 93.1 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "1000\n", "1.95 ms ± 1.53 μs per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "3162\n", "20.3 ms ± 193 μs per loop (mean ± std. dev. 
of 2 runs, 10 loops each)\n", "fortran\n", "10\n", "501 ns ± 0.712 ns per loop (mean ± std. dev. of 2 runs, 1,000,000 loops each)\n", "31\n", "2.29 μs ± 2.48 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "100\n", "20.4 μs ± 48.4 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "316\n", "201 μs ± 680 ns per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "1000\n", "2.32 ms ± 25.3 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n", "3162\n", "22.8 ms ± 251 μs per loop (mean ± std. dev. of 2 runs, 10 loops each)\n", "numba\n", "10\n", "604 ns ± 1.75 ns per loop (mean ± std. dev. of 2 runs, 1,000,000 loops each)\n", "31\n", "2.37 μs ± 3.2 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "100\n", "20 μs ± 125 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "316\n", "192 μs ± 1.5 μs per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "1000\n", "2.39 ms ± 34.1 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n", "3162\n", "24 ms ± 27.8 μs per loop (mean ± std. dev. of 2 runs, 10 loops each)\n", "numba_fast_parallel\n", "10\n", "95.4 μs ± 126 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "31\n", "96 μs ± 293 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "100\n", "99.9 μs ± 19.5 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "316\n", "140 μs ± 5.56 μs per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "1000\n", "655 μs ± 762 ns per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "3162\n", "5.7 ms ± 6.79 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n", "jax\n", "10\n", "43.3 μs ± 4.18 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "31\n", "73.4 μs ± 1.52 μs per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "100\n", "161 μs ± 557 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "316\n", "1.41 ms ± 2.31 μs per loop (mean ± std. dev. 
of 2 runs, 1,000 loops each)\n", "1000\n", "10.6 ms ± 1.3 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n", "3162\n", "128 ms ± 2.93 ms per loop (mean ± std. dev. of 2 runs, 10 loops each)\n", "jax_jit\n", "10\n", "6.17 μs ± 12.4 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "31\n", "17.4 μs ± 7.84 ns per loop (mean ± std. dev. of 2 runs, 100,000 loops each)\n", "100\n", "20.7 μs ± 19.4 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "316\n", "99.5 μs ± 206 ns per loop (mean ± std. dev. of 2 runs, 10,000 loops each)\n", "1000\n", "649 μs ± 82.6 ns per loop (mean ± std. dev. of 2 runs, 1,000 loops each)\n", "3162\n", "6.1 ms ± 11.6 μs per loop (mean ± std. dev. of 2 runs, 100 loops each)\n" ] } ], "source": [ "\n", "Ms = np.logspace(1, 3.5, 6).astype(int)\n", "N = 3\n", "\n", "paiwise_times_all = {}\n", "for name, func in pairwise_functions.items():\n", "    print(name)\n", "    timings = []\n", "    for M in Ms:\n", "        X = np.random.random((M, N))\n", "        print(M)\n", "        # t = %timeit -o func(X)\n", "        if \"jax\" in name:\n", "            X_jnp = np.asarray(X, dtype=jnp.float64)\n", "            # Warm-up call: skip the first measurement, because compilation happens during it\n", "            func(X_jnp)\n", "            t = %timeit -o -r 2 func(X_jnp)\n", "        else:\n", "            t = %timeit -o -r 2 func(X)\n", "\n", "        timings.append(t.average)\n", "\n", "    paiwise_times_all[name] = timings\n" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "ExecuteTime": { "end_time": "2024-04-18T08:48:52.464845Z", "start_time": "2024-04-18T08:48:52.062914Z" } }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "\n", "for name, times in paiwise_times_all.items():\n", "    ax.plot(Ms, times, label=name)\n", "\n", "# Set the logarithmic axes once, after all curves are drawn\n", "ax.set_xscale('log')\n", "ax.set_yscale('log')\n", "\n", "ax.set_xlabel('M')\n", "ax.set_ylabel('time [s]')\n", "\n", "ax.legend()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other options\n", "\n", "* The built-in [`ctypes`](http://docs.python.org/2/library/ctypes.html) module lets you call functions from external dynamic libraries. For use with NumPy, see the [Cookbook](http://wiki.scipy.org/Cookbook/Ctypes).\n", "* [cffi](https://pypi.org/project/cffi/) is an alternative to ctypes.\n", "* [CuPy](https://cupy.chainer.org/) uses the GPU.\n", "* [numexpr](https://github.com/pydata/numexpr) can compile NumPy expressions.\n", "* [Theano](http://www.deeplearning.net/software/theano/index.html) focuses on machine learning; it also optimizes vector (matrix) operations and can run them on the GPU.\n", "* [Nuitka](http://nuitka.net/) compiles Python but, unlike Numba, does not specialize functions based on types.\n", "* [SWIG](http://www.swig.org/) can be used to interface with many languages." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exercise\n", "\n", "Apply a similar procedure to computing the cumulative sum, which is defined as\n", "\n", "$\\displaystyle\n", "S_j = \\sum\\limits_{i = 1}^j x_i $\n", "\n", "Compare the results and timings with `numpy.cumsum`.\n", "\n", "*Hint: In your function, first create the output NumPy array with `numpy.empty_like`.*" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.5" } }, "nbformat": 4, 
"nbformat_minor": 2 }