# # Recursive net on IMDB sentiment treebank
# In this example, we create a recursive neural network to perform sentiment
# analysis using IMDB data.
# This type of model can learn tree-like structures (directed acyclic graphs).
# It computes compositional vector representations for phrases of variable
# length, which are then used as features for classification.
#
# [Source](https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf)
#
# This example uses the [Stanford Sentiment Treebank
# (SST)](https://nlp.stanford.edu/sentiment/index.html), a dataset often used
# as a benchmark for new language models.
# Every node of its parse trees is labelled with one of five sentiment
# classes, from very negative to very positive, and the goal is to predict
# that label.
# To run this example, we need the following packages:
using Flux
using Flux: logitcrossentropy, throttle
using Flux.Data: Tree, children, isleaf
using Parameters: @with_kw
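# Each tree node in SST carries one of the five sentiment classes described
# above. As a small sketch of the target encoding (an assumption about what
# `data.jl` produces; the actual encoding lives in that script), a
# "very positive" label as a Flux one-hot vector:
Flux.onehot(5, 1:5)   ## one-hot vector with the fifth entry active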
# The script `data.jl` contains the function `getdata`, which obtains
# and preprocesses the SST dataset.
include("data.jl")
# We set default values for the hyperparameters:
@with_kw mutable struct Args
    lr::Float64 = 1e-3   ## Learning rate
    N::Int = 300         ## Size of the phrase embedding vectors
    throttle::Int = 10   ## Throttle timeout (seconds between callback calls)
end
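# Thanks to `@with_kw` (from Parameters.jl), any field can be overridden by
# keyword at construction time; for example, a configuration with a smaller
# learning rate:
Args(lr = 5e-4)   ## N and throttle keep their defaults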
# ## Build the model
# The function `train` loads the data, then builds and trains the model.
# For more information on how the recursive neural network works, see
# section 4 of [Recursive Deep Models for Semantic Compositionality
# Over a Sentiment Treebank](https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf).
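# At the heart of the model is the composition step from the paper: the two
# child phrase vectors are concatenated and passed through a shared `Dense`
# layer, yielding a parent vector of the same length. A standalone sketch
# with a toy dimension of 4 (the real model uses `args.N`):
let Wtoy = Dense(8, 4, tanh), a = randn(Float32, 4), b = randn(Float32, 4)
    size(Wtoy([a; b]))   ## (4,): the parent phrase vector
end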
function train(; kws...)
    ## Initialize hyperparameters
    args = Args(; kws...)

    ## Load data
    @info("Loading Data...")
    train_data, alphabet = getdata()

    @info("Constructing model....")
    ## Token embedding: one N-dimensional column per alphabet entry
    embedding = randn(Float32, args.N, length(alphabet))
    @info "Size of the embedding" size(embedding)

    ## Composition layer: merges two child vectors (length 2N) into one parent vector (length N)
    W = Dense(2*args.N, args.N, tanh)
    combine(a, b) = W([a; b])

    ## Classifier from a phrase vector to the five sentiment classes
    sentiment = Chain(Dense(args.N, 5))

    ## Recursively compute a tree's phrase vector together with its loss:
    ## every node, leaf or internal, contributes a classification term.
    function forward(tree)
        if isleaf(tree)
            token, sent = tree.value
            phrase = embedding * token
            phrase, logitcrossentropy(sentiment(phrase), sent)
        else
            _, sent = tree.value
            c1, l1 = forward(tree[1])
            c2, l2 = forward(tree[2])
            phrase = combine(c1, c2)
            phrase, l1 + l2 + logitcrossentropy(sentiment(phrase), sent)
        end
    end
    loss(tree) = forward(tree)[2]

    opt = ADAM(args.lr)
    ps = params(embedding, W, sentiment)
    evalcb = () -> @show loss(train_data[1])
    @info("Training Model...")
    ## `zip` wraps each tree in a 1-tuple, which `train!` splats into `loss`
    Flux.train!(loss, ps, zip(train_data), opt, cb = throttle(evalcb, args.throttle))
end
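# `train` forwards its keyword arguments to `Args`, so the hyperparameters can
# be overridden per run, e.g. `train(lr = 5e-4, throttle = 30)`.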
# ## Train the model
cd(@__DIR__)   ## run from the script's directory so relative data paths resolve
train()