#include "DNSKEY_RR.h"

using namespace BigIntegerLibrary;
using namespace System::Security::Cryptography;

namespace ADNS {

	DNSKEY_RR::DNSKEY_RR()
	{
		// A new DNSKEY record of class IN.  RDATA starts as just the 4-byte fixed
		// header (flags[2], protocol, algorithm), zero-initialised by gcnew, with
		// the protocol byte forced to 3 and no key material yet.
		rr_type = RR_TYPE::DNSKEY;
		rr_class = RR_CLASS::IN;
		rdata = gcnew array<Byte>(4);
		SetProtocol();
	}

	Void DNSKEY_RR::SetProtocol()
	{
		// RFC 4034 section 2.1.2: the Protocol field (RDATA byte 2) MUST be 3;
		// a DNSKEY RR with any other value is invalid for DNSSEC.
		const Byte dnssecProtocol = 3;
		rdata[2] = dnssecProtocol;
	}

	Void DNSKEY_RR::SetAlgorithm(CRYPTO_ALGORITHM a)
	{
		// RDATA byte 3 holds the DNSSEC algorithm number.
		Byte algorithmNumber = (Byte) a;
		rdata[3] = algorithmNumber;
	}

	CRYPTO_ALGORITHM DNSKEY_RR::GetAlgorithm()
	{
		// Interpret RDATA byte 3 as the DNSSEC algorithm number.
		Byte algorithmNumber = rdata[3];
		return (CRYPTO_ALGORITHM) algorithmNumber;
	}

	array<Byte>^ DNSKEY_RR::GetKey()
	{
		// Returns a copy of the public key: every RDATA byte after the 4-byte
		// header (flags, protocol, algorithm).  Returns nullptr when the record
		// carries no key bytes at all.
		const int headerLength = 4;
		int keyLength = rdata->Length - headerLength;
		if (keyLength <= 0)
			return nullptr;

		array<Byte>^ key = gcnew array<Byte>(keyLength);
		Array::Copy(rdata, headerLength, key, 0, keyLength);
		return key;
	}

	Void DNSKEY_RR::SetKey(array<Byte>^ key)
	{
		// Replaces the public-key portion of RDATA with a copy of `key`.
		// Array::Resize keeps the existing leading bytes, so the 4-byte header
		// (flags, protocol, algorithm) survives a change in key length.
		int requiredLength = key->Length + 4;
		if (rdata->Length != requiredLength)
		{
			Array::Resize(rdata, requiredLength);
		}
		key->CopyTo(rdata, 4);
	}

	bool DNSKEY_RR::GetZoneKeyFlag()
	{
		// Zone Key flag = bit 7 of the 16-bit Flags field (RFC 4034 section 2.1.1),
		// i.e. the least significant bit of RDATA byte 0.  The remaining bits of
		// that byte are reserved and MUST be ignored on receipt, so test only this
		// bit instead of comparing the whole byte with 1.
		return (rdata[0] & 0x01) != 0;
	}

	UInt16 DNSKEY_RR::GetFlags()
	{
		// The Flags field is the first two RDATA bytes in network (big-endian)
		// order.  Assemble the value explicitly; BitConverter follows the host's
		// byte order, so the previous copy+reverse decode was only correct on
		// little-endian hosts.
		return (UInt16) ((rdata[0] << 8) | rdata[1]);
	}

	Void DNSKEY_RR::SetZoneKeyFlag(bool flag)
	{
		// Set or clear only the Zone Key bit (LSB of RDATA byte 0, Flags bit 7),
		// leaving the other bits of that byte untouched, the same way the
		// Revoke and SEP setters mask their bits.
		if (flag)
			rdata[0] = (Byte) (rdata[0] | 0x01);
		else
			rdata[0] = (Byte) (rdata[0] & 0xFE);
	}

	bool DNSKEY_RR::GetRevokedFlag()
	{
		// REVOKE is Flags bit 8 (RFC 5011), the most significant bit of RDATA byte 1.
		return (rdata[1] & 0x80) != 0;
	}

	Void DNSKEY_RR::SetRevokedFlag(bool flag)
	{
		// Toggle only the REVOKE bit (MSB of RDATA byte 1); other bits are kept.
		const Byte revokeMask = 0x80;
		Byte flagsLow = rdata[1];
		rdata[1] = flag ? (Byte) (flagsLow | revokeMask) : (Byte) (flagsLow & ~revokeMask);
	}

	bool DNSKEY_RR::GetKeySignFlag()
	{
		// Secure Entry Point ("key signing") flag = Flags bit 15 (RFC 4034
		// section 2.1.1), the least significant bit of RDATA byte 1.
		return (rdata[1] & 0x01) != 0;
	}

	Void DNSKEY_RR::SetKeySignFlag(bool flag)
	{
		// Toggle only the SEP bit (LSB of RDATA byte 1); other bits are kept.
		const Byte sepMask = 0x01;
		Byte flagsLow = rdata[1];
		rdata[1] = flag ? (Byte) (flagsLow | sepMask) : (Byte) (flagsLow & ~sepMask);
	}

	String^ DNSKEY_RR::Print()
	{
		// Text form: <header> <flags> <protocol> <algorithm> <base64 public key>
		// (field order as in RFC 4034 section 2.2).
		String^ output;

		output = PrintHeader();
		output += " ";
		output += Convert::ToString(GetFlags());
		output += " ";
		output += Convert::ToString(rdata[2]);
		output += " ";
		output += Convert::ToString(rdata[3]);
		output += " ";

		// GetKey() returns nullptr when the record has no key bytes;
		// Convert::ToBase64String would throw ArgumentNullException on that,
		// so a keyless record prints an empty key field instead.
		array<Byte>^ key = GetKey();
		if (key != nullptr)
			output += Convert::ToBase64String(key);

		return output;
	}
	UInt16 DNSKEY_RR::CalcKeytag()
	{
		// Key tag per RFC 4034 Appendix B, computed over the whole DNSKEY RDATA
		// (flags | protocol | algorithm | public key).  Returns 0 if RDATA is
		// shorter than the 4-byte fixed header.
		if (rdata->Length < 4)
			return 0;

		if (rdata[3] == (Byte) CRYPTO_ALGORITHM::RSAMD5)
		{
			// Appendix B.1: for RSA/MD5 the tag is the most significant 16 bits of
			// the least significant 24 bits of the modulus, i.e. the octets at
			// Length-3 and Length-2 read in network (big-endian) order.
			// BitConverter::ToUInt16 uses host order and returned a byte-swapped
			// tag on little-endian machines.
			if (rdata->Length > 4)
				return (UInt16) ((rdata[rdata->Length - 3] << 8) | rdata[rdata->Length - 2]);
			return 0;
		}

		// Appendix B reference algorithm: sum RDATA as big-endian 16-bit words
		// (even offsets are the high byte), then fold the carry back in once.
		unsigned int ac32 = 0;
		for (int i = 0; i < rdata->Length; ++i)
			ac32 += (i & 1) ? rdata[i] : (rdata[i] << 8);
		ac32 += (ac32 >> 16) & 0xFFFF;

		return (UInt16) (ac32 & 0xFFFF);
	}

	bool DNSKEY_RR::ValidateRSASHA1Signature(array<Byte>^ signature, array<Byte>^ thingtohash)
	{
		// Verifies an RSA/SHA-1 signature (algorithm RSASHA1 or RSASHA1_NSEC3)
		// over `thingtohash` with this record's public key.
		//
		// Key layout: RFC 3110 section 2 (exponent length, exponent, modulus).
		// Signature encoding: RFC 4034 section 3.1.8.1 / RFC 3110 section 3
		// (PKCS#1 v1.5 block with the SHA-1 DigestInfo prefix).
		//
		// Returns false for a different algorithm and for missing or malformed
		// key/signature data instead of throwing.

		// ASN.1 DigestInfo prefix for SHA-1 (RFC 3110 section 3).
		array<Byte>^ pkcsprefix = { 0x30,0x21,0x30,0x09,0x06,0x05,0x2B,0x0E,0x03,0x02,0x1A,0x05,0x00,0x04,0x14 };

		if ((GetAlgorithm() != CRYPTO_ALGORITHM::RSASHA1) && (GetAlgorithm() != CRYPTO_ALGORITHM::RSASHA1_NSEC3))
			return false;

		if (signature == nullptr || thingtohash == nullptr)
			return false;

		array<Byte>^ key = GetKey();
		if (key == nullptr)
			return false;

		// Split the key into exponent and modulus.  A leading zero byte means the
		// exponent length follows in the next two bytes (network order);
		// otherwise the first byte itself is the exponent length.
		int exponentsize;
		int exponentstart;
		if (key[0] == 0)
		{
			if (key->Length < 3)
				return false;
			exponentsize = (key[1] << 8) | key[2];
			exponentstart = 3;
		}
		else
		{
			exponentsize = key[0];
			exponentstart = 1;
		}

		int modulusstart = exponentstart + exponentsize;
		int modulussize = key->Length - modulusstart;
		if (exponentsize == 0 || modulussize <= 0)
			return false;

		// Expected encoded message without its leading 0x00 octet (which
		// disappears once the block is read as an integer):
		//   01 | FF ... FF | 00 | DigestInfo prefix | SHA-1 hash
		SHA1^ sha = gcnew SHA1CryptoServiceProvider();
		array<Byte>^ hash = sha->ComputeHash(thingtohash);
		delete sha;

		int numff = modulussize - 1 - 2 - pkcsprefix->Length - hash->Length;
		if (numff <= 0)
			return false;

		array<Byte>^ thingtosign = gcnew array<Byte>(modulussize - 1);
		int pos = 0;
		thingtosign[pos++] = 0x01;
		for (int i = 0; i < numff; ++i)
			thingtosign[pos++] = 0xFF;
		thingtosign[pos++] = 0x00;
		pkcsprefix->CopyTo(thingtosign, pos);
		pos += pkcsprefix->Length;
		hash->CopyTo(thingtosign, pos);

		// Read each big-endian octet string as an unsigned integer.
		BigInteger^ expi = gcnew BigInteger(0);
		for (int i = exponentstart; i < modulusstart; ++i)
			expi = (expi * 256) + key[i];

		BigInteger^ modi = gcnew BigInteger(0);
		for (int i = modulusstart; i < key->Length; ++i)
			modi = (modi * 256) + key[i];

		BigInteger^ sigi = gcnew BigInteger(0);
		for (int i = 0; i < signature->Length; ++i)
			sigi = (sigi * 256) + signature[i];

		BigInteger^ thingi = gcnew BigInteger(0);
		for (int i = 0; i < thingtosign->Length; ++i)
			thingi = (thingi * 256) + thingtosign[i];

		// RSA verification: signature^exponent mod modulus must reproduce the
		// expected encoded message.
		BigInteger^ resi = BigInteger::ModularExponentiation(sigi, expi, modi);
		if (resi == thingi)
			return true;

		return false;
	}

	array<Byte>^ DNSKEY_RR::GetSHA1Hash()
	{
		// SHA-1 over (owner name bytes | DNSKEY RDATA), the same concatenation
		// RFC 4034 section 5.1.4 uses as DS digest input.
		// NOTE(review): assumes owner->Size() equals owner->GetName()->Length and
		// that GetName() yields the canonical wire form -- confirm in DOMAIN_NAME.
		int nameLength = owner->Size();
		array<Byte>^ buffer = gcnew array<Byte>(rdata->Length + nameLength);
		owner->GetName()->CopyTo(buffer, 0);
		rdata->CopyTo(buffer, nameLength);

		SHA1^ hasher = gcnew SHA1CryptoServiceProvider();
		return hasher->ComputeHash(buffer);
	}


	array<Byte>^ DNSKEY_RR::GetDigest(HASH_ALGORITHM algo)
	{
		// Digest of (owner name bytes | DNSKEY RDATA) -- the DS digest input of
		// RFC 4034 section 5.1.4 -- using SHA-1 or SHA-256.
		// Returns nullptr when the owner has no name bytes or when `algo` is
		// neither SHA1 nor SHA256.
		array<Byte>^ name = owner->GetName();
		if (name == nullptr)
			return nullptr;

		// Build a fresh buffer so the owner's own name array is never modified.
		array<Byte>^ buffer = gcnew array<Byte>(name->Length + rdata->Length);
		name->CopyTo(buffer, 0);
		rdata->CopyTo(buffer, name->Length);

		if (algo == HASH_ALGORITHM::SHA1)
		{
			SHA1^ hasher = gcnew SHA1CryptoServiceProvider();
			return hasher->ComputeHash(buffer);
		}
		if (algo == HASH_ALGORITHM::SHA256)
		{
			SHA256^ hasher = gcnew SHA256Managed();
			return hasher->ComputeHash(buffer);
		}
		return nullptr;
	}

	DNSKEY_RR^ DNSKEY_RR::Clone()
	{
		// Deep copy: header fields, a cloned owner name, and an independent copy
		// of RDATA so that mutating the clone never affects this record.
		DNSKEY_RR^ newrr = gcnew DNSKEY_RR();
		newrr->rr_type = rr_type;
		// A record built with the default constructor has no owner yet; cloning
		// it must not throw NullReferenceException.
		if (owner != nullptr)
			newrr->owner = owner->Clone();
		newrr->ttl = ttl;
		newrr->rr_class = rr_class;
		newrr->rdata = gcnew array<Byte>(rdata->Length);
		rdata->CopyTo(newrr->rdata, 0);
		return newrr;
	}

	bool DNSKEY_RR::ValidateRSASHA256Signature(array<Byte>^ signature, array<Byte>^ thingtohash)
	{
		// Verifies an RSA/SHA-256 signature (algorithm RSASHA256) over
		// `thingtohash` with this record's public key.
		//
		// Key layout: RFC 3110 section 2 (exponent length, exponent, modulus).
		// Signature encoding: RFC 5702 section 3 -- a PKCS#1 v1.5 block carrying
		// the SHA-256 DigestInfo prefix.  Reusing the SHA-1 prefix (as this
		// function previously did) produces a block that no valid RSA/SHA-256
		// signature can match.
		//
		// Returns false for a different algorithm and for missing or malformed
		// key/signature data instead of throwing.

		// ASN.1 DigestInfo prefix for SHA-256 (RFC 5702 section 3).
		array<Byte>^ pkcsprefix = { 0x30,0x31,0x30,0x0D,0x06,0x09,0x60,0x86,0x48,0x01,0x65,0x03,0x04,0x02,0x01,0x05,0x00,0x04,0x20 };

		if (GetAlgorithm() != CRYPTO_ALGORITHM::RSASHA256)
			return false;

		if (signature == nullptr || thingtohash == nullptr)
			return false;

		array<Byte>^ key = GetKey();
		if (key == nullptr)
			return false;

		// Split the key into exponent and modulus.  A leading zero byte means the
		// exponent length follows in the next two bytes (network order);
		// otherwise the first byte itself is the exponent length.
		int exponentsize;
		int exponentstart;
		if (key[0] == 0)
		{
			if (key->Length < 3)
				return false;
			exponentsize = (key[1] << 8) | key[2];
			exponentstart = 3;
		}
		else
		{
			exponentsize = key[0];
			exponentstart = 1;
		}

		int modulusstart = exponentstart + exponentsize;
		int modulussize = key->Length - modulusstart;
		if (exponentsize == 0 || modulussize <= 0)
			return false;

		// Expected encoded message without its leading 0x00 octet (which
		// disappears once the block is read as an integer):
		//   01 | FF ... FF | 00 | DigestInfo prefix | SHA-256 hash
		SHA256^ sha256 = gcnew SHA256Managed();
		array<Byte>^ hash = sha256->ComputeHash(thingtohash);
		delete sha256;

		int numff = modulussize - 1 - 2 - pkcsprefix->Length - hash->Length;
		if (numff <= 0)
			return false;

		array<Byte>^ thingtosign = gcnew array<Byte>(modulussize - 1);
		int pos = 0;
		thingtosign[pos++] = 0x01;
		for (int i = 0; i < numff; ++i)
			thingtosign[pos++] = 0xFF;
		thingtosign[pos++] = 0x00;
		pkcsprefix->CopyTo(thingtosign, pos);
		pos += pkcsprefix->Length;
		hash->CopyTo(thingtosign, pos);

		// Read each big-endian octet string as an unsigned integer.
		BigInteger^ expi = gcnew BigInteger(0);
		for (int i = exponentstart; i < modulusstart; ++i)
			expi = (expi * 256) + key[i];

		BigInteger^ modi = gcnew BigInteger(0);
		for (int i = modulusstart; i < key->Length; ++i)
			modi = (modi * 256) + key[i];

		BigInteger^ sigi = gcnew BigInteger(0);
		for (int i = 0; i < signature->Length; ++i)
			sigi = (sigi * 256) + signature[i];

		BigInteger^ thingi = gcnew BigInteger(0);
		for (int i = 0; i < thingtosign->Length; ++i)
			thingi = (thingi * 256) + thingtosign[i];

		// RSA verification: signature^exponent mod modulus must reproduce the
		// expected encoded message.
		BigInteger^ resi = BigInteger::ModularExponentiation(sigi, expi, modi);
		if (resi == thingi)
			return true;

		return false;
	}

	ResourceRecord^ DNSKEY_RR::ParseResourceRecord(array<Byte>^ domainname, UInt16 rr_type, UInt16 rr_class, UInt32 ttl, UInt16 rdata_len, array<Byte>^ packet, int rdata_start)
	{
		// Builds a DNSKEY_RR from wire-format RDATA (RFC 4034 section 2.1):
		//   flags (2 bytes) | protocol (1) | algorithm (1) | public key (rdata_len - 4)
		// rr_type is not read: the result is always a DNSKEY.  The protocol byte is
		// not copied from the packet; the constructor fixes it at 3.
		DNSKEY_RR^ dnskeyout = gcnew DNSKEY_RR();
		dnskeyout->owner = gcnew DOMAIN_NAME(domainname);
		dnskeyout->rr_class = (RR_CLASS) rr_class;
		dnskeyout->ttl = ttl;

		dnskeyout->SetAlgorithm((CRYPTO_ALGORITHM) packet[rdata_start + 3]);

		array<Byte>^ tmparray = gcnew array<Byte>(rdata_len - 4);
		Array::Copy(packet, rdata_start + 4, tmparray, 0, tmparray->Length);
		dnskeyout->SetKey(tmparray);

		// Copy both Flags bytes verbatim rather than re-deriving only the Zone
		// Key, SEP and REVOKE bits.  The key tag (CalcKeytag) and DS digests are
		// computed over the whole RDATA, so dropping any other flag bit would make
		// them disagree with the record as received.
		dnskeyout->rdata[0] = packet[rdata_start];
		dnskeyout->rdata[1] = packet[rdata_start + 1];

		return dnskeyout;
	}

	String^ DNSKEY_RR::PrintRR(ResourceRecord^ rec)
	{
		// Type-dispatch helper: `rec` must actually be a DNSKEY_RR; safe_cast
		// throws InvalidCastException otherwise.
		DNSKEY_RR^ dnskey = safe_cast<DNSKEY_RR^>(rec);
		return dnskey->Print();
	}

	ResourceRecord^ DNSKEY_RR::CloneRR(ResourceRecord^ rec)
	{
		// Type-dispatch helper: `rec` must actually be a DNSKEY_RR; safe_cast
		// throws InvalidCastException otherwise.
		DNSKEY_RR^ dnskey = safe_cast<DNSKEY_RR^>(rec);
		return dnskey->Clone();
	}

}
