#!/usr/bin/env python3

import operator, functools, math, sys, os
import SAT_lib, my_utils, SAT_common

from typing import List

def _enumerate_solutions(s, a, b, c, result, input_size):
    """Enumerate the models of the already-constrained instance ``s`` one by one.

    Each model is printed as soon as it is found.  It is then excluded with
    blocking clauses (for the (a, b) assignment and for the mirrored (b, a)
    assignment with the same c), and the solver is re-run.  Once ``s.solve()``
    reports no further model, ``s.deinit()`` is called and the number of
    distinct solutions is printed.

    This is an alternative to the external model counter used by ``f``.
    Because it enumerates one model at a time, it is presumably only practical
    for small bounds.
    """
    solutions = set()
    solutions_t = 0

    def add_blocking_clause(a_n, b_n, c_n):
        # Forbid (a, b, c) == (a_n, b_n, c_n) and its mirror (b_n, a_n, c_n).
        for x_n, y_n in ((a_n, b_n), (b_n, a_n)):
            s.fix(s.AND_list([
                s.BV_EQ(a, s.n_to_BV(x_n, input_size)),
                s.BV_EQ(b, s.n_to_BV(y_n, input_size)),
                s.BV_EQ(c, s.n_to_BV(c_n, input_size))]
            ), False)

    while s.solve():
        solutions_t += 1
        a_n = s.get_val_from_solution(a)
        b_n = s.get_val_from_solution(b)
        c_n = s.get_val_from_solution(c)
        result_n = s.get_val_from_solution(result)
        print("solution_t=%d a_n=0x%x, b_n=0x%x, c_n=0x%x, result_n=0x%x" % (solutions_t, a_n, b_n, c_n, result_n))
        sys.stdout.flush()
        solutions.add((a_n, b_n, c_n, result_n))
        add_blocking_clause(a_n, b_n, c_n)

    s.deinit()
    print("solutions:", len(solutions))


def f(a_b_max, c_part, *, input_size=24, c_mask=0x000FFF, cnf_path="tmp.cnf",
      sat_solver="kissat", enumerate_solutions=False):
    """Encode one slice of the search space as CNF and count its models.

    Triples of ``input_size``-bit vectors (a, b, c) are constrained by

        result == [0]+a*a  XOR  shift_left_1(a*b)  XOR  [0]+b*b
        result == [0]+c*c
        a <= b <= a_b_max
        (c & c_mask) == c_part

    where ``*`` is ``SAT_common.gen_gf2_mul``.  That is presumably the
    carry-less (GF(2)) product of Project Euler 945; confirm against
    SAT_common.  Sweeping ``c_part`` over every value of ``c & c_mask`` covers
    all c.

    By default the CNF is written to ``cnf_path`` and counted with
    ``./ganak``, whose stdout goes to ``f"{c_part}.log"``.

    Args:
        a_b_max: Inclusive upper bound for a and b.
        c_part: Required value of ``c & c_mask`` for this slice.
        input_size: Bit width of a, b and c; result is twice as wide.
        c_mask: Selects the bits of c that ``c_part`` fixes.
        cnf_path: Where to write the CNF.  Use a distinct path per process
            when running several slices concurrently.
        sat_solver: Backend passed to ``SAT_lib.SAT_lib``.  It is only used
            by the solver itself, i.e. when ``enumerate_solutions`` is true.
        enumerate_solutions: If true, enumerate and print every model with
            blocking clauses instead of calling ganak.

    Raises:
        ValueError: if ``a_b_max``, ``c_mask`` or ``c_part`` cannot be
            represented in ``input_size`` bits / under ``c_mask``.
    """
    import subprocess  # local import: replaces the shell-based os.system call

    print(f"f() {a_b_max=} {c_part=}")

    # All constants are encoded as input_size-bit vectors.  Reject values
    # that would not fit instead of relying on how n_to_BV treats overflow.
    if not 0 <= a_b_max < (1 << input_size):
        raise ValueError(f"a_b_max={a_b_max} does not fit in {input_size} bits")
    if c_mask < 0 or c_mask >> input_size:
        raise ValueError(f"c_mask=0x{c_mask:x} does not fit in {input_size} bits")
    if c_part & ~c_mask:
        # (c & c_mask) could never equal c_part: the slice would be trivially UNSAT.
        raise ValueError(f"c_part=0x{c_part:x} has bits outside c_mask=0x{c_mask:x}")

    s = SAT_lib.SAT_lib(verbose=0, SAT_solver=sat_solver)
    print(f"INPUT_SIZE={input_size}")

    a = s.alloc_BV(input_size)
    b = s.alloc_BV(input_size)
    c = s.alloc_BV(input_size)
    result = s.alloc_BV(input_size * 2)

    # result == [0]+a*a XOR shift_left_1(a*b) XOR [0]+b*b.
    # Prepending const_false pads each square to result's width (assumes the
    # gen_gf2_mul output is one bit narrower than result -- TODO confirm).
    a_sq = [s.const_false] + SAT_common.gen_gf2_mul(s, a, a)
    ab_shifted = SAT_common.my_shift_left_1(s, SAT_common.gen_gf2_mul(s, a, b))
    partial = s.BV_XOR(a_sq, ab_shifted)
    b_sq = [s.const_false] + SAT_common.gen_gf2_mul(s, b, b)
    s.fix(s.BV_EQ(result, s.BV_XOR(partial, b_sq)), True)

    # Fix the bits of c selected by c_mask to c_part (one slice of the sweep).
    # Mask and value use c's own width so BV_AND / BV_EQ operands agree.
    mask_bv = s.n_to_BV(c_mask, input_size)
    s.fix(s.BV_EQ(s.BV_AND(c, mask_bv), s.n_to_BV(c_part, input_size)), True)

    # c*c must reproduce result.
    s.fix(s.BV_EQ(result, [s.const_false] + SAT_common.gen_gf2_mul(s, c, c)), True)

    # Symmetry breaking: b >= a, i.e. only a <= b is counted.
    s.fix(s.comparator_GE(b, a), True)

    # a_b_max >= a and a_b_max >= b.
    bound = s.n_to_BV(a_b_max, input_size)
    s.fix(s.comparator_GE(bound, a), True)
    s.fix(s.comparator_GE(bound, b), True)

    if enumerate_solutions:
        _enumerate_solutions(s, a, b, c, result, input_size)
        return

    s.write_CNF_to_file(cnf_path)
    # Argument list instead of a shell string; stdout goes to the per-slice log
    # exactly as the former "> {c_part}.log" redirect did.
    with open(f"{c_part}.log", "w") as log:
        status = subprocess.run(["./ganak", cnf_path], stdout=log).returncode
    if status != 0:
        # Report rather than abort, so the remaining slices still run.
        # NOTE(review): confirm ganak's exit-code convention (0 on success assumed).
        print(f"warning: ganak exited with status {status} for {c_part=}",
              file=sys.stderr)

# Project Euler problem 945.
# Upper bound for a and b; use a_b_max = 10 for a quick sanity run.
a_b_max = 10**7

if __name__ == "__main__":
    # f() fixes the bits of c selected by its mask (0x000FFF by default, the
    # low 12 bits) to c_part, so the 0x1000 slices below together cover every c.
    #
    # Usage:
    #   script.py            sweep every c_part in range(0x1000)
    #   script.py C_PART     run only that slice (e.g. one job per slice).
    #                        Run concurrent jobs from separate working
    #                        directories: the CNF and log files are written
    #                        to the current directory.
    if len(sys.argv) > 1:
        c_parts = [int(sys.argv[1])]
    else:
        c_parts = range(0x1000)
    for c_part in c_parts:
        f(a_b_max, c_part)

