#!/usr/bin/env python3

import random, math, time
import SAT_lib, my_utils

from typing import List
from typing import Any

# Test fixtures: Latin squares with known transversal counts, checked by the
# self-tests in the __main__ section below.  Each square is a list of rows
# with cell values 0..order-1.  A "short form" comment gives the same square
# as one row-major string with one base-36 digit per cell, i.e. the format
# read by base36_str_to_list_of_lists.

# D.Knuth, fasc0a, L(6)
# must be 808 transversals
tst_DEK=[
[0,1,2,3,4,5,6,7,8,9],
[1,8,3,2,5,4,7,6,9,0],
[2,9,5,6,3,0,8,4,7,1],
[3,7,0,9,8,6,1,5,2,4],
[4,6,7,5,2,9,0,8,1,3],
[5,0,9,4,7,8,3,1,6,2],
[6,5,4,7,1,3,2,9,0,8],
[7,4,1,8,0,2,9,3,5,6],
[8,3,6,0,9,1,5,2,4,7],
[9,2,8,1,6,7,4,0,3,5],
]

# D.Knuth, fasc0a
# 2 MOLS
# must be 800 transversals (asserted in the self-test)
tst_DEK_a=[
[3,1,4,5,9,2,6,8,7,0],
[2,8,1,9,7,6,3,5,0,4],
[9,4,5,2,3,0,7,1,6,8],
[6,2,0,8,4,5,1,7,9,3],
[8,3,6,4,0,9,5,2,1,7],
[5,9,8,1,2,7,4,0,3,6],
[4,6,2,7,5,3,0,9,8,1],
[0,5,7,6,1,4,8,3,2,9],
[1,7,3,0,6,8,9,4,5,2],
[7,0,9,3,8,1,2,6,4,5]
]
# short form: 3145926870281976350494523071686208451793836409521759812740364627530981057614832917306894527093812645

# D.Knuth, fasc0a
# no MOLS
# must be 824 transversals (asserted in the self-test)
tst_DEK_b=[
[2,7,1,8,4,5,9,0,3,6],
[0,2,8,7,1,3,5,6,4,9], 
[7,5,2,4,0,9,3,1,6,8], 
[1,4,3,5,9,6,2,7,8,0], 
[6,3,9,0,7,1,8,4,2,5], 
[4,0,6,9,2,7,1,8,5,3], 
[3,1,0,2,6,8,4,5,9,7], 
[9,8,7,1,5,4,6,3,0,2], 
[8,9,5,6,3,0,7,2,1,4], 
[5,6,4,3,8,2,0,9,7,1]
]
# short form: 2718459036028713564975240931681435962780639071842540692718533102684597987154630289563072145643820971

# D.Knuth, fasc0a
# 6 MOLS
# must be 852 transversals (asserted in the self-test)
tst_DEK_c=[
[0,5,7,2,1,6,4,9,3,8], 
[6,0,5,1,2,9,8,4,7,3], 
[4,8,6,7,0,3,9,2,1,5], 
[1,4,3,9,8,0,7,6,5,2], 
[8,3,2,4,7,5,6,0,9,1], 
[7,2,0,3,9,4,1,5,8,6], 
[5,6,1,0,4,7,3,8,2,9], 
[9,1,4,8,6,2,5,3,0,7], 
[2,7,9,5,3,8,0,1,6,4], 
[3,9,8,6,5,1,2,7,4,0]
]
# short form: 0572164938605129847348670392151439807652832475609172039415865610473829914862530727953801643986512740

# D.Knuth, fasc0a
# no MOLS
# must be 864 transversals (asserted in the self-test)
tst_DEK_d=[
[1,6,8,0,3,9,7,4,2,5], 
[8,3,4,6,5,1,2,0,9,7], 
[9,8,0,5,7,6,1,3,4,2], 
[2,7,5,4,6,8,9,1,3,0], 
[0,5,3,8,9,7,6,2,1,4], 
[4,9,6,3,8,2,0,5,7,1], 
[7,1,9,2,0,3,4,6,5,8], 
[6,2,1,9,4,0,5,7,8,3], 
[3,4,7,1,2,5,8,9,0,6], 
[5,0,2,7,1,4,3,8,6,9]
]
# short form: 1680397425834651209798057613422754689130053897621449638205717192034658621940578334712589065027143869

# D.Knuth, fasc0a
# E.T.Parker's square, (e)
# must be 5504 transversals
# 12,265,168 MOLS
tst_DEK_e=[
[7,8,2,3,4,5,6,0,1,9], 
[8,2,3,4,0,6,7,1,9,5], 
[2,3,4,0,1,7,8,9,5,6], 
[3,4,0,1,2,8,9,5,6,7], 
[4,0,1,2,3,9,5,6,7,8], 
[5,6,7,8,9,1,2,3,4,0], 
[6,7,8,9,5,2,3,4,0,1], 
[0,1,9,5,6,3,4,7,8,2], 
[1,9,5,6,7,4,0,8,2,3], 
[9,5,6,7,8,0,1,2,3,4],
]
# short form: 7823456019823406719523401789563401289567401239567856789123406789523401019563478219567408239567801234

# Order-5 square.
# from https://math.stackexchange.com/questions/1743064/transversals-of-latin-squares
# must be 3 transversals
tst_SO=[
[0, 1, 2, 3, 4],
[1, 3, 0, 4, 2], 
[2, 4, 3, 1, 0], 
[3, 0, 4, 2, 1], 
[4, 2, 1, 0, 3]]

def latin_add_constraints(SIZE, s, ar):
    """Add the Latin-square constraints for a 'cube' of SAT variables.

    ar[r][c][n] is (presumably) the variable meaning "cell (r, c) holds value
    n".  Three families of s.make_one_hot constraints are emitted, in order:
      1. for every cell (r, c), over all values n;
      2. for every row r and value n, over all columns c;
      3. for every column c and value n, over all rows r.
    Each constraint receives a fresh list of SIZE literals.
    """
    axis = range(SIZE)
    # Each selector maps (outer, inner, varying) indices to a cube variable.
    selectors = (
        lambda i, j, k: ar[i][j][k],  # fixed (r, c), varying n
        lambda i, j, k: ar[i][k][j],  # fixed (r, n), varying c
        lambda i, j, k: ar[k][i][j],  # fixed (c, n), varying r
    )
    for pick in selectors:
        for i in axis:
            for j in axis:
                s.make_one_hot([pick(i, j, k) for k in axis])

def latin_add_constraints_diagonal(SIZE, s, ar):
    """Require every value to appear on both diagonals of the 'cube' *ar*.

    For each value z, one clause is added over the main-diagonal variables
    ar[i][i][z], then one over the anti-diagonal variables
    ar[SIZE-1-i][i][z] (presumably a disjunction: "z occurs somewhere on
    that diagonal").
    """
    diagonals = (
        lambda i: (i, i),             # main diagonal
        lambda i: (SIZE - 1 - i, i),  # anti-diagonal
    )
    for z in range(SIZE):
        for cell_of in diagonals:
            s.add_clause([ar[r][c][z] for r, c in map(cell_of, range(SIZE))])

def fix_first_row_increasing(s, a, SIZE):
    """Pin row 0 of square *a* to 0, 1, ..., SIZE-1 (left to right).

    Value v is encoded as the one-hot bit vector of the number 1 << v.
    """
    for col in range(SIZE):
        one_hot = SAT_lib.n_to_BV(1 << col, SIZE)  # value `col` as one-hot
        s.fix_BV(a[0][col], one_hot)

def fix_first_col_increasing(s, a, SIZE):
    """Pin column 0 of square *a* to 0, 1, ..., SIZE-1 (top to bottom).

    Only the first column is constrained.  Value v is encoded as the one-hot
    bit vector of the number 1 << v.
    """
    for row in range(SIZE):
        one_hot = SAT_lib.n_to_BV(1 << row, SIZE)  # value `row` as one-hot
        s.fix_BV(a[row][0], one_hot)

# Previous version of fix_square_to_hardcoded took the square as a list of
# strings of base-16 digits:
#     tmp=SAT_lib.n_to_BV(1 << int(hardcoded[r][c], 16), SIZE) # FIXME: only base 16
#     s.fix_BV(a[r][c], tmp)
def fix_square_to_hardcoded(s, a, hardcoded:List[List[int]], SIZE):
    """Pin every cell of square *a* to the integer given in *hardcoded*.

    Cells are visited row by row; value v is fixed as the one-hot bit vector
    of the number 1 << v.
    """
    cells = ((r, c) for r in range(SIZE) for c in range(SIZE))
    for r, c in cells:
        wanted = SAT_lib.n_to_BV(1 << hardcoded[r][c], SIZE)
        s.fix_BV(a[r][c], wanted)

def make_mutually_orthogonal(SIZE, s, a, b):
    """Constrain squares *a* and *b* to be orthogonal.

    For every cell (row-major), s.mult_one_hots combines the two one-hot
    cell vectors -- presumably into a one-hot vector over all SIZE*SIZE
    (a, b) value pairs (TODO confirm against SAT_lib).  The per-cell vectors
    are transposed, and each resulting group (one pair position across all
    cells) gets an always-true OR, i.e. every pair must occur in some cell.
    If the assumption above holds, SIZE**2 cells and SIZE**2 pairs make each
    pair occur exactly once.
    """
    per_cell = [s.mult_one_hots(a[r][c], b[r][c])
                for r in range(SIZE) for c in range(SIZE)]
    for same_pair_across_cells in my_utils.transpose_matrix(per_cell):
        s.OR_always(same_pair_across_cells)

_09=list(map(chr, range(ord('0'), ord('9')+1)))
_az=list(map(chr, range(ord('a'), ord('z')+1)))
_09_az=_09+_az

def cell_to_str(v, SIZE):
    """Format cell value *v* of an order-SIZE square for printing.

    Orders up to 36 use a single base-36 character ('0'-'9', 'a'-'z');
    orders up to 100 use two zero-padded decimal digits; larger orders use
    three (so values >= 1000 come out wider).
    """
    if SIZE > 100:
        return "%03d" % v
    if SIZE > 36:
        return "%02d" % v
    return _09_az[v]

def placeholder(SIZE):
    """Return the filler for an empty cell, as wide as cell_to_str's output.

    One dot up to order 36, two up to order 100, three beyond.
    """
    width = 1 if SIZE <= 36 else (2 if SIZE <= 100 else 3)
    return "." * width

# as list of lists
def get_square(s, a:List[List[List[int]]]) -> List[List[int]]:
    """Decode the solved square *a* from solver *s* into rows of ints.

    Each a[r][c] is read from the solution as a bit vector, turned into a
    number and decoded from one-hot form via SAT_lib, row by row.  The order
    is taken as len(a), so *a* is assumed to be square.
    """
    def decode(cell_bv):
        raw = SAT_lib.BV_to_number(s.get_BV_from_solution(cell_bv))
        return SAT_lib.one_hot_to_number(raw)

    order = len(a)
    return [[decode(a[r][c]) for c in range(order)] for r in range(order)]

def randomize_first_square(s, SIZE, a, NORMALIZE):
    """Pin a (partly random) partial filling of square *a* in solver *s*.

    Builds a SIZE x SIZE table of fixed values (None = free cell):
    * if NORMALIZE, row 0 and column 0 are set to 0..SIZE-1;
    * unless SIZE is 2, 3 or 4, a SIZE-dependent number of random cells
      outside row/column 0 get values that do not repeat within their row or
      column, and the resulting table is printed.
    Every non-None cell is then fixed in *s* as a one-hot bit vector.

    A random partial square need not extend to a full Latin square, so the
    SAT instance may become unsatisfiable; the per-SIZE cell counts below are
    hand-tuned.
    """
    fixed_vals = [[None for _ in range(SIZE)] for _ in range(SIZE)]
    if NORMALIZE:
        for r in range(SIZE):
            fixed_vals[r][0] = r
        for c in range(SIZE):
            fixed_vals[0][c] = c

    # Orders 2..4 are not randomized (may be tuned).  Only the random part is
    # skipped for them: the NORMALIZE cells above must still reach the solver.
    if SIZE not in (2, 3, 4):
        if SIZE < 7:
            random_cells = (SIZE-4)*2
        elif SIZE == 7:
            random_cells = (SIZE-3)*2-2
        elif SIZE <= 10:
            random_cells = (SIZE-2)*2
        else:
            random_cells = (SIZE-1)*2
        print ("randomize. random_cells=", random_cells)

        for _ in range(random_cells):
            # Randomize a random cell; do not touch first col/row.
            r = random.randint(1, SIZE-1)
            c = random.randint(1, SIZE-1)
            current_row = fixed_vals[r]
            # Read column c directly instead of transposing the whole table.
            current_col = [fixed_vals[i][c] for i in range(SIZE)]
            to_pick_from = (set(range(SIZE)) - set(current_row)) - set(current_col)
            if to_pick_from:
                fixed_vals[r][c] = random.choice(list(to_pick_from))

        # Self-test: no value repeats within any row or column.
        for i in range(SIZE):
            row_vals = [v for v in fixed_vals[i] if v is not None]
            col_vals = [fixed_vals[j][i] for j in range(SIZE) if fixed_vals[j][i] is not None]
            assert len(set(row_vals)) == len(row_vals)
            assert len(set(col_vals)) == len(col_vals)

        print ("randomize. fixed_vals:")
        for row in fixed_vals:
            print ("".join((placeholder(SIZE) if v is None else cell_to_str(v, SIZE)) + " " for v in row))

    # Pin every fixed cell as a one-hot bit vector.
    for r in range(SIZE):
        for c in range(SIZE):
            if fixed_vals[r][c] is not None:
                tmp = SAT_lib.n_to_BV(1 << fixed_vals[r][c], SIZE)
                s.fix_BV(a[r][c], tmp)

# https://en.wikipedia.org/wiki/ANSI_escape_code
def ANSI_set_foreground_color_2(color):
    """Return the ANSI SGR escape sequence selecting foreground *color*.

    *color* is a 16-entry palette index: 0-7 map to SGR codes 30-37 and
    8-15 to the 'bright' codes 90-97.
    """
    assert color>=0 and color<=15
    sgr = {i: (30 + i if i < 8 else 82 + i) for i in range(16)}
    return '\033[%dm' % sgr[color]

def ANSI_set_background_color_2(color):
    """Return the ANSI SGR escape sequence selecting background *color*.

    *color* is a 16-entry palette index: 0-7 map to SGR codes 40-47 and
    8-15 to the 'bright' codes 100-107.
    """
    assert color>=0 and color<=15
    sgr = {i: (40 + i if i < 8 else 92 + i) for i in range(16)}
    return '\033[%dm' % sgr[color]

# sq - list of lists of ints
def print_transversal(sq, coords):
    """Print square *sq* with the cells listed in *coords* highlighted.

    *coords* is an iterable of tuples whose first two items are (row, col);
    any further items (e.g. the value) are ignored.  Highlighted cells are
    preceded by ANSI foreground color 9; every cell is followed by an ANSI
    reset and a space, and each row ends with a newline.
    """
    order = len(sq)
    marked = {(t[0], t[1]) for t in coords}
    for r in range(order):
        pieces = []
        for c in range(order):
            if (r, c) in marked:
                pieces.append(ANSI_set_foreground_color_2(9))
            pieces.append(cell_to_str(sq[r][c], order))
            pieces.append(my_utils.ANSI_reset())
            pieces.append(" ")
        print ("".join(pieces))

# FIXME: to my_utils
def all_unique_vals_in_list(x):
    """Return True if no element of *x* occurs twice.

    Elements must be hashable.  Linear time: the list is unique exactly when
    its set of elements is as long as the list itself.
    """
    distinct = set(x)
    return len(x) == len(distinct)

# bitmasks[n] has the low n bits set (2**n - 1), for n = 3..17: the
# "all values 0..n-1 used" mask for a square of order n.
bitmasks={n: (1 << n) - 1 for n in range(3, 17+1)}

# SLOW: plain row-by-row backtracking.
def find_transversals(sq:List[List[int]], transversals, print_debug=False, start_at_row=0, cur_transversal_vals=0, cur_transversal_coords=None, cnt=0, cur_cols_in_transversal=None):
    """Count, and optionally collect, the transversals of Latin square *sq*.

    A transversal selects one cell from every row and every column such that
    all selected values are distinct.  Rows are filled one at a time by
    recursive backtracking.

    :param sq: square as a list of rows; values must be in 0..order-1 since
        the values used so far are tracked as bits of an int.
    :param transversals: list that receives each transversal found, as a list
        of ``(row, col, val)`` tuples in row order.  If None, transversals
        are only counted.
    :param print_debug: print every transversal as it is found, after a short
        pause (visual effect, for video recording).
    :param start_at_row, cur_transversal_vals, cur_transversal_coords, cnt,
        cur_cols_in_transversal: recursion state -- next row to fill, bitmask
        of values already used, cells chosen so far, running count, and the
        list of columns already used.  Normally left at their defaults.
    :returns: *cnt* plus the number of transversals found from this state.
    """
    # None instead of [] defaults: no shared mutable default lists.
    if cur_transversal_coords is None:
        cur_transversal_coords = []
    if cur_cols_in_transversal is None:
        cur_cols_in_transversal = []

    order = len(sq)
    # Mask with the low `order` bits set (every value used).  Computed
    # directly so any order works; the `bitmasks` table only has orders 3..17.
    full_mask = (1 << order) - 1

    if cur_transversal_vals == full_mask:
        if print_debug:
            time.sleep(0.1/2) # visual candy effect, for video recording
            # cur_transversal_vals is always full_mask here, so show the
            # actual values of the transversal instead.
            print ("found transversal", [val for _, _, val in cur_transversal_coords])
            print_transversal(sq, cur_transversal_coords)
        if transversals is not None:
            transversals.append(cur_transversal_coords)
        return cnt+1
    if start_at_row == order:
        return cnt # finish: no rows left
    row = sq[start_at_row]
    for i in range(order):
        # col must not repeat in transversal:
        if i in cur_cols_in_transversal: # FIXME: can be bitmask as well (see count_transversals)
            continue
        val = row[i]
        val_bit = 1 << val
        # value must not repeat in transversal:
        if (val_bit & cur_transversal_vals) == 0:
            cnt = find_transversals(sq, transversals, print_debug, start_at_row+1,
                                    cur_transversal_vals | val_bit,
                                    cur_transversal_coords+[(start_at_row, i, val)],
                                    cnt, cur_cols_in_transversal+[i])
    return cnt

# SLOW: plain row-by-row backtracking.
def count_transversals(sq:List[List[int]], start_at_row=0, cur_transversal_vals=0, cnt=0, cur_cols_in_transversal=0):
    """Count the transversals of Latin square *sq* without storing them.

    Same search as find_transversals, but used columns are tracked as a
    bitmask too and no transversal is collected.

    :param sq: square as a list of rows; values must be in 0..order-1.
    :param start_at_row, cur_transversal_vals, cnt, cur_cols_in_transversal:
        recursion state -- next row to fill, bitmask of values used, running
        count, bitmask of columns used.  Normally left at their defaults.
    :returns: *cnt* plus the number of transversals found from this state.
    """
    order = len(sq)
    # Mask with the low `order` bits set.  Computed directly so any order
    # works; the `bitmasks` table only has orders 3..17.
    full_mask = (1 << order) - 1

    if cur_transversal_vals == full_mask:
        return cnt+1
    if start_at_row == order:
        return cnt # finish: no rows left
    row = sq[start_at_row]
    for i in range(order):
        col_bit = 1 << i
        # col must not repeat in transversal:
        if (col_bit & cur_cols_in_transversal) != 0:
            continue
        val_bit = 1 << row[i]
        # value must not repeat in transversal:
        if (val_bit & cur_transversal_vals) == 0:
            cnt = count_transversals(sq, start_at_row+1, cur_transversal_vals | val_bit, cnt, cur_cols_in_transversal | col_bit)
    return cnt

# For an order x order grid there are order! such selections (one per
# permutation), independent of the square's values.
# Author's note: for normalized LS a bit smaller -- order4 = 8 possible
# transversals, order5 = 78, order6 = 599, order7 = at least 4320.
def find_all_possible_transversals(order, transversals, start_at_row=0, cur_transversal_coords=None, cnt=0, cur_cols_in_transversal=None):
    """Enumerate every way to choose one cell per row and per column.

    These are the candidate positions of a transversal in an order x order
    grid; cell values are not looked at.

    :param order: grid size.
    :param transversals: list receiving each selection as a list of
        ``(row, col)`` tuples in row order; if None, selections are only
        counted.
    :param start_at_row, cur_transversal_coords, cnt, cur_cols_in_transversal:
        recursion state -- next row, cells chosen so far, running count,
        columns already used.  Normally left at their defaults.
    :returns: *cnt* plus the number of selections (order! from the initial
        state).
    """
    # None instead of [] defaults: no shared mutable default lists.
    if cur_transversal_coords is None:
        cur_transversal_coords = []
    if cur_cols_in_transversal is None:
        cur_cols_in_transversal = []

    if len(cur_cols_in_transversal) == order:
        if transversals is not None:
            transversals.append(cur_transversal_coords)
        return cnt+1
    if start_at_row == order:
        return cnt # finish
    for i in range(order):
        # col must not repeat in transversal:
        if i in cur_cols_in_transversal:
            continue

        cnt = find_all_possible_transversals(order, transversals, start_at_row+1, cur_transversal_coords+[(start_at_row, i)], cnt, cur_cols_in_transversal+[i])

    return cnt

def list_of_lists_of_int_to_base36_str(l:List[List[int]]):
    """Serialize square *l* to its row-major base-36 short form.

    Each cell becomes one character from '0'-'9', 'a'-'z', so every value
    must be in 0..35.  The order is taken as len(l); only the first len(l)
    columns of each row are used.  Inverse of base36_str_to_list_of_lists
    (for lowercase input).
    """
    order = len(l)
    chars = []
    for r in range(order):
        for c in range(order):
            v = l[r][c]
            # Check the lower bound too: a negative v would silently index
            # _09_az from the end instead of failing.
            assert 0 <= v < 10+26
            chars.append(_09_az[v])
    return "".join(chars)

def base36_str_to_list_of_lists(s):
    """Parse a row-major base-36 short form into a square (list of rows).

    *s* must be a string whose length is a perfect square SIZE**2; each
    character is one cell, parsed with ``int(ch, 36)`` (so '0'-'9' and
    'a'-'z', case-insensitive).  An empty string gives an empty square.
    Inverse of list_of_lists_of_int_to_base36_str.
    """
    assert isinstance(s, str)
    SIZE = int(math.sqrt(len(s)))
    assert SIZE**2 == len(s)
    # Row r is the slice s[r*SIZE:(r+1)*SIZE], one digit per cell.
    return [[int(ch, 36) for ch in s[r*SIZE:(r+1)*SIZE]] for r in range(SIZE)]

# Example output of print_LS for squares of various orders:
#
# 3 2 1 0
# 2 3 0 1
# 1 0 3 2
# 0 1 2 3
#
# or
#
# g e a 9 d b 8 4 6 1 f h j 7 0 2 3 5 i c
# ...
# e f 7 h b d c g 8 a 5 j 4 0 1 3 2 i 6 9
#
# or
#
# 33 32 34 31 35 37 38 36 39 23 22 10 11 04 30 29 01 07 09 12 14 13 17 15 21 16 06 25 18 26 27 20 02 28 00 03 08 05 24 19
# ...
# 34 35 33 36 37 39 24 38 20 25 26 27 31 29 32 30 06 08 10 11 12 16 14 17 01 07 18 23 19 28 04 03 00 02 05 13 09 15 22 21
#
# or
#
# 096 090 066 040 ... 075 057 047 046
# ...
# 083 073 096 051 ... 067 090 057 050
def print_LS(sq:List[List[int]]):
    """Print a square, one row per line, every cell followed by a space.

    Cells are formatted with cell_to_str: one base-36 character for orders
    up to 36 (e.g. "3 2 1 0 "), two decimal digits up to order 100
    (e.g. "33 32 34 ..."), three digits beyond that (e.g. "096 090 ...").
    """
    order = len(sq)
    for r in range(order):
        print ("".join(cell_to_str(sq[r][c], order) + " " for c in range(order)))

# s - SAT_lib object, a - list of lists of BVs
def print_LS_from_SAT_vars(s, a:List[List[List[int]]]):
    """Decode square *a* from the solution in solver *s* and print it."""
    print_LS(get_square(s, a))

def print_MOLS2_from_SAT_vars(s, a, b):
    """Print two solved squares *a* and *b* superimposed, one hex byte per cell.

    Each cell is printed as ``"%02x "`` of ``a_val*0x10 + b_val``: the high
    hex digit comes from *a*, the low one from *b*, so every value must be
    < 16.  Afterwards asserts that all SIZE**2 (a, b) pairs are distinct,
    i.e. that the two squares really are orthogonal.
    """
    SIZE = len(a)
    # Decode each square once, outside the cell loop: get_square reads the
    # whole square from the solver on every call.
    sq_a = get_square(s, a)
    sq_b = get_square(s, b)
    all_vals = []
    for r in range(SIZE):
        for c in range(SIZE):
            from_a = sq_a[r][c]
            assert from_a<0x10 # limit: one hex digit
            from_b = sq_b[r][c]
            assert from_b<0x10 # limit: one hex digit
            x = from_a*0x10 + from_b
            print ("%02x " % x, end='')
            all_vals.append(x)
        print ("")

    # Self-test: every pair occurs once.  (Checking only len(all_vals) was
    # always true; compare print_MOLS3_from_SAT_vars.)
    assert len(set(all_vals))==SIZE**2

def print_MOLS3_from_SAT_vars(s, a, b, c):
    """Print three solved squares superimposed, one 3-digit hex number per cell.

    Each cell is printed as ``"%03x "`` of ``a*0x100 + b*0x10 + c``: one hex
    digit per square, so every value must be < 16.  Afterwards asserts that
    all SIZE**2 printed numbers (i.e. (a, b, c) triples) are distinct.
    """
    SIZE = len(a)
    assert len(a)==len(b)
    assert len(a)==len(c)
    # Decode each square once, outside the cell loop: get_square reads the
    # whole square from the solver on every call.
    sq_a = get_square(s, a)
    sq_b = get_square(s, b)
    sq_c = get_square(s, c)
    all_vals = []
    for _r in range(SIZE):
        for _c in range(SIZE):
            from_a = sq_a[_r][_c]
            assert from_a<0x10 # limit: one hex digit
            from_b = sq_b[_r][_c]
            assert from_b<0x10 # limit: one hex digit
            from_c = sq_c[_r][_c]
            assert from_c<0x10 # limit: one hex digit
            x = from_a*0x100 + from_b*0x10 + from_c
            print ("%03x " % x, end='')
            all_vals.append(x)
        print ("")

    # self-check: all triples distinct
    assert len(set(all_vals))==SIZE**2

def sort_columns(sq:List[List[int]]):
    """Return a copy of *sq* with columns permuted so that row 0 reads 0..order-1.

    Assumes row 0 contains every value 0..order-1 (otherwise a KeyError is
    raised).  The input square is not modified.
    """
    order = len(sq)
    # column_of[v] = index of the column whose first-row entry is v
    column_of = {v: i for i, v in enumerate(sq[0])}
    return [[sq[r][column_of[c]] for c in range(order)] for r in range(order)]

def sort_rows(sq:List[List[int]]):
    """Return a copy of *sq* with rows permuted so that column 0 reads 0..order-1.

    Assumes column 0 contains every value 0..order-1 (otherwise a KeyError
    is raised).  The input square is not modified.
    """
    order = len(sq)
    # row_of[v] = index of the row whose first-column entry is v
    row_of = {sq[r][0]: r for r in range(order)}
    return [[sq[row_of[r]][c] for c in range(order)] for r in range(order)]

def normalize(sq:List[List[int]]):
    """Return the normalized form of Latin square *sq*.

    Columns are permuted first so that row 0 reads 0..order-1, then rows are
    permuted so that column 0 reads 0..order-1.  For a Latin square on
    0..order-1 the second step keeps row 0 in place.
    """
    columns_sorted = sort_columns(sq)
    return sort_rows(columns_sorted)

def is_valid_LS(sq:List[List[int]]):
    """Return True if no value repeats in any row or column of *sq*.

    With order = len(sq), only the first *order* cells of each row are
    inspected, and the values themselves are not range-checked.  Rows are
    checked before columns.
    """
    order = len(sq)
    idx = range(order)
    if not all(len({sq[r][c] for c in idx}) == order for r in idx):
        return False
    return all(len({sq[r][c] for r in idx}) == order for c in idx)

def normalize_test():
    """Self-test for normalize() on order-10 squares given in base-36 short form.

    Each source square is parsed, normalized and serialized back, then
    compared with the expected short form at the same position.
    """
    sources=[
    "9153748026837160495234071628952589076143764893521010245896375716823409426039758168952103740932451768",
    "7895634012895674012395671082345601982347408952367134782195606712395408012345678912340678952340871956",
    '7895634012895604712325478019363408912567671932540856712893404082593671012345678912347608959360178254',
    "0123456789106723954827803154963608124957431209867552319078646549870123749568103289547623019876543210",
    "0125436789193406782595401782363678912540471239560854012893676089523471789365401282567401932367801954",]

    expected=[
    "0123456789154978306226183409573875194206470621983552809613746034827591796253841084976051239351072648",
    "0123456789123406789523408719563478219560408952367156019823476712395408789563401289567401239567108234",
    "0123456789123476089525478019363408912567408259367156712893406719325408789563401289560471239360178254",
    "0123456789106723954827803154963608124957431209867552319078646549870123749568103289547623019876543210",
    "0123456789193604782523608719543671982540471932560854082193676082593471789563401282547601939547108236",]

    for src, want in zip(sources, expected):
        square=base36_str_to_list_of_lists(src)
        got=list_of_lists_of_int_to_base36_str(normalize(square))
        assert got==want

# tests

if __name__ == "__main__":

    normalize_test()

    # There are order! ways to pick one cell per row and per column.
    for i in range(9+1):
        assert find_all_possible_transversals(i, [])==math.factorial(i)

    # self-test: known transversal counts of the fixture squares
    # (pass print_debug=True to find_transversals to watch the search)

    x: List[Any]
    x=[]
    assert find_transversals(tst_SO, x)==3
    assert count_transversals(tst_SO)==3
    x=[]
    assert find_transversals(tst_DEK, x)==808
    assert count_transversals(tst_DEK)==808
    x=[]
    assert find_transversals(tst_DEK_a, x)==800
    assert count_transversals(tst_DEK_a)==800
    x=[]
    assert find_transversals(tst_DEK_b, x)==824
    x=[]
    assert find_transversals(tst_DEK_c, x)==852
    x=[]
    assert find_transversals(tst_DEK_d, x)==864
    x=[]
    assert find_transversals(tst_DEK_e, x)==5504
    assert count_transversals(tst_DEK_e)==5504

    # Color demo: palette indices are 0..15, so show all 16 entries.
    for i in range(16):
        print (ANSI_set_foreground_color_2(i), end="")
        print (i, end="")
        print (my_utils.ANSI_reset(), end="")
        print ("")

    for i in range(16):
        print (ANSI_set_background_color_2(i), end="")
        print (i, end="")
        print (my_utils.ANSI_reset(), end="")
        print ("")

    assert list_of_lists_of_int_to_base36_str([[1,2,3],[10,11,12],[30,31,32]])=="123abcuvw"

    # round trip through the base-36 short form
    s="0123456789abcdef1f4829ae6b5307cd2b57014d9c68fea33d046abc2f9e5817453b98f0de21a67c59b130e6ad4c72f86ed5127fc38a49b073c2bd5108ef946a8c7d560b1af9e32492fa74c8e536d10ba610eb34f7d52c89b8afced241073596c06e8f93527dba41d786a329b0c41f5ee49cf71a36b280d5fae9dc8574106b32"
    l=base36_str_to_list_of_lists(s)
    assert list_of_lists_of_int_to_base36_str(l)==s


