AES-256 Implementation with CTR Mode and X9.31 Extensions


Project Overview

Source: Link to the full Google Colab notebook.

The AES project implements the Advanced Encryption Standard (AES-256) for secure data encryption and pseudo-random number generation. It builds on Professor Avi Kak's lecture materials from Purdue University's ECE404: Introduction to Computer Security and provides the following functionality:

  • Encrypting and decrypting text files using AES-256.
  • Encrypting PPM image files in Counter (CTR) mode.
  • Generating pseudo-random numbers using the ANSI X9.31 algorithm.

The project uses the BitVector library for bit-level operations, including the GF(2^8) arithmetic (modular multiplication and multiplicative inverses) that AES relies on.
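As a rough illustration of what BitVector's `gf_multiply_modular` and `gf_MI` compute with the AES modulus `100011011` (x^8 + x^4 + x^3 + x + 1, i.e. 0x11B), here is a small pure-Python sketch. The helpers `gf_mul` and `gf_inv` are hypothetical names, not part of BitVector; the checks {57}·{83} = {C1} and {53}^-1 = {CA} are standard GF(2^8) example values.

```python
AES_MODULUS = 0x11B  # x^8 + x^4 + x^3 + x + 1, same as BitVector(bitstring='100011011')


def gf_mul(a: int, b: int) -> int:
    """Multiply two bytes in GF(2^8) modulo the AES polynomial."""
    result = 0
    while b:
        if b & 1:          # add (XOR) a when the current bit of b is set
            result ^= a
        a <<= 1            # multiply a by x
        if a & 0x100:      # reduce once the degree reaches 8
            a ^= AES_MODULUS
        b >>= 1
    return result


def gf_inv(a: int) -> int:
    """Multiplicative inverse in GF(2^8) as a^254; 0 maps to 0, as AES requires."""
    result = 1
    for _ in range(254):
        result = gf_mul(result, a)
    return result


print(hex(gf_mul(0x57, 0x83)))  # 0xc1
print(hex(gf_inv(0x53)))        # 0xca
```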

Author: William Wong

Emails: wong371@purdue.edu, willwong812@gmail.com

Source: Adapted from Professor Avi Kak's lecture notes and code examples (Lecture 8).

Imports and Module Loading

In [1]:
!pip install BitVector
!pip install matplotlib

import sys
import matplotlib.pyplot as plt
from PIL import Image
from BitVector import *

AES Class and Functions

  • Includes the Counter (CTR) mode extension for PPM image encryption.
  • Generates pseudo-random numbers using the ANSI X9.31 algorithm.
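CTR mode turns a block cipher into a stream cipher: each 128-bit counter value is encrypted and the result is XORed with the data, so encryption and decryption are the same operation. The sketch below shows that structure under stated assumptions: `ctr_transform` and `toy_encrypt_block` are hypothetical names, the counter is assumed to start at the IV and increase by one per block, and the toy block function is a SHA-256-based stand-in (not AES-256, not secure) used only so the example runs on its own.

```python
import hashlib
from typing import Callable

BLOCK_BYTES = 16  # AES block size in bytes (128 bits)


def ctr_transform(data: bytes, iv: int, encrypt_block: Callable[[bytes], bytes]) -> bytes:
    """Encrypt or decrypt `data` in CTR mode; the same call does both."""
    out = bytearray()
    counter = iv
    for start in range(0, len(data), BLOCK_BYTES):
        chunk = data[start:start + BLOCK_BYTES]
        # Keystream block = E_K(counter); in the notebook this would be AES-256
        keystream = encrypt_block(counter.to_bytes(BLOCK_BYTES, "big"))
        # zip stops at the shorter input, so a short final chunk needs no padding
        out += bytes(d ^ k for d, k in zip(chunk, keystream))
        counter = (counter + 1) % (1 << 128)  # assumed: +1 per block, 128-bit wraparound
    return bytes(out)


def toy_encrypt_block(block: bytes) -> bytes:
    # Hypothetical stand-in, NOT AES and NOT secure: truncated keyed SHA-256
    return hashlib.sha256(b"demo key" + block).digest()[:BLOCK_BYTES]


pixels = bytes([255, 0, 0] * 32)  # 32 identical red pixels
encrypted = ctr_transform(pixels, 0x1234, toy_encrypt_block)
decrypted = ctr_transform(encrypted, 0x1234, toy_encrypt_block)
print(decrypted == pixels)  # True
```

Because only the keystream depends on the key, no padding or inverse cipher is needed, which is one reason CTR suits arbitrary-length pixel data.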
In [2]:
class AES:

    def __init__(self, keyfile: str) -> None:
        with open(keyfile, 'r') as key_file:
            key_text = key_file.read().strip()
        self.AES_modulus = BitVector(bitstring='100011011')
        self.key_bv = BitVector(textstring=key_text)
        self.key_words = self.gen_key_schedule_256(self.key_bv)
        self.key_words_128 = self.gen_key_schedule_128(self.key_bv)
        self.subBytesTable, self.invSubBytesTable = self.genTables()

    def genTables(self):
        subBytesTable = []
        invSubBytesTable = []
        c = BitVector(bitstring='01100011')
        d = BitVector(bitstring='00000101')
        for i in range(0, 256):
            # For the encryption SBox
            a = BitVector(intVal = i, size=8).gf_MI(self.AES_modulus, 8) if i != 0 else BitVector(intVal=0)
            # For bit scrambling for the encryption SBox entries:
            a1,a2,a3,a4 = [a.deep_copy() for x in range(4)]
            a ^= (a1 >> 4) ^ (a2 >> 5) ^ (a3 >> 6) ^ (a4 >> 7) ^ c
            subBytesTable.append(int(a))
            # For the decryption Sbox:
            b = BitVector(intVal = i, size=8)
            # For bit scrambling for the decryption SBox entries:
            b1,b2,b3 = [b.deep_copy() for x in range(3)]
            b = (b1 >> 2) ^ (b2 >> 5) ^ (b3 >> 7) ^ d
            check = b.gf_MI(self.AES_modulus, 8)
            b = check if isinstance(check, BitVector) else 0
            invSubBytesTable.append(int(b))
        return subBytesTable, invSubBytesTable

    def gen_subbytes_table(self):
        subBytesTable = []
        c = BitVector(bitstring='01100011')
        for i in range(256):
            a = BitVector(intVal=i, size=8).gf_MI(self.AES_modulus, 8) if i != 0 else BitVector(intVal=0)
            a1, a2, a3, a4 = [a.deep_copy() for _ in range(4)]
            a ^= (a1 >> 4) ^ (a2 >> 5) ^ (a3 >> 6) ^ (a4 >> 7) ^ c
            subBytesTable.append(int(a))
        return subBytesTable

    def gee(self, keyword, round_constant, byte_sub_table):
        '''
        This is the g() function you see in Figure 4 of Lecture 8.
        '''
        rotated_word = keyword.deep_copy()
        rotated_word << 8   # BitVector's << is an in-place circular left shift (RotWord)
        newword = BitVector(size=0)
        for i in range(4):
            newword += BitVector(intVal = byte_sub_table[rotated_word[8*i:8*i+8].intValue()], size = 8)
        newword[:8] ^= round_constant
        round_constant = round_constant.gf_multiply_modular(BitVector(intVal=0x02), self.AES_modulus, 8)
        return newword, round_constant

    def gen_key_schedule_128(self, key_bv):
        byte_sub_table = self.gen_subbytes_table()
        # We need 44 keywords in the key schedule for 128 bit AES. Each keyword is 32-bits
        # wide. The 128-bit AES uses the first four keywords to xor the input block with.
        # Subsequently, each of the 10 rounds uses 4 keywords from the key schedule. We will
        # store all 44 keywords in the following list:
        key_words = [None for i in range(44)]
        round_constant = BitVector(intVal = 0x01, size=8)
        for i in range(4):
            key_words[i] = key_bv[i*32 : i*32 + 32]
        for i in range(4,44):
            if i%4 == 0:
                kwd, round_constant = self.gee(key_words[i-1], round_constant, byte_sub_table)
                key_words[i] = key_words[i-4] ^ kwd
            else:
                key_words[i] = key_words[i-4] ^ key_words[i-1]
        return key_words

    def gen_key_schedule_256(self, key_bv):
        # We need 60 keywords (each keyword consists of 32 bits) in the key schedule for
        # 256 bit AES. The 256-bit AES uses the first four keywords to xor the input
        # block with. Subsequently, each of the 14 rounds uses 4 keywords from the key
        # schedule. We will store all 60 keywords in the following list:
        byte_sub_table = self.gen_subbytes_table()
        key_words = [None for _ in range(60)]
        round_constant = BitVector(intVal=0x01, size=8)
        for i in range(8):
            key_words[i] = key_bv[i*32 : i*32 + 32]
        for i in range(8, 60):
            if i % 8 == 0:
                kwd, round_constant = self.gee(key_words[i-1], round_constant, byte_sub_table)
                key_words[i] = key_words[i-8] ^ kwd
            elif (i - (i//8)*8) < 4:
                key_words[i] = key_words[i-8] ^ key_words[i-1]
            elif (i - (i//8)*8) == 4:
                key_words[i] = BitVector(size=0)
                for j in range(4):
                    key_words[i] += BitVector(intVal=byte_sub_table[key_words[i-1][8*j:8*j+8].intValue()], size=8)
                key_words[i] ^= key_words[i-8]
            elif ((i - (i//8)*8) > 4) and ((i - (i//8)*8) < 8):
                key_words[i] = key_words[i-8] ^ key_words[i-1]
            else:
                sys.exit("error in key scheduling algo for i = %d" % i)
        return key_words

    def encrypt(self, plaintext: str, ciphertext: str) -> None:
        with open(plaintext, 'r') as plain_file:
            plaintext_bv = BitVector(textstring=plain_file.read())
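        # Zero-pad to a multiple of 128 bits only when the length is not already aligned
        if len(plaintext_bv) % 128 != 0:
            plaintext_bv.pad_from_right(128 - (len(plaintext_bv) % 128))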
        state_array = [[0 for _ in range(4)] for _ in range(4)]
        with open(ciphertext, 'w') as cipher_file:
            for i in range(0, len(plaintext_bv), 128):
                block = plaintext_bv[i:i+128]

                # create state array
                for row in range(4):
                    for col in range(4):
                        state_array[row][col] = block[32*col + 8*row:32*col + 8*(row+1)]

                # Initial round key: XOR with words 0-3 of the key schedule
                state_array = self.add_round_key(state_array, self.key_words[0:4])

                # Rounds 1-13: SubBytes, ShiftRows, MixColumns, AddRoundKey
                for round_num in range(1, 14):
                    state_array = self.sub_bytes(state_array, self.subBytesTable)
                    state_array = self.shift_rows(state_array)
                    state_array = self.mix_columns(state_array)
                    state_array = self.add_round_key(state_array, self.key_words[4*round_num:4*(round_num+1)])

                # Final round (14): no MixColumns, round key = words 56-59
                state_array = self.sub_bytes(state_array, self.subBytesTable)
                state_array = self.shift_rows(state_array)
                state_array = self.add_round_key(state_array, self.key_words[56:60])

                # Serialize column by column, the same order the block was loaded in
                ciphertext_bv = BitVector(size=0)
                for col in range(4):
                    for row in range(4):
                        ciphertext_bv += state_array[row][col]
                cipher_file.write(ciphertext_bv.get_hex_string_from_bitvector())

    def decrypt(self, ciphertext: str, decrypted: str) -> None:
        with open(ciphertext, 'r') as f:
            ciphertext_bv = BitVector(hexstring=f.read().strip())

        with open(decrypted, 'w') as out_f:
            for i in range(0, len(ciphertext_bv), 128):
                block = ciphertext_bv[i:i+128]

                # Create the state array as state_array[row][col] (column-major, as in encrypt)
                state_array = [[block[32*col + 8*row : 32*col + 8*(row+1)]
                        for col in range(4)] for row in range(4)]

                # add last round key first
                state_array = self.add_round_key(state_array, self.key_words[56:60])

                # inverse throughout the rounds
                for round_num in range(13, 0, -1):
                    state_array = self.inv_shift_rows(state_array)
                    state_array = self.inv_sub_bytes(state_array, self.invSubBytesTable)
                    state_array = self.add_round_key(state_array, self.key_words[4*round_num:4*(round_num+1)])
                    state_array = self.inv_mix_columns(state_array)

                # No need for mix columns on last round
                state_array = self.inv_shift_rows(state_array)
                state_array = self.inv_sub_bytes(state_array, self.invSubBytesTable)
                state_array = self.add_round_key(state_array, self.key_words[0:4])

                output = BitVector(size=0)
                for col in range(4):
                    for row in range(4):
                        output += state_array[row][col]
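                decrypted_text = output.get_text_from_bitvector()
                # The last block may end with the zero padding added by encrypt()
                if i + 128 >= len(ciphertext_bv):
                    decrypted_text = decrypted_text.rstrip('\x00')
                out_f.write(decrypted_text)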

    def add_round_key(self, state_array, round_key):
        new_state_array = [[0 for _ in range(4)] for _ in range(4)]
        for row in range(4):
            for col in range(4):
                round_key_bytes = round_key[col][8*row:8*(row+1)]
                new_state_array[row][col] = state_array[row][col] ^ round_key_bytes
        return new_state_array

    def sub_bytes(self, state_array, sub_table):
        new_state_array = [[0 for _ in range(4)] for _ in range(4)]
        for row in range(4):
            for col in range(4):
                new_state_array[row][col] = BitVector(intVal=sub_table[state_array[row][col].intValue()], size=8)
        return new_state_array

    def shift_rows(self, state_array):
        new_state_array = [[0 for _ in range(4)] for _ in range(4)]
        for row in range(4):
            for col in range(4):
                new_state_array[row][col] = state_array[row][(col + row) % 4]
        return new_state_array

    def mix_columns(self, state_array):
        new_state_array = [[0] * 4 for _ in range(4)]
        for col in range(4):
            new_state_array[0][col] = (
                state_array[0][col].gf_multiply_modular(BitVector(intVal=0x02), self.AES_modulus, 8) ^
                state_array[1][col].gf_multiply_modular(BitVector(intVal=0x03), self.AES_modulus, 8) ^
                state_array[2][col] ^
                state_array[3][col]
            )
            new_state_array[1][col] = (
                state_array[0][col] ^
                state_array[1][col].gf_multiply_modular(BitVector(intVal=0x02), self.AES_modulus, 8) ^
                state_array[2][col].gf_multiply_modular(BitVector(intVal=0x03), self.AES_modulus, 8) ^
                state_array[3][col]
            )
            new_state_array[2][col] = (
                state_array[0][col] ^
                state_array[1][col] ^
                state_array[2][col].gf_multiply_modular(BitVector(intVal=0x02), self.AES_modulus, 8) ^
                state_array[3][col].gf_multiply_modular(BitVector(intVal=0x03), self.AES_modulus, 8)
            )
            new_state_array[3][col] = (
                state_array[0][col].gf_multiply_modular(BitVector(intVal=0x03), self.AES_modulus, 8) ^
                state_array[1][col] ^
                state_array[2][col] ^
                state_array[3][col].gf_multiply_modular(BitVector(intVal=0x02), self.AES_modulus, 8)
            )
        return new_state_array

    def inv_shift_rows(self, state_array):
        new_state_array = [[0] * 4 for _ in range(4)]
        for row in range(4):
            for col in range(4):
                new_state_array[row][col] = state_array[row][(col - row) % 4]
        return new_state_array

    def inv_sub_bytes(self, state_array, inv_sub_table):
        new_state_array = [[0] * 4 for _ in range(4)]
        for row in range(4):
            for col in range(4):
                new_state_array[row][col] = BitVector(intVal=inv_sub_table[state_array[row][col].intValue()], size=8)
        return new_state_array

    def inv_mix_columns(self, state_array):
        new_state_array = [[0] * 4 for _ in range(4)]
        for col in range(4):
            new_state_array[0][col] = (
                state_array[0][col].gf_multiply_modular(BitVector(intVal=0x0E), self.AES_modulus, 8) ^
                state_array[1][col].gf_multiply_modular(BitVector(intVal=0x0B), self.AES_modulus, 8) ^
                state_array[2][col].gf_multiply_modular(BitVector(intVal=0x0D), self.AES_modulus, 8) ^
                state_array[3][col].gf_multiply_modular(BitVector(intVal=0x09), self.AES_modulus, 8)
            )
            new_state_array[1][col] = (
                state_array[0][col].gf_multiply_modular(BitVector(intVal=0x09), self.AES_modulus, 8) ^
                state_array[1][col].gf_multiply_modular(BitVector(intVal=0x0E), self.AES_modulus, 8) ^
                state_array[2][col].gf_multiply_modular(BitVector(intVal=0x0B), self.AES_modulus, 8) ^
                state_array[3][col].gf_multiply_modular(BitVector(intVal=0x0D), self.AES_modulus, 8)
            )
            new_state_array[2][col] = (
                state_array[0][col].gf_multiply_modular(BitVector(intVal=0x0D), self.AES_modulus, 8) ^
                state_array[1][col].gf_multiply_modular(BitVector(intVal=0x09), self.AES_modulus, 8) ^
                state_array[2][col].gf_multiply_modular(BitVector(intVal=0x0E), self.AES_modulus, 8) ^
                state_array[3][col].gf_multiply_modular(BitVector(intVal=0x0B), self.AES_modulus, 8)
            )
            new_state_array[3][col] = (
                state_array[0][col].gf_multiply_modular(BitVector(intVal=0x0B), self.AES_modulus, 8) ^
                state_array[1][col].gf_multiply_modular(BitVector(intVal=0x0D), self.AES_modulus, 8) ^
                state_array[2][col].gf_multiply_modular(BitVector(intVal=0x09), self.AES_modulus, 8) ^
                state_array[3][col].gf_multiply_modular(BitVector(intVal=0x0E), self.AES_modulus, 8)
            )
        return new_state_array

    def ctr_aes_image(self, iv, image_file, enc_image):
        '''
        Encrypts a PPM image file using AES-256 in CTR mode.
        Inputs:
        iv (BitVector): 128-bit initialization vector
        image_file (str): input .ppm file name
        enc_image (str): output .ppm file name
        '''
        with open(image_file, "rb") as img_file:
            header = []
    # ... (other methods not shown for brevity)
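To cross-check the round structure that `encrypt()` should follow (an initial AddRoundKey with words 0-3, MixColumns in rounds 1-13 only, and a final round 14 keyed with words 56-59) and the column-major byte order of the state, here is a minimal pure-Python AES-256 single-block encryption that does not depend on BitVector. It is an independent sketch following FIPS-197, not the author's implementation; every helper name in it is hypothetical. The final line checks it against the FIPS-197 Appendix C.3 test vector.

```python
def xtime(a: int) -> int:
    # Multiply by {02} in GF(2^8), reducing by x^8 + x^4 + x^3 + x + 1 (0x11B)
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a


def gf_mul(a: int, b: int) -> int:
    result = 0
    while b:
        if b & 1:
            result ^= a
        a, b = xtime(a), b >> 1
    return result


def rotl8(x: int, n: int) -> int:
    return ((x << n) | (x >> (8 - n))) & 0xFF


def build_sbox() -> list:
    sbox = []
    for i in range(256):
        inv = 1
        for _ in range(254):  # i^254 is the inverse of i (and 0 stays 0)
            inv = gf_mul(inv, i)
        # Affine transformation, equivalent to the circular shifts in genTables()
        sbox.append(inv ^ rotl8(inv, 1) ^ rotl8(inv, 2) ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63)
    return sbox


SBOX = build_sbox()
RCON = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40]


def sub_word(word: int) -> int:
    return int.from_bytes(bytes(SBOX[b] for b in word.to_bytes(4, "big")), "big")


def expand_key_256(key: bytes) -> list:
    # 60 32-bit words: 8 copied from the key, 52 derived (Nk = 8)
    w = [int.from_bytes(key[4 * i:4 * i + 4], "big") for i in range(8)]
    for i in range(8, 60):
        temp = w[i - 1]
        if i % 8 == 0:
            temp = ((temp << 8) | (temp >> 24)) & 0xFFFFFFFF   # RotWord
            temp = sub_word(temp) ^ (RCON[i // 8 - 1] << 24)   # SubWord + Rcon
        elif i % 8 == 4:
            temp = sub_word(temp)  # extra SubWord step specific to 256-bit keys
        w.append(w[i - 8] ^ temp)
    return w


def add_round_key(state: list, w: list, rnd: int) -> list:
    # Word c of the round key is XORed into column c (bytes state[4*c .. 4*c+3])
    key_bytes = b"".join(w[4 * rnd + c].to_bytes(4, "big") for c in range(4))
    return [s ^ k for s, k in zip(state, key_bytes)]


def shift_rows(state: list) -> list:
    # Row r is rotated left by r: new[r][c] = old[r][(c + r) % 4]
    return [state[4 * ((c + r) % 4) + r] for c in range(4) for r in range(4)]


def mix_columns(state: list) -> list:
    out = []
    for c in range(4):
        s0, s1, s2, s3 = state[4 * c:4 * c + 4]
        out += [gf_mul(s0, 2) ^ gf_mul(s1, 3) ^ s2 ^ s3,
                s0 ^ gf_mul(s1, 2) ^ gf_mul(s2, 3) ^ s3,
                s0 ^ s1 ^ gf_mul(s2, 2) ^ gf_mul(s3, 3),
                gf_mul(s0, 3) ^ s1 ^ s2 ^ gf_mul(s3, 2)]
    return out


def aes256_encrypt_block(block: bytes, key: bytes) -> bytes:
    w = expand_key_256(key)
    state = add_round_key(list(block), w, 0)          # words 0-3
    for rnd in range(1, 15):
        state = shift_rows([SBOX[b] for b in state])  # SubBytes, then ShiftRows
        if rnd < 14:                                  # MixColumns in rounds 1-13 only
            state = mix_columns(state)
        state = add_round_key(state, w, rnd)          # words 4*rnd .. 4*rnd+3
    return bytes(state)


ct = aes256_encrypt_block(bytes.fromhex("00112233445566778899aabbccddeeff"), bytes(range(32)))
print(ct.hex())  # FIPS-197 C.3 expects 8ea2b7ca516745bfeafc49904b496089
```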
In [3]:
# Note: This cell runs the AES class to produce the encrypted image and the pseudo-random numbers shown below.
[Figure: AES Encrypted Image]
=== ANSI X9.31 Pseudo-Random Number Generator ===
Generated 5 pseudo-random numbers:
Number 1: 248079769487624148781305850458853487813
Number 2: 65814991611129857698228726854865446749
Number 3: 38758186303580379152434860054134701891
Number 4: 101836235598497239834638240368629799639
Number 5: 110128351161781364920043990112939672576
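The ANSI X9.31 generator's core step encrypts a date/time vector DT to get I = E_K(DT), outputs R = E_K(I XOR V), and updates the seed to V = E_K(R XOR I). The sketch below shows that loop under assumptions: `x931_generate` and `toy_encrypt_block` are hypothetical names, DT is assumed to advance by one per output (one common way to update it), and the SHA-256-based block function stands in for AES-256 so the example is self-contained. It does not reproduce the numbers printed above.

```python
import hashlib
from typing import Callable, List

MASK_128 = (1 << 128) - 1


def x931_generate(encrypt_block: Callable[[int], int], v: int, dt: int, count: int) -> List[int]:
    """Produce `count` 128-bit outputs with the X9.31 update rule."""
    outputs = []
    for _ in range(count):
        i = encrypt_block(dt)      # I = E_K(DT)
        r = encrypt_block(i ^ v)   # R = E_K(I XOR V), the random output
        v = encrypt_block(r ^ i)   # V = E_K(R XOR I), the next seed
        outputs.append(r)
        dt = (dt + 1) & MASK_128   # assumption: advance DT by one per output
    return outputs


def toy_encrypt_block(x: int) -> int:
    # Hypothetical stand-in, NOT AES and NOT secure: keyed SHA-256 truncated to 128 bits
    digest = hashlib.sha256(b"demo key" + x.to_bytes(16, "big")).digest()
    return int.from_bytes(digest[:16], "big")


for n, value in enumerate(x931_generate(toy_encrypt_block, v=1, dt=2, count=5), start=1):
    print(f"Number {n}: {value}")
```

The same seed V, DT, and key always give the same sequence, which is why the secrecy of those inputs is what makes the outputs unpredictable.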

Conclusion

The AES-Crypt project extends AES-256 with two functionalities: Counter (CTR) mode image encryption and ANSI X9.31 pseudo-random number generation (PRNG). These extensions show AES-256's versatility in cryptographic applications. In CTR mode, successive counter values are encrypted with AES-256 and the resulting keystream is XORed with the PPM pixel data. Because every block uses a different counter value, repeated pixel patterns do not produce repeated ciphertext, so the encrypted image hides the original's structure, and without the key an attacker cannot recover the pixels, provided a counter value is never reused under the same key. The ANSI X9.31 PRNG generates pseudo-random numbers using AES-256. Unlike true random sources, such as the path of a falling leaf, PRNGs like X9.31 are deterministic, but they are designed so that attackers who do not know the seed and key cannot feasibly predict their outputs. The goal of X9.31 is to provide random numbers that resist cryptographic attacks.
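An encrypted PPM can presumably only be displayed if its header is left in plaintext and only the pixel bytes are encrypted. A minimal sketch of that split, assuming a binary P6 file whose header is exactly three newline-terminated lines (magic number, width and height, maximum value) with no comment lines and 8-bit samples; `split_ppm` is a hypothetical helper, not part of the notebook.

```python
def split_ppm(data: bytes):
    """Split a binary PPM (P6) file into (header, pixel_bytes, width, height)."""
    end = 0
    for _ in range(3):  # magic, "width height", maxval
        end = data.index(b"\n", end) + 1
    header, pixels = data[:end], data[end:]
    fields = header.split()
    if fields[0] != b"P6":
        raise ValueError("not a binary P6 PPM file")
    width, height, maxval = int(fields[1]), int(fields[2]), int(fields[3])
    # Only 8-bit samples (maxval <= 255, 3 bytes per pixel) are handled here
    if maxval > 255 or len(pixels) != width * height * 3:
        raise ValueError("unexpected pixel data size")
    return header, pixels, width, height


# A tiny 2x1 image: one red pixel and one blue pixel
sample = b"P6\n2 1\n255\n" + bytes([255, 0, 0, 0, 0, 255])
header, pixels, width, height = split_ppm(sample)
# An encrypted file would then be written as header + (CTR-encrypted pixels)
```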