# pycryptodome
from Crypto.PublicKey import ECC
from Crypto.Hash import HMAC, SHA256
from Crypto.Cipher import AES
from Crypto.Util import Counter
from Crypto.Random import get_random_bytes
from base64 import b64encode, b64decode
from sys import exit
import json

class CANEdgeSecurity(object):
    """
    Encrypts configuration field values with an AES key in CTR mode.

    The symmetric key is either supplied directly (base64) or derived from the
    device public key: a fresh user ECC key pair is generated, the x-coordinate
    of (device public point * user private scalar) is used as the shared secret,
    and the key is the first 16 bytes of HMAC-SHA256(shared secret, b'config').
    """

    _ECC_CURVE = 'secp256r1'

    # Byte length of one big-endian point coordinate (x or y) as sliced below
    _COORD_LEN = 32

    # Generate a new symmetric key using device public key
    def __gen_sym_key(self, device_public_key_string_xy):
        """
        Derive a symmetric key from the device public key and a new user key pair
        :param device_public_key_string_xy: Device public key bytes, x followed by y (32 bytes each, big-endian)
        :return: Tuple (symmetric key of 16 bytes, user public key bytes as x followed by y)
        :raises ValueError: If the device public key is not exactly 64 bytes
        """

        # Reject malformed keys up front; slicing a key of the wrong length would
        # silently yield wrong coordinates and fail later with a less clear error
        expected_len = 2 * self._COORD_LEN
        if len(device_public_key_string_xy) != expected_len:
            raise ValueError(
                "Device public key must be {} bytes (x followed by y), got {}".format(
                    expected_len, len(device_public_key_string_xy)))

        # Construct ECC point from device public key
        device_kpub_int_x = int.from_bytes(device_public_key_string_xy[:self._COORD_LEN], byteorder='big')
        device_kpub_int_y = int.from_bytes(device_public_key_string_xy[self._COORD_LEN:], byteorder='big')
        device_kpub_p = ECC.construct(curve=self._ECC_CURVE, point_x=device_kpub_int_x, point_y=device_kpub_int_y)

        # Create user private / public key pair
        user_key_pair = ECC.generate(curve=self._ECC_CURVE)

        # The shared secret is calculated using the device public point and the private key.
        # The secret is the x-coordinate of the resulting point
        shared_secret_int = (device_kpub_p.pointQ * user_key_pair.d).x

        # Calculate symmetric key from shared secret using hmac-sha256 and static data "config"
        shared_secret_string = int(shared_secret_int).to_bytes(self._COORD_LEN, byteorder='big')
        h = HMAC.new(shared_secret_string, msg=b'config', digestmod=SHA256)

        # Truncate to get shared private key (16 bytes)
        symmetric_key = h.digest()[0:16]

        # Create public key byte strings
        user_kpub_string_x = int(user_key_pair.pointQ.x).to_bytes(self._COORD_LEN, byteorder='big')
        user_kpub_string_y = int(user_key_pair.pointQ.y).to_bytes(self._COORD_LEN, byteorder='big')

        return symmetric_key, user_kpub_string_x + user_kpub_string_y

    def __init__(self, **kwargs):
        """
        Constructs a CANedge security object used for configuration file field encryption.

        The key source is chosen in this order: symkeyBase64, then
        devicePublicKeyBase64, then deviceFilePath. When a symmetric key is given,
        no user key pair is generated and both public key attributes are None.

        :param kwargs:
        :keyword devicePublicKeyBase64: Device public key string in base64 format
        :keyword deviceFilePath: Device.json file path, containing the kpub key field
        :keyword symkeyBase64: Symmetric cipher key in base64 format
        :raises SystemExit: If none of the keywords is provided
        :raises ValueError: If the decoded device public key is not 64 bytes
        :return:
        """

        self.devicePublicKeyBase64 = kwargs.get("devicePublicKeyBase64")
        self.deviceFilePath = kwargs.get("deviceFilePath")
        self.symkeyBase64 = kwargs.get("symkeyBase64")

        # Symmetric key provided: reuse it as-is, no public keys are involved
        if self.symkeyBase64 is not None:
            self.user_kpub_string_xy = None
            self.device_kpub_string_xy = None
            self.ksym = b64decode(self.symkeyBase64)
            return

        # Locate the base64 device public key, either given directly or read from file
        if self.devicePublicKeyBase64 is not None:
            device_kpub_base64 = self.devicePublicKeyBase64
        elif self.deviceFilePath is not None:
            with open(self.deviceFilePath, encoding="utf-8") as f:
                device_kpub_base64 = json.load(f)["kpub"]
        else:
            # Exit with a message (non-zero status) like the getters do, instead of silently
            exit("No key provided: use devicePublicKeyBase64, deviceFilePath or symkeyBase64")

        self.device_kpub_string_xy = b64decode(device_kpub_base64)
        self.ksym, self.user_kpub_string_xy = self.__gen_sym_key(self.device_kpub_string_xy)

    def getSymKeyBase64(self):
        """
        :return: Symmetric key as a base64 string
        """
        return b64encode(self.ksym).decode()

    def getDevicePublicKeyBase64(self):
        """
        :return: Device public key (x followed by y) as a base64 string
        :raises SystemExit: If the object was created from a symmetric key only
        """
        if self.device_kpub_string_xy is None:
            exit("Device public key not set")
        return b64encode(self.device_kpub_string_xy).decode()

    def getUserPublicKeyBase64(self):
        """
        :return: User public key (x followed by y) as a base64 string
        :raises SystemExit: If the object was created from a symmetric key only
        """
        if self.user_kpub_string_xy is None:
            exit("User public key not set")
        return b64encode(self.user_kpub_string_xy).decode()

    def encryptAndEncode(self, field_value: str) -> str:
        """
        Encrypt a field value with AES in CTR mode and encode the result
        :param field_value: Plain text field value; must be ASCII (encoding raises UnicodeEncodeError otherwise)
        :return: Base64 string of the 16 byte IV followed by the cipher text
        """
        # Create CTR cipher. A fresh random 16 byte IV is generated here for every
        # call and used as the initial value of the full 128 bit counter
        iv = get_random_bytes(16)
        ctr = Counter.new(128, initial_value=int.from_bytes(iv, byteorder='big'))
        cipher = AES.new(self.ksym, AES.MODE_CTR, counter=ctr)

        # Encrypt data
        ct = cipher.encrypt(bytes(field_value, 'ascii'))

        # Concatenate and encode (the IV is prepended to the cipher text)
        return b64encode(iv + ct).decode()

if __name__ == "__main__":

    # Example 1: build the security object from a base64 device public key,
    # then encrypt a field value with the derived symmetric key
    example_device_kpub = "oNjS7h06MDbZSxVcvWvt1glSVfKmZ2QNJEZOhNznpmVhAmys6Vjc5L8a8H5HakyH1s1sFMPah7rHU7bnT398oQ=="
    security_from_kpub = CANEdgeSecurity(devicePublicKeyBase64=example_device_kpub)
    encrypted_field = security_from_kpub.encryptAndEncode('fieldValue1')

    # Example 2: build the security object from a device.json file
    # (relative path), then encrypt a field value
    security_from_file = CANEdgeSecurity(deviceFilePath="device.json")
    encrypted_field = security_from_file.encryptAndEncode('fieldValue2')

    # Example 3: reload the symmetric key of example 2 and reuse it
    # without creating new user keys, then encrypt a field value
    security_from_symkey = CANEdgeSecurity(symkeyBase64=security_from_file.getSymKeyBase64())
    encrypted_field = security_from_symkey.encryptAndEncode('fieldValue3')