#!/usr/bin/python3

from z3 import *
import sys

# Global Z3 configuration: request the parallel solver and cap it at 10 threads.
# NOTE(review): whether Z3 actually runs the integer/array queries built below
# in parallel depends on the tactic it selects -- confirm this has an effect.
set_param('parallel.enable', True)
set_param('parallel.threads.max', 10)

def check(goal, chain_len):
    """Ask Z3 whether an addition chain of chain_len elements ends in goal.

    The chain is r[0], ..., r[chain_len-1] with r[0] == 1.  Every later
    element must be the sum of two (not necessarily distinct) earlier
    elements, and the last element must equal goal.

    Prints a progress line; on success also prints every step of the chain
    that was found, followed by the number of additions (chain_len-1).

    Returns:
        True if the solver found such a chain, False if it reported unsat.

    Raises:
        Z3Exception: if the solver answers "unknown", since then there is
            neither a model to print nor a proof that no chain exists.
    """
    print("trying for %d, chain_len=%d" % (goal, chain_len))

    # r[i] is the i-th chain element; x[i] and y[i] are the indices of the
    # two earlier elements whose sum is r[i].
    x = Array('x', IntSort(), IntSort())
    y = Array('y', IntSort(), IntSort())
    r = Array('r', IntSort(), IntSort())

    # Plain integer mirrors of the array cells, so that individual values
    # can be looked up in the model by variable.
    x_val = [Int('x_val_%d' % i) for i in range(chain_len)]
    y_val = [Int('y_val_%d' % i) for i in range(chain_len)]
    r_val = [Int('r_val_%d' % i) for i in range(chain_len)]

    s = Solver()
    s.add(r[0] == 1)

    for i in range(chain_len):
        s.add(x_val[i] == x[i])
        s.add(y_val[i] == y[i])
        s.add(r_val[i] == r[i])

    # Only indices 1..chain_len-1 belong to the chain (index 0 is the fixed
    # starting 1), so there is nothing to constrain at index chain_len.
    for i in range(1, chain_len):
        s.add(x[i] >= 0)
        s.add(y[i] >= 0)
        # Operands must come from strictly earlier positions.
        s.add(x[i] < i)
        s.add(y[i] < i)
        s.add(r[i] == r[x[i]] + r[y[i]])

    s.add(r[chain_len - 1] == goal)

    result = s.check()
    if result == unsat:
        print("unsat")
        return False
    if result != sat:
        # "unknown": no model is available, so fail with the solver's reason
        # instead of an opaque error from s.model().
        raise Z3Exception("solver returned %s for goal=%d, chain_len=%d: %s"
                          % (result, goal, chain_len, s.reason_unknown()))

    m = s.model()
    # first 1 is not counted:
    for i in range(1, chain_len):
        left = m[r_val[m[x_val[i]].as_long()]]
        right = m[r_val[m[y_val[i]].as_long()]]
        print(m[r_val[i]], "=", left, "+", right)
    print("chain_len=%d for goal=%d" % (chain_len - 1, goal))
    return True

# For every goal, try chain lengths in increasing order and stop at the first
# one for which check() succeeds.
for goal in range(1, 300):
    for chain_len in range(1, 100):
        found = check(goal, chain_len)
        # Flush after every attempt, including the successful one, so that a
        # found chain is visible immediately when stdout is piped.
        sys.stdout.flush()
        if found:
            break


