#!/usr/bin/python3

import SAT_lib, itertools
from typing import List

# Not used by do_all() (its call there is commented out); exercised only by the self-test below.
def gen_onehot_vars_adder(s, x, y):
    """Combine two one-hot vectors into a vector indexed by the sum of indices.

    Entry k of the result is s.OR_list() over every s.AND(x[i], y[j]) with
    i + j == k, so (assuming AND/OR_list build the usual gates) it is true
    when the set positions of x and y add up to k.

    s -- solver object providing AND() and OR_list()
    x, y -- sequences of solver variables (expected to be one-hot)

    Returns a list of len(x) + len(y) - 1 gate outputs.
    """
    width = len(x) + len(y) - 1
    result = []
    for target in range(width):
        # every index pair whose positions add up to `target`
        terms = [
            s.AND(x_bit, y_bit)
            for i, x_bit in enumerate(x)
            for j, y_bit in enumerate(y)
            if i + j == target
        ]
        result.append(s.OR_list(terms))
    return result

def gen_onehot_vars_adder_test():
    """Self-check for gen_onehot_vars_adder().

    Allocates two 8-wide one-hot vectors, forces position 1 of their
    combined vector to be true, solves, and asserts that the decoded
    operands add up to the decoded result.
    """
    solver = SAT_lib.SAT_lib()
    lhs = solver.alloc_BV(8)  # 0..7
    rhs = solver.alloc_BV(8)  # 0..7
    for operand in (lhs, rhs):
        solver.make_one_hot(operand)

    total = gen_onehot_vars_adder(solver, lhs, rhs)
    solver.fix(total[1], True)
    assert solver.solve() == True

    def decode(bits):
        # solution bits -> number -> index of the single set bit
        as_number = SAT_lib.BV_to_number(solver.get_BV_from_solution(bits))
        return SAT_lib.one_hot_to_number(as_number)

    lhs_value = decode(lhs)
    rhs_value = decode(rhs)
    total_value = decode(total)
    assert lhs_value + rhs_value == total_value

# Run the adder self-check at module load, before the ruler search starts.
gen_onehot_vars_adder_test()

def one_onehot_LT_than_another(s, x, y):
    """Build a gate comparing the set positions of two one-hot vectors.

    The result is s.OR_list() over s.AND(x[i], y[j]) for every index pair
    with i < j, i.e. (assuming the usual gate semantics) it is true when
    the set position of x has a smaller index than the set position of y.

    s -- solver object providing AND() and OR_list()
    x, y -- sequences of solver variables (expected to be one-hot)
    """
    index_pairs = itertools.product(range(len(x)), range(len(y)))
    terms = [s.AND(x[i], y[j]) for i, j in index_pairs if i < j]
    return s.OR_list(terms)

def gen_onehot_vars_diff(s, x, y):
    """Combine two one-hot vectors into a vector indexed by index distance.

    Entry d of the result is s.OR_list() over every s.AND(x[i], y[j]) with
    abs(i - j) == d, so (assuming the usual gate semantics) it is true when
    the set positions of x and y are exactly d apart.

    s -- solver object providing AND() and OR_list()
    x, y -- sequences of solver variables (expected to be one-hot)

    Returns a list with len(x) entries (distances 0 .. len(x)-1).
    """
    def pairs_at_distance(distance):
        # all index pairs that are exactly `distance` apart
        return [
            s.AND(x_bit, y_bit)
            for i, x_bit in enumerate(x)
            for j, y_bit in enumerate(y)
            if abs(i - j) == distance
        ]

    return [s.OR_list(pairs_at_distance(d)) for d in range(len(x))]

# Mirror image of a ruler: the same gaps between marks, in reverse order.
def mirror_ruler(ruler):
    """Return the mirror image of a ruler given as a list of mark positions.

    The gaps between consecutive marks are reversed and re-accumulated
    starting from 0, so the result always begins with 0 (an empty or
    single-mark ruler yields [0]).
    """
    gaps = [right - left for left, right in zip(ruler, ruler[1:])]
    # running totals of the reversed gaps, seeded with the mark at 0
    return list(itertools.accumulate(reversed(gaps), initial=0))

# map one hot to list of SAT vars
def one_hot_to_SAT_vars(s, one_hot_n, bit_width):
    """Encode the constant one_hot_n as a list of solver constants.

    The list has bit_width entries and runs from position bit_width-1
    down to position 0; the entry for position one_hot_n is s.const_true
    and all others are s.const_false.  If one_hot_n is outside
    0..bit_width-1 every entry is s.const_false.
    """
    return [
        s.const_true if position == one_hot_n else s.const_false
        for position in reversed(range(bit_width))  # highest position first
    ]

def _forbid_ruler(s, marks, ruler, bit_width):
    """Exclude the assignment in which every mark equals its entry in ruler."""
    matches = [
        s.BV_EQ(mark, one_hot_to_SAT_vars(s, position, bit_width))
        for mark, position in zip(marks, ruler)
    ]
    s.fix(s.AND_list(matches), False)


def do_all(ORDER, LENGTH, TOTAL_UNIQUE):
    """Enumerate rulers with ORDER marks in [0..LENGTH] and print each one.

    Every mark is a one-hot vector of LENGTH+1 solver variables.  The first
    mark is pinned to 0, the last to LENGTH, and marks are forced to be
    strictly ordered.  The per-pair distance vectors are OR-ed together and
    passed to s.POPCNT(TOTAL_UNIQUE, ...), which presumably constrains the
    number of distinct distances to TOTAL_UNIQUE (confirm against SAT_lib).

    Each solution found is printed together with its mirror image and its
    set of pairwise distances; then the solution and its mirror are excluded
    and the solver is run again.

    Returns False if the instance is unsatisfiable from the start, True
    once all solutions have been enumerated.
    """
    print(f"do_all {ORDER=} {LENGTH=} {TOTAL_UNIQUE=}")
    s = SAT_lib.SAT_lib(SAT_solver="libcadical")
    marks = [s.alloc_BV(LENGTH + 1) for _ in range(ORDER)]
    for mark in marks:
        s.make_one_hot(mark)

    # Vectors are laid out highest position first (same layout as
    # one_hot_to_SAT_vars), so index LENGTH is position 0 and index 0 is
    # position LENGTH.
    # first mark must be at 0:
    s.fix(marks[0][LENGTH], True)
    # last mark must be at LENGTH:
    s.fix(marks[ORDER - 1][0], True)

    # marks are strictly ordered: the set index of each following mark is
    # smaller than that of its predecessor
    for i in range(ORDER - 1):
        s.fix(one_onehot_LT_than_another(s, marks[i + 1], marks[i]), True)

    # one distance vector per unordered pair of marks (i < j, same order as
    # a nested index loop)
    pair_diffs = [
        gen_onehot_vars_diff(s, first, second)
        for first, second in itertools.combinations(marks, 2)
    ]
    s.POPCNT(TOTAL_UNIQUE, s.BV_OR_list(pair_diffs))

    if s.solve() == False:
        print("unsat")
        return False

    sol_n = 1
    while True:
        print(f"Solution #{sol_n}")
        ruler = [
            SAT_lib.one_hot_to_number(SAT_lib.BV_to_number(s.get_BV_from_solution(mark)))
            for mark in marks
        ]
        mirrored_ruler = mirror_ruler(ruler)
        print(f"{ruler=} AKA {mirrored_ruler}")

        uniq_diffs = set()
        for v1, v2 in itertools.combinations(ruler, 2):
            diff = abs(v1 - v2)
            # print larger minus smaller so the shown equation is true
            print(max(v1, v2), "-", min(v1, v2), "=", diff)
            uniq_diffs.add(diff)
        print(f"{uniq_diffs=}")
        print(f"{len(uniq_diffs)=}")
        if len(uniq_diffs) == LENGTH:
            print("this is perfect ruler")

        # block this ruler and its mirror image so the next solve() call
        # has to produce a different one
        _forbid_ruler(s, marks, ruler, LENGTH + 1)
        _forbid_ruler(s, marks, mirrored_ruler, LENGTH + 1)

        sol_n += 1
        if s.solve() == False:
            return True

# ORDER is the number of marks on the ruler.
# LENGTH is as in Wikipedia: e.g. LENGTH=34 means that marks lie in [0..34].
# Uncomment exactly one of the configurations below.
#ORDER, LENGTH = 6, 17
ORDER, LENGTH = 7, 25
#ORDER, LENGTH = 8, 34
#ORDER, LENGTH = 9, 44

# Try the largest target for the number of distinct distances first
# (LENGTH) and lower it step by step down to ORDER; stop at the first
# target for which do_all() finds at least one ruler.
for TOTAL_UNIQUE in range(LENGTH, ORDER-1, -1):
    if do_all(ORDER, LENGTH, TOTAL_UNIQUE):
        break

