/******************************************************************************
 *
 * Copyright (c) 1999-2001 AppGate AB. All Rights Reserved.
 * 
 * This file contains Original Code and/or Modifications of Original Code as
 * defined in and that are subject to the MindTerm Public Source License,
 * Version 1.1, (the 'License'). You may not use this file except in compliance
 * with the License.
 * 
 * You should have received a copy of the MindTerm Public Source License
 * along with this software; see the file LICENSE.  If not, write to
 * AppGate AB, Stora Badhusgatan 18-20, 41121 Goteborg, SWEDEN
 *
 *****************************************************************************/

package com.mindbright.ssh2;

import java.io.InputStream;
import java.io.OutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.util.StringTokenizer;
import java.util.NoSuchElementException;

import com.mindbright.jca.security.KeyPair;
import com.mindbright.jca.security.PublicKey;
import com.mindbright.jca.security.PrivateKey;
import com.mindbright.jca.security.KeyFactory;
import com.mindbright.jca.security.MessageDigest;
import com.mindbright.jca.security.InvalidKeyException;
import com.mindbright.jca.security.NoSuchAlgorithmException;
import com.mindbright.jca.security.interfaces.DSAParams;
import com.mindbright.jca.security.interfaces.DSAPublicKey;
import com.mindbright.jca.security.interfaces.RSAPublicKey;
import com.mindbright.jca.security.interfaces.DSAPrivateKey;
import com.mindbright.jca.security.interfaces.RSAPrivateCrtKey;
import com.mindbright.jca.security.spec.KeySpec;
import com.mindbright.jca.security.spec.DSAPublicKeySpec;
import com.mindbright.jca.security.spec.DSAPrivateKeySpec;
import com.mindbright.jca.security.spec.RSAPublicKeySpec;
import com.mindbright.jca.security.spec.RSAPrivateCrtKeySpec;

import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;

import com.mindbright.security.publickey.RSAAlgorithm;

import com.mindbright.util.Base64;

import com.mindbright.util.ASCIIArmour;

import com.mindbright.ssh2.SSH2Signature;

/**
 * Reader/writer for SSH2 key files.
 * <p>
 * Public keys are stored ASCII-armoured between {@link #BEGIN_PUB_KEY} and
 * {@link #END_PUB_KEY}. Key pairs are stored as an armoured, optionally
 * encrypted, private key blob between {@link #BEGIN_PRV_KEY} and
 * {@link #END_PRV_KEY}. Helpers are also provided for the single-line OpenSSH
 * public key format {@code "<format> <base64-blob> [comment]"}.
 * <p>
 * Only the {@code "ssh-dss"} and {@code "ssh-rsa"} key algorithms are
 * supported.
 */
public final class SSH2SimplePKIFile {
    /** Magic number that every private key blob must start with. */
    public final static int SSH_PRIVATE_KEY_MAGIC = 0x3f6ff9eb;

    public final static String BEGIN_PUB_KEY = "---- BEGIN SSH2 PUBLIC KEY ----";
    public final static String END_PUB_KEY   = "---- END SSH2 PUBLIC KEY ----";
    public final static String BEGIN_PRV_KEY = "---- BEGIN SSH2 ENCRYPTED PRIVATE KEY ----";
    public final static String END_PRV_KEY   = "---- END SSH2 ENCRYPTED PRIVATE KEY ----";

    /** Name of the armour header field holding the key's subject. */
    public final static String FILE_SUBJECT = "Subject";
    /** Name of the armour header field holding the key's comment. */
    public final static String FILE_COMMENT = "Comment";

    private final ASCIIArmour   armour;
    private final SSH2Signature signature;
    private final KeyPair       keyPair;

    private SSH2SimplePKIFile(ASCIIArmour armour, SSH2Signature signature) {
        this(null, armour, signature);
    }

    private SSH2SimplePKIFile(KeyPair keyPair, ASCIIArmour armour,
                              SSH2Signature signature) {
        this.keyPair   = keyPair;
        this.armour    = armour;
        this.signature = signature;
    }

    /**
     * @return the armour the key was read with; its header fields (e.g.
     *         {@link #FILE_SUBJECT}, {@link #FILE_COMMENT}) are available
     *         through it
     */
    public ASCIIArmour getArmour() {
        return armour;
    }

    /**
     * @return the signature object initialized from the file (for signing or
     *         verifying, depending on how the file was loaded)
     */
    public SSH2Signature getSignature() {
        return signature;
    }

    /**
     * @return the key pair read from a private key file, or {@code null} when
     *         this instance was created from a public key file
     */
    public KeyPair getKeyPair() {
        return keyPair;
    }

    /**
     * Reads an armoured SSH2 public key file.
     *
     * @param in stream positioned at the armoured public key
     * @return a file object whose signature is initialized for verification
     * @throws IOException   if reading the stream fails
     * @throws SSH2Exception if the key blob cannot be used for verification
     */
    public static SSH2SimplePKIFile createFromPublicKeyFile(InputStream in)
        throws IOException, SSH2Exception
    {
        ASCIIArmour armour  = new ASCIIArmour(BEGIN_PUB_KEY, END_PUB_KEY);
        byte[]      keyBlob = armour.decode(in);

        SSH2Signature signature =
            SSH2SimpleSignature.getVerifyInstance(keyBlob);

        return new SSH2SimplePKIFile(armour, signature);
    }

    /**
     * Creates a public key armour with the given subject and comment headers.
     */
    public static ASCIIArmour getPublicKeyArmour(String subject,
                                                 String comment) {
        ASCIIArmour armour = new ASCIIArmour(BEGIN_PUB_KEY, END_PUB_KEY);
        armour.setHeaderField(FILE_SUBJECT, subject);
        armour.setHeaderField(FILE_COMMENT, comment);
        return armour;
    }

    /**
     * Creates a private key armour with the given subject and comment headers.
     */
    public static ASCIIArmour getPrivateKeyArmour(String subject,
                                                  String comment) {
        ASCIIArmour armour = new ASCIIArmour(BEGIN_PRV_KEY, END_PRV_KEY);
        armour.setHeaderField(FILE_SUBJECT, subject);
        armour.setHeaderField(FILE_COMMENT, comment);
        return armour;
    }

    /**
     * Writes the public key held by {@code signature} as an armoured public
     * key file.
     *
     * @return the armoured text that was written
     */
    public static String writePublicKeyFile(SSH2Signature signature,
                                            ASCIIArmour armour,
                                            OutputStream out)
        throws IOException, SSH2Exception
    {
        byte[] keyBlob = signature.getPublicKeyBlob();
        return writePublicKeyFile(keyBlob, armour, out);
    }

    /**
     * Encodes {@code publicKey} as an SSH2 key blob and writes it as an
     * armoured public key file.
     *
     * @return the armoured text that was written
     * @throws SSH2FatalException if the key is neither DSA nor RSA
     */
    public static String writePublicKeyFile(PublicKey publicKey,
                                            ASCIIArmour armour,
                                            OutputStream out)
        throws IOException, SSH2Exception
    {
        String        type    = getKeyAlgorithm(publicKey);
        SSH2Signature encoder = SSH2Signature.getEncodingInstance(type);
        byte[]        keyBlob = encoder.encodePublicKey(publicKey);
        return writePublicKeyFile(keyBlob, armour, out);
    }

    /**
     * Writes {@code publicKey} to {@code out} as a single OpenSSH public key
     * line (see {@link #writeOpenSSHPublicKeyString(byte[], String, String)}).
     *
     * @return the line that was written
     */
    public static String writeOpenSSHPublicKeyFile(PublicKey publicKey,
                                                   String comment,
                                                   OutputStream out)
        throws IOException, SSH2Exception
    {
        String        type      = getKeyAlgorithm(publicKey);
        SSH2Signature encoder   = SSH2Signature.getEncodingInstance(type);
        byte[]        keyBlob   = encoder.encodePublicKey(publicKey);
        String        keyString = writeOpenSSHPublicKeyString(keyBlob,
                                                              type,
                                                              comment);
        out.write(keyString.getBytes());

        return keyString;
    }

    /**
     * Writes an already encoded public key blob as an armoured public key
     * file. The armour is switched to non-canonical line endings first.
     *
     * @return the armoured text (the blob is encoded a second time to
     *         produce it)
     */
    public static String writePublicKeyFile(byte[] keyBlob, ASCIIArmour armour,
                                            OutputStream out) throws IOException
    {
        armour.setCanonicalLineEnd(false);
        armour.encode(out, keyBlob);

        return new String(armour.encode(keyBlob));
    }

    /**
     * Reads an armoured SSH2 private key file and prepares a signature object
     * from it.
     *
     * @param password password for an encrypted key; may be {@code null} only
     *                 if the key is not encrypted
     * @param in       stream positioned at the armoured private key
     * @param sign     {@code true} to initialize the signature for signing
     *                 (with the public key also attached), {@code false} to
     *                 initialize it for verification only
     * @throws SSH2Exception if the key blob is malformed, the password is
     *                       wrong/missing or the key type is unsupported
     */
    public static SSH2SimplePKIFile createFromKeyPairFile(String password,
                                                          InputStream in,
                                                          boolean sign)
        throws IOException, SSH2Exception, NoSuchAlgorithmException
    {
        ASCIIArmour armour  = new ASCIIArmour(BEGIN_PRV_KEY, END_PRV_KEY);
        byte[]      keyBlob = armour.decode(in);

        String algorithm = getKeyAlgorithm(keyBlob);

        SSH2Signature signature = SSH2Signature.getInstance(algorithm);
        KeyPair       keyPair   = readKeyPair(algorithm, keyBlob, password);

        if(sign) {
            signature.initSign(keyPair.getPrivate());
            signature.setPublicKey(keyPair.getPublic());
        } else {
            signature.initVerify(keyPair.getPublic());
        }

        return new SSH2SimplePKIFile(keyPair, armour, signature);
    }

    /**
     * Determines the SSH2 algorithm name of a private key blob from its type
     * string.
     *
     * @return {@code "ssh-dss"} for "dl-modp" keys, {@code "ssh-rsa"} for
     *         "if-modn" keys
     * @throws SSH2FatalException if the blob does not start with
     *                            {@link #SSH_PRIVATE_KEY_MAGIC} or its type
     *                            is unknown
     */
    public static String getKeyAlgorithm(byte[] keyPairBlob)
        throws SSH2FatalException
    {
        SSH2DataBuffer buf = new SSH2DataBuffer(keyPairBlob.length);
        buf.writeRaw(keyPairBlob);

        // Validate the magic before parsing further so that non-key input
        // yields a clear error rather than a failure inside string parsing.
        int magic = buf.readInt();
        if(magic != SSH_PRIVATE_KEY_MAGIC) {
            throw new SSH2FatalException("Invalid magic in private key: " +
                                         magic);
        }
        buf.readInt(); // total blob length, not needed here

        String type = buf.readJavaString();

        if(type.indexOf("dl-modp") != -1) {
            return "ssh-dss";
        } else if(type.indexOf("if-modn") != -1) {
            return "ssh-rsa";
        } else {
            throw new SSH2FatalException("Unknown KeyPair alg: " + type);
        }
    }

    /**
     * Maps a public key object to its SSH2 algorithm name.
     *
     * @return {@code "ssh-dss"} for DSA keys, {@code "ssh-rsa"} for RSA keys
     * @throws SSH2FatalException for any other key type
     */
    public static String getKeyAlgorithm(PublicKey publicKey)
        throws SSH2FatalException
    {
        if(publicKey instanceof DSAPublicKey) {
            return "ssh-dss";
        } else if(publicKey instanceof RSAPublicKey) {
            return "ssh-rsa";
        } else {
            throw new SSH2FatalException("Unknown publickey alg: " + publicKey);
        }
    }

    /**
     * Derives a symmetric key of {@code keyLen} bytes from a password by
     * concatenating MD5 digests: K1 = MD5(pw), K2 = MD5(pw || K1),
     * K3 = MD5(pw || K1 || K2), ... and truncating to {@code keyLen} bytes.
     * The password is converted with the platform default charset.
     *
     * @throws Error if the digest computation fails (e.g. MD5 unavailable)
     */
    public static byte[] expandPasswordToKey(String password, int keyLen) {
        try {
            MessageDigest md5    = MessageDigest.getInstance("MD5");
            int           digLen = md5.getDigestLength();
            // Round up to a whole number of digests so every digest() call
            // has room for its full output.
            byte[]        buf    = new byte[((keyLen + digLen) / digLen) *
                                            digLen];
            byte[]        pwd    = password.getBytes();
            int           cnt    = 0;
            while(cnt < keyLen) {
                md5.update(pwd);
                if(cnt > 0) {
                    // Chain all previously produced key material.
                    md5.update(buf, 0, cnt);
                }
                md5.digest(buf, cnt, digLen);
                cnt += digLen;
            }
            byte[] key = new byte[keyLen];
            System.arraycopy(buf, 0, key, 0, keyLen);
            return key;
        } catch (Exception e) {
            throw new Error("Error in SSH2SimplePKIFile.expandPasswordToKey: " +
                            e);
        }
    }

    /**
     * Parses a private key blob into a key pair.
     * <p>
     * Blob layout (mirroring {@link #writeKeyPair}):
     * <pre>
     *   int    magic ({@link #SSH_PRIVATE_KEY_MAGIC})
     *   int    total blob length
     *   string key type ("dl-modp..." for DSA, "if-modn..." for RSA)
     *   string cipher name ("none" when unencrypted)
     *   string key data, encrypted with the cipher:
     *            int length of the key material that follows
     *            DSA: int 0 (explicit params), then p, g, q, y, x
     *            RSA: n, e, d, u, p, q
     * </pre>
     *
     * @param algorithm {@code "ssh-dss"} or {@code "ssh-rsa"}
     * @param keyBlob   the decoded (de-armoured) private key blob
     * @param password  password for an encrypted key; may be {@code null}
     *                  only when the cipher is "none"
     * @throws SSH2FatalException       on malformed blobs, mismatched or
     *                                  unsupported key types
     * @throws SSH2AccessDeniedException if no password was supplied for an
     *                                  encrypted key, or decryption produced
     *                                  garbage (wrong password)
     */
    public static KeyPair readKeyPair(String algorithm, byte[] keyBlob,
                                      String password)
        throws SSH2Exception, NoSuchAlgorithmException
    {
        SSH2DataBuffer buf = new SSH2DataBuffer(keyBlob.length);

        buf.writeRaw(keyBlob);

        // Check the magic first so that garbage input is rejected before any
        // string fields are interpreted.
        int magic = buf.readInt();
        if(magic != SSH_PRIVATE_KEY_MAGIC) {
            // !!! TODO: keyformaterror exception?
            throw new SSH2FatalException("Invalid magic in private key: " +
                                         magic);
        }

        buf.readInt(); // total blob length, not needed for parsing
        String type   = buf.readJavaString();
        String cipher = buf.readJavaString();
        int    bufLen = buf.readInt();

        if((algorithm.equals("ssh-dss") && (type.indexOf("dl-modp") == -1)) ||
           (algorithm.equals("ssh-rsa") && (type.indexOf("if-modn") == -1)))
        {
            // !!! TODO: keyformaterror exception?
            throw new SSH2FatalException("Wrong key type '" + type + "' for " +
                                         algorithm);
        }

        if(!cipher.equals("none")) {
            if(password == null) {
                throw new SSH2AccessDeniedException(
                    "Password required to decrypt private key");
            }
            // The data is decrypted in place; make sure the declared length
            // actually fits in what is left of the blob.
            if(bufLen < 0 || bufLen > buf.getMaxReadSize()) {
                throw new SSH2FatalException(
                    "Corrupt private key blob, invalid data length: " + bufLen);
            }
            try {
                int keyLen =
                    SSH2TransportPreferences.getCipherKeyLen(cipher);
                String cipherName =
                    SSH2TransportPreferences.ssh2ToJCECipher(cipher);
                byte[] key = expandPasswordToKey(password, keyLen);
                Cipher decrypt = Cipher.getInstance(cipherName);
                decrypt.init(Cipher.DECRYPT_MODE,
                             new SecretKeySpec(key, decrypt.getAlgorithm()));
                byte[] data   = buf.getData();
                int    offset = buf.getRPos();
                decrypt.doFinal(data, offset, bufLen, data, offset);
            } catch (InvalidKeyException e) {
                throw new SSH2FatalException("Invalid key derived in " +
                                             "readKeyPair: " + e);
            }
        }

        // A wrong password produces random cleartext; an implausible length
        // here is the first place that shows up.
        int parmLen = buf.readInt();
        if(parmLen > buf.getMaxReadSize() || parmLen < 0) {
            throw new SSH2AccessDeniedException("Invalid password or corrupt key blob");
        }

        KeySpec prvSpec     = null;
        KeySpec pubSpec     = null;
        String  keyFactType = null;

        if(algorithm.equals("ssh-dss")) {
            keyFactType = "DSA";
            int value = buf.readInt();
            if(value == 0) {
                // Explicit parameters, stored in the order p, g, q, y, x.
                BigInteger p, q, g, x, y;
                p = buf.readBigIntBits();
                g = buf.readBigIntBits();
                q = buf.readBigIntBits();
                y = buf.readBigIntBits();
                x = buf.readBigIntBits();

                prvSpec = new DSAPrivateKeySpec(x, p, q, g);
                pubSpec = new DSAPublicKeySpec(y, p, q, g);
            } else {
                // !!! TODO: predefined params
                // Unsupported file content is a key-format problem, not a VM
                // failure, so report it as an SSH2 exception.
                throw new SSH2FatalException(
                    "Predefined DSA params not implemented (" +
                    value + ") '" + buf.readJavaString() + "'");
            }
        } else if(algorithm.equals("ssh-rsa")) {
            keyFactType = "RSA";
            BigInteger n, e, d, p, q, u, pe, qe;

            // Stored in the order n, e, d, u, p, q.
            n = buf.readBigIntBits();
            e = buf.readBigIntBits();
            d = buf.readBigIntBits();
            u = buf.readBigIntBits();
            p = buf.readBigIntBits();
            q = buf.readBigIntBits();

            // !!! OUCH
            // The file lacks the CRT prime exponents, so they are recomputed
            // from d and each prime.
            pe = RSAAlgorithm.getPrimeExponent(d, p);
            qe = RSAAlgorithm.getPrimeExponent(d, q);

            prvSpec = new RSAPrivateCrtKeySpec(n, e, d, p, q, pe, qe, u);
            pubSpec = new RSAPublicKeySpec(n, e);
        } else {
            throw new SSH2FatalException("Unsupported key type: " + algorithm);
        }

        try {
            KeyFactory keyFact = KeyFactory.getInstance(keyFactType);
            return new KeyPair(keyFact.generatePublic(pubSpec),
                               keyFact.generatePrivate(prvSpec));
        } catch (Exception e) {
            throw new SSH2FatalException("Error in readKeyPair: " + e );
        }
    }

    /**
     * Serializes a key pair into a private key blob in the layout described
     * in {@link #readKeyPair}, optionally encrypting the key data.
     *
     * @param algorithm {@code "ssh-dss"} or {@code "ssh-rsa"}
     * @param password  password to derive the encryption key from; required
     *                  unless {@code cipher} is "none"
     * @param cipher    SSH2 cipher name, or "none" for no encryption
     * @param keyPair   the key pair; RSA private keys must be CRT keys
     * @return the (unarmoured) private key blob
     * @throws SSH2FatalException if the algorithm is unsupported, a password
     *                            is missing, or key derivation fails
     */
    public static byte[] writeKeyPair(String algorithm, String password,
                                      String cipher, KeyPair keyPair)
        throws SSH2FatalException, NoSuchAlgorithmException
    {
        SSH2DataBuffer toBeEncrypted = new SSH2DataBuffer(8192);
        int            totLen        = 0;
        String         type          = null;

        toBeEncrypted.writeInt(0); // unenc length (filled in below)

        if(algorithm.equals("ssh-dss")) {
            DSAPublicKey  pubKey = (DSAPublicKey)keyPair.getPublic();
            DSAPrivateKey prvKey = (DSAPrivateKey)keyPair.getPrivate();
            DSAParams     params = pubKey.getParams();

            toBeEncrypted.writeInt(0); // type 0 is explicit params (as opposed to predefined)
            toBeEncrypted.writeBigIntBits(params.getP());
            toBeEncrypted.writeBigIntBits(params.getG());
            toBeEncrypted.writeBigIntBits(params.getQ());
            toBeEncrypted.writeBigIntBits(pubKey.getY());
            toBeEncrypted.writeBigIntBits(prvKey.getX());

            type = "dl-modp{sign{dsa-nist-sha1},dh{plain}}";
        } else if(algorithm.equals("ssh-rsa")) {
            RSAPublicKey     pubKey = (RSAPublicKey)keyPair.getPublic();
            RSAPrivateCrtKey prvKey = (RSAPrivateCrtKey)keyPair.getPrivate();

            toBeEncrypted.writeBigIntBits(pubKey.getModulus());
            toBeEncrypted.writeBigIntBits(pubKey.getPublicExponent());
            toBeEncrypted.writeBigIntBits(prvKey.getPrivateExponent());
            toBeEncrypted.writeBigIntBits(prvKey.getCrtCoefficient());
            toBeEncrypted.writeBigIntBits(prvKey.getPrimeP());
            toBeEncrypted.writeBigIntBits(prvKey.getPrimeQ());

            type = "if-modn{sign{rsa-pkcs1-md5},sign{rsa-pkcs1-sha1}," +
                "encrypt{rsa-pkcs1-none}}";
        } else {
            throw new SSH2FatalException("Unsupported key type: " + algorithm);
        }

        // Back-patch the leading length with the size of the key material.
        totLen = toBeEncrypted.getWPos();
        toBeEncrypted.setWPos(0);
        toBeEncrypted.writeInt(totLen - 4);

        if(!cipher.equals("none")) {
            if(password == null) {
                throw new SSH2FatalException("Password required for cipher " +
                                             cipher);
            }
            try {
                int keyLen =
                    SSH2TransportPreferences.getCipherKeyLen(cipher);
                String cipherName =
                    SSH2TransportPreferences.ssh2ToJCECipher(cipher);
                byte[] key = expandPasswordToKey(password, keyLen);
                Cipher encrypt = Cipher.getInstance(cipherName);
                encrypt.init(Cipher.ENCRYPT_MODE,
                             new SecretKeySpec(key, encrypt.getAlgorithm()));
                byte[] data = toBeEncrypted.getData();
                int    bs   = encrypt.getBlockSize();
                // Pad up to a whole number of cipher blocks using whatever
                // the buffer holds past the written data. A block size of 0
                // (stream cipher) needs no padding and must not be used as
                // a divisor.
                if(bs > 0) {
                    totLen += (bs - (totLen % bs)) % bs;
                }
                totLen = encrypt.doFinal(data, 0, totLen, data, 0);
            } catch (InvalidKeyException e) {
                throw new SSH2FatalException("Invalid key derived in " +
                                             "writeKeyPair: " + e);
            }
        }

        SSH2DataBuffer buf = new SSH2DataBuffer(512 + totLen);

        buf.writeInt(SSH_PRIVATE_KEY_MAGIC);
        buf.writeInt(0); // total length (filled in below)
        buf.writeString(type);
        buf.writeString(cipher);
        buf.writeString(toBeEncrypted.getData(), 0, totLen);

        totLen = buf.getWPos();
        buf.setWPos(4);
        buf.writeInt(totLen);

        byte[] keyBlob = new byte[totLen];
        System.arraycopy(buf.data, 0, keyBlob, 0, totLen);

        return keyBlob;
    }

    /**
     * Serializes {@code keyPair} (see {@link #writeKeyPair}) and writes it to
     * {@code out} using {@code armour}, with non-canonical line endings.
     */
    public static void writeKeyPairFile(KeyPair keyPair,
                                        String password, String cipher,
                                        ASCIIArmour armour,
                                        OutputStream out)
        throws IOException, SSH2FatalException, NoSuchAlgorithmException
    {
        String algorithm = getKeyAlgorithm(keyPair.getPublic());
        byte[] keyBlob   = writeKeyPair(algorithm, password, cipher, keyPair);
        armour.setCanonicalLineEnd(false);
        armour.encode(out, keyBlob);
    }

    /**
     * Formats the public key of {@code signature} as an OpenSSH public key
     * line, using the signature's algorithm name as the format.
     */
    public static String writeOpenSSHPublicKeyString(SSH2Signature signature,
                                                     String comment)
        throws SSH2Exception
    {
        byte[] keyBlob = signature.getPublicKeyBlob();
        String format  = signature.getAlgorithmName();
        return writeOpenSSHPublicKeyString(keyBlob, format, comment);
    }

    /**
     * Formats a public key blob as an OpenSSH public key line:
     * {@code "<format> <base64-blob> <comment>\n"}. The comment is optional
     * in that format and is omitted when {@code null} or empty.
     */
    public static String writeOpenSSHPublicKeyString(byte[] keyBlob,
                                                     String format,
                                                     String comment)
    {
        byte[] base64 = Base64.encode(keyBlob);

        StringBuffer buf = new StringBuffer();

        buf.append(format);
        buf.append(' ');
        buf.append(new String(base64));
        if(comment != null && comment.length() > 0) {
            buf.append(' ');
            buf.append(comment);
        }
        buf.append('\n');

        return buf.toString();
    }

    /**
     * Parses an OpenSSH public key line ({@code "<format> <base64-blob>
     * [comment]"}) into a signature object initialized for verification.
     * Anything after the blob (the optional comment) is ignored.
     *
     * @throws SSH2FatalException if the format or blob field is missing
     */
    public static SSH2Signature readOpenSSHPublicKeyString(String pubKeyString)
        throws SSH2Exception
    {
        if(pubKeyString == null) {
            throw new SSH2FatalException("Corrupt openssh public key string");
        }

        StringTokenizer st = new StringTokenizer(pubKeyString);
        // Only format and key blob are mandatory; the comment may be absent.
        if(st.countTokens() < 2) {
            throw new SSH2FatalException("Corrupt openssh public key string");
        }
        String format = st.nextToken();
        String base64 = st.nextToken();

        SSH2Signature signature = SSH2Signature.getInstance(format);
        byte[]        keyBlob   = Base64.decode(base64.getBytes());
        signature.initVerify(keyBlob);

        return signature;
    }

    /**
     * Manual test harness: loads a DSA key pair, signs and verifies some
     * data, then writes the key pair and its public key back out to separate
     * files.
     */
    public static void main(String[] argv) {
        java.io.FileInputStream  in  = null;
        java.io.FileOutputStream out = null;
        try {
            in = new java.io.FileInputStream("/home/mats/dsakey.prv");
            SSH2SimplePKIFile dss =
                SSH2SimplePKIFile.createFromKeyPairFile("foobar", in, true);
            in.close();
            in = null;

            byte[] data = new byte[] { 0x01, 0x02, 0x03, 0x04, 0x05,
                                       0x06, 0x07, 0x08, 0x09, 0x0a };
            SSH2Signature sign = dss.getSignature();
            byte[] sig = sign.sign(data);

            sign = SSH2Signature.getInstance("ssh-dss");
            sign.initVerify(dss.getKeyPair().getPublic());

            System.out.println("Verify: " + sign.verify(sig, data));

            String subject = dss.getArmour().getHeaderField(FILE_SUBJECT);
            String comment = dss.getArmour().getHeaderField(FILE_COMMENT);
            System.out.println("subject: " + subject);
            System.out.println("comment: " + comment);

            out = new java.io.FileOutputStream("/home/mats/dsakey2.prv");
            writeKeyPairFile(dss.getKeyPair(), "foobar", "3des-cbc",
                             dss.getArmour(), out);
            out.close();
            out = null;

            // The .pub file must contain only the public key, never the
            // (encrypted) private key pair.
            out = new java.io.FileOutputStream("/home/mats/dsakey2.pub");
            ASCIIArmour pubArm =
                SSH2SimplePKIFile.getPublicKeyArmour(subject, comment);
            writePublicKeyFile(dss.getKeyPair().getPublic(), pubArm, out);
            out.close();
            out = null;

        } catch (Exception e) {
            System.out.println("Error: " + e);
            e.printStackTrace();
        } finally {
            if(in != null) {
                try { in.close(); } catch (IOException ignored) { }
            }
            if(out != null) {
                try { out.close(); } catch (IOException ignored) { }
            }
        }
    }

}
