001/*
002 * $HeadURL: http://juliusdavies.ca/svn/not-yet-commons-ssl/tags/commons-ssl-0.3.15/src/java/org/apache/commons/ssl/PKCS8Key.java $
003 * $Revision: 153 $
004 * $Date: 2009-09-15 22:40:53 -0700 (Tue, 15 Sep 2009) $
005 *
006 * ====================================================================
007 * Licensed to the Apache Software Foundation (ASF) under one
008 * or more contributor license agreements.  See the NOTICE file
009 * distributed with this work for additional information
010 * regarding copyright ownership.  The ASF licenses this file
011 * to you under the Apache License, Version 2.0 (the
012 * "License"); you may not use this file except in compliance
013 * with the License.  You may obtain a copy of the License at
014 *
015 *   http://www.apache.org/licenses/LICENSE-2.0
016 *
017 * Unless required by applicable law or agreed to in writing,
018 * software distributed under the License is distributed on an
019 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
020 * KIND, either express or implied.  See the License for the
021 * specific language governing permissions and limitations
022 * under the License.
023 * ====================================================================
024 *
025 * This software consists of voluntary contributions made by many
026 * individuals on behalf of the Apache Software Foundation.  For more
027 * information on the Apache Software Foundation, please see
028 * <http://www.apache.org/>.
029 *
030 */
031
032package org.apache.commons.ssl;
033
034import org.apache.commons.ssl.asn1.*;
035
036import javax.crypto.*;
037import javax.crypto.spec.IvParameterSpec;
038import javax.crypto.spec.RC2ParameterSpec;
039import javax.crypto.spec.RC5ParameterSpec;
040import javax.crypto.spec.SecretKeySpec;
041import java.io.*;
042import java.math.BigInteger;
043import java.security.*;
044import java.security.interfaces.DSAParams;
045import java.security.interfaces.DSAPrivateKey;
046import java.security.interfaces.RSAPrivateCrtKey;
047import java.security.spec.DSAPublicKeySpec;
048import java.security.spec.KeySpec;
049import java.security.spec.PKCS8EncodedKeySpec;
050import java.security.spec.RSAPublicKeySpec;
051import java.util.Arrays;
052import java.util.Collections;
053import java.util.Iterator;
054import java.util.List;
055
056/**
057 * Utility for decrypting PKCS8 private keys.  Way easier to use than
058 * javax.crypto.EncryptedPrivateKeyInfo since all you need is the byte[] array
059 * and the password.  You don't need to know anything else about the PKCS8
060 * key you pass in.
061 * </p><p>
062 * Can handle base64 PEM, or raw DER.
063 * Can handle PKCS8 Version 1.5 and 2.0.
064 * Can also handle OpenSSL encrypted or unencrypted private keys (DSA or RSA).
065 * </p><p>
066 * The PKCS12 key derivation (the "pkcs12()" method) comes from BouncyCastle.
067 * </p>
068 *
069 * @author Credit Union Central of British Columbia
070 * @author <a href="http://www.cucbc.com/">www.cucbc.com</a>
071 * @author <a href="mailto:juliusdavies@cucbc.com">juliusdavies@cucbc.com</a>
 * @author <a href="http://www.bouncycastle.org/">bouncycastle.org</a>
073 * @since 7-Nov-2006
074 */
075public class PKCS8Key {
    /** Algorithm OID identifying an RSA private key (rsaEncryption). */
    public final static String RSA_OID = "1.2.840.113549.1.1.1";
    /** Algorithm OID identifying a DSA private key (id-dsa). */
    public final static String DSA_OID = "1.2.840.10040.4.1";

    /** PEM label prefix of an unencrypted PKCS8 key. */
    public final static String PKCS8_UNENCRYPTED = "PRIVATE KEY";
    /** PEM label prefix of an encrypted PKCS8 key. */
    public final static String PKCS8_ENCRYPTED = "ENCRYPTED PRIVATE KEY";
    /** PEM label prefix of an OpenSSL ("traditional") RSA key. */
    public final static String OPENSSL_RSA = "RSA PRIVATE KEY";
    /** PEM label prefix of an OpenSSL ("traditional") DSA key. */
    public final static String OPENSSL_DSA = "DSA PRIVATE KEY";

    /** The key produced by KeyFactory.generatePrivate(). */
    private final PrivateKey privateKey;
    /** Unencrypted PKCS8 DER bytes the private key was built from. */
    private final byte[] decryptedBytes;
    /** Cipher transformation used to decrypt the key, or "UNENCRYPTED". */
    private final String transformation;
    /** Symmetric key size in bits used for decryption (0 if unencrypted). */
    private final int keySize;
    /** True when the key was parsed as DSA. */
    private final boolean isDSA;
    /** True when the key was parsed as RSA (always the opposite of isDSA). */
    private final boolean isRSA;
090
    // Runs once, when the PKCS8Key class is initialised.  JavaImpl.load() is
    // defined elsewhere; presumably it selects/initialises the crypto
    // implementation used by this package -- NOTE(review): confirm in JavaImpl.
    static {
        JavaImpl.load();
    }
094
095    /**
096     * @param in       pkcs8 file to parse (pem or der, encrypted or unencrypted)
097     * @param password password to decrypt the pkcs8 file.  Ignored if the
098     *                 supplied pkcs8 is already unencrypted.
099     * @throws GeneralSecurityException If a parsing or decryption problem
100     *                                  occured.
101     * @throws IOException              If the supplied InputStream could not be read.
102     */
103    public PKCS8Key(final InputStream in, char[] password)
104        throws GeneralSecurityException, IOException {
105        this(Util.streamToBytes(in), password);
106    }
107
108    /**
109     * @param in       pkcs8 file to parse (pem or der, encrypted or unencrypted)
110     * @param password password to decrypt the pkcs8 file.  Ignored if the
111     *                 supplied pkcs8 is already unencrypted.
112     * @throws GeneralSecurityException If a parsing or decryption problem
113     *                                  occured.
114     */
115    public PKCS8Key(final ByteArrayInputStream in, char[] password)
116        throws GeneralSecurityException {
117        this(Util.streamToBytes(in), password);
118    }
119
120    /**
121     * @param encoded  pkcs8 file to parse (pem or der, encrypted or unencrypted)
122     * @param password password to decrypt the pkcs8 file.  Ignored if the
123     *                 supplied pkcs8 is already unencrypted.
124     * @throws GeneralSecurityException If a parsing or decryption problem
125     *                                  occured.
126     */
127    public PKCS8Key(final byte[] encoded, char[] password)
128        throws GeneralSecurityException {
129        DecryptResult decryptResult =
130            new DecryptResult("UNENCRYPTED", 0, encoded);
131
132        List pemItems = PEMUtil.decode(encoded);
133        PEMItem keyItem = null;
134        byte[] derBytes = null;
135        if (pemItems.isEmpty()) {
136            // must be DER encoded - PEMUtil wasn't able to extract anything.
137            derBytes = encoded;
138        } else {
139            Iterator it = pemItems.iterator();
140            boolean opensslRSA = false;
141            boolean opensslDSA = false;
142
143            while (it.hasNext()) {
144                PEMItem item = (PEMItem) it.next();
145                String type = item.pemType.trim().toUpperCase();
146                boolean plainPKCS8 = type.startsWith(PKCS8_UNENCRYPTED);
147                boolean encryptedPKCS8 = type.startsWith(PKCS8_ENCRYPTED);
148                boolean rsa = type.startsWith(OPENSSL_RSA);
149                boolean dsa = type.startsWith(OPENSSL_DSA);
150                if (plainPKCS8 || encryptedPKCS8 || rsa || dsa) {
151                    opensslRSA = opensslRSA || rsa;
152                    opensslDSA = opensslDSA || dsa;
153                    if (derBytes != null) {
154                        throw new ProbablyNotPKCS8Exception("More than one pkcs8 or OpenSSL key found in the supplied PEM Base64 stream");
155                    }
156                    derBytes = item.getDerBytes();
157                    keyItem = item;
158                    decryptResult = new DecryptResult("UNENCRYPTED", 0, derBytes);
159                }
160            }
161            // after the loop is finished, did we find anything?
162            if (derBytes == null) {
163                throw new ProbablyNotPKCS8Exception("No pkcs8 or OpenSSL key found in the supplied PEM Base64 stream");
164            }
165
166            if (opensslDSA || opensslRSA) {
167                String c = keyItem.cipher.trim();
168                boolean encrypted = !"UNKNOWN".equals(c) && !"".equals(c);
169                if (encrypted) {
170                    decryptResult = opensslDecrypt(keyItem, password);
171                }
172
173                String oid = RSA_OID;
174                if (opensslDSA) {
175                    oid = DSA_OID;
176                }
177                derBytes = formatAsPKCS8(decryptResult.bytes, oid, null);
178
179                String tf = decryptResult.transformation;
180                int ks = decryptResult.keySize;
181                decryptResult = new DecryptResult(tf, ks, derBytes);
182            }
183        }
184
185        ASN1Structure pkcs8;
186        try {
187            pkcs8 = ASN1Util.analyze(derBytes);
188        }
189        catch (Exception e) {
190            throw new ProbablyNotPKCS8Exception("asn1 parse failure: " + e);
191        }
192
193        String oid = RSA_OID;
194        // With the OpenSSL unencrypted private keys in DER format, the only way
195        // to even have a hope of guessing what we've got (DSA or RSA?) is to
196        // count the number of DERIntegers occurring in the first DERSequence.
197        int derIntegerCount = -1;
198        if (pkcs8.derIntegers != null) {
199            derIntegerCount = pkcs8.derIntegers.size();
200        }
201        switch (derIntegerCount) {
202            case 6:
203                oid = DSA_OID;
204            case 9:
205                derBytes = formatAsPKCS8(derBytes, oid, pkcs8);
206                pkcs8.oid1 = oid;
207
208                String tf = decryptResult.transformation;
209                int ks = decryptResult.keySize;
210                decryptResult = new DecryptResult(tf, ks, derBytes);
211                break;
212            default:
213                break;
214        }
215
216        oid = pkcs8.oid1;
217        if (!oid.startsWith("1.2.840.113549.1")) {
218            boolean isOkay = false;
219            if (oid.startsWith("1.2.840.10040.4.")) {
220                String s = oid.substring("1.2.840.10040.4.".length());
221                // 1.2.840.10040.4.1 -- id-dsa
222                // 1.2.840.10040.4.3 -- id-dsa-with-sha1
223                isOkay = s.equals("1") || s.startsWith("1.") ||
224                         s.equals("3") || s.startsWith("3.");
225            }
226            if (!isOkay) {
227                throw new ProbablyNotPKCS8Exception("Valid ASN.1, but not PKCS8 or OpenSSL format.  OID=" + oid);
228            }
229        }
230
231        boolean isRSA = RSA_OID.equals(oid);
232        boolean isDSA = DSA_OID.equals(oid);
233        boolean encrypted = !isRSA && !isDSA;
234        byte[] decryptedPKCS8 = encrypted ? null : derBytes;
235
236        if (encrypted) {
237            decryptResult = decryptPKCS8(pkcs8, password);
238            decryptedPKCS8 = decryptResult.bytes;
239        }
240        if (encrypted) {
241            try {
242                pkcs8 = ASN1Util.analyze(decryptedPKCS8);
243            }
244            catch (Exception e) {
245                throw new ProbablyBadPasswordException("Decrypted stream not ASN.1.  Probably bad decryption password.");
246            }
247            oid = pkcs8.oid1;
248            isDSA = DSA_OID.equals(oid);
249        }
250
251        KeySpec spec = new PKCS8EncodedKeySpec(decryptedPKCS8);
252        String type = "RSA";
253        PrivateKey pk;
254        try {
255            KeyFactory KF;
256            if (isDSA) {
257                type = "DSA";
258                KF = KeyFactory.getInstance("DSA");
259            } else {
260                KF = KeyFactory.getInstance("RSA");
261            }
262            pk = KF.generatePrivate(spec);
263        }
264        catch (Exception e) {
265            throw new ProbablyBadPasswordException("Cannot create " + type + " private key from decrypted stream.  Probably bad decryption password. " + e);
266        }
267        if (pk != null) {
268            this.privateKey = pk;
269            this.isDSA = isDSA;
270            this.isRSA = !isDSA;
271            this.decryptedBytes = decryptedPKCS8;
272            this.transformation = decryptResult.transformation;
273            this.keySize = decryptResult.keySize;
274        } else {
275            throw new GeneralSecurityException("KeyFactory.generatePrivate() returned null and didn't throw exception!");
276        }
277    }
278
279    public boolean isRSA() {
280        return isRSA;
281    }
282
283    public boolean isDSA() {
284        return isDSA;
285    }
286
287    public String getTransformation() {
288        return transformation;
289    }
290
291    public int getKeySize() {
292        return keySize;
293    }
294
295    public byte[] getDecryptedBytes() {
296        return decryptedBytes;
297    }
298
299    public PrivateKey getPrivateKey() {
300        return privateKey;
301    }
302
303    public PublicKey getPublicKey() throws GeneralSecurityException {
304        if (privateKey instanceof DSAPrivateKey) {
305            DSAPrivateKey dsa = (DSAPrivateKey) privateKey;
306            DSAParams params = dsa.getParams();
307            BigInteger g = params.getG();
308            BigInteger p = params.getP();
309            BigInteger q = params.getQ();
310            BigInteger x = dsa.getX();
311            BigInteger y = q.modPow( x, p );
312            DSAPublicKeySpec dsaKeySpec = new DSAPublicKeySpec(y, p, q, g);
313            return KeyFactory.getInstance("DSA").generatePublic(dsaKeySpec);
314        } else if (privateKey instanceof RSAPrivateCrtKey) {
315            RSAPrivateCrtKey rsa = (RSAPrivateCrtKey) privateKey;
316            RSAPublicKeySpec rsaKeySpec = new RSAPublicKeySpec(
317                    rsa.getModulus(),
318                    rsa.getPublicExponent()
319            );
320            return KeyFactory.getInstance("RSA").generatePublic(rsaKeySpec);
321        } else {
322            throw new GeneralSecurityException("Not an RSA or DSA key");
323        }
324    }
325
326    public static class DecryptResult {
327        public final String transformation;
328        public final int keySize;
329        public final byte[] bytes;
330
331        protected DecryptResult(String transformation, int keySize,
332                                byte[] decryptedBytes) {
333            this.transformation = transformation;
334            this.keySize = keySize;
335            this.bytes = decryptedBytes;
336        }
337    }
338
339    private static DecryptResult opensslDecrypt(final PEMItem item,
340                                                final char[] password)
341        throws GeneralSecurityException {
342        final String cipher = item.cipher;
343        final String mode = item.mode;
344        final int keySize = item.keySizeInBits;
345        final byte[] salt = item.iv;
346        final boolean des2 = item.des2;
347        final DerivedKey dk = OpenSSL.deriveKey(password, salt, keySize, des2);
348        return decrypt(cipher, mode, dk, des2, null, item.getDerBytes());
349    }
350
351    public static Cipher generateCipher(String cipher, String mode,
352                                        final DerivedKey dk,
353                                        final boolean des2,
354                                        final byte[] iv,
355                                        final boolean decryptMode)
356        throws NoSuchAlgorithmException, NoSuchPaddingException,
357        InvalidKeyException, InvalidAlgorithmParameterException {
358        if (des2 && dk.key.length >= 24) {
359            // copy first 8 bytes into last 8 bytes to create 2DES key.
360            System.arraycopy(dk.key, 0, dk.key, 16, 8);
361        }
362
363        final int keySize = dk.key.length * 8;
364        cipher = cipher.trim();
365        String cipherUpper = cipher.toUpperCase();
366        mode = mode.trim().toUpperCase();
367        // Is the cipher even available?
368        Cipher.getInstance(cipher);
369        String padding = "PKCS5Padding";
370        if (mode.startsWith("CFB") || mode.startsWith("OFB")) {
371            padding = "NoPadding";
372        }
373
374        String transformation = cipher + "/" + mode + "/" + padding;
375        if (cipherUpper.startsWith("RC4")) {
376            // RC4 does not take mode or padding.
377            transformation = cipher;
378        }
379
380        SecretKey secret = new SecretKeySpec(dk.key, cipher);
381        IvParameterSpec ivParams;
382        if (iv != null) {
383            ivParams = new IvParameterSpec(iv);
384        } else {
385            ivParams = dk.iv != null ? new IvParameterSpec(dk.iv) : null;
386        }
387
388        Cipher c = Cipher.getInstance(transformation);
389        int cipherMode = Cipher.ENCRYPT_MODE;
390        if (decryptMode) {
391            cipherMode = Cipher.DECRYPT_MODE;
392        }
393
394        // RC2 requires special params to inform engine of keysize.
395        if (cipherUpper.startsWith("RC2")) {
396            RC2ParameterSpec rcParams;
397            if (mode.startsWith("ECB") || ivParams == null) {
398                // ECB doesn't take an IV.
399                rcParams = new RC2ParameterSpec(keySize);
400            } else {
401                rcParams = new RC2ParameterSpec(keySize, ivParams.getIV());
402            }
403            c.init(cipherMode, secret, rcParams);
404        } else if (cipherUpper.startsWith("RC5")) {
405            RC5ParameterSpec rcParams;
406            if (mode.startsWith("ECB") || ivParams == null) {
407                // ECB doesn't take an IV.
408                rcParams = new RC5ParameterSpec(16, 12, 32);
409            } else {
410                rcParams = new RC5ParameterSpec(16, 12, 32, ivParams.getIV());
411            }
412            c.init(cipherMode, secret, rcParams);
413        } else if (mode.startsWith("ECB") || cipherUpper.startsWith("RC4")) {
414            // RC4 doesn't require any params.
415            // Any cipher using ECB does not require an IV.
416            c.init(cipherMode, secret);
417        } else {
418            // DES, DESede, AES, BlowFish require IVParams (when in CBC, CFB,
419            // or OFB mode).  (In ECB mode they don't require IVParams).
420            c.init(cipherMode, secret, ivParams);
421        }
422        return c;
423    }
424
425    public static DecryptResult decrypt(String cipher, String mode,
426                                        final DerivedKey dk,
427                                        final boolean des2,
428                                        final byte[] iv,
429                                        final byte[] encryptedBytes)
430
431        throws NoSuchAlgorithmException, NoSuchPaddingException,
432        InvalidKeyException, InvalidAlgorithmParameterException,
433        IllegalBlockSizeException, BadPaddingException {
434        Cipher c = generateCipher(cipher, mode, dk, des2, iv, true);
435        final String transformation = c.getAlgorithm();
436        final int keySize = dk.key.length * 8;
437        byte[] decryptedBytes = c.doFinal(encryptedBytes);
438        return new DecryptResult(transformation, keySize, decryptedBytes);
439    }
440
441    private static DecryptResult decryptPKCS8(ASN1Structure pkcs8,
442                                              char[] password)
443        throws GeneralSecurityException {
444        boolean isVersion1 = true;
445        boolean isVersion2 = false;
446        boolean usePKCS12PasswordPadding = false;
447        boolean use2DES = false;
448        String cipher = null;
449        String hash = null;
450        int keySize = -1;
451        // Almost all PKCS8 encrypted keys use CBC.  Looks like the AES OID's can
452        // support different modes, and RC4 doesn't use any mode at all!
453        String mode = "CBC";
454
455        // In PKCS8 Version 2 the IV is stored in the ASN.1 structure for
456        // us, so we don't need to derive it.  Just leave "ivSize" set to 0 for
457        // those ones.
458        int ivSize = 0;
459
460        String oid = pkcs8.oid1;
461        if (oid.startsWith("1.2.840.113549.1.12."))  // PKCS12 key derivation!
462        {
463            usePKCS12PasswordPadding = true;
464
465            // Let's trim this OID to make life a little easier.
466            oid = oid.substring("1.2.840.113549.1.12.".length());
467
468            if (oid.equals("1.1") || oid.startsWith("1.1.")) {
469                // 1.2.840.113549.1.12.1.1
470                hash = "SHA1";
471                cipher = "RC4";
472                keySize = 128;
473            } else if (oid.equals("1.2") || oid.startsWith("1.2.")) {
474                // 1.2.840.113549.1.12.1.2
475                hash = "SHA1";
476                cipher = "RC4";
477                keySize = 40;
478            } else if (oid.equals("1.3") || oid.startsWith("1.3.")) {
479                // 1.2.840.113549.1.12.1.3
480                hash = "SHA1";
481                cipher = "DESede";
482                keySize = 192;
483            } else if (oid.equals("1.4") || oid.startsWith("1.4.")) {
484                // DES2 !!!
485
486                // 1.2.840.113549.1.12.1.4
487                hash = "SHA1";
488                cipher = "DESede";
489                keySize = 192;
490                use2DES = true;
491                // later on we'll copy the first 8 bytes of the 24 byte DESede key
492                // over top the last 8 bytes, making the key look like K1-K2-K1
493                // instead of the usual K1-K2-K3.
494            } else if (oid.equals("1.5") || oid.startsWith("1.5.")) {
495                // 1.2.840.113549.1.12.1.5
496                hash = "SHA1";
497                cipher = "RC2";
498                keySize = 128;
499            } else if (oid.equals("1.6") || oid.startsWith("1.6.")) {
500                // 1.2.840.113549.1.12.1.6
501                hash = "SHA1";
502                cipher = "RC2";
503                keySize = 40;
504            }
505        } else if (oid.startsWith("1.2.840.113549.1.5.")) {
506            // Let's trim this OID to make life a little easier.
507            oid = oid.substring("1.2.840.113549.1.5.".length());
508
509            if (oid.equals("1") || oid.startsWith("1.")) {
510                // 1.2.840.113549.1.5.1 -- pbeWithMD2AndDES-CBC
511                hash = "MD2";
512                cipher = "DES";
513                keySize = 64;
514            } else if (oid.equals("3") || oid.startsWith("3.")) {
515                // 1.2.840.113549.1.5.3 -- pbeWithMD5AndDES-CBC
516                hash = "MD5";
517                cipher = "DES";
518                keySize = 64;
519            } else if (oid.equals("4") || oid.startsWith("4.")) {
520                // 1.2.840.113549.1.5.4 -- pbeWithMD2AndRC2_CBC
521                hash = "MD2";
522                cipher = "RC2";
523                keySize = 64;
524            } else if (oid.equals("6") || oid.startsWith("6.")) {
525                // 1.2.840.113549.1.5.6 -- pbeWithMD5AndRC2_CBC
526                hash = "MD5";
527                cipher = "RC2";
528                keySize = 64;
529            } else if (oid.equals("10") || oid.startsWith("10.")) {
530                // 1.2.840.113549.1.5.10 -- pbeWithSHA1AndDES-CBC
531                hash = "SHA1";
532                cipher = "DES";
533                keySize = 64;
534            } else if (oid.equals("11") || oid.startsWith("11.")) {
535                // 1.2.840.113549.1.5.11 -- pbeWithSHA1AndRC2_CBC
536                hash = "SHA1";
537                cipher = "RC2";
538                keySize = 64;
539            } else if (oid.equals("12") || oid.startsWith("12.")) {
540                // 1.2.840.113549.1.5.12 - id-PBKDF2 - Key Derivation Function
541                isVersion2 = true;
542            } else if (oid.equals("13") || oid.startsWith("13.")) {
543                // 1.2.840.113549.1.5.13 - id-PBES2: PBES2 encryption scheme
544                isVersion2 = true;
545            } else if (oid.equals("14") || oid.startsWith("14.")) {
546                // 1.2.840.113549.1.5.14 - id-PBMAC1 message authentication scheme
547                isVersion2 = true;
548            }
549        }
550        if (isVersion2) {
551            isVersion1 = false;
552            hash = "HmacSHA1";
553            oid = pkcs8.oid2;
554
555            // really ought to be:
556            //
557            // if ( oid.startsWith( "1.2.840.113549.1.5.12" ) )
558            //
559            // but all my tests still pass, and I figure this to be more robust:
560            if (pkcs8.oid3 != null) {
561                oid = pkcs8.oid3;
562            }
563            if (oid.startsWith("1.3.6.1.4.1.3029.1.2")) {
564                // 1.3.6.1.4.1.3029.1.2 - Blowfish
565                cipher = "Blowfish";
566                mode = "CBC";
567                keySize = 128;
568            } else if (oid.startsWith("1.3.14.3.2.")) {
569                oid = oid.substring("1.3.14.3.2.".length());
570                if (oid.equals("6") || oid.startsWith("6.")) {
571                    // 1.3.14.3.2.6 - desECB
572                    cipher = "DES";
573                    mode = "ECB";
574                    keySize = 64;
575                } else if (oid.equals("7") || oid.startsWith("7.")) {
576                    // 1.3.14.3.2.7 - desCBC
577                    cipher = "DES";
578                    mode = "CBC";
579                    keySize = 64;
580                } else if (oid.equals("8") || oid.startsWith("8.")) {
581                    // 1.3.14.3.2.8 - desOFB
582                    cipher = "DES";
583                    mode = "OFB";
584                    keySize = 64;
585                } else if (oid.equals("9") || oid.startsWith("9.")) {
586                    // 1.3.14.3.2.9 - desCFB
587                    cipher = "DES";
588                    mode = "CFB";
589                    keySize = 64;
590                } else if (oid.equals("17") || oid.startsWith("17.")) {
591                    // 1.3.14.3.2.17 - desEDE
592                    cipher = "DESede";
593                    mode = "CBC";
594                    keySize = 192;
595
596                    // If the supplied IV is all zeroes, then this is DES2
597                    // (Well, that's what happened when I played with OpenSSL!)
598                    if (allZeroes(pkcs8.iv)) {
599                        mode = "ECB";
600                        use2DES = true;
601                        pkcs8.iv = null;
602                    }
603                }
604            }
605
606            // AES
607            // 2.16.840.1.101.3.4.1.1  - id-aes128-ECB
608            // 2.16.840.1.101.3.4.1.2  - id-aes128-CBC
609            // 2.16.840.1.101.3.4.1.3  - id-aes128-OFB
610            // 2.16.840.1.101.3.4.1.4  - id-aes128-CFB
611            // 2.16.840.1.101.3.4.1.21 - id-aes192-ECB
612            // 2.16.840.1.101.3.4.1.22 - id-aes192-CBC
613            // 2.16.840.1.101.3.4.1.23 - id-aes192-OFB
614            // 2.16.840.1.101.3.4.1.24 - id-aes192-CFB
615            // 2.16.840.1.101.3.4.1.41 - id-aes256-ECB
616            // 2.16.840.1.101.3.4.1.42 - id-aes256-CBC
617            // 2.16.840.1.101.3.4.1.43 - id-aes256-OFB
618            // 2.16.840.1.101.3.4.1.44 - id-aes256-CFB
619            else if (oid.startsWith("2.16.840.1.101.3.4.1.")) {
620                cipher = "AES";
621                if (pkcs8.iv == null) {
622                    ivSize = 128;
623                }
624                oid = oid.substring("2.16.840.1.101.3.4.1.".length());
625                int x = oid.indexOf('.');
626                int finalDigit;
627                if (x >= 0) {
628                    finalDigit = Integer.parseInt(oid.substring(0, x));
629                } else {
630                    finalDigit = Integer.parseInt(oid);
631                }
632                switch (finalDigit % 10) {
633                    case 1:
634                        mode = "ECB";
635                        break;
636                    case 2:
637                        mode = "CBC";
638                        break;
639                    case 3:
640                        mode = "OFB";
641                        break;
642                    case 4:
643                        mode = "CFB";
644                        break;
645                    default:
646                        throw new RuntimeException("Unknown AES final digit: " + finalDigit);
647                }
648                switch (finalDigit / 10) {
649                    case 0:
650                        keySize = 128;
651                        break;
652                    case 2:
653                        keySize = 192;
654                        break;
655                    case 4:
656                        keySize = 256;
657                        break;
658                    default:
659                        throw new RuntimeException("Unknown AES final digit: " + finalDigit);
660                }
661            } else if (oid.startsWith("1.2.840.113549.3.")) {
662                // Let's trim this OID to make life a little easier.
663                oid = oid.substring("1.2.840.113549.3.".length());
664
665                if (oid.equals("2") || oid.startsWith("2.")) {
666                    // 1.2.840.113549.3.2 - RC2-CBC
667                    // Note:  keysize determined in PKCS8 Version 2.0 ASN.1 field.
668                    cipher = "RC2";
669                    keySize = pkcs8.keySize * 8;
670                } else if (oid.equals("4") || oid.startsWith("4.")) {
671                    // 1.2.840.113549.3.4 - RC4
672                    // Note:  keysize determined in PKCS8 Version 2.0 ASN.1 field.
673                    cipher = "RC4";
674                    keySize = pkcs8.keySize * 8;
675                } else if (oid.equals("7") || oid.startsWith("7.")) {
676                    // 1.2.840.113549.3.7 - DES-EDE3-CBC
677                    cipher = "DESede";
678                    keySize = 192;
679                } else if (oid.equals("9") || oid.startsWith("9.")) {
680                    // 1.2.840.113549.3.9 - RC5 CBC Pad
681                    // Note:  keysize determined in PKCS8 Version 2.0 ASN.1 field.
682                    keySize = pkcs8.keySize * 8;
683                    cipher = "RC5";
684
685                    // Need to find out more about RC5.
686                    // How do I create the RC5ParameterSpec?
687                    // (int version, int rounds, int wordSize, byte[] iv)
688                }
689            }
690        }
691
692        // The pkcs8 structure has been thoroughly examined.  If we don't have
693        // a cipher or hash at this point, then we don't support the file we
694        // were given.
695        if (cipher == null || hash == null) {
696            throw new ProbablyNotPKCS8Exception("Unsupported PKCS8 format. oid1=[" + pkcs8.oid1 + "], oid2=[" + pkcs8.oid2 + "]");
697        }
698
699        // In PKCS8 Version 1.5 we need to derive an 8 byte IV.  In those cases
700        // the ASN.1 structure doesn't have the IV, anyway, so I can use that
701        // to decide whether to derive one or not.
702        //
703        // Note:  if AES, then IV has to be 16 bytes.
704        if (pkcs8.iv == null) {
705            ivSize = 64;
706        }
707
708        byte[] salt = pkcs8.salt;
709        int ic = pkcs8.iterationCount;
710
711        // PKCS8 converts the password to a byte[] array using a simple
712        // cast.  This byte[] array is ignored if we're using the PKCS12
713        // key derivation, since that employs a different technique.
714        byte[] pwd = new byte[password.length];
715        for (int i = 0; i < pwd.length; i++) {
716            pwd[i] = (byte) password[i];
717        }
718
719        DerivedKey dk;
720        if (usePKCS12PasswordPadding) {
721            MessageDigest md = MessageDigest.getInstance(hash);
722            dk = deriveKeyPKCS12(password, salt, ic, keySize, ivSize, md);
723        } else {
724            if (isVersion1) {
725                MessageDigest md = MessageDigest.getInstance(hash);
726                dk = deriveKeyV1(pwd, salt, ic, keySize, ivSize, md);
727            } else {
728                Mac mac = Mac.getInstance(hash);
729                dk = deriveKeyV2(pwd, salt, ic, keySize, ivSize, mac);
730            }
731        }
732
733
734        return decrypt(cipher, mode, dk, use2DES, pkcs8.iv, pkcs8.bigPayload);
735    }
736
737
738    public static DerivedKey deriveKeyV1(byte[] password, byte[] salt,
739                                         int iterations, int keySizeInBits,
740                                         int ivSizeInBits, MessageDigest md) {
741        int keySize = keySizeInBits / 8;
742        int ivSize = ivSizeInBits / 8;
743        md.reset();
744        md.update(password);
745        byte[] result = md.digest(salt);
746        for (int i = 1; i < iterations; i++) {
747            // Hash of the hash for each of the iterations.
748            result = md.digest(result);
749        }
750        byte[] key = new byte[keySize];
751        byte[] iv = new byte[ivSize];
752        System.arraycopy(result, 0, key, 0, key.length);
753        System.arraycopy(result, key.length, iv, 0, iv.length);
754        return new DerivedKey(key, iv);
755    }
756
757    public static DerivedKey deriveKeyPKCS12(char[] password, byte[] salt,
758                                             int iterations, int keySizeInBits,
759                                             int ivSizeInBits,
760                                             MessageDigest md) {
761        byte[] pwd;
762        if (password.length > 0) {
763            pwd = new byte[(password.length + 1) * 2];
764            for (int i = 0; i < password.length; i++) {
765                pwd[i * 2] = (byte) (password[i] >>> 8);
766                pwd[i * 2 + 1] = (byte) password[i];
767            }
768        } else {
769            pwd = new byte[0];
770        }
771        int keySize = keySizeInBits / 8;
772        int ivSize = ivSizeInBits / 8;
773        byte[] key = pkcs12(1, keySize, salt, pwd, iterations, md);
774        byte[] iv = pkcs12(2, ivSize, salt, pwd, iterations, md);
775        return new DerivedKey(key, iv);
776    }
777
778    /**
779     * This PKCS12 key derivation code comes from BouncyCastle.
780     *
781     * @param idByte         1 == key, 2 == iv
782     * @param n              keysize or ivsize
783     * @param salt           8 byte salt
784     * @param password       password
785     * @param iterationCount iteration-count
786     * @param md             The message digest to use
787     * @return byte[] the derived key
788     */
789    private static byte[] pkcs12(int idByte, int n, byte[] salt,
790                                 byte[] password, int iterationCount,
791                                 MessageDigest md) {
792        int u = md.getDigestLength();
793        // sha1, md2, md5 all use 512 bits.  But future hashes might not.
794        int v = 512 / 8;
795        md.reset();
796        byte[] D = new byte[v];
797        byte[] dKey = new byte[n];
798        for (int i = 0; i != D.length; i++) {
799            D[i] = (byte) idByte;
800        }
801        byte[] S;
802        if ((salt != null) && (salt.length != 0)) {
803            S = new byte[v * ((salt.length + v - 1) / v)];
804            for (int i = 0; i != S.length; i++) {
805                S[i] = salt[i % salt.length];
806            }
807        } else {
808            S = new byte[0];
809        }
810        byte[] P;
811        if ((password != null) && (password.length != 0)) {
812            P = new byte[v * ((password.length + v - 1) / v)];
813            for (int i = 0; i != P.length; i++) {
814                P[i] = password[i % password.length];
815            }
816        } else {
817            P = new byte[0];
818        }
819        byte[] I = new byte[S.length + P.length];
820        System.arraycopy(S, 0, I, 0, S.length);
821        System.arraycopy(P, 0, I, S.length, P.length);
822        byte[] B = new byte[v];
823        int c = (n + u - 1) / u;
824        for (int i = 1; i <= c; i++) {
825            md.update(D);
826            byte[] result = md.digest(I);
827            for (int j = 1; j != iterationCount; j++) {
828                result = md.digest(result);
829            }
830            for (int j = 0; j != B.length; j++) {
831                B[j] = result[j % result.length];
832            }
833            for (int j = 0; j < (I.length / v); j++) {
834                /*
835                     * add a + b + 1, returning the result in a. The a value is treated
836                     * as a BigInteger of length (b.length * 8) bits. The result is
837                     * modulo 2^b.length in case of overflow.
838                     */
839                int aOff = j * v;
840                int bLast = B.length - 1;
841                int x = (B[bLast] & 0xff) + (I[aOff + bLast] & 0xff) + 1;
842                I[aOff + bLast] = (byte) x;
843                x >>>= 8;
844                for (int k = B.length - 2; k >= 0; k--) {
845                    x += (B[k] & 0xff) + (I[aOff + k] & 0xff);
846                    I[aOff + k] = (byte) x;
847                    x >>>= 8;
848                }
849            }
850            if (i == c) {
851                System.arraycopy(result, 0, dKey, (i - 1) * u, dKey.length - ((i - 1) * u));
852            } else {
853                System.arraycopy(result, 0, dKey, (i - 1) * u, result.length);
854            }
855        }
856        return dKey;
857    }
858
859    public static DerivedKey deriveKeyV2(byte[] password, byte[] salt,
860                                         int iterations, int keySizeInBits,
861                                         int ivSizeInBits, Mac mac)
862        throws InvalidKeyException {
863        int keySize = keySizeInBits / 8;
864        int ivSize = ivSizeInBits / 8;
865
866        // Because we're using an Hmac, we need to initialize with a SecretKey.
867        // HmacSHA1 doesn't need SecretKeySpec's 2nd parameter, hence the "N/A".
868        SecretKeySpec sk = new SecretKeySpec(password, "N/A");
869        mac.init(sk);
870        int macLength = mac.getMacLength();
871        int derivedKeyLength = keySize + ivSize;
872        int blocks = (derivedKeyLength + macLength - 1) / macLength;
873        byte[] blockIndex = new byte[4];
874        byte[] finalResult = new byte[blocks * macLength];
875        for (int i = 1; i <= blocks; i++) {
876            int offset = (i - 1) * macLength;
877            blockIndex[0] = (byte) (i >>> 24);
878            blockIndex[1] = (byte) (i >>> 16);
879            blockIndex[2] = (byte) (i >>> 8);
880            blockIndex[3] = (byte) i;
881            mac.reset();
882            mac.update(salt);
883            byte[] result = mac.doFinal(blockIndex);
884            System.arraycopy(result, 0, finalResult, offset, result.length);
885            for (int j = 1; j < iterations; j++) {
886                mac.reset();
887                result = mac.doFinal(result);
888                for (int k = 0; k < result.length; k++) {
889                    finalResult[offset + k] ^= result[k];
890                }
891            }
892        }
893        byte[] key = new byte[keySize];
894        byte[] iv = new byte[ivSize];
895        System.arraycopy(finalResult, 0, key, 0, key.length);
896        System.arraycopy(finalResult, key.length, iv, 0, iv.length);
897        return new DerivedKey(key, iv);
898    }
899
900    public static byte[] formatAsPKCS8(byte[] privateKey, String oid,
901                                       ASN1Structure pkcs8) {
902        DERInteger derZero = new DERInteger(BigInteger.ZERO);
903        ASN1EncodableVector outterVec = new ASN1EncodableVector();
904        ASN1EncodableVector innerVec = new ASN1EncodableVector();
905        DEROctetString octetsToAppend;
906        try {
907            DERObjectIdentifier derOID = new DERObjectIdentifier(oid);
908            innerVec.add(derOID);
909            if (DSA_OID.equals(oid)) {
910                if (pkcs8 == null) {
911                    try {
912                        pkcs8 = ASN1Util.analyze(privateKey);
913                    }
914                    catch (Exception e) {
915                        throw new RuntimeException("asn1 parse failure " + e);
916                    }
917                }
918                if (pkcs8.derIntegers == null || pkcs8.derIntegers.size() < 6) {
919                    throw new RuntimeException("invalid DSA key - can't find P, Q, G, X");
920                }
921
922                DERInteger[] ints = new DERInteger[pkcs8.derIntegers.size()];
923                pkcs8.derIntegers.toArray(ints);
924                DERInteger p = ints[1];
925                DERInteger q = ints[2];
926                DERInteger g = ints[3];
927                DERInteger x = ints[5];
928
929                byte[] encodedX = encode(x);
930                octetsToAppend = new DEROctetString(encodedX);
931                ASN1EncodableVector pqgVec = new ASN1EncodableVector();
932                pqgVec.add(p);
933                pqgVec.add(q);
934                pqgVec.add(g);
935                DERSequence pqg = new DERSequence(pqgVec);
936                innerVec.add(pqg);
937            } else {
938                innerVec.add(DERNull.INSTANCE);
939                octetsToAppend = new DEROctetString(privateKey);
940            }
941
942            DERSequence inner = new DERSequence(innerVec);
943            outterVec.add(derZero);
944            outterVec.add(inner);
945            outterVec.add(octetsToAppend);
946            DERSequence outter = new DERSequence(outterVec);
947            return encode(outter);
948        }
949        catch (IOException ioe) {
950            throw JavaImpl.newRuntimeException(ioe);
951        }
952    }
953
954    private static boolean allZeroes(byte[] b) {
955        for (int i = 0; i < b.length; i++) {
956            if (b[i] != 0) {
957                return false;
958            }
959        }
960        return true;
961    }
962
963    public static byte[] encode(DEREncodable der) throws IOException {
964        ByteArrayOutputStream baos = new ByteArrayOutputStream(1024);
965        ASN1OutputStream out = new ASN1OutputStream(baos);
966        out.writeObject(der);
967        out.close();
968        return baos.toByteArray();
969    }
970
971    public static void main(String[] args) throws Exception {
972        String password = "changeit";
973        if (args.length == 0) {
974            System.out.println("Usage1:  [password] [file:private-key]      Prints decrypted PKCS8 key (base64).");
975            System.out.println("Usage2:  [password] [file1] [file2] etc...  Checks that all private keys are equal.");
976            System.out.println("Usage2 assumes that all files can be decrypted with the same password.");
977        } else if (args.length == 1 || args.length == 2) {
978            FileInputStream in = new FileInputStream(args[args.length - 1]);
979            if (args.length == 2) {
980                password = args[0];
981            }
982            byte[] bytes = Util.streamToBytes(in);
983            PKCS8Key key = new PKCS8Key(bytes, password.toCharArray());
984            PEMItem item = new PEMItem(key.getDecryptedBytes(), "PRIVATE KEY");
985            byte[] pem = PEMUtil.encode(Collections.singleton(item));
986            System.out.write(pem);
987        } else {
988            byte[] original = null;
989            File f = new File(args[0]);
990            int i = 0;
991            if (!f.exists()) {
992                // File0 doesn't exist, so it must be a password!
993                password = args[0];
994                i++;
995            }
996            for (; i < args.length; i++) {
997                FileInputStream in = new FileInputStream(args[i]);
998                byte[] bytes = Util.streamToBytes(in);
999                PKCS8Key key = null;
1000                try {
1001                    key = new PKCS8Key(bytes, password.toCharArray());
1002                }
1003                catch (Exception e) {
1004                    System.out.println(" FAILED! " + args[i] + " " + e);
1005                }
1006                if (key != null) {
1007                    byte[] decrypted = key.getDecryptedBytes();
1008                    int keySize = key.getKeySize();
1009                    String keySizeStr = "" + keySize;
1010                    if (keySize < 10) {
1011                        keySizeStr = "  " + keySizeStr;
1012                    } else if (keySize < 100) {
1013                        keySizeStr = " " + keySizeStr;
1014                    }
1015                    StringBuffer buf = new StringBuffer(key.getTransformation());
1016                    int maxLen = "Blowfish/CBC/PKCS5Padding".length();
1017                    for (int j = buf.length(); j < maxLen; j++) {
1018                        buf.append(' ');
1019                    }
1020                    String transform = buf.toString();
1021                    String type = key.isDSA() ? "DSA" : "RSA";
1022
1023                    if (original == null) {
1024                        original = decrypted;
1025                        System.out.println("   SUCCESS    \t" + type + "\t" + transform + "\t" + keySizeStr + "\t" + args[i]);
1026                    } else {
1027                        boolean identical = Arrays.equals(original, decrypted);
1028                        if (!identical) {
1029                            System.out.println("***FAILURE*** \t" + type + "\t" + transform + "\t" + keySizeStr + "\t" + args[i]);
1030                        } else {
1031                            System.out.println("   SUCCESS    \t" + type + "\t" + transform + "\t" + keySizeStr + "\t" + args[i]);
1032                        }
1033                    }
1034                }
1035            }
1036        }
1037    }
1038
1039}