from math import log, ceil
from binascii import hexlify, unhexlify


def get_diff(ori_string: str, new_string: str):
    """Collect the characters at which two strings disagree.

    The strings are compared position by position; comparison stops at the
    end of the shorter string. Returns a pair of strings: the differing
    characters taken from ori_string, and the characters at the same
    positions taken from new_string.
    """
    mismatches = [
        (old_char, new_char)
        for old_char, new_char in zip(ori_string, new_string)
        if old_char != new_char
    ]
    from_original = "".join(old_char for old_char, _ in mismatches)
    from_new = "".join(new_char for _, new_char in mismatches)
    return from_original, from_new


def hex_to_string(hex_string: str) -> str:
    """Decode a hex string into the text its bytes represent.

    The hex digits are turned into raw bytes, which are then decoded with
    the default codec of bytes.decode().
    """
    raw_bytes = unhexlify(hex_string)
    return raw_bytes.decode()


def str_to_hex(string: str) -> str:
    """Encode text as a lowercase hex string.

    The text is encoded with the default codec of str.encode() and each
    resulting byte is written as two hex digits.
    """
    raw_bytes = string.encode()
    hex_bytes = hexlify(raw_bytes)
    return hex_bytes.decode()


def hex_to_int(hex_str: str) -> int:
    """Parse a base-16 string (an optional '0x' prefix is accepted) as an int."""
    return int(hex_str, base=16)


def int_to_hex(num: int) -> str:
    """Convert an int to a hex string without the '0x' prefix.

    num: Integer to convert (may be negative)
    returns lowercase hex digits, with a leading '-' for negative numbers
    """
    # hex(-255) is '-0xff'; slicing off the first two characters would leave
    # 'xff'. Removing only the '0x' marker keeps the sign intact.
    return hex(num).replace("0x", "", 1)


def get_keyspace(p: float, g: int):
    """Compute a keyspace size from a probability and a guess count.

        p: probability
        g: number of guesses
        returns g divided by p
    """
    keyspace = g / p
    return keyspace


def get_bits(num: int):
    """
    num: Number to be stored (int or float, must be positive)
    returns ceil(log2(num)), the number of bits required to store the number
    raises ValueError if num is zero or negative
    """
    if num <= 0:
        raise ValueError("num must be positive")

    if num >= 1:
        # For num >= 1, ceil(log2(num)) equals ceil(log2(ceil(num))), and for
        # a positive integer n that is exactly (n - 1).bit_length(). Integer
        # arithmetic avoids the rounding error of log(num) / log(2), which
        # can land just above an integer for exact powers of two and make
        # ceil() report one bit too many.
        return (ceil(num) - 1).bit_length()

    # Fractions below 1 have a negative logarithm; keep the float formula.
    return ceil(log(num) / log(2))


def xor_hex(*hex_strs: str) -> str:
    """Takes hex strings, xor them and return a hex string

    hex_strs: Any number of hex strings (an optional '0x' prefix is accepted)
    returns the XOR of all inputs as lowercase hex without a '0x' prefix;
            '0' when called with no arguments
    raises ValueError if any input is not a valid hex string
    """
    # Accumulate as an int: 0 is the XOR identity, so an empty call yields
    # '0' rather than an empty string, and no per-step string round-trip
    # is needed.
    acc = 0
    for s in hex_strs:
        acc ^= int(s, 16)
    # format(..., 'x') never emits '0x', so a negative result keeps its sign.
    return format(acc, "x")


def get_birthday_amount(prob: float, key_space: int):
    """
    prob: Probability
    key_space: Key Space
    returns Number of guesses needed for hash collision probability to reach (prob) in (keyspace),
            together with the collision probability actually reached, as a tuple (guesses, probability)
    raises ValueError if prob is greater than 1, or if key_space is negative while prob is positive
    """
    # A probability above 1 can never be reached, and a negative key space
    # makes the running product grow instead of shrink; either way the loop
    # below would never terminate.
    if prob > 1:
        raise ValueError("prob must not exceed 1")
    if prob > 0 and key_space < 0:
        raise ValueError("key_space must not be negative")

    # acc is the probability that all guesses so far are distinct;
    # curr is the number of guesses made.
    acc = 1
    curr = 1
    while 1 - acc < prob:
        # The next guess must avoid the curr values already taken.
        acc *= (key_space - curr) / key_space
        curr += 1

    return curr, 1 - acc


def get_birthday_approx(prob: float, keyspace: int):
    """Birthday problem using e approximation
    prob: Probability (0 <= prob < 1)
    keyspace: Key Space (non-negative)
    returns Number of guesses needed for hash collision probability to reach (prob) in (keyspace),
            rounded up to a whole number of guesses
    raises ValueError if prob is outside [0, 1) or keyspace is negative
    """
    if not 0 <= prob < 1:
        raise ValueError("prob must satisfy 0 <= prob < 1")
    if keyspace < 0:
        raise ValueError("keyspace must not be negative")

    # With P(collision) ~= 1 - e^(-n^2 / (2K)), solving for n gives
    # n = sqrt(2 * K * ln(1 / (1 - p))).
    val = (2 * keyspace * log(1 / (1 - prob))) ** 0.5
    # Round up: a fractional guess count means one more whole guess is needed.
    return ceil(val)


if __name__ == "__main__":
    # Interactive worksheet: prompts for a ciphertext and two altered
    # versions of it, then prints the XOR of the differing characters with
    # the constants 0x01 and 0x02. Questions 1 and 8 are disabled below.
    print("Mock Quiz Answers:")

    # print("Question 1:")
    # distinct_places = int(input('Number of distinct places: '))
    # res = pow(2, distinct_places)
    # print(f"Part 1: {res}")
    # probability = float(input("Probability: "))
    # print(f"Part 2: {get_bits(get_keyspace(probability, res))}")

    print("\nQuestion 2:")
    print("Part 1: 0x01")
    ciphertext = input("Original hex Ciphertext: ")
    first_pad = input("First Correct pad: ")
    # Only the characters that changed between the two inputs are XORed.
    old_chars, new_chars = get_diff(ciphertext, first_pad)
    part_two = xor_hex(old_chars, new_chars, '0x01')
    print(f"Part 2: {part_two}")
    print("Part 3: 0x02")
    second_pad = input("Second Correct pad: ")
    old_chars, new_chars = get_diff(ciphertext, second_pad)
    print(old_chars, new_chars)
    # Part 4 uses only the first two differing characters of each side.
    part_four = xor_hex(old_chars[:2], new_chars[:2], '0x02')
    print(f"Part 4: {part_four}")

    # print("\nQuestion 8:")
    # init = input("MAC: ")
    # k_hex = str_to_hex('Kareem')
    # s_hex = str_to_hex('Shruti')

    # print(
    #     f"Answer: {xor_hex(k_hex, s_hex, init[:len(k_hex)]) + init[len(k_hex):]}")
