#!/usr/bin/env python3

import SAT_lib, my_utils, functools, operator

from typing import List

"""
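Layout of the carry-less (GF(2)) multiplication below. Each row of dots is
one partial product: `a`, shifted left by the position of one bit of `b`.
Underscores are zero padding. The product is the XOR of all rows, taken
column by column:

     aaaa
b ___....
b __...._
b _....__
b ....___

width  = OUTPUT_SIZE = INPUT_SIZE*2-1
height = INPUT_SIZE (one row per bit of b)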
"""

def gen_gf2_mul(s, a, b):
    """Encode carry-less (GF(2)) multiplication of bit vectors a and b.

    One partial product is built per bit of b. Each is `a`, zero-extended to
    the output width, shifted left by that bit's position, and ANDed with the
    bit. A freshly allocated output vector is then constrained to equal the
    XOR of all partial products.

    Args:
        s: solver object providing the bit-vector helpers used below
           (alloc_BV, shift_left, BV_zero_extend, BV_AND, BV_XOR, fix_BV_EQ).
        a: bit vector of width INPUT_SIZE.
        b: bit vector of the same width as a. It is indexed from the end
           (shift 0 pairs with b[INPUT_SIZE-1]), so presumably vectors are
           MSB-first -- TODO confirm against SAT_lib.

    Returns:
        The product bit vector, of width 2*INPUT_SIZE-1.
    """
    INPUT_SIZE = len(a)
    assert INPUT_SIZE == len(b), "a and b must have the same width"
    OUTPUT_SIZE = INPUT_SIZE * 2 - 1

    partial_products = [s.alloc_BV(OUTPUT_SIZE) for _ in range(INPUT_SIZE)]
    product = s.alloc_BV(OUTPUT_SIZE)

    for i in range(INPUT_SIZE):
        t = s.shift_left(s.BV_zero_extend(a, OUTPUT_SIZE), i)

        # We index the b[] array the other way round here: shift i pairs
        # with b[INPUT_SIZE-1-i].
        b_idx = INPUT_SIZE - i - 1
        # Broadcast the single bit b[b_idx] across the whole row, so the AND
        # either keeps the shifted copy of a or zeroes it.
        mask = [b[b_idx]] * OUTPUT_SIZE
        s.fix_BV_EQ(partial_products[b_idx], s.BV_AND(t, mask))

    # FIXME: this is slow! due to BV_XOR
    # The per-column alternative is WAY SLOWER. It constrains each
    # product[col] to s.XOR_list(...) over partial_products[row][col] for
    # all rows, for every one of the OUTPUT_SIZE columns.
    s.fix_BV_EQ(functools.reduce(s.BV_XOR, partial_products), product)
    return product

def my_shift_left_1 (s, x:List[int]) -> List[int]:
    """Return x with one s.const_false bit appended at its end.

    The result is one element wider than x, and x itself is left untouched.
    If bit vectors are MSB-first (presumably -- TODO confirm), this is a
    widening left shift by one.
    """
    zero_bit = [s.const_false]
    return x + zero_bit
