#!/usr/bin/env python3

import SAT_lib, my_utils, latin_utils
import sys, random, os, copy
import cProfile, tempfile

def cnt_transv(sq_a):
    """Count the transversals of a Latin square using the external ./cnt_transv tool.

    The square is serialised with
    latin_utils.list_of_lists_of_int_to_base36_str and written,
    newline-terminated (as `echo` used to do), to the tool's stdin.
    The first space-separated field of the tool's stdout is the count.

    Args:
        sq_a: the square as a list of lists of ints.

    Returns:
        The transversal count reported by ./cnt_transv, as an int.

    Raises:
        RuntimeError: if the tool's output contains no space (unexpected format).
        OSError: if ./cnt_transv cannot be executed.
    """
    import subprocess

    sq = latin_utils.list_of_lists_of_int_to_base36_str(sq_a)
    # Feed the square on stdin with an argument list (no shell) and read stdout
    # directly. This replaces the previous `os.system` + temp-file round trip,
    # which leaked an unclosed file handle and had an unreachable os.unlink.
    # stderr is not captured, so tool diagnostics still reach the terminal.
    # The exit status is not checked, matching the previous behaviour.
    proc = subprocess.run(
        ["./cnt_transv"],
        input=sq + "\n",
        stdout=subprocess.PIPE,
        text=True,
        check=False,
    )
    out = proc.stdout
    if " " not in out:
        raise RuntimeError(f"unexpected output from ./cnt_transv: {out!r}")
    return int(out.split(" ")[0])

def symb_to_str(symb):
    """Render an integer symbol as lowercase hexadecimal without a prefix (e.g. 10 -> 'a')."""
    return f"{symb:x}"

def _print_diff(original_sq, sq_new, mutate_row1, mutate_col1):
    """Print sq_new as a cell-by-cell diff against original_sq; return the number of changed cells.

    Unchanged cells print as '.', changed cells as their new symbol in hex.
    The forced cell (mutate_row1, mutate_col1), if changed, is suffixed with
    '*' instead of a space.
    """
    size = len(original_sq)
    all_diffs = 0
    for r in range(size):
        for c in range(size):
            if original_sq[r][c] != sq_new[r][c]:
                marker = "*" if (r == mutate_row1 and c == mutate_col1) else " "
                print(symb_to_str(sq_new[r][c]) + marker, end="")
                all_diffs += 1
            else:
                print(". ", end="")
        print("")
    return all_diffs


def mutate(original_sq, mutate_row1, mutate_col1, new_val1):
    """Force one cell to a new value and let a MaxSAT solver repair the rest of the square.

    A deep copy of original_sq has cell (mutate_row1, mutate_col1) set to
    new_val1, and that cell is fixed as a hard constraint. Every other cell
    is added as a weight-1 soft constraint to keep its current value.
    Together with latin_utils.latin_add_constraints (presumably the
    Latin-square rules — TODO confirm), the solver is presumably asked to
    change as few other cells as possible.

    On success the original square, the diff, the number of changed cells,
    and the transversal count (via cnt_transv) are printed.

    Args:
        original_sq: square as a list of lists of ints; not modified.
        mutate_row1, mutate_col1: coordinates of the forced cell.
        new_val1: symbol to force into that cell.

    Returns:
        The solver's square (from latin_utils.get_square) if s.solve() is
        truthy, otherwise original_sq itself.
    """
    sq = copy.deepcopy(original_sq)
    SIZE = len(sq)
    sq[mutate_row1][mutate_col1] = new_val1

    s = SAT_lib.SAT_lib(maxsat=True)
    # try/finally so the solver is deinit'ed even if constraint building,
    # solving, or cnt_transv() raises (previously it was only released on
    # the two normal return paths).
    try:
        a = [[s.alloc_BV(SIZE) for c in range(SIZE)] for r in range(SIZE)]

        latin_utils.latin_add_constraints(SIZE, s, a)
        #latin_utils.latin_add_constraints_diagonal(SIZE, s, a) # NEW

        for r in range(SIZE):
            for c in range(SIZE):
                # Each cell is a SIZE-bit vector with exactly bit `value` set.
                target = SAT_lib.n_to_BV(1 << sq[r][c], SIZE)
                if r == mutate_row1 and c == mutate_col1:
                    s.fix_BV(a[r][c], target)  # hard: the forced cell
                else:
                    s.fix_BV_soft(a[r][c], target, 1)  # soft, weight 1

        if not s.solve():
            # UNSAT (or UNKNOWN): nothing found, keep the input square.
            return original_sq

        # TODO check correctness
        sq_new = latin_utils.get_square(s, a)
        print("original_sq:")
        latin_utils.print_LS(original_sq)
        print("diff:")
        all_diffs = _print_diff(original_sq, sq_new, mutate_row1, mutate_col1)
        print("mutate() all_diffs", all_diffs)  # TODO print changes
        short = latin_utils.list_of_lists_of_int_to_base36_str(sq_new)
        transv_cnt = cnt_transv(sq_new)
        print("transv_cnt=%08d short=%s" % (transv_cnt, short))
        sys.stdout.flush()
        return sq_new
    finally:
        s.deinit()

def main(argv=None):
    """Try every single-cell mutation of the Latin square given on the command line.

    argv[1] is the square as a base36 string (parsed by
    latin_utils.base36_str_to_list_of_lists). The transversal count of the
    input is printed first. Then, for each cell (r, c) and each symbol i in
    range(size) different from the cell's current value, mutate(sq, r, c, i)
    is called. mutate() deep-copies its input, so every attempt starts from
    the same original square.

    Args:
        argv: argument vector; defaults to sys.argv.

    Returns:
        Process exit status: 1 when the square argument is missing, else 0.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        prog = argv[0] if argv else "script"
        print(f"usage: {prog} <square-in-base36>", file=sys.stderr)
        return 1

    sq = latin_utils.base36_str_to_list_of_lists(argv[1])
    print("init sq:", argv[1])
    print("transv cnt", cnt_transv(sq))
    size = len(sq)
    for r in range(size):
        for c in range(size):
            for i in range(size):
                if sq[r][c] == i:
                    continue
                print("set row/col to i", r, c, i)
                mutate(sq, r, c, i)
    return 0


if __name__ == "__main__":
    sys.exit(main())

