#!/usr/bin/python3

import os, sys
import SAT_lib

# Command-line interface: <SIZE> <TIMEOUT>.
#   SIZE    - side length of the square grid.
#   TIMEOUT - forwarded unchanged as SAT_lib's maxsat_inc_timeout
#             (presumably seconds -- TODO confirm against SAT_lib).
# Fail with a usage message instead of a bare IndexError traceback when the
# arguments are missing.  Extra arguments are still ignored, as before.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} SIZE TIMEOUT")

SIZE = int(sys.argv[1])
TIMEOUT = int(sys.argv[2])

# Solver instance configured for (incremental) MaxSAT.
s = SAT_lib.SAT_lib(maxsat=True, maxsat_inc=True, maxsat_inc_timeout=TIMEOUT)

# One solver variable per cell, indexed as grid[row][column].
grid = [[s.create_var() for _ in range(SIZE)] for _ in range(SIZE)]

def print_grid(stars):
    """Write a SIZE x SIZE picture of the grid to stdout.

    Debugging helper.  ``stars`` is a container of ``(row, column)`` tuples;
    every listed cell is drawn as ``"* "`` and every other cell as ``". "``.
    Each grid row ends with a newline.
    """
    for row in range(SIZE):
        cells = ["* " if (row, col) in stars else ". " for col in range(SIZE)]
        sys.stdout.write("".join(cells) + "\n")

# Enumerate every axis-aligned square whose four corners lie on the grid and
# forbid all four corner cells from being set at the same time.
#
# (top, left) is the upper-left corner; `side` is the offset to the opposite
# corners.  For debugging, the corner coordinates can be passed to
# print_grid() to visualise a single square.
for top in range(SIZE - 1):
    for left in range(SIZE - 1):
        # Same bound as min(SIZE - left, SIZE - top): the far corners must
        # stay inside the grid both horizontally and vertically.
        side_limit = SIZE - max(top, left)
        for side in range(1, side_limit):
            bottom = top + side
            right = left + side
            corner_vars = (
                grid[top][left],
                grid[top][right],
                grid[bottom][left],
                grid[bottom][right],
            )
            # "All four corners are true" must never hold.
            corners_all_set = s.AND_list(
                [s.EQ(var, s.const_true) for var in corner_vars])
            s.fix_always_false(corners_all_set)

# Soft constraints: ask for every cell to be true, each with weight 1, so the
# solver is pushed towards setting as many cells as possible.
for r in range(SIZE):
    for c in range(SIZE):
        s.fix_soft_always_true(grid[r][c], 1)

# Run the solver.  This must not live inside an `assert`: under `python -O`
# assert statements are removed entirely, which would skip the solve() call
# and leave no solution for the code below to read.
if not s.solve():
    raise RuntimeError("solver did not return a solution")

def dump_grid_to_file(f):
    """Write the solved grid to ``f`` and return how many cells are set.

    ``f`` is any object with a ``write(str)`` method (e.g. ``sys.stdout`` or
    an open text file).  A cell whose variable is truthy in ``s.solution`` is
    written as ``"* "``, any other cell as ``". "``; each row ends with a
    newline.
    """
    star_count = 0
    for row_idx in range(SIZE):
        row_flags = [
            bool(s.solution[grid[row_idx][col_idx]])
            for col_idx in range(SIZE)
        ]
        star_count += sum(row_flags)
        line = "".join("* " if is_set else ". " for is_set in row_flags)
        f.write(line + "\n")
    return star_count

# Show the solution on stdout together with the number of set cells.
total = dump_grid_to_file(sys.stdout)
print(f"{total=}")

# Save the same picture to a file whose name records the run parameters and
# the achieved total.  The context manager closes the file even if writing
# fails part-way through.
fname = f"size_{SIZE}_timeout_{TIMEOUT}_total_{total}"
with open(fname, "wt") as f:
    dump_grid_to_file(f)

