@@ -10,13 +10,20 @@ import (
10
10
"crypto/cipher"
11
11
"crypto/hmac"
12
12
"crypto/sha1"
13
+ "crypto/subtle"
14
+ "errors"
13
15
"io"
16
+ "io/ioutil"
14
17
15
18
"golang.org/x/crypto/pbkdf2"
16
19
)
17
20
18
- // Counter (CTR) mode.
21
+ // Decryption Errors
22
+ var (
23
+ ErrDecryption = errors .New ("zip: decryption error" )
24
+ )
19
25
26
+ // Counter (CTR) mode.
20
27
// CTR converts a block cipher into a stream cipher by
21
28
// repeatedly encrypting an incrementing counter and
22
29
// xoring the resulting stream of data with the input.
@@ -111,14 +118,75 @@ func xorBytes(dst, a, b []byte) int {
111
118
return n
112
119
}
113
120
121
+ type authReader struct {
122
+ data io.Reader // data to be authenticated
123
+ adata io.Reader // the authentication code to read
124
+ akey []byte // authentication key
125
+ buf * bytes.Buffer // buffer to store data to authenticate
126
+ err error
127
+ auth bool
128
+ }
129
+
130
+ func newAuthReader (akey []byte , data , adata io.Reader ) io.Reader {
131
+ return & authReader {
132
+ data : data ,
133
+ adata : adata ,
134
+ akey : akey ,
135
+ buf : new (bytes.Buffer ),
136
+ err : nil ,
137
+ auth : false ,
138
+ }
139
+ }
140
+
141
+ // Read will fully buffer the file data payload to authenticate first.
142
+ // If authentication fails, returns ErrDecryption immediately.
143
+ // Else, sends data along for decryption.
144
+ func (a * authReader ) Read (b []byte ) (int , error ) {
145
+ // check for sticky error
146
+ if a .err != nil {
147
+ return 0 , a .err
148
+ }
149
+ // make sure we have auth'ed before we send any data
150
+ if ! a .auth {
151
+ nn , err := io .Copy (a .buf , a .data )
152
+ if err != nil {
153
+ a .err = ErrDecryption
154
+ return 0 , a .err
155
+ }
156
+ ab := new (bytes.Buffer )
157
+ nn , err = io .Copy (ab , a .adata )
158
+ if err != nil || nn != 10 {
159
+ a .err = ErrDecryption
160
+ return 0 , a .err
161
+ }
162
+ a .auth = checkAuthentication (a .buf .Bytes (), ab .Bytes (), a .akey )
163
+ if ! a .auth {
164
+ a .err = ErrDecryption
165
+ return 0 , a .err
166
+ }
167
+ }
168
+ // so we've authenticated the data, now just pass it on.
169
+ n , err := a .buf .Read (b )
170
+ if err != nil {
171
+ a .err = err
172
+ }
173
+ return n , a .err
174
+ }
175
+
114
176
func checkAuthentication (message , authcode , key []byte ) bool {
115
177
mac := hmac .New (sha1 .New , key )
116
178
mac .Write (message )
117
179
expectedAuthCode := mac .Sum (nil )
118
180
// Truncate at the first 10 bytes
119
181
expectedAuthCode = expectedAuthCode [:10 ]
120
182
// Change to use crypto/subtle for constant time comparison
121
- return bytes .Equal (expectedAuthCode , authcode )
183
+ b := subtle .ConstantTimeCompare (expectedAuthCode , authcode ) > 0
184
+ return b
185
+ }
186
+
187
+ func checkPasswordVerification (pwvv , pwv []byte ) bool {
188
+ b := subtle .ConstantTimeCompare (pwvv , pwv ) > 0
189
+ return b
122
190
}
123
191
124
192
func generateKeys (password , salt []byte , keySize int ) (encKey , authKey , pwv []byte ) {
@@ -130,7 +198,7 @@ func generateKeys(password, salt []byte, keySize int) (encKey, authKey, pwv []by
130
198
return
131
199
}
132
200
133
- func newDecryptionReader (r io.Reader , f * File ) (io.Reader , error ) {
201
+ func newDecryptionReader (r * io.SectionReader , f * File ) (io.ReadCloser , error ) {
134
202
keyLen := aesKeyLen (f .aesStrength )
135
203
saltLen := keyLen / 2 // salt is half of key len
136
204
if saltLen == 0 {
@@ -141,38 +209,39 @@ func newDecryptionReader(r io.Reader, f *File) (io.Reader, error) {
141
209
// See:
142
210
// https://www.imperialviolet.org/2014/06/27/streamingencryption.html
143
211
// https://www.imperialviolet.org/2015/05/16/aeads.html
144
- content := make ([]byte , f .CompressedSize64 )
145
- if _ , err := io .ReadFull (r , content ); err != nil {
212
+ // grab the salt, pwvv, data, and authcode
213
+ saltpwvv := make ([]byte , saltLen + 2 )
214
+ if _ , err := r .Read (saltpwvv ); err != nil {
146
215
return nil , ErrDecryption
147
216
}
148
- // grab the salt, pwvv, data, and authcode
149
- salt := content [: saltLen ]
150
- pwvv := content [ saltLen : saltLen + 2 ]
151
- content = content [ saltLen + 2 :]
152
- size := f . CompressedSize64 - uint64 ( saltLen ) - 2 - 10
153
- data := content [: size ]
154
- authcode := content [ size :]
217
+ salt := saltpwvv [: saltLen ]
218
+ pwvv := saltpwvv [ saltLen : saltLen + 2 ]
219
+ dataOff := int64 ( saltLen + 2 )
220
+ dataLen := int64 ( f . CompressedSize64 - uint64 ( saltLen ) - 2 - 10 )
221
+ data := io . NewSectionReader ( r , dataOff , dataLen )
222
+ authOff := dataOff + dataLen
223
+ authcode := io . NewSectionReader ( r , authOff , 10 )
155
224
// generate keys
156
225
decKey , authKey , pwv := generateKeys (f .password , salt , keyLen )
157
226
// check password verifier (pwv)
158
227
// Change to use crypto/subtle for constant time comparison
159
- if ! bytes .Equal (pwv , pwvv ) {
160
- return nil , ErrDecryption
161
- }
162
- // check authentication
163
- if ! checkAuthentication (data , authcode , authKey ) {
228
+ if ! checkPasswordVerification (pwv , pwvv ) {
164
229
return nil , ErrDecryption
165
230
}
166
- return decryptStream (data , decKey ), nil
231
+ // setup auth reader
232
+ ar := newAuthReader (authKey , data , authcode )
233
+ // return decryption reader
234
+ dr := decryptStream (decKey , ar )
235
+ return ioutil .NopCloser (dr ), nil
167
236
}
168
237
169
- func decryptStream (ciphertext , key []byte ) io.Reader {
238
+ func decryptStream (key []byte , ciphertext io. Reader ) io.Reader {
170
239
block , err := aes .NewCipher (key )
171
240
if err != nil {
172
241
return nil
173
242
}
174
243
stream := newWinZipCTR (block )
175
- reader := cipher.StreamReader {S : stream , R : bytes . NewReader ( ciphertext ) }
244
+ reader := & cipher.StreamReader {S : stream , R : ciphertext }
176
245
return reader
177
246
}
178
247
0 commit comments