Optimalizace až na Cost

Naučíme se optimalizovat funkce. Začneme od čisté implementace v Pythonu, zkusíme vyřešit problém v NumPy. Poté si ukážeme, jak postupně přejít k jazykům nižší úrovně, jako je Fortran nebo C, a také si představíme balík Numba.

In [ ]:
%matplotlib inline
from IPython.display import Image

import matplotlib.pyplot as plt
import numpy as np

Základní koncepce optimalizace

Už víme, že Python je jazyk, ve kterém jde velice efektivně programovat. A díky balíkům, jako jsou NumPy, SciPy nebo SymPy jde velice rychle řešit různé vědecké úlohy. Je pochopitelné, že s takovouto mírou abstrakce může být výsledný program pomalejší, než kdyby byl dobře napsán v nějakém kompilovaném jazyce typu C/C++ nebo Fortran. Musíme si ovšem uvědomit, že efektivita programu se měří v celkovém čase stráveném na vývoji a běhu programu (pro daný soubor úloh). Schématicky můžeme znázornit závislost rychlosti běhu programu v závislosti na délce vývoje asi takto:

In [ ]:
Image(filename='optimizing-what.png')

Všimněte si například, že typicky existuje nějaký offset ve vývojovém čase, tj. trvá nám déle v nízkoúrovňovém jazyce, než vůbec dostaneme první výsledek. Potřeba optimalizace tedy silně závisí na objemu výpočtů, které budeme s daným programem řešit.

Toto ovšem neznamená, že pokud je objem velký, máme hned začít programovat v C nebo Fortranu. Za chvíli si ukážeme, jak optimalizaci řešit chytřeji a postupně. Empirické pravidlo říká, že 90 % výpočetního času zabere 10 % zdrojového kódu. Jedná se o konkrétní příklad obecného Paretova 80 / 20 principu. Je tedy vhodné nejprve těchto 10 % najít a poté je teprve začít optimalizovat. Python nám k tomuto účelu poskytuje velice mocné nástroje.

Profilování

Profilování je nástroj, který nám umožní najít kritická místa v našem programu, oněch 10 %, které stojí za to optimalizovat. Zkusme si to ukázat na jednoduchém příkladu.

In [ ]:
def heavy_calc(X):
    Y = X.copy()
    for i in range(10):
        Y = Y**i
    return Y

def heavy_loop(inputs):
    res = []
    for X in inputs:
        res.append(heavy_calc(X))
    return res

def code_setup():
    from numpy.random import rand
    N = 20
    M = 1000
    print("Will generate {} random arrays".format(N))
    inputs = [rand(M, M) for n in range(N)]
    print("Will calculate now")
    result = heavy_loop(inputs)
    print("Finished calculation")

Python obsahuje dva základní mofuly pro profilování - profile a cProfile, z nichž ten druhý je rychlejší. Pomocí funkce run pustíme výpočet pod dohledem cProfile, výsledky uložíme do souboru.

In [ ]:
import cProfile
cProfile.run('code_setup()', 'pstats')

Dále budeme potřebovat modul pstats, který nám umožní s výsledky pracovat. Použije k tomu třídu Stats.

In [ ]:
from pstats import Stats
p = Stats('pstats')

print_stats nám zobrazí prvních n záznamů.

In [ ]:
p.print_stats(10)

Ty jsou ovšem nesetříděné. Následující výstup už je užitečnější, záznamy jsou totiž setříděné podle celkového času stráveného v dané funkci. Navíc strip_dirs odstraní adresáře ze jmen funkcí.

In [ ]:
p.strip_dirs().sort_stats('cumulative').print_stats(10)

Takto vypadá výstup setříděný pomocí nekumulovaného času.

In [ ]:
p.sort_stats('time').print_stats(10)

Jupyter nám může usnadnit práci pomocí %prun a %%prun. Např.

In [ ]:
%prun -s cumulative -l 10 code_setup()

Z obou výstupů celkem jasně vydíme, že naprostou většinu času trávíme ve funkci heavy_calc. Pokud se tedy chceme pustit do optimalizace, musíme se zaměřit právě na tuto část našeho programu.

Výsledky můžete navíc spojit s nástroji pro vizualizaci, např.SnakeViz nebo vprof, popř. pokročilý editor jako PyCharm.

Vzorová úloha - vzdálenost množiny bodů ve vícerozměrném prostoru

(Tento příklad byl převzat z http://jakevdp.github.io/blog/2013/06/15/numba-vs-cython-take-2.)

Zadání je jednoduché: pro M bodů v N rozměrném prostoru spočítejte vzájemnou vzdálenost $d$, která je pro dva body $x,y$ definovaná jako $\sqrt {\sum_{i=1}^N {{{\left( {{x_i} - {y_i}} \right)}^2}} } $. Výslekem je tedy (symetrická) matice $M\times M$.

In [ ]:
# toto nechť jsou naše vstupní data
M = 1000
N = 3
X = np.random.random((M, N))

Implementace v čistém Pythonu

Nemůžeme asi očekávat, že toto bude nejrychlejší a nejsnadnější verze našeho programu. Přesto stojí za to ji vyzkoušet, navíc ji budeme ještě potřebovat.

In [ ]:
def pairwise_python(X):
    M = X.shape[0]
    N = X.shape[1]
    D = np.empty((M, M), dtype=float)
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = np.sqrt(d)
    return D

Tahle funkce nám bude pomáhat ukládat výsledné časy z %timeit.

Do pairwise_times si uložíme výsledné časy.

In [ ]:
pairwise_times = {}
In [ ]:
timings = %timeit -o pairwise_python(X)
pairwise_times['plain_python'] = timings

To samé pomocí NumPy

V případě NumPy můžeme v tomto případě využít broadcasting. Celá funkce tak zabere jeden rádek.

In [ ]:
def pairwise_numpy(X):
    return np.sqrt(((X[:, np.newaxis, :] - X) ** 2).sum(-1))

Zkusíme, jestli výsledky jsou stejné pomocí assert a numpy.allclose.

In [ ]:
assert np.allclose(pairwise_numpy(X), pairwise_python(X), rtol=1e-10, atol=1e-15)

Výsledky jsou stejné až na velmi malé rozdíly - to je nebezpečí numerických výpočtů s konečnou přesností.

In [ ]:
timings = %timeit -o pairwise_numpy(X)
pairwise_times['numpy'] = timings

Vidíme, že jsme zkrátili běh programu více než 100-krát. To není špatné, navíc je implementace daleko jednodušší.

Přichází Cython

Cython je nástroj, který z Python programu, obohaceného o nějaké Cython direktivy, vytvoří program v C (případně C++), který je možné zkompilovat a okamžitě použít jako modul v Pythonu. Typickým příkladem Cython direktiv jsou statické typy. Cython samozřejmě umožňuje používat funkce z binárních knihoven s C rozhraním.

Zkusíme optimalizovat naší funkci pairwise_python.

  • Cython zdroják má koncovku .pyx (za začátku byl Pyrex).
  • Cython dokáže přeložit jakýkoli Python. Výsledkem je ale minimální (nebo spíš žádná) optimalizace.
  • cimport je analogie import, pracuje ale s Cython definicemi funkcí (.pxd soubory).
  • Cython dodává numpy.pyx, obsahující dodatečné informace pro kompilace NumPy modulů. Proto voláme cimport numpy.
  • Podobně libc je speciální modul Cythonu.

  • Funkce se deklarují (moho deklarovat) se statickými typy vstupních parametrů. My použijeme np.ndarray[np.float64_t, ndim=2].

  • Proměnné se deklarují pomocí cdef.
In [ ]:
# Odkomentujte pro instalaci Cythonu
# !pip install cython
In [ ]:
%%file cyfuncs.pyx

language_level = "3str"

import numpy as np
# numpy pro Cython
cimport numpy as np
from libc.math cimport sqrt

# tohle je čistý Python
def pairwise0(X):
    M = X.shape[0]
    N = X.shape[1]
    D = np.empty((M, M), dtype=float)
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = np.sqrt(d)
    return D

# tady už začínáme optimalizovat, změny ale nejsou drastické
def pairwise1(np.ndarray[np.float64_t, ndim=2] X):
    cdef int M = X.shape[0]
    cdef int N = X.shape[1]
    cdef double tmp, d
    cdef np.ndarray D = np.empty((M, M), dtype=np.float64)
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = sqrt(d)
    return D
In [ ]:
%%file setup.py

from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
  name='cyfuncs',
  include_dirs=[numpy.get_include()],
  ext_modules=cythonize("cyfuncs.pyx"),
)
In [ ]:
!python setup.py build_ext --inplace

Jak jsme již říkali, Cython vytvoří C zdroják, který se pak kompiluje pomocí běžného překladače (např. gcc). Pojďme se na tento soubor podívat.

In [ ]:
from IPython.display import FileLink
FileLink('cyfuncs.c')

Ten soubor je dlouhý ... Obsahuje spoustu Python "balastu", na kterém vidíme, jak je vlastně samotný CPython naprogramován. Naštěstí tento soubor obsahuje i komentáře, které říkají, které řádce daný blok odpovídá. Např.

 /* "cyfuncs.pyx":16
  *                 tmp = X[i, k] - X[j, k]
  *                 d += tmp * tmp
  *             D[i, j] = np.sqrt(d)             # <<<<<<<<<<<<<<
  *     return D
  * 
  */
In [ ]:
import cyfuncs
In [ ]:
print("cyfuncs obsahuje: " + ", ".join((d for d in dir(cyfuncs) if not d.startswith("_"))))

Podívejme se, jestli dostávám stále stejné výsledky.

In [ ]:
assert np.allclose(pairwise_numpy(X), cyfuncs.pairwise1(X), rtol=1e-10, atol=1e-15)

No a jak jsme na tom s časem?

In [ ]:
timings = %timeit -o cyfuncs.pairwise0(X)
pairwise_times['cython0'] = timings
In [ ]:
timings = %timeit -o cyfuncs.pairwise1(X)
pairwise_times['cython1'] = timings
IPython %%cython magic

IPython, tak jako v mnoha jiných případech, nám práci s Cythonem usnadňuje pomocí triku %%cython. Zkusíme ho použít. Zároveň zkusíme ještě více náš kód optimalizovat, zatím je totiž pomalejší než numpy.

In [ ]:
%load_ext Cython
In [ ]:
%%cython -a

import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport sqrt

@cython.boundscheck(False)
@cython.wraparound(False)
def pairwise_cython(double[:, ::1] X):
    cdef int M = X.shape[0]
    cdef int N = X.shape[1]
    cdef double tmp, d
    cdef double[:, ::1] D = np.empty((M, M), dtype=np.float64)
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = sqrt(d)
    return np.asarray(D)
In [ ]:
assert np.allclose(pairwise_numpy(X), pairwise_cython(X), rtol=1e-10, atol=1e-15)
In [ ]:
timings = %timeit -o pairwise_cython(X)
pairwise_times['cython2'] = timings

Tohle je už konečně výrazné zrychlení oproti NumPy.

Cython toho nabízí mnoho

Podívejte se na http://docs.cython.org co všechno Cython nabízí -- není toho málo, např.

  • použití C++
  • šablony (templates)
  • OpenMP (k tomu se možná ještě dostaneme)
  • vytváření C-API
  • třídy (cdef classes)

Z Fortranu (nebo C) do Pythonu pomocí F2PY

F2PY je nástroj, který byl v podstatě vytvořen pro NumPy a SciPy, protože, jak dobře víme, tyto moduly volají externí knihovny napsané ve Fortrane nebo C. Dokumentaci (trochu zastaralou) najdeme zde. Bylo tedy velice výhodné vytvořit nástroj, který toto usnadní. A tak se zrodilo F2PY. Ve zkratce, F2PY umožňuje velice jednoduše z Fortran nebo C funkcí vytvořit Python modul. Využívá navíc vlastností NumPy pro předávání vícerozměnrých polí.

Poďme chvilku programovat ve Fortranu :)

In [ ]:
%%file pairwise_fort.f90

subroutine pairwise_fort(X, D, m, n)
    integer :: n,m
    double precision, intent(in) :: X(m, n)
    double precision, intent(out) :: D(m, m)
    integer :: i, j, k
    double precision :: r

    do j = 1,m
        do i = 1,m
            r = 0
            do k = 1,n
                r = r + (X(i,k) - X(j,k)) * (X(i,k) - X(j,k))
            end do
            D(i,j) = sqrt(r)
        end do
    end do
end subroutine pairwise_fort

Z čeho f2py bere informace o vytvoření modulu?

  1. Pole (double precision) převádí na numpy array.
  2. intent(in) = vstupní argument.
  3. intent(out) = výstupní argument.
  4. f2py schová explicitně zadané rozměry polí (m, n).

Pokud bychom programovali v C, je potřeba dodat f2py nějaké informace navíc, neboť např. intent v C neexistuje.

Tento soubor přeložíme pomocí f2py:

In [ ]:
!f2py -c pairwise_fort.f90 -m pairwise_fort

-m pairwise_fort je požadované jméno modulu. Můžeme ho rovnou importovat, resp. jeho stejnojmennou funkci.

In [ ]:
from pairwise_fort import pairwise_fort
print(pairwise_fort.__doc__)

Fortran a C používají jiné uspořádání paměti pro ukládání vícerozměrných polí. Fortran je "column-major" zatímco C je "row-major". NumPy dokáže pracovat s obojím a pro uživatele je to obvykle jedno. Pokud ovšem chceme předat vícerozměrné pole do Fortran funkce, je lepší mít prvky uložené v paměti jako to dělá Fortran. V takovém případě totiž f2py předá pouze referenci (ukazatel) na dané místo v paměti. V opačném případě f2py nejprve pole musí transponovat, tj. vytvořit kopii s jiným uspořádáním, což může být samozřejmě náročné na pamět a procesor.

Vytvoříme si proměnnou XF, která má Fortran uspořádání, pomocí numpy.asfortranarray (prozaický název :)

In [ ]:
XF = np.asfortranarray(X)

Vyzkoušíme, jestli stále dostáváme stejné výsledky.

In [ ]:
assert np.allclose(pairwise_numpy(X), pairwise_fort(X), rtol=1e-10, atol=1e-15)

No a konečně se můžeme podívat, jak je to s rychlostí ...

In [ ]:
timings = %timeit -o pairwise_fort(X)
pairwise_times['fortran'] = timings

Představuje se numba

Numba kompiluje Python kód pomocí LLVM. Podporujme just-in-time kompilaci pomocí dekorátoru jit (http://numba.pydata.org/numba-doc/latest/user/jit.html).

@numba.jit(
    signature=None, 
    nopython=False, 
    nogil=False, 
    cache=False, 
    forceobj=False, 
    parallel=False, 
    error_model='python', 
    fastmath=False, locals={}
)
In [ ]:
# Odkomentujte pro instalaci balíku Numba
# !pip install numba
In [ ]:
import numba
In [ ]:
pairwise_numba = numba.jit(pairwise_python)

Tradiční kontrola. Po prvním spuštění navíc Numba funkci poprvé zkompiluje.

In [ ]:
assert np.allclose(pairwise_numpy(X), pairwise_numba(X), rtol=1e-10, atol=1e-15)

Jaký čas od takto "jednoduché" optimalizace můžeme očekávat?

In [ ]:
timings = %timeit -o pairwise_numba(X)
pairwise_times['numba'] = timings

Vidíme, že zrychlení je výborné - jsme na úrovni zatím nejlepšího výsledku! A navíc že jsme toho dosáhli jediným řádkem (kromě importů).

Ještě můžeme zkusit výledek vylepšit pomocí paralelizace, nopython a / nebo fastmath režimu. Pro nopython musíme vytvořit výsledný numpy objekt vně kompilované funkce. Paralelizace docílíme pomocí parallel=True a numba.prange. Všimněte si použití @jit jako dekorátoru.

In [ ]:
@numba.jit(nopython=True, parallel=True, fastmath=True)
def _pairwise_nopython(X: np.ndarray, D: np.ndarray) -> np.ndarray:
    M = X.shape[0]
    N = X.shape[1]
    for i in numba.prange(M):
        for j in numba.prange(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = np.sqrt(d)
    return D


def pairwise_numba_fast_parallel(X: np.ndarray) -> np.ndarray:
    D = np.empty((X.shape[0], X.shape[0]), dtype = float)
    _pairwise_nopython(X, D)
    return D
In [ ]:
assert np.allclose(pairwise_numpy(X), pairwise_numba_fast_parallel(X), rtol=1e-10, atol=1e-15)
In [ ]:
timings = %timeit -o pairwise_numba_fast_parallel(X)
pairwise_times['numba_fast_parallel'] = timings

Pojdme počítat na GPU i CPU: JAX

JAX: High-Performance Array Computing poskytuje plnou paletu nastrojů pro výpočty jak na CPU, tak na GPU.

Mezi jeho hlavní výhody patří:

  • NumPy-like API - JAX podporuje většinu NumPy operací.
  • JIT kompilace - JAX umožňuje, podobně jako numba, kompilaci funkcí pomocí jax.jit.
  • Běží všude - stejný kód může běžet na CPU, GPU i TPU.
  • Automatická diferenciace - JAX umožňuje výpočet gradientů a Hessiánů.
In [ ]:
# Odkomentujte pro instalaci balíku JAX
# !pip install jax[cpu] jaxlib

JAX NumPy API je téměř identické s NumPy API. Jediný rozdíl je, že JAX funkce vrací vlastní datovy typ, které reprezentují pole na určitém zařízení (CPU, GPU, TPU).). Je proto vyhodné importovat JAX jako import jax.numpy as jnp společně s NumPy jako import numpy as np.

In [ ]:
import jax.numpy as jnp
import jax

V případě, že bychom měli dostupné GPU (AMD/NVIDIA), vyděli bychom jej v nasledujícím seznamu.

In [ ]:
jax.devices()

Pro zpřístupnění vypočtů na GPU stačí nahradit np za jnp a JAX se postará o zbytek.

In [ ]:
def pairwise_jax(X):
    return jnp.sqrt(jnp.power(X[:, jnp.newaxis, :] - X, 2).sum(-1))
    #return jnp.sqrt((jnp.power(X[:, jnp.newaxis, :] - X, 2)).sum(-1))
In [ ]:
assert np.allclose(pairwise_numpy(X), pairwise_jax(X), rtol=1e-10, atol=1e-15)

Ouha, co je špatně? Jelikož GPU pracuje s nižší přesností, JAX pracuje v zakladním režimu float32 namísto float64.

In [ ]:
pairwise_jax(X).dtype

Pro vypočty s 64-bitovou přesností je potřeba povolit jax_enable_x64.

In [ ]:
from jax import config
config.update("jax_enable_x64", True)
In [ ]:
pairwise_jax(X).dtype
In [ ]:
assert np.allclose(pairwise_numpy(X), pairwise_jax(X), rtol=1e-10, atol=1e-15)

Proj JIT kompilaci stačí přidat dekorátor jax.jit.

Nicméně pozor:

  • První volání funkce s JIT kompilací může byt pomalejší, než když bychom funkci volali bez kompilace.
  • JIT vyžaduje znát typy a tvar (shape) vstupních parametrů v době kompilace.
In [ ]:
@jax.jit
def pairwise_jax_jit(X):
    res = (jnp.sqrt((jnp.power(X[:, jnp.newaxis, :] - X, 2)))).sum(-1)
    return res
In [ ]:
X_jnp = np.asarray(X, dtype=jnp.float64)
In [ ]:
timings = %timeit -o pairwise_jax(X_jnp)
pairwise_times['jax'] = timings
In [ ]:
timings = %timeit -o pairwise_jax_jit(X_jnp)
pairwise_times['jax_jit'] = timings

Srovnání výsledků

Výsledky můžeme porovnat pomocí grafu.

In [ ]:
pairwise_times
In [ ]:
fig, ax = plt.subplots()
values = np.array([t.average for t in pairwise_times.values()])
x = range(len(pairwise_times))
ax.bar(x, values)
ax.set_xticks(x)
ax.set_xticklabels(tuple(pairwise_times.keys()), rotation='vertical')
ax.set_ylabel('time [ms]')
ax.set_yscale('log')
In [ ]:
### Důkladnější srovnaní:
pairwise_functions = {
    'plain_python': pairwise_python,
    'numpy': pairwise_numpy,
    'cython0': cyfuncs.pairwise0,
    'cython1': cyfuncs.pairwise1,
    'cython2': pairwise_cython,
    'fortran': pairwise_fort,
    'numba': pairwise_numba,
    'numba_fast_parallel': pairwise_numba_fast_parallel,
    'jax': pairwise_jax,
    'jax_jit': pairwise_jax_jit
}
In [ ]:
Ms = np.logspace(1, 3.5, 6).astype(int)
N = 3

paiwise_times_all = {}
for name, func in pairwise_functions.items():
    print(name)
    timings = []
    for M in Ms:
        X = np.random.random((M, N))
        print(M)
        # t = %timeit -o func(X)
        if "jax" in name:
            X_jnp = np.asarray(X, dtype=jnp.float64)
            # Přeskočíme první měření, protože probíha kompilace
            func(X_jnp)
            t = %timeit -o -r 2 func(X_jnp)
        else:
            t = %timeit -o -r 2 func(X)

        timings.append(t.average)

    paiwise_times_all[name] = timings
In [ ]:
fig, ax = plt.subplots()

for name, times in paiwise_times_all.items():
    ax.plot(Ms, times, label=name)
    ax.set_xscale('log')
    ax.set_yscale('log')

ax.set_xlabel('M')
ax.set_ylabel('time [s]')

ax.legend()

Další možnosti

  • Vestavěný modul ctypes dovoluje volat funkce z externích dynamických knihoven. Pro použití s NumPy viz Cookbook.
  • Alternativou k ctypes je cffi.
  • CuPy využívá GPU.
  • numexpr dokáže kompilovat Numpy výrazy.
  • Theano se zaměřuje na strojové učení, také optimalizuje vektorové (maticové) operace, dovoluje je spouštět na CPU.
  • Nuitka kompiluje Python, ale na rozdíl od Numba nespecializuje funkce na základě typů.
  • SWIG jde použít pro propojení s mnoha jazyky.'

Cvičení

Obdobný postup aplikujte na výpočet kumulativního součtu, který je definovaný jako

$\displaystyle S_j = \sum\limits_{i = 1}^j x_i $

Výsledky a časování porovnejte s numpy.cumsum.

Nápověda: Ve vaší funkci vytvořte nejprve výsledné numpy pole pomocí numpy.empty_like.

Komentáře

Comments powered by Disqus