@@ -151,12 +151,10 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
151
151
/// Returns the hidden states of the encoder LSTM applied to `x`, using
152
152
/// `device`.
153
153
public func encode( _ x: CharacterSequence , device: Device ) -> [ Tensor < Float > ] {
154
- var embedded = encoderEmbedding ( x. tensor ( device: device) )
155
- embedded = dropout ( embedded)
154
+ let embedded = dropout ( encoderEmbedding ( x. tensor ( device: device) ) )
156
155
let encoderStates = encoderLSTM ( embedded. unstacked ( ) . differentiableMap { $0. rankLifted ( ) } )
157
- var encoderResult = Tensor (
158
- stacking: encoderStates. differentiableMap { $0. hidden. squeezingShape ( at: 0 ) } )
159
- encoderResult = dropout ( encoderResult)
156
+ let encoderResult = dropout ( Tensor (
157
+ stacking: encoderStates. differentiableMap { $0. hidden. squeezingShape ( at: 0 ) } ) )
160
158
return encoderResult. unstacked ( )
161
159
}
162
160
@@ -196,8 +194,7 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
196
194
) . transposed ( )
197
195
198
196
// [time x batch x hiddenSize]
199
- var embeddedX = decoderEmbedding ( x)
200
- embeddedX = dropout ( embeddedX)
197
+ let embeddedX = dropout ( decoderEmbedding ( x) )
201
198
202
199
// [batch x hiddenSize]
203
200
let stateBatch = state. rankLifted ( ) . tiled ( multiples: [ candidates. count, 1 ] )
@@ -210,9 +207,8 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
210
207
hidden: stateBatch) )
211
208
212
209
// [time x batch x hiddenSize]
213
- var decoderResult = Tensor (
214
- stacking: decoderStates. differentiableMap { $0. hidden } )
215
- decoderResult = dropout ( decoderResult)
210
+ let decoderResult = dropout ( Tensor (
211
+ stacking: decoderStates. differentiableMap { $0. hidden } ) )
216
212
217
213
// [time x batch x alphabet.count]
218
214
let logits = decoderDense ( decoderResult)
0 commit comments