
Commit 0628c2d

Updated and verified tests
1 parent 95550b9 commit 0628c2d

File tree

6 files changed: +25 additions, -15 deletions

src/ULMFiT/custom_layers.jl

Lines changed: 16 additions & 5 deletions
@@ -48,7 +48,8 @@ mutable struct WeightDroppedLSTMCell{A, V, S, M}
     Wi::A
     Wh::A
     b::V
-    state0::S
+    h::S
+    c::S
     p::Float64
     maskWi::M
     maskWh::M
@@ -62,8 +63,8 @@ function WeightDroppedLSTMCell(in::Integer, out::Integer, p::Float64=0.0;
         init(out*4, in),
         init(out*4, out),
         init(out*4),
-        (reshape(zeros(Float32, out),out, 1),
-        reshape(zeros(Float32, out), out, 1)),
+        reshape(zeros(Float32, out),out, 1),
+        reshape(zeros(Float32, out), out, 1),
         p,
         drop_mask((out*4, in), p),
         drop_mask((out*4, out), p),
@@ -89,7 +90,7 @@ end
8990

9091
Flux.@functor WeightDroppedLSTMCell
9192

92-
Flux.trainable(m::WeightDroppedLSTMCell) = (m.Wi, m.Wh, m.b, m.state0...)
93+
Flux.trainable(m::WeightDroppedLSTMCell) = (m.Wi, m.Wh, m.b, m.h, m.c)
9394

9495
testmode!(m::WeightDroppedLSTMCell, mode=true) =
9596
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
@@ -107,10 +108,20 @@ julia> wd = WeightDroppedLSTM(4, 5, 0.3);
 """
 function WeightDroppedLSTM(a...; kw...)
     cell = WeightDroppedLSTMCell(a...;kw...)
-    hidden = cell.state0
+    hidden = (cell.h, cell.c)
     return Flux.Recur(cell, hidden)
 end
 
+# override of reset! so it works with the pretrained model
+function reset!(m)
+    try
+        (m.state = (m.cell.h, m.cell.c))
+    catch
+        Flux.reset!(m)
+    end
+end
+
+
 """
     reset_masks!(layer)
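Note: the cell's learnable initial hidden and cell states now live in separate h and c fields instead of a state0 tuple, and the in-module reset! re-seeds a wrapped layer from those fields, falling back to Flux.reset! when they are absent. A minimal usage sketch, following the access path used in test/ulmfit.jl (reaching the new method as ULMFiT.reset! is an assumption about the module's namespace):

using TextModels

wd = ULMFiT.WeightDroppedLSTM(4, 5, 0.3)   # Flux.Recur over a WeightDroppedLSTMCell
wd.state == (wd.cell.h, wd.cell.c)         # true: the state is seeded from the two new fields
ULMFiT.reset!(wd)                          # restores wd.state from wd.cell.h and wd.cell.c;
                                           # layers without h/c fall through to Flux.reset!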

src/ULMFiT/pretrain_lm.jl

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ function loss(lm, gen)
     H = forward(lm, take!(gen))
     Y = broadcast(x -> gpu(Flux.onehotbatch(x, lm.vocab, "_unk_")), take!(gen))
     l = sum(Flux.crossentropy.(H, Y))
-    Flux.reset!(lm.layers)
+    reset!(lm.layers)
     return l
 end

src/ULMFiT/train_text_classifier.jl

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ function TextClassifier(lm::LanguageModel=LanguageModel(), clsfr_out_sz::Integer
         lm.vocab,
         lm.layers[1:8],
         Chain(
-            gpu(PooledDense(length(lm.layers[7].layer.cell.state0[1]), clsfr_hidden_sz)),
+            gpu(PooledDense(length(lm.layers[7].layer.cell.h), clsfr_hidden_sz)),
             gpu(BatchNorm(clsfr_hidden_sz, relu)),
             Dropout(clsfr_hidden_drop),
             gpu(Dense(clsfr_hidden_sz, clsfr_out_sz)),
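Note: the classifier head now takes its input width from the cell's h field rather than from state0[1]. A small sketch checking that the new field carries the hidden size, with placeholder sizes (not the package defaults) and the same constructor path the tests use:

cell = ULMFiT.WeightDroppedLSTMCell(4, 5, 0.3)   # in = 4, out = 5
length(cell.h)                                   # 5, the width PooledDense now reads directly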

test/crf.jl

Lines changed: 3 additions & 3 deletions
@@ -118,14 +118,14 @@ using TextModels: score_sequence, forward_score
 
     function train()
         for d in data
-            reset!(lstm)
+            Flux.reset!(lstm)
             grads = gradient(() -> loss(d[1], d[2]), ps)
             Flux.Optimise.update!(opt, ps, grads)
         end
     end
 
     function find_loss(d)
-        reset!(lstm)
+        Flux.reset!(lstm)
         loss(d[1], d[2])
     end
     to_sum = [find_loss(d) for d in data]
@@ -138,7 +138,7 @@ using TextModels: score_sequence, forward_score
         train()
     end
 
-    dense_param_2 = deepcopy(d_out.W))
+    dense_param_2 = deepcopy(d_out.W)
     lstm_param_2 = deepcopy(lstm.cell.Wh)
     crf_param_2 = deepcopy(c.W)
     l2 = sum([find_loss(d) for d in data])
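Note: the CRF tests drive a plain Flux LSTM, so the reset call is now written as Flux.reset!, presumably to keep it unambiguous alongside the reset! introduced in custom_layers.jl. A standalone sketch with placeholder sizes:

using Flux

lstm = LSTM(4, 5)             # plain Flux recurrent layer; its cell has no h/c fields
lstm(rand(Float32, 4, 1))     # advance the hidden state
Flux.reset!(lstm)             # explicitly qualified reset back to the initial state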

test/runtests.jl

Lines changed: 3 additions & 4 deletions
@@ -5,8 +5,7 @@ using TextModels
 println("Running tests:")
 
 include("crf.jl")
-#include("ner.jl")
-#include("pos.jl")
-#include("averagePerceptronTagger.jl")
+include("ner.jl")
+include("pos.jl")
+include("averagePerceptronTagger.jl")
 include("ulmfit.jl")
-#include("sentiment.jl")

test/ulmfit.jl

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ using BSON
 @testset "Custom layers" begin
     @testset "WeightDroppedLSTM" begin
         wd = ULMFiT.WeightDroppedLSTM(4, 5, 0.3)
-        @test all((wd.cell.state0) .== wd.state)
+        @test all((wd.cell.h, wd.cell.c) .== wd.state)
         @test size(wd.cell.Wi) == size(wd.cell.maskWi)
         @test size(wd.cell.Wh) == size(wd.cell.maskWh)
         @test wd.cell.active
