Skip to content

Commit 6cd9824

Browse files
Reshape pretrained weights ULMFiT LM
1 parent 8398f61 commit 6cd9824

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/ULMFiT/pretrain_lm.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ end
163163
# Load saved weights from the BSON file at `filepath` into the language model `lm`.
# The file must contain an entry named `weights`; selected entries are reshaped
# into single-column matrices before everything is passed to `Flux.loadparams!`.
164164
function load_model!(lm::LanguageModel, filepath::String)
165165
BSON.@load filepath weights
166+
# Reshape the saved weights at the listed indices into (length, 1) column
# matrices so they match the recurrent (h, c) state shape.
# NOTE(review): the indices are hard-coded; presumably they are the (h, c)
# entries of the recurrent layers in the parameter order - confirm against the model.
167+
layers = [5, 6, 10, 11, 15, 16]
168+
for l in layers
169+
weights[l] = reshape(weights[l], length(weights[l]), 1)
170+
end
166171
Flux.loadparams!(lm, weights)
167172
end
168173

test/ulmfit.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ end
9595
@test length(ULMFiT.get_trainable_params(lm.layers)) == 10
9696

9797
pretrained_weights = BSON.load(datadep"Pretrained ULMFiT Language Model/ulmfit_lm_en.bson")
98+
# reshape weights of (h, c)
99+
layers = [5, 6, 10, 11, 15, 16]
100+
for i in layers
101+
pretrained_weights[:weights][i] = reshape(pretrained_weights[:weights][i], length(pretrained_weights[:weights][i]), 1)
102+
end
103+
98104
@test length(pretrained_weights[:weights]) == 16
99105
@test all(size.(params(lm)) .== size.(pretrained_weights[:weights]))
100106
end

0 commit comments

Comments
 (0)