
13: Julia in Machine Learning

This notebook is a teaching material for the course BI-JUL.21, taught in the winter semester of the academic year 2021/2022 by Tomáš Kalvoda. The creation of these materials was supported by NVS FIT.

The course's main page, where you can find further notebooks and other interesting information, is its Course Pages site.

versioninfo()
Julia Version 1.8.3
Commit 0434deb161e (2022-11-14 20:14 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i5-8250U CPU @ 1.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores

1. Introduction

Julia quite naturally finds use in various areas of machine learning (ML). A wide range of tools exists in this field, often built on Python, and a number of Julia packages provide interfaces to these well-known tools.

Alongside these, tools built directly in Julia are also emerging.

In this Christmas notebook we will look at two such packages, DecisionTree.jl and Flux.jl. Thematically, we will cover decision trees and neural networks. The field of machine learning is very broad, and many of you have not yet studied this material, so this notebook is more of a teaser for what awaits you and how Julia can help in this area.


2. Decision Trees

A classic demonstration of decision trees is the classification of irises based on the dimensions of their sepals and petals, stored in the so-called Iris dataset.

We will use this task to demonstrate the DecisionTree.jl package; don't forget to install it with ] add DecisionTree.

using DecisionTree

The data are distributed directly with the package.

features, labels = load_data("iris")
(Any[5.1 3.5 1.4 0.2; 4.9 3.0 1.4 0.2; … ; 6.2 3.4 5.4 2.3; 5.9 3.0 5.1 1.8], Any["Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa"  …  "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica"])

The features are four measurements (sepal length and width, petal length and width); each row corresponds to one flower:

features = float.(features)
150×4 Matrix{Float64}:
 5.1  3.5  1.4  0.2
 4.9  3.0  1.4  0.2
 4.7  3.2  1.3  0.2
 4.6  3.1  1.5  0.2
 5.0  3.6  1.4  0.2
 5.4  3.9  1.7  0.4
 4.6  3.4  1.4  0.3
 5.0  3.4  1.5  0.2
 4.4  2.9  1.4  0.2
 4.9  3.1  1.5  0.1
 5.4  3.7  1.5  0.2
 4.8  3.4  1.6  0.2
 4.8  3.0  1.4  0.1
 ⋮              
 6.0  3.0  4.8  1.8
 6.9  3.1  5.4  2.1
 6.7  3.1  5.6  2.4
 6.9  3.1  5.1  2.3
 5.8  2.7  5.1  1.9
 6.8  3.2  5.9  2.3
 6.7  3.3  5.7  2.5
 6.7  3.0  5.2  2.3
 6.3  2.5  5.0  1.9
 6.5  3.0  5.2  2.0
 6.2  3.4  5.4  2.3
 5.9  3.0  5.1  1.8

We also have labels telling us which iris species each of these samples belongs to:

labels = string.(labels)
150-element Vector{String}:
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 ⋮
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"

In total, the dataset distinguishes three species of iris:

unique(labels)
3-element Vector{String}:
 "Iris-setosa"
 "Iris-versicolor"
 "Iris-virginica"

We build and train the model:

model = build_tree(labels, features)
Decision Tree
Leaves: 9
Depth:  5
print_tree(model)
Feature 3 < 2.45 ?
├─ Iris-setosa : 50/50
└─ Feature 4 < 1.75 ?
    ├─ Feature 3 < 4.95 ?
        ├─ Feature 4 < 1.65 ?
            ├─ Iris-versicolor : 47/47
            └─ Iris-virginica : 1/1
        └─ Feature 4 < 1.55 ?
            ├─ Iris-virginica : 3/3
            └─ Feature 1 < 6.95 ?
                ├─ Iris-versicolor : 2/2
                └─ Iris-virginica : 1/1
    └─ Feature 3 < 4.85 ?
        ├─ Feature 1 < 5.95 ?
            ├─ Iris-versicolor : 1/1
            └─ Iris-virginica : 2/2
        └─ Iris-virginica : 43/43

We prune the tree (leaves whose combined purity is at least 0.9 are merged):

model = prune_tree(model, 0.9)
Decision Tree
Leaves: 8
Depth:  5

We can display the decision tree in a readable form:

print_tree(model)
Feature 3 < 2.45 ?
├─ Iris-setosa : 50/50
└─ Feature 4 < 1.75 ?
    ├─ Feature 3 < 4.95 ?
        ├─ Iris-versicolor : 47/48
        └─ Feature 4 < 1.55 ?
            ├─ Iris-virginica : 3/3
            └─ Feature 1 < 6.95 ?
                ├─ Iris-versicolor : 2/2
                └─ Iris-virginica : 1/1
    └─ Feature 3 < 4.85 ?
        ├─ Feature 1 < 5.95 ?
            ├─ Iris-versicolor : 1/1
            └─ Iris-virginica : 2/2
        └─ Iris-virginica : 43/43

And use it to classify a sample:

apply_tree(model, [5.9, 3.0, 5.1, 1.9])
"Iris-virginica"

Let's run it on all 150 records in the dataset:

preds = apply_tree(model, features)
150-element Vector{String}:
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 ⋮
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"

How accurate are our predictions? To express how often our classifier "hit" the correct class, we can use a confusion matrix:

DecisionTree.confusion_matrix(labels, preds)

We can see that it misclassifies one flower from the data used (the 47/48 leaf above).
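
Note that we evaluated the tree on the very data it was trained on. For a more honest estimate, DecisionTree.jl also offers n-fold cross validation; a minimal sketch, assuming the nfoldCV_tree function with positional fold-count and pruning-purity arguments:

# 3-fold cross validation with pruning purity 0.9;
# returns the accuracy achieved on each fold.
n_folds, pruning_purity = 3, 0.9
accuracies = nfoldCV_tree(labels, features, n_folds, pruning_purity)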

We can also obtain information about how probable our prediction is. The following features again come directly from the dataset, so the resulting category has probability 1.

apply_tree_proba(model, [5.9, 3.0, 5.1, 1.9], ["Iris-setosa", "Iris-versicolor", "Iris-virginica"])
3-element Vector{Float64}:
 0.0
 0.0
 1.0

In the following example the classifier leans strongly, but not with full certainty, toward one of the categories.

apply_tree_proba(model, [5.7, 3.2, 4.9, 1.7], ["Iris-setosa", "Iris-versicolor", "Iris-virginica"])
3-element Vector{Float64}:
 0.0
 0.9791666666666666
 0.020833333333333332

For further examples, see the package's pages.


3. Neural Networks

For working with neural networks, the Flux.jl package looks very attractive; install it with ] add Flux.

using Flux

Let us demonstrate (for instance, for students who have not studied this material yet) the basic principle of training a neural network.

Suppose the dependence of one real variable $y$ on one real variable $x$ is given by the explicit formula:

actual(x) = 4x + 2
actual (generic function with 1 method)

We prepare training and test data, each consisting of just a few values.

x_train, x_test = hcat(0:5...), hcat(6:10...)
([0 1 … 4 5], [6 7 … 9 10])

And the true values on both the training and the test data.

y_train, y_test = actual.(x_train), actual.(x_test)
([2 6 … 18 22], [26 30 … 38 42])

We build the model and make predictions. Specifically, we will have a single neuron with one input and one output.

model = Dense(1, 1)
Dense(1 => 1)       # 2 parameters

In this case it is the function $\sigma(w \cdot x + b)$, where $w$ is the weight:

model.weight
1×1 Matrix{Float32}:
 0.5675245

$b$ is the bias:

model.bias
1-element Vector{Float32}:
 0.0

and $\sigma$ is the activation function, here the identity:

model.σ
identity (generic function with 1 method)

This trivial model thus has two real parameters. Currently it gives the following predictions (completely off, of course; we have not trained it yet):

model(x_train)
1×6 Matrix{Float32}:
 0.0  0.567524  1.13505  1.70257  2.2701  2.83762
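
Since the activation is the identity, these predictions are nothing but the affine map applied to the inputs, as we can verify directly:

# A Dense layer with identity activation computes weight * x .+ bias.
model.weight * x_train .+ model.bias ≈ model(x_train)   # true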

We will measure the accuracy of our model's predictions using a loss (objective) function, here the MSE (mean squared error):

loss(x, y) = Flux.mse(model(x), y)
loss (generic function with 1 method)

This is the mean of the squared deviations, i.e., for $x, y \in \mathbb{R}^n$ the expression

$$\frac{1}{n} \sum_{j=1}^n (x_j - y_j)^2.$$

"Chyba" je tedy zatím dost velká (z nějakého důvodu zde máme 32 bitový float):

loss(x_train, y_train)
146.3254f0

As a check:

sum((model(x_train) - y_train) .* (model(x_train) - y_train)) / length(x_train)
146.3254f0

For training we will use simple gradient descent. At this point we are, once again, doing nothing but solving an optimization problem!

opt = Descent() # gradient descent
Descent(0.1)

The complete training data:

data = [(x_train, y_train)]
1-element Vector{Tuple{Matrix{Int64}, Matrix{Int64}}}:
 ([0 1 … 4 5], [2 6 … 18 22])

The model parameters (weight and bias):

parameters = Flux.params(model)
Params([Float32[0.5675245;;], Float32[0.0]])

One epoch (one iteration/step of the optimization algorithm) is performed by calling the train! method:

Flux.Optimise.train!(loss, parameters, data, opt) # the `model` is "hidden" inside the loss function!

The loss function has decreased!

loss(x_train, y_train)
138.86488f0

Of course, our two parameters have changed as well:

parameters
Params([Float32[7.860397;;], Float32[2.116238]])
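
For the record, one such step is roughly equivalent to the following sketch of what train! does internally (using the implicit-parameters API of this Flux version):

# Differentiate the loss with respect to the tracked parameters
# and apply one gradient-descent update, θ .-= 0.1 .* ∇θ.
grads = Flux.gradient(() -> loss(x_train, y_train), parameters)
Flux.Optimise.update!(opt, parameters, grads)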

That was just a single epoch; we went through the data only once. Let's run more epochs.

for epoch in 1:100
    Flux.Optimise.train!(loss, parameters, data, opt)
end

We are apparently approaching a minimum (or are we? :-)).

loss(x_train, y_train)
0.749681f0
parameters
Params([Float32[4.2658257;;], Float32[2.0727146]])

In these parameters you surely recognize the original values from which we generated the data. If needed, we could push the training through a few more epochs.

What predictions does our model give on the test data?

model(x_test)
1×5 Matrix{Float32}:
 27.6677  31.9335  36.1993  40.4651  44.731

"Správně" bychom očekávali:

y_test
1×5 Matrix{Int64}:
 26  30  34  38  42

In essence, of course, we have done nothing but linear regression (fitting a line through the data).
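
Indeed, an ordinary least-squares fit of the same data gives practically the same parameters; a quick check using Julia's backslash operator:

# Solve the least-squares problem y ≈ w*x + b via the design matrix [x 1];
# since the data are noiseless, this recovers w = 4, b = 2.
X = [vec(x_train) ones(length(x_train))]
w, b = X \ vec(y_train)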


Digit recognition

Let us look at a more complicated perceptron example from the Flux documentation (the link is currently a 404), or rather from this blog.

using Flux
using MLDatasets
using Statistics

import Flux: onehotbatch, onecold, crossentropy, @epochs, unsqueeze

First let us get the data, in this case MNIST (Modified National Institute of Standards and Technology database), which contains images of the Arabic numerals zero through nine as 28×28-pixel images. The data are also annotated with the "correct" value.

On the first run of the following command you must confirm the download of the files.

# training data
x_train, y_train = MLDatasets.MNIST.traindata(Float32);
┌ Warning: MNIST.traindata() is deprecated, use `MNIST(split=:train)[:]` instead.
└ @ MLDatasets /home/kalvin/.julia/packages/MLDatasets/A3giY/src/datasets/vision/mnist.jl:187

The array x_train holds the images, a total of $60\,000$ of them.

typeof(x_train), size(x_train)
(Array{Float32, 3}, (28, 28, 60000))

Let us at least look at a few of them (rotated/mirrored?).

x_train[:, :, 1]
28×28 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0        …  0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.215686  0.533333   0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0        …  0.67451   0.992157   0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.886275  0.992157   0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.992157  0.992157   0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.992157  0.831373   0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.992157  0.529412   0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0        …  0.992157  0.517647   0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.956863  0.0627451  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0117647     0.521569  0.0        0.0  0.0  0.0
 ⋮                        ⋮          ⋱                       ⋮         
 0.0  0.0  0.0  0.0  0.0  0.494118      0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.533333      0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.686275      0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.101961      0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.65098    …  0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0           0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.968627      0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.498039      0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0        …  0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.0       0.0        0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0           0.0       0.0        0.0  0.0  0.0

And now actually as images:

using Images

display(Gray.(x_train[:, :, 1]))
display(Gray.(x_train[:, :, 2]))
display(Gray.(x_train[:, :, 3]))
display(Gray.(x_train[:, :, 4]))
display(Gray.(x_train[:, :, 5]))

The array y_train stores the digits the images represent.

typeof(y_train), size(y_train)
(Vector{Int64}, (60000,))

The first five images displayed above should therefore represent the following digits:

y_train[1:5]
5-element Vector{Int64}:
 5
 0
 4
 1
 9

Next we prepare the test data.

# test (validation) data
x_valid, y_valid = MLDatasets.MNIST.testdata(Float32);
┌ Warning: MNIST.testdata() is deprecated, use `MNIST(split=:test)[:]` instead.
└ @ MLDatasets /home/kalvin/.julia/packages/MLDatasets/A3giY/src/datasets/vision/mnist.jl:195

At the moment the data are plain matrices, but Flux expects image data to include a color channel (in our case just one, grayscale). We therefore have to add one more dimension (of length 1) to the data. For this we have the unsqueeze method:

x_train = unsqueeze(x_train, 3)
x_valid = unsqueeze(x_valid, 3);
typeof(x_train), size(x_train)
(Array{Float32, 4}, (28, 28, 1, 60000))

As earlier with Sudoku, instead of the digits themselves we will work with ten-component vectors consisting of zeros and a single one at the position corresponding to the digit. The onehotbatch method converts our data in exactly this way. The result is a sparse matrix that can be worked with efficiently.

y_train = onehotbatch(y_train, 0:9)
y_valid = onehotbatch(y_valid, 0:9)
10×10000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  …  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅     ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  1     ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
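
The inverse direction is provided by onecold, which picks the position of the one in each column and maps it back through the label range:

# onecold inverts onehotbatch; this should print [5, 0, 4, 1, 9].
onecold(y_train, 0:9)[1:5]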

Now we combine everything into a single dataset representing the training data:

train_data = Flux.Data.DataLoader((x_train, y_train), batchsize=128);
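
The DataLoader serves the data in minibatches of 128 samples; we can peek at the first batch:

# The first minibatch: 128 images together with their one-hot labels.
xb, yb = first(train_data)
size(xb), size(yb)   # expect ((28, 28, 1, 128), (10, 128))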

Now we assemble our model. We choose eight layers; we will not go into the details at this point.

model = Chain(
    # 28x28 => 14x14
    Conv((5, 5), 1=>8, pad=2, stride=2, relu),
    # 14x14 => 7x7
    Conv((3, 3), 8=>16, pad=1, stride=2, relu),
    # 7x7 => 4x4
    Conv((3, 3), 16=>32, pad=1, stride=2, relu),
    # 4x4 => 2x2
    Conv((3, 3), 32=>32, pad=1, stride=2, relu),
    
    GlobalMeanPool(),
    Flux.flatten,
    
    Dense(32, 10),
    softmax
)
Chain(
  Conv((5, 5), 1 => 8, relu, pad=2, stride=2),  # 208 parameters
  Conv((3, 3), 8 => 16, relu, pad=1, stride=2),  # 1_168 parameters
  Conv((3, 3), 16 => 32, relu, pad=1, stride=2),  # 4_640 parameters
  Conv((3, 3), 32 => 32, relu, pad=1, stride=2),  # 9_248 parameters
  GlobalMeanPool(),
  Flux.flatten,
  Dense(32 => 10),                      # 330 parameters
  NNlib.softmax,
)                   # Total: 10 arrays, 15_594 parameters, 62.445 KiB.
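
We can sanity-check the architecture without feeding it any real data; a sketch, assuming Flux's outputsize utility is available in this version:

# Propagate a (width, height, channels, batch) shape through the model;
# the result should be (10, batch), one score per digit.
Flux.outputsize(model, (28, 28, 1, 1))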

We have not trained the model yet, but we can still check that nothing is broken:

# Getting predictions
z = model(x_train)
# Decoding predictions
z = onecold(z)
println("Prediction of first image: $(z[1])")
Prediction of first image: 2

Pojďme "přesnost" měřit následovně (tj. 0 == nikdy jsme se netrefili, 1 == kompletní shoda):

accuracy(z, y) = mean(onecold(z) .== onecold(y))
accuracy (generic function with 1 method)

Right now we should see very low agreement, as we can easily verify:

accuracy(z, y_train)
0.0

It remains to define the loss function and prepare for training.

loss(x, y) = Flux.crossentropy(model(x), y)
opt = Descent()
ps = Flux.params(model)
Params([[0.04345867 0.0519028 … 0.030679185 -0.11206507; -0.0695501 0.13440974 … -0.0028491213 -0.10806305; … ; 0.082303025 -0.121138975 … -0.0427864 0.027985096; 0.04478124 -0.03732655 … -0.11871343 -0.1358215;;;; 0.103053965 0.13191748 … 0.0390348 -0.0020494254; 0.0094644055 0.14085227 … -0.16249433 0.0603232; … ; 0.0560955 0.018597117 … 0.10506808 0.032174602; -0.055019394 -0.005297032 … -0.06916303 -0.15951283;;;; -0.06410445 -0.09105586 … -0.06425119 -0.09833508; -0.04323213 -0.1531362 … 0.08278113 0.0850655; … ; -0.12999636 -0.029409831 … 0.049379554 -0.16300139; 0.14659558 -0.07386534 … -0.13154443 -0.12172512;;;; -0.05182951 -0.1317463 … 0.0729956 -0.03497103; -0.07068439 -0.10213475 … -0.16189586 0.13950673; … ; -0.15886292 0.123373725 … -0.013489984 -0.08300134; -0.014450125 -0.10403124 … 0.045309998 0.03167637;;;; -0.14134175 -0.010377847 … -0.021209463 -0.15274462; 0.13829422 -0.035228282 … -0.064470604 -0.11101778; … ; 0.038298585 0.13067764 … -0.0801855 0.011693977; 0.1358617 -0.02376526 … 0.15542275 -0.10442679;;;; -0.050319158 0.1476676 … 0.16282177 -0.14849484; 0.08358729 -0.05367564 … 0.14980118 -0.1147771; … ; 0.016804868 -0.10823167 … 0.08958686 0.14045078; 0.039942905 -0.082392946 … 0.036086455 -0.15655947;;;; 0.0903863 -0.020434666 … -0.13846995 -0.13052544; 0.060188998 0.06278592 … -0.025665043 0.048034184; … ; 0.115773186 -0.07855419 … 0.093607664 -0.00447748; -0.102154076 -0.14633189 … 0.06698112 -0.014349677;;;; 0.084439114 0.029865608 … 0.14361604 -0.0940607; -0.10479833 0.013312485 … -0.14890146 0.15549752; … ; -0.10898441 0.12475652 … -0.15087014 -0.12635462; 0.0006256044 0.07149066 … -0.09766414 -0.14514737], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.098492965 -0.010989706 -0.041809164; 0.13799118 0.13081816 0.13059676; -0.010215163 0.038489245 0.070277736;;; 0.03993122 0.06476049 0.14447889; -0.14566231 0.06540462 0.009050211; 0.14777529 0.033976078 0.0010724267;;; -0.06651016 -0.095367774 0.16582757; 0.058266103 0.03701379 0.052316092; -0.08332573 -0.16044515 0.16260934;;; 0.13754058 -0.10994242 0.010162791; 0.033871353 0.0045828223 -0.1133553; 0.14175114 -0.15293282 -0.08582091;;; 0.08274567 -0.022414645 -0.044816714; -0.09981688 -0.026266655 0.1239665; 0.11118913 0.10677578 -0.06420557;;; -0.13225113 0.10585936 0.022342127; -0.010513286 -0.054634493 -0.021112185; 0.05273676 0.07203364 -0.11901917;;; 0.069302246 0.07534087 -0.11597924; 0.104338095 -0.025957745 0.16102087; 0.01743283 0.106756434 -0.017382305;;; 0.11708307 0.13680556 -0.03600675; -0.047263723 -0.1545913 -0.07226305; 0.14493328 0.032409173 -0.023005188;;;; 0.09586947 0.1549335 0.15256786; -0.15707688 -0.15437728 0.023636203; 0.15659812 0.04080659 -0.07040231;;; -0.09862816 -0.05103165 0.016438723; -0.10111177 -0.120266244 -0.002058665; -0.06792232 -0.00051967305 -0.062930845;;; -0.12354261 -0.095856234 -0.011102955; 0.14153475 -0.11184941 0.027495405; -0.07055505 -0.059702635 0.09269863;;; 0.09687066 0.01978024 0.02634424; 0.14300317 -0.1439642 -0.082803175; 0.09497845 -0.032181602 -0.06105141;;; -0.10922917 0.09570746 0.06308907; 0.15460376 0.115887865 0.0996663; 0.09217229 -0.051554263 0.072806045;;; -0.07811707 -0.09252606 0.046781402; 0.1333267 -0.11654854 -0.040844463; 0.060453556 -0.11527413 -0.068430685;;; 0.14375529 0.058389448 -0.13363127; -0.049835324 -0.08915661 0.030547759; -0.07883346 -0.1645546 -0.054050088;;; 0.10647708 0.05644385 -0.07629798; 0.15248261 0.032525502 -0.058862984; 0.02066958 -0.07171428 0.1004148;;;; -0.116436526 0.105549 -0.0949738; -0.10417541 
0.0052270098 0.14972451; 0.08823957 -0.017463883 0.1054759;;; 0.15600704 -0.02823351 -0.06637331; 0.094292484 -0.0113812685 0.12053772; 0.01969411 0.15973982 0.072850585;;; -0.08573066 -0.11685296 -0.027546626; 0.0915156 -0.08701968 0.06121997; 0.059749763 -0.09152462 0.11666916;;; 0.07177578 0.037980318 0.060850978; -0.14005297 -0.12648241 -0.09105885; -0.109086975 0.14097899 -0.14319992;;; -0.06702237 0.16316232 -0.16356628; 0.1128574 0.11679797 0.16601723; 0.0734192 -0.042667113 -0.015418808;;; 0.16467457 -0.017160058 0.07305124; 0.09180361 -0.1431375 -0.09943782; -0.15129958 0.08867562 -0.15186772;;; 0.06166806 -0.117923915 0.073538825; 0.14543709 -0.163568 0.077689454; -0.046191793 0.08256644 0.11223342;;; -0.090913855 0.047593158 -0.16209517; 0.13696182 -0.0459856 0.15929079; 0.1293783 -0.04593573 0.10933006;;;; … ;;;; 0.08875257 0.047455333 -0.100108646; 0.03484909 -0.11461117 0.03654494; 0.13974601 -0.12704888 0.095922455;;; 0.0012560487 0.12795761 -0.11744418; -0.09306564 -0.11646005 -0.12544835; 0.10318874 -0.14477369 -0.14946476;;; 0.05153386 -0.09951709 -0.05317738; 0.09343183 0.06696044 0.15841262; -0.15880094 -0.007952373 -0.08935833;;; 0.051874023 0.08695706 0.07288488; -0.15548918 0.13025355 0.050305188; -0.055870198 -0.09664252 0.0038958192;;; -0.104099415 0.0443457 0.058937054; -0.09595668 -0.064209685 -0.16097015; -0.08442762 0.07722992 -0.14754824;;; 0.0765483 -0.10550817 -0.07092645; -0.0010806521 0.10051893 -0.115517795; -0.030120512 -0.036198854 -0.04333683;;; -0.13958043 0.14062634 0.13870166; -0.023347577 -0.106552504 -0.11686467; 0.08174803 0.08532139 0.12068486;;; -0.13310568 -0.020809155 -0.021818241; -0.16565667 0.035318814 0.0019530853; -0.14833221 0.0042972765 0.025458574;;;; 0.06026026 0.15534821 -0.012906155; 0.07124941 -0.028288364 -0.15423372; -0.00926286 0.16189763 -0.08036699;;; 0.1123808 -0.05326621 0.12505071; -0.09422588 0.014620444 0.09752326; 0.14823113 0.14449194 -0.15550487;;; -0.051793378 0.14822605 0.06264137; 0.16574594 0.077170774 0.07470574; -0.08949685 0.123805806 0.13185392;;; -0.13680416 0.16319329 -0.06956373; -0.14383934 -0.16383497 0.073431596; 0.16225612 0.0035276613 0.04317957;;; -0.091075204 -0.13167195 0.07686739; 0.13491625 0.040214263 0.1255645; 0.06552337 -0.12093544 0.038570147;;; -0.06843047 -0.0045178533 0.16651492; -0.01803426 0.0930454 0.06653478; 0.08362579 -0.04865108 -0.036420267;;; 0.09165492 -0.0044929786 0.0134309735; -0.09496224 0.103297636 0.0040376983; -0.032184403 0.14868289 -0.13937023;;; 0.15932915 0.081981204 0.06538048; 0.076568626 0.010508339 0.048101544; 0.15669504 -0.07415293 0.0699504;;;; -0.08069881 0.062364362 0.124893785; -0.13495265 -0.09873589 -0.050736845; -0.16598257 -0.12863098 -0.06857874;;; -0.045879744 0.16120599 -0.15732601; -0.103535816 -0.081530295 -0.005281548; -0.07928747 -0.053081416 0.06965824;;; -0.15905218 -0.10493247 -0.057815075; -0.08391162 -0.03292237 -0.07383923; -0.00984329 -0.011959533 -0.052526753;;; 0.053180795 0.15874664 -0.055878144; -0.038234632 0.118108034 -0.1611167; -0.06724515 0.07239348 0.015467465;;; -0.035428107 0.15569595 -0.06487626; -0.11449985 -0.06978661 0.14976531; 0.1495658 -0.030165693 -0.14154415;;; 0.15598403 0.018530708 0.15468585; -0.05259844 0.16338083 -0.10575688; 0.0548775 0.012418966 -0.013709366;;; 0.08392076 0.15710297 0.07483542; -0.031504553 0.15619342 0.039251428; 0.041137695 0.09533759 0.022108018;;; 0.053296328 -0.10033536 -0.13819176; -0.006598036 0.10779407 0.061223924; -0.015153408 -0.014042894 0.088644825], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-0.06042359 -0.08660967 -0.055630397; -0.08815988 0.064195566 -0.035169143; -0.09603527 -0.099422 -0.06701526;;; 0.08788788 -0.069484025 -0.0776086; -0.07148152 -0.07081234 0.10296843; 0.112362266 0.07793139 -0.038195413;;; -0.0055710687 0.026661886 -0.070060045; -0.05720735 0.092402965 -0.018252775; 0.01340142 -0.00014064403 0.042126577;;; … ;;; 0.117023125 0.11748216 -0.11220874; 0.04532208 -0.09431376 -0.10699835; 0.09185852 -0.07859564 -0.091722;;; 0.03148271 0.00016438676 0.08032767; 0.005468441 -0.031866223 -0.09998303; 0.06701791 -0.09354439 -0.0048075225;;; -0.10478373 -0.09356782 0.09225203; -0.11187694 -0.07542693 -0.00878072; -0.09351926 0.0727439 -0.09802294;;;; 0.08613692 0.056650266 -0.049715325; 0.056839466 -0.002410828 -0.04682342; -0.042816196 -0.096252054 -0.08556118;;; -0.11031078 0.07599973 0.055166025; 0.07578067 0.09646988 -0.06731202; -0.05186746 -0.088826574 0.09309893;;; -0.101404585 -0.06369933 0.0445657; 0.08334697 0.08309819 0.021893026; 0.07859269 0.06334059 -0.010844525;;; … ;;; -0.11077307 -0.02789825 0.11473771; -0.029184824 0.047109976 0.025483418; 0.0809176 -0.002515352 -0.058323078;;; -0.11509346 -0.04009819 0.10520754; -0.081775114 -0.07908829 -0.00053310144; 0.10619639 0.012081717 0.109755576;;; -0.041703142 -0.061618287 0.012191313; 0.021045579 0.0071340706 -0.08278022; 0.093283184 0.062081955 -0.027269544;;;; 0.0843725 -0.09425848 0.08070717; 0.07467716 -0.06560002 0.099061936; 0.10863947 0.070024766 0.059391513;;; -0.051938727 0.08050739 0.03373129; -0.050433382 0.030697798 -0.011945035; 0.02388906 0.0063602687 -0.117438525;;; 0.08406686 -0.04519471 -0.0637578; 0.09074557 -0.049044687 0.020482386; 0.11428881 -0.012114634 -0.0033043833;;; … ;;; 0.06549185 -0.09307323 0.0259651; -0.054547757 -0.039960578 -0.068566814; -0.056072194 0.053014092 -0.09135668;;; 0.068792425 0.017731067 0.07870205; -0.11121788 -0.0008447212 -0.0877369; -0.010374307 0.106969334 0.029535739;;; -0.040468715 -0.03560931 -0.088743605; 0.09394024 0.016264651 -0.105080776; -0.016833521 0.06633837 -0.027417494;;;; … ;;;; 0.07104594 -0.11241389 0.08947321; 0.033435713 0.10731884 0.036904864; -0.10268164 0.04816664 -0.029407864;;; 0.044742633 -0.106065355 -0.093439266; 0.03483663 -0.091866896 0.043162826; 0.030338187 0.08185432 0.0074583343;;; -0.079810455 0.07063996 -0.06707529; -0.017727287 -0.006290108 -0.07651689; -0.07984296 0.07496621 0.054145746;;; … ;;; -0.051756583 0.08356277 -0.03481806; 0.06852649 -0.077274814 0.0040810234; 0.08165109 -0.08012873 0.08699555;;; -0.046830192 -0.070529774 0.06638509; -0.06324479 0.021830916 -0.04188585; 0.080379196 0.056811903 -0.0096851215;;; -0.11022355 0.099013284 -0.1149995; 0.10566989 0.017512761 -0.116805; 0.017660458 0.10399492 0.02979369;;;; 0.07733827 0.077586725 -0.11116498; -0.08654636 -0.10688979 -0.08701194; -0.02741033 -0.10634871 -0.04353292;;; -0.055868234 -0.0014820659 0.100929126; -0.049722508 0.013145392 0.0028019927; -0.061871953 0.10925259 -0.0003455058;;; -0.009819106 0.048155963 -0.0096191475; -0.021656007 0.03695735 0.012263033; 0.10033004 0.032884166 -0.07901535;;; … ;;; 0.061449654 0.0152114015 0.07324399; 0.10943332 -0.0830591 0.06555774; 0.06152087 -0.09826199 0.082387805;;; -0.023884956 0.02383077 0.0874208; -0.04709382 0.033487484 0.03512149; -0.026025904 0.07301527 0.031429455;;; -0.08404336 -0.072803035 0.027120654; 0.050990675 0.113487594 0.042047072; 0.0016886697 -0.047027804 0.032680724;;;; 0.08992643 -0.07953932 -0.027652169; -0.05680954 -0.08573395 
0.020324165; 0.10532847 -0.11007345 -0.10200549;;; 0.087424256 -0.090772495 0.06064836; -0.088461444 -0.018909745 -0.111272275; 0.013029151 -0.051981874 -0.010968268;;; 0.09599723 0.09119425 -0.099256024; 0.023567254 -0.07808867 0.07775103; 0.05172749 -0.084476344 -0.11702098;;; … ;;; 0.057024658 0.03588449 0.06518215; -0.11517725 -0.079458065 -0.10393292; -0.06126313 -0.08893906 0.09449007;;; -0.115184024 0.08676186 0.02913651; -0.038352903 0.1081972 -0.024200495; -0.078161694 0.1079326 0.03777875;;; -0.07327348 0.058588687 0.097199135; -0.064685464 0.0035126028 -0.08930635; 0.068497665 -0.03347491 -0.10834619], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05736505 0.07910824 -0.0050783153; -0.07096387 -0.04436114 0.10167752; 0.055980556 -0.030940648 -0.028611299;;; -0.056326773 -0.09635433 0.03841428; 0.072198845 -0.096073344 0.05787758; -0.015899519 -0.054047003 -0.072522745;;; 0.0474121 0.00083285035 0.0019753566; 0.068468 0.009147119 -0.0952422; -0.024435002 -0.022641307 -0.06922704;;; … ;;; 0.094621 -0.07240909 0.014108682; -0.055640362 -0.087646075 -0.0895864; 0.09805373 -0.08738999 -0.10046777;;; -0.023010107 -0.0031301514 0.027530989; 0.013230036 0.039232444 -0.02645624; 0.076439895 0.003336171 0.00083011284;;; -0.03280271 0.0070184735 -0.023207597; -0.09866495 0.014952762 -0.0019330772; 0.021318015 -0.024061458 -0.05573497;;;; 0.099357106 0.0045144837 -0.098532975; -0.02292141 0.056106348 0.0017980263; -0.0073047937 0.06938027 0.07668876;;; 0.022939514 -0.009469964 0.056705732; 0.095790006 0.032307096 0.10071559; -0.056059603 -0.026924526 -0.055061553;;; 0.10158025 0.094438516 -0.05205785; -0.07706227 0.024250615 -0.014179614; 0.0006407861 -0.08686509 0.028403394;;; … ;;; -0.043543987 0.10076445 -0.07981682; -0.059287723 0.09425819 0.08304173; 0.046162345 0.041298795 -0.0230308;;; 0.06902185 -0.096342854 0.075525485; -0.04645577 -0.08231316 0.096911065; -0.0010954939 -0.07271828 0.040537972;;; -0.0997284 -0.012448043 -0.052605562; 0.049763385 -0.044758312 0.050574347; -0.012608072 -0.03561493 0.041805346;;;; 0.057001725 -0.040021397 -0.048545823; -0.05471984 0.025910866 0.059996665; 0.051619265 -0.00630622 0.065758996;;; -0.01365771 -0.05062153 0.06148809; -0.007890221 0.040930994 0.010517107; -0.030643512 -0.01675551 -0.026299385;;; -0.042533964 0.0660248 -0.007557436; -0.017219232 -0.025686985 -0.03565666; -0.09520086 -0.025579236 -0.061450265;;; … ;;; 0.087049186 -0.041221343 0.07075191; 0.04937247 -0.08693265 0.048757717; -0.002759406 -0.060941692 0.06823887;;; -0.043862574 -0.08877707 -0.059135843; 0.06773625 -0.013757209 -0.020028425; 0.06628728 0.073348105 0.05313266;;; -0.0039883573 0.04677601 0.043142714; 0.04248746 0.04744277 0.10118595; 0.08364808 -0.03160483 -0.035021033;;;; … ;;;; -0.06965584 -0.0013961829 -0.029060155; 0.0033122145 -0.07103713 0.093489386; 0.0958181 -0.06467735 -0.02018444;;; 0.072027594 -0.078316055 -0.011911331; 0.089672066 -0.08199512 -0.002356261; -0.003622065 -0.04414743 0.03233464;;; 0.061915398 0.098239884 -0.02738399; 0.08235395 -0.048877306 -0.017131522; 0.09187188 -0.0074713444 0.096671455;;; … ;;; 0.021906594 0.0014109047 -0.010903632; 0.09164269 0.08559022 -0.040745646; 0.059964996 0.05697846 0.09211376;;; 0.043448452 -0.027687307 -0.038102057; 0.027154222 0.04608887 0.051346082; -0.024543153 0.020256137 0.07648824;;; -0.06574008 0.08072355 0.074340396; -0.054879136 -0.028355323 -0.09271009; 0.050845534 -0.010056376 0.07782329;;;; 0.045539785 0.032338083 -0.048904218; 
0.04658834 0.07193689 0.08490136; 0.071139336 -0.066213645 -0.100889266;;; -0.1010549 0.07626708 0.0746525; -0.05450647 -0.0278802 -0.07270774; 0.061567124 -0.023196768 -0.08654951;;; -0.013370234 -0.07297455 -0.012319051; 0.100312866 -0.038974985 0.10009289; 0.03813009 0.056243468 -0.07155648;;; … ;;; -0.058383573 -0.09969674 0.07579363; 0.085557796 0.025419025 -0.09474081; -0.020902654 0.017982515 -0.07690825;;; -0.02497057 0.032917198 -0.08355293; 0.025056504 0.049819816 0.023732373; -0.08410779 -0.06505822 0.021739192;;; -0.035714917 0.062294662 0.07860457; 0.006272871 -0.05302064 0.09013072; 0.07470063 -0.038121607 0.0003606224;;;; 0.06661162 0.09476825 0.010694595; -0.038434513 0.081536114 0.005156523; 0.018002674 0.09085246 -0.07659161;;; 0.06984942 0.017836513 0.012237472; -0.07525216 0.036720023 0.06105864; -0.09645195 0.07460411 -0.045834266;;; -0.046984803 0.0872234 0.0012835067; 0.058173563 0.002069004 -0.08194415; 0.099978335 0.014651271 0.08476727;;; … ;;; -0.017271148 -0.052735586 0.05737821; -0.0011333325 0.01949461 0.06530301; 0.038627926 0.037624437 -0.083214216;;; -0.09757997 0.040057607 -0.08728757; -0.011686466 0.0005903184 0.02830309; 0.031193096 0.007301728 -0.058197398;;; 0.071791515 -0.0657642 0.099956326; 0.031907212 -0.07410844 0.04256641; -0.0015776986 -0.050071266 0.02444027], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.28695634 -0.2080844 … 0.17181304 -0.083971284; 0.06789562 0.34318665 … 0.24812546 -0.10470303; … ; 0.1238623 -0.31849182 … 0.36174598 -0.27657127; -0.08542383 -0.26908675 … -0.08216387 0.34210646], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

For training we can use the @epochs macro.

number_epochs = 10
@epochs number_epochs Flux.Optimise.train!(loss, ps, train_data, opt)
accuracy(model(x_train), y_train)
┌ Warning: The macro `@epochs` will be removed from Flux 0.14.
│ As an alternative, you can write a simple `for i in 1:epochs` loop.
│   caller = eval at boot.jl:368 [inlined]
└ @ Core ./boot.jl:368
┌ Info: Epoch 1
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 2
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 3
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 4
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 5
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 6
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 7
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 8
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 9
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 10
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
0.97865
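
As the deprecation warning suggests, the same training can be written without the macro as a plain loop:

# Equivalent to `@epochs number_epochs train!(...)`, without the macro.
for epoch in 1:number_epochs
    Flux.Optimise.train!(loss, ps, train_data, opt)
end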

How does our model predict, and how accurately?

onecold(model(x_valid))[1:5]
5-element Vector{Int64}:
 8
 3
 2
 1
 5
onecold(y_valid)[1:5]
5-element Vector{Int64}:
 8
 3
 2
 1
 5
accuracy(model(x_valid), y_valid)
0.9743

4. Closing the Semester

This brings us to the end of the first run of BI-JUL.21. In the remaining time of this Christmas session we will go over any questions and points of interest.

Don't forget to fill in the course evaluation survey!


References

Besides the packages and tools mentioned in the introduction, you can also browse the Machine Learning category in the package database.