Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit a275252

Browse files
authored
Converting dropout-related vars to lets. (#616)
1 parent 2323b1c commit a275252

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

Models/Text/WordSeg/Model.swift

+6-10
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,10 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
151151
/// Returns the hidden states of the encoder LSTM applied to `x`, using
152152
/// `device`.
153153
public func encode(_ x: CharacterSequence, device: Device) -> [Tensor<Float>] {
154-
var embedded = encoderEmbedding(x.tensor(device: device))
155-
embedded = dropout(embedded)
154+
let embedded = dropout(encoderEmbedding(x.tensor(device: device)))
156155
let encoderStates = encoderLSTM(embedded.unstacked().differentiableMap { $0.rankLifted() })
157-
var encoderResult = Tensor(
158-
stacking: encoderStates.differentiableMap { $0.hidden.squeezingShape(at: 0) })
159-
encoderResult = dropout(encoderResult)
156+
let encoderResult = dropout(Tensor(
157+
stacking: encoderStates.differentiableMap { $0.hidden.squeezingShape(at: 0) }))
160158
return encoderResult.unstacked()
161159
}
162160

@@ -196,8 +194,7 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
196194
).transposed()
197195

198196
// [time x batch x hiddenSize]
199-
var embeddedX = decoderEmbedding(x)
200-
embeddedX = dropout(embeddedX)
197+
let embeddedX = dropout(decoderEmbedding(x))
201198

202199
// [batch x hiddenSize]
203200
let stateBatch = state.rankLifted().tiled(multiples: [candidates.count, 1])
@@ -210,9 +207,8 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
210207
hidden: stateBatch))
211208

212209
// [time x batch x hiddenSize]
213-
var decoderResult = Tensor(
214-
stacking: decoderStates.differentiableMap { $0.hidden })
215-
decoderResult = dropout(decoderResult)
210+
let decoderResult = dropout(Tensor(
211+
stacking: decoderStates.differentiableMap { $0.hidden }))
216212

217213
// [time x batch x alphabet.count]
218214
let logits = decoderDense(decoderResult)

0 commit comments

Comments
 (0)