#!/usr/bin/python3

import sys, os

# Parse the matrix from the file named by the first command-line argument:
# one row per non-empty line, cells separated by commas.
matrix=[]
with open(sys.argv[1]) as f:
    # The with-block closes the input file as soon as parsing is done.
    for line in f:
        line=line.rstrip()
        if not line:
            # Skip blank lines (e.g. a trailing newline at end of file).
            continue
        matrix.append([int(cell) for cell in line.split(",")])

# Number of rows read; used below as both the row and the column count.
matrix_size=len(matrix)

def row_col_to_vertex_n(row, col):
    """Map a (row, col) cell position to its vertex number.

    Cells are numbered row by row, so the result is row*matrix_size+col.
    Both coordinates are asserted to be below matrix_size.
    """
    side = matrix_size
    assert row < side
    assert col < side
    vertex_n = side * row + col
    return vertex_n

def vertex_n_to_row_col(n):
    """Map a vertex number back to its (row, col) cell position.

    Inverse of row_col_to_vertex_n: the quotient of n by matrix_size is the
    row and the remainder is the column. Both are asserted to be below
    matrix_size before being returned as a tuple.
    """
    position = divmod(n, matrix_size)
    assert position[0] < matrix_size
    assert position[1] < matrix_size
    return position

# ANSI escape sequences used to colour terminal output.
RED = '\x1b[1;31m'    # bold red
GREEN = '\x1b[1;32m'  # bold green
NC = '\x1b[0m'        # "no colour": reset all attributes

def print_matrix(matrix, cells_to_HL):
    """Print *matrix* to stdout, one row per line, highlighting chosen cells.

    Every value is printed with the "%3d " format. A cell whose (row, col)
    tuple is in cells_to_HL is preceded by the RED escape sequence; the NC
    reset sequence is emitted after every cell, highlighted or not.

    The dimensions are taken from the matrix argument itself rather than
    from the module-level matrix_size, so the function prints exactly the
    matrix it was given.

    matrix      -- sequence of rows, each a sequence of integers
    cells_to_HL -- iterable of (row, col) tuples to highlight
    """
    # A set makes the per-cell membership test O(1); with a list the cost
    # grows with the path length for every single cell printed.
    highlighted = set(cells_to_HL)
    for row, values in enumerate(matrix):
        for col, value in enumerate(values):
            if (row, col) in highlighted:
                print (RED, end="")
            print ("%3d " % value, end="")
            print (NC, end="")
        print ("")

# Number of the first extra ("special") vertex. It must be larger than every
# cell vertex number, the largest of which is
# row_col_to_vertex_n(matrix_size-1, matrix_size-1).
special_vertex_n = matrix_size * matrix_size
assert special_vertex_n > row_col_to_vertex_n(matrix_size - 1, matrix_size - 1)

# Lisp's "car": the head of a sequence.
def first(lst):
    """Return the element at index 0 of *lst*."""
    head = lst[0]
    return head

def _build_edges(PE_problem):
    """Return the list of (src, dst, weight) edges for the given problem.

    The weight of an edge is the matrix value of its destination cell.
    Edges are emitted in a fixed order: first one edge from each special
    vertex to the column-0 cell of its row, then, cell by cell, the
    down/right moves (all problems), the up move (82 and 83) and the left
    move (83 only).
    """
    last = matrix_size - 1
    edges = []

    # One special source vertex per row, entering the matrix at column 0
    # and paying for that first cell.
    for row in range(matrix_size):
        edges.append ((special_vertex_n+row, row_col_to_vertex_n(row, 0), matrix[row][0]))

    for row in range(matrix_size):
        for col in range(matrix_size):
            src = row_col_to_vertex_n(row, col)
            # down
            if row != last:
                edges.append ((src, row_col_to_vertex_n(row+1, col), matrix[row+1][col]))
            # right
            if col != last:
                edges.append ((src, row_col_to_vertex_n(row, col+1), matrix[row][col+1]))
            # up
            if PE_problem in (82, 83) and row != 0:
                edges.append ((src, row_col_to_vertex_n(row-1, col), matrix[row-1][col]))
            # left
            if PE_problem == 83 and col != 0:
                edges.append ((src, row_col_to_vertex_n(row, col-1), matrix[row][col-1]))
    return edges


def _result_patterns(PE_problem):
    """Return the "src -> dst," substrings that mark result lines of interest.

    For problems 81 and 83 the only pair is (first special vertex,
    bottom-right cell); for problem 82 it is every pair of (special vertex
    of any row, cell in the last column of any row).
    """
    last = matrix_size - 1
    if PE_problem in (81, 83):
        return ["%d -> %d," % (special_vertex_n, row_col_to_vertex_n(last, last))]
    return ["%d -> %d," % (special_vertex_n+row1, row_col_to_vertex_n(row2, last))
            for row1 in range(matrix_size)
            for row2 in range(matrix_size)]


def do_all(PE_problem):
    """Solve Project Euler problem 81, 82 or 83 for the loaded matrix.

    Writes the graph to "floyd_warshall.input", runs the external
    ./floyd_warshall program with its output redirected to
    "floyd_warshall.output", then prints the minimal path sum and, when the
    output carries path information, the matrix with each minimal path
    highlighted.

    PE_problem -- 81, 82 or 83; selects which moves are allowed and which
                  source/destination pairs are considered.

    Raises ValueError if PE_problem is not supported, or if no line of the
    output file matches any expected source/destination pair.
    """
    # Fail before doing any work rather than after the external run.
    if PE_problem not in (81, 82, 83):
        raise ValueError("unsupported PE problem: %r (expected 81, 82 or 83)" % (PE_problem,))

    # Cell vertices plus one special vertex per row.
    vertices_total=special_vertex_n+matrix_size
    edges_for_c=_build_edges(PE_problem)

    # Input format as written here: a "vertices edges" header line, then
    # one "src dst weight" line per edge.
    with open("floyd_warshall.input", "wt") as f:
        f.write("%d %d\n" % (vertices_total, len(edges_for_c)))
        for src, dst, weight in edges_for_c:
            f.write ("%d %d %d\n" % (src, dst, weight))

    print ("Going to run floyd_warshall")
    if matrix_size==80:
        # Lousy hack: we are only interested in source vertices >= special_vertex_n,
        # which is 80^2=6400. Filtering these lines in Python is much slower
        # than using grep.
        os.system ('./floyd_warshall floyd_warshall.input | grep "^64[0-9][0-9]" | grep -v 100000000 > floyd_warshall.output')
    else:
        os.system ("./floyd_warshall floyd_warshall.input > floyd_warshall.output")
    print ("floyd_warshall finished")

    to_grep=_result_patterns(PE_problem)

    # Each kept entry is (sum,) or (sum, path_text). The parsing below assumes
    # a matching line looks like "src -> dst, sum[, v0->v1->...]"
    # -- NOTE(review): confirm against the floyd_warshall program.
    infos=[]
    with open("floyd_warshall.output") as f:
        for l in f:
            # Lines containing 100000000 are dropped, as the grep -v above does.
            if "100000000" in l:
                continue
            l=l.rstrip()
            for str_to_grep in to_grep:
                if str_to_grep in l:
                    fields=l.split(",")
                    min_path_sum=int(fields[1])
                    if len(fields)>2:
                        # Keep the path as text; it is parsed only for the
                        # entries that actually get printed.
                        infos.append ((min_path_sum, fields[2]))
                    else:
                        infos.append ((min_path_sum,))

    if not infos:
        raise ValueError("no matching result lines in floyd_warshall.output "
                         "(did ./floyd_warshall run successfully?)")

    min_=min(map(first, infos))
    print ("min path sum", min_)
    if len(infos[0])>1:
        for info in infos:
            # Only minimal entries that carry a path can be drawn.
            if info[0]!=min_ or len(info)<2:
                continue
            # Special vertices have no cell in the matrix, so leave them out.
            cells_to_HL=[vertex_n_to_row_col(vert)
                         for vert in map(int, info[1].split("->"))
                         if vert < special_vertex_n]
            print_matrix (matrix, cells_to_HL)
    else:
        print ("matrix not printed - no info about paths")

# Solve the three Project Euler problems in turn, each preceded by a banner.
for banner, problem_n in (("*** PE 81", 81), ("*** PE 82", 82), ("*** PE 83", 83)):
    print (banner)
    do_all(problem_n)

