13: Julia in Machine Learning
This notebook is teaching material for the course BI-JUL.21, taught in the winter semester of the academic year 2021/2022 by Tomáš Kalvoda. The creation of these materials was supported by NVS FIT.
The main page of the course, with further notebooks and other interesting information, is its Course Pages site.
versioninfo()
Julia Version 1.8.3
Commit 0434deb161e (2022-11-14 20:14 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i5-8250U CPU @ 1.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores
1. Introduction
Julia quite naturally finds use in various areas of machine learning (ML). A whole range of tools exists in this field, often built on Python, and a number of Julia packages provide interfaces to these well-known tools, e.g.:
- Keras.jl
- TensorFlow.jl
- and others.
Alongside these, tools built directly in Julia are emerging as well, e.g.:
- DecisionTree.jl
- Flux.jl (from the FluxML ecosystem)
- MLJ.jl
- and others.
In this Christmas notebook we will look at the first two packages from the latter list. Thematically, we will cover decision trees and neural networks. The field of machine learning is very broad, and many of you have not studied this material yet, so this notebook is more of an advertisement for what awaits you and how Julia can help in this area.
2. Decision Trees
A classic demonstration of decision trees is the classification of irises based on the dimensions of their petals and sepals, stored in the so-called Iris dataset.
We will use this task to demonstrate the DecisionTree.jl package; don't forget to install it first: ] add DecisionTree.
using DecisionTree
The data is distributed directly with the package.
features, labels = load_data("iris")
(Any[5.1 3.5 1.4 0.2; 4.9 3.0 1.4 0.2; … ; 6.2 3.4 5.4 2.3; 5.9 3.0 5.1 1.8], Any["Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa" … "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica", "Iris-virginica"])
The features are four flower measurements (sepal length and width, petal length and width); each row corresponds to one flower:
features = float.(features)
150×4 Matrix{Float64}:
 5.1  3.5  1.4  0.2
 4.9  3.0  1.4  0.2
 4.7  3.2  1.3  0.2
 4.6  3.1  1.5  0.2
 5.0  3.6  1.4  0.2
 5.4  3.9  1.7  0.4
 4.6  3.4  1.4  0.3
 5.0  3.4  1.5  0.2
 4.4  2.9  1.4  0.2
 4.9  3.1  1.5  0.1
 5.4  3.7  1.5  0.2
 4.8  3.4  1.6  0.2
 4.8  3.0  1.4  0.1
 ⋮
 6.0  3.0  4.8  1.8
 6.9  3.1  5.4  2.1
 6.7  3.1  5.6  2.4
 6.9  3.1  5.1  2.3
 5.8  2.7  5.1  1.9
 6.8  3.2  5.9  2.3
 6.7  3.3  5.7  2.5
 6.7  3.0  5.2  2.3
 6.3  2.5  5.0  1.9
 6.5  3.0  5.2  2.0
 6.2  3.4  5.4  2.3
 5.9  3.0  5.1  1.8
And we also know which iris species each of these specimens was:
labels = string.(labels)
150-element Vector{String}:
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 ⋮
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
In total, the dataset distinguishes three species of irises:
unique(labels)
3-element Vector{String}:
 "Iris-setosa"
 "Iris-versicolor"
 "Iris-virginica"
We build and train a model:
model = build_tree(labels, features)
Decision Tree
Leaves: 9
Depth:  5
print_tree(model)
Feature 3 < 2.45 ?
├─ Iris-setosa : 50/50
└─ Feature 4 < 1.75 ?
    ├─ Feature 3 < 4.95 ?
        ├─ Feature 4 < 1.65 ?
            ├─ Iris-versicolor : 47/47
            └─ Iris-virginica : 1/1
        └─ Feature 4 < 1.55 ?
            ├─ Iris-virginica : 3/3
            └─ Feature 1 < 6.95 ?
                ├─ Iris-versicolor : 2/2
                └─ Iris-virginica : 1/1
    └─ Feature 3 < 4.85 ?
        ├─ Feature 1 < 5.95 ?
            ├─ Iris-versicolor : 1/1
            └─ Iris-virginica : 2/2
        └─ Iris-virginica : 43/43
We prune the tree:
model = prune_tree(model, 0.9)
Decision Tree
Leaves: 8
Depth:  5
We can display the decision tree clearly:
print_tree(model)
Feature 3 < 2.45 ?
├─ Iris-setosa : 50/50
└─ Feature 4 < 1.75 ?
    ├─ Feature 3 < 4.95 ?
        ├─ Iris-versicolor : 47/48
        └─ Feature 4 < 1.55 ?
            ├─ Iris-virginica : 3/3
            └─ Feature 1 < 6.95 ?
                ├─ Iris-versicolor : 2/2
                └─ Iris-virginica : 1/1
    └─ Feature 3 < 4.85 ?
        ├─ Feature 1 < 5.95 ?
            ├─ Iris-versicolor : 1/1
            └─ Iris-virginica : 2/2
        └─ Iris-virginica : 43/43
And use it to classify a new sample:
apply_tree(model, [5.9, 3.0, 5.1, 1.9])
"Iris-virginica"
Let's run it on all 150 records in the dataset:
preds = apply_tree(model, features)
150-element Vector{String}:
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 ⋮
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
How accurate are our predictions? To express how often our classifier "hit" the correct class, we can use a confusion matrix:
DecisionTree.confusion_matrix(labels, preds)
We see that it misclassifies one flower from the data used.
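As a quick sanity check (a small sketch of our own, not part of the original workflow), we can compute the overall training accuracy directly; mean comes from the standard Statistics library, and given the one misclassified flower above we expect 149 of 150 matches:

using Statistics
# fraction of training samples the pruned tree classifies correctly
mean(preds .== labels)   # expected: 149/150 ≈ 0.9933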
We can also obtain information about how probable our prediction is. The following features are again taken directly from the dataset, so the resulting category has probability 1.
apply_tree_proba(model, [5.9, 3.0, 5.1, 1.9], ["Iris-setosa", "Iris-versicolor", "Iris-virginica"])
3-element Vector{Float64}:
 0.0
 0.0
 1.0
In the following example, the classifier favors one of the categories, but not with complete certainty.
apply_tree_proba(model, [5.7, 3.2, 4.9, 1.7], ["Iris-setosa", "Iris-versicolor", "Iris-virginica"])
3-element Vector{Float64}:
 0.0
 0.9791666666666666
 0.020833333333333332
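For a fairer accuracy estimate we should evaluate on data the tree has not seen. DecisionTree.jl provides an n-fold cross-validation helper; the following sketch (using a pruning purity of 0.9, matching the pruning above) returns the accuracy of each fold:

# 3-fold cross-validation of a pruned tree on the Iris data
accuracies = nfoldCV_tree(labels, features, 3, 0.9)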
For further examples, see the pages of the package.
3. Neural Networks
For working with neural networks, the Flux.jl package looks very attractive; install it with ] add Flux.
using Flux
Let us demonstrate (for example, for students who have not studied this material yet) the basic principle of training a neural network.
Suppose the dependence of one real variable on another real variable is given by an explicit formula:
actual(x) = 4x + 2
actual (generic function with 1 method)
We prepare training and test data, each consisting of just a few values.
x_train, x_test = hcat(0:5...), hcat(6:10...)
([0 1 … 4 5], [6 7 … 9 10])
And the true values on both the training and test data.
y_train, y_test = actual.(x_train), actual.(x_test)
([2 6 … 18 22], [26 30 … 38 42])
Now we build the model and predict. Specifically, we will have a single neuron with one input and one output.
model = Dense(1, 1)
Dense(1 => 1) # 2 parameters
In this case the model is the function $x \mapsto \sigma(Wx + b)$, where $W$ is the weight:
model.weight
1×1 Matrix{Float32}:
 0.5675245
$b$ is the bias:
model.bias
1-element Vector{Float32}:
 0.0
and $\sigma$ is the activation function, here the identity:
model.σ
identity (generic function with 1 method)
This trivial model thus has two real parameters. At the moment the model gives the following predictions (completely off, of course, since we have not trained it yet):
model(x_train)
1×6 Matrix{Float32}:
 0.0  0.567524  1.13505  1.70257  2.2701  2.83762
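We can verify (a small check of our own) that the Dense layer really computes the function above from the three ingredients just shown:

# Dense applies the activation elementwise to weight * input .+ bias
model.σ.(model.weight * x_train .+ model.bias) ≈ model(x_train)   # true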
We will measure the accuracy of our model's predictions with a loss (objective) function, here the mean squared error (MSE):
loss(x, y) = Flux.mse(model(x), y)
loss (generic function with 1 method)
It is the mean of the squared deviations, i.e., for predictions $\hat{y}_i$ and true values $y_i$, $i = 1, \ldots, n$, the expression

$$\frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^2.$$
"Chyba" je tedy zatím dost velká (z nějakého důvodu zde máme 32 bitový float):
loss(x_train, y_train)
146.3254f0
As a sanity check:
sum((model(x_train) - y_train) .* (model(x_train) - y_train)) / length(x_train)
146.3254f0
For training we will use plain gradient descent. At this point we are, once again, doing nothing but solving an optimization problem!
opt = Descent() # gradient descent
Descent(0.1)
The complete training data:
data = [(x_train, y_train)]
1-element Vector{Tuple{Matrix{Int64}, Matrix{Int64}}}:
 ([0 1 … 4 5], [2 6 … 18 22])
The model parameters (weight and bias):
parameters = Flux.params(model)
Params([Float32[0.5675245;;], Float32[0.0]])
One epoch (one iteration/one step of the optimization algorithm) is performed by calling the train! method:
Flux.Optimise.train!(loss, parameters, data, opt) # `model` is "hidden" inside the loss function!
The loss has decreased!
loss(x_train, y_train)
138.86488f0
Naturally, our two parameters have changed as well:
parameters
Params([Float32[7.860397;;], Float32[2.116238]])
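For the curious: one call to train! with Descent(0.1) is equivalent to the following explicit step (a sketch using the implicit-parameter gradient API of this Flux version; running it would of course move the parameters once more):

# compute gradients of the loss with respect to all model parameters...
grads = Flux.gradient(() -> loss(x_train, y_train), parameters)
# ...and take one gradient-descent step with learning rate 0.1
for p in parameters
    p .-= 0.1 .* grads[p]
end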
That was just one epoch; we passed over the data only once. Let's run several more epochs.
for epoch in 1:100
Flux.Optimise.train!(loss, parameters, data, opt)
end
We are apparently approaching the minimum (or are we? :-)).
loss(x_train, y_train)
0.749681f0
parameters
Params([Float32[4.2658257;;], Float32[2.0727146]])
In these parameters you surely recognize the original values from which we generated the data. If needed, we could push the training through a few more epochs.
What predictions does our model give on the test data?
model(x_test)
1×5 Matrix{Float32}:
 27.6677  31.9335  36.1993  40.4651  44.731
"Správně" bychom očekávali:
y_test
1×5 Matrix{Int64}:
 26  30  34  38  42
In essence, of course, we have done nothing but linear regression (fitting a straight line through the data).
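Indeed, we can compare with the exact least-squares solution, which Julia's backslash operator computes directly (a check of our own):

# design matrix [x 1] for the model y = w*x + b
X = hcat(vec(x_train), ones(length(x_train)))
X \ vec(y_train)   # exactly [4.0, 2.0]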
Digit Recognition
Let's look at a more complicated perceptron example from the Flux documentation (currently a 404), or rather from this blog post.
using Flux
using MLDatasets
using Statistics
import Flux: onehotbatch, onecold, crossentropy, @epochs, unsqueeze
First, let's get the data, in this case MNIST (Modified National Institute of Standards and Technology database), containing 28×28-pixel images of the digits zero through nine. The data is also annotated with the "correct" values.
When you run the following command for the first time, you must confirm the download of the files.
# training data
x_train, y_train = MLDatasets.MNIST.traindata(Float32);
┌ Warning: MNIST.traindata() is deprecated, use `MNIST(split=:train)[:]` instead.
└ @ MLDatasets /home/kalvin/.julia/packages/MLDatasets/A3giY/src/datasets/vision/mnist.jl:187
The array x_train holds the images, 60,000 of them in total.
typeof(x_train), size(x_train)
(Array{Float32, 3}, (28, 28, 60000))
Let's look at at least a few of them (do you notice a rotation/mirroring?).
x_train[:, :, 1]
28×28 Matrix{Float32}:
 (mostly zeros, with grayscale values between 0.0 and 1.0 toward the center; full listing omitted)
And now truly graphically:
using Images
display(Gray.(x_train[:, :, 1]))
display(Gray.(x_train[:, :, 2]))
display(Gray.(x_train[:, :, 3]))
display(Gray.(x_train[:, :, 4]))
display(Gray.(x_train[:, :, 5]))
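The digits indeed render rotated/mirrored: the first array axis is the image's horizontal coordinate. Transposing each slice (a small fix of our own, not in the original) should display them upright:

# swap the two spatial axes so the digit is drawn upright
display(Gray.(permutedims(x_train[:, :, 1])))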
The array y_train holds the digits that the images represent.
typeof(y_train), size(y_train)
(Vector{Int64}, (60000,))
The first five images displayed above should therefore represent the following digits:
y_train[1:5]
5-element Vector{Int64}:
 5
 0
 4
 1
 9
Next we prepare the test data.
# test (validation) data
x_valid, y_valid = MLDatasets.MNIST.testdata(Float32);
┌ Warning: MNIST.testdata() is deprecated, use `MNIST(split=:test)[:]` instead.
└ @ MLDatasets /home/kalvin/.julia/packages/MLDatasets/A3giY/src/datasets/vision/mnist.jl:195
At the moment the data is purely in the form of matrices, but Flux expects image data to include a color channel (we have just one: grayscale). We therefore have to extend the data by one more dimension (of length 1).
For this we have the unsqueeze method at our disposal:
x_train = unsqueeze(x_train, 3)
x_valid = unsqueeze(x_valid, 3);
typeof(x_train), size(x_train)
(Array{Float32, 4}, (28, 28, 1, 60000))
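The unsqueeze helper imported above simply inserts a singleton dimension at the given position, as this tiny standalone example illustrates:

# a 2×2 matrix becomes a 2×2×1 array
A = rand(Float32, 2, 2)
size(unsqueeze(A, 3))   # (2, 2, 1)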
Just as earlier with Sudoku, instead of digits we will work with a ten-component vector consisting of zeros and a single one at the position corresponding to the digit.
The onehotbatch method performs this conversion of our data in one simple call.
The result is a sparse matrix that can be worked with efficiently.
y_train = onehotbatch(y_train, 0:9)
y_valid = onehotbatch(y_valid, 0:9)
10×10000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 (one-hot columns, a single 1 per column; full listing omitted)
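The imported onecold function is the inverse operation: it maps each one-hot column back to a label. As a round-trip check (our own), the first five training columns should give the digits shown earlier:

# decode one-hot columns back to digits 0–9
onecold(y_train[:, 1:5], 0:9)   # [5, 0, 4, 1, 9]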
Now we join everything into a single dataset representing the training data:
train_data = Flux.Data.DataLoader((x_train, y_train), batchsize=128);
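The DataLoader is an iterable over mini-batches; peeking at the first one (a quick check of our own) shows 128 images together with their one-hot labels:

# first mini-batch: image tensor and label matrix
xb, yb = first(train_data)
size(xb), size(yb)   # ((28, 28, 1, 128), (10, 128))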
Now we assemble our model. We choose eight layers; at this point we will not go into the details.
model = Chain(
# 28x28 => 14x14
Conv((5, 5), 1=>8, pad=2, stride=2, relu),
# 14x14 => 7x7
Conv((3, 3), 8=>16, pad=1, stride=2, relu),
# 7x7 => 4x4
Conv((3, 3), 16=>32, pad=1, stride=2, relu),
# 4x4 => 2x2
Conv((3, 3), 32=>32, pad=1, stride=2, relu),
GlobalMeanPool(),
Flux.flatten,
Dense(32, 10),
softmax
)
Chain(
  Conv((5, 5), 1 => 8, relu, pad=2, stride=2),    # 208 parameters
  Conv((3, 3), 8 => 16, relu, pad=1, stride=2),   # 1_168 parameters
  Conv((3, 3), 16 => 32, relu, pad=1, stride=2),  # 4_640 parameters
  Conv((3, 3), 32 => 32, relu, pad=1, stride=2),  # 9_248 parameters
  GlobalMeanPool(),
  Flux.flatten,
  Dense(32 => 10),                                # 330 parameters
  NNlib.softmax,
)                   # Total: 10 arrays, 15_594 parameters, 62.445 KiB.
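To see what the size comments in the model definition mean, we can push a single dummy image through the layers one by one (a sketch of our own): each strided convolution halves the spatial dimensions, the pooling collapses them to 1×1, and flatten plus Dense produce the ten class scores.

# trace the shape of one 28×28 grayscale "image" through the network
x = rand(Float32, 28, 28, 1, 1)
for layer in model.layers
    x = layer(x)
    println(size(x))
end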
We haven't trained the model yet, but we can still check that there is no error anywhere:
# Getting predictions
z = model(x_train)
# Decoding predictions
z = onecold(z)
println("Prediction of first image: $(z[1])")
Prediction of first image: 2
Pojďme "přesnost" měřit následovně (tj. 0 == nikdy jsme se netrefili, 1 == kompletní shoda):
accuracy(z, y) = mean(onecold(z) .== onecold(y))
accuracy (generic function with 1 method)
At the moment the agreement should be very low, as we can easily verify:
accuracy(z, y_train)
0.0
It remains to define the loss function and get ready for training.
loss(x, y) = Flux.crossentropy(model(x), y)
opt = Descent()
ps = Flux.params(model)
Params([…])   (the ten randomly initialized weight and bias arrays; full listing omitted)
The @epochs macro can be used for training:
number_epochs = 10
@epochs number_epochs Flux.Optimise.train!(loss, ps, train_data, opt)
accuracy(model(x_train), y_train)
┌ Warning: The macro `@epochs` will be removed from Flux 0.14.
│ As an alternative, you can write a simple `for i in 1:epochs` loop.
│   caller = eval at boot.jl:368 [inlined]
└ @ Core ./boot.jl:368
┌ Info: Epoch 1
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 2
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 3
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 4
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 5
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 6
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 7
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 8
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 9
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
┌ Info: Epoch 10
└ @ Main /home/kalvin/.julia/packages/Flux/ZdbJr/src/optimise/train.jl:185
0.97865
How does our model predict, and how accurately?
onecold(model(x_valid))[1:5]
5-element Vector{Int64}:
 8
 3
 2
 1
 5
onecold(y_valid)[1:5]
5-element Vector{Int64}:
 8
 3
 2
 1
 5
accuracy(model(x_valid), y_valid)
0.9743
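As a final hypothetical follow-up (not part of the original notebook), we can break the validation accuracy down by digit:

# per-digit accuracy on the validation set
preds = onecold(model(x_valid), 0:9)
truth = onecold(y_valid, 0:9)
for d in 0:9
    idx = truth .== d
    println(d, ": ", round(mean(preds[idx] .== truth[idx]); digits=3))
end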
4. Closing the Semester
This brings us to the end of the first run of BI-JUL.21. In the remaining time of this Christmas session we will go over any questions and interesting topics.
Don't forget to fill in the course evaluation survey!
References
Besides the packages above and the tools mentioned in the introduction, you can also browse the Machine Learning category in the package database.