## Proving sorting network correctness using Z3 SMT solver

##### Thanks to @linuxenia for the idea

Sorting networks are highly popular in electronics, GPGPU and even in SAT encodings: https://en.wikipedia.org/wiki/Sorting_network.

Bitonic sorters, which are also sorting networks, are especially popular: https://en.wikipedia.org/wiki/Bitonic_sorter.

They are relatively easy to construct, but finding the smallest one is a challenge.

The smallest sorting network for 9 channels has only 25 comparators.

It is a combinational circuit: each connection is a comparator+swapper, which swaps its two input values if they are out of order and passes them on to the next level.
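A single comparator+swapper can be sketched on plain Python integers. This is only an illustration on concrete values; `apply_comparator` is a helper name introduced here, not something from the article or from Z3.

```python
def comparator(x, y):
    """One comparator+swapper: return (smaller, larger)."""
    return (min(x, y), max(x, y))

def apply_comparator(wires, i, j):
    """Connect wires i and j (i < j): the smaller value ends up on wire i."""
    wires[i], wires[j] = comparator(wires[i], wires[j])
    return wires

print(apply_comparator([5, 2, 9], 0, 1))  # out of order, so the two values are swapped
print(apply_comparator([2, 5, 9], 0, 1))  # already in order, so nothing changes
```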

I copypasted it from the article: Michael Codish, Luís Cruz-Filipe, Michael Frank, and Peter Schneider-Kamp, "Twenty-Five Comparators is Optimal when Sorting Nine Inputs (and Twenty-Nine for Ten)".

Another article about it: Ian Parberry, "A Computer Assisted Optimal Depth Lower Bound for Nine-Input Sorting Networks".

I don't know (yet) how they proved that it is the smallest, but interestingly, it's extremely easy to prove its correctness using the Z3 SMT solver. We just construct the network out of comparators/swappers and ask Z3 to find a counterexample: an input for which the output of the network is not sorted. And it can't, meaning that the output is always sorted, no matter what values are plugged into the inputs.

    from z3 import *

    a, b, c, d, e, f, g, h, i=Ints('a b c d e f g h i')

    def Z3_min (a, b):
        return If(a<b, a, b)

    def Z3_max (a, b):
        return If(a>b, a, b)

    # comparator+swapper: returns (smaller, larger)
    def comparator (a, b):
        return (Z3_min(a, b), Z3_max(a, b))

    # apply one row of the network: each consecutive pair of "+" marks
    # in the parameter string is one comparator between these two channels
    def line(lst, params):
        rt=lst
        start=0
        while start+1 < len(params):
            try:
                first=params.index("+", start)
            except ValueError:
                # no more "+" in parameter string
                return rt
            second=params.index("+", first+1)
            rt[first], rt[second]=comparator(lst[first], lst[second])
            start=second+1
        # parameter string ended
        return rt

    l=[i, h, g, f, e, d, c, b, a]
    l=line(l, " ++++++++")
    l=line(l, " + + + + ")
    l=line(l, "   +   + ")
    l=line(l, " +   +   ")
    l=line(l, "+      + ")
    l=line(l, "  + + + +")
    l=line(l, "    +   +")
    l=line(l, "  +   +  ")
    l=line(l, "    + +  ")
    l=line(l, "   + +++ ")
    l=line(l, "+   +    ")
    l=line(l, "+ + + +  ")
    l=line(l, "+  +     ")
    l=line(l, "  +  +   ")
    l=line(l, "++++++ ++")

    # construct expression like And(..., k[2]>=k[1], k[1]>=k[0])
    expr=[(l[k+1]>=l[k]) for k in range(len(l)-1)]

    # True if everything works correctly:
    correct=And(*expr)

    s=Solver()

    # we want to find inputs for which correct==False:
    s.add(Not(correct))
    print(s.check()) # must be unsat



( The full source code: https://github.com/DennisYurichev/yurichev.com/blob/master/blog/sorting_network/test9.py )
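As a cross-check that does not need Z3, the same rows can be tested in plain Python using the zero-one principle: a comparator network that sorts every sequence of 0s and 1s sorts every input, so 2^9 = 512 test cases are enough. This is a sketch of a different technique, not part of the original script; `ROWS`, `pairs`, `run` and `counterexample` are names introduced here.

```python
from itertools import product

# The 15 rows of the 9-channel network, copied from the script above.
ROWS = [
    " ++++++++",
    " + + + + ",
    "   +   + ",
    " +   +   ",
    "+      + ",
    "  + + + +",
    "    +   +",
    "  +   +  ",
    "    + +  ",
    "   + +++ ",
    "+   +    ",
    "+ + + +  ",
    "+  +     ",
    "  +  +   ",
    "++++++ ++",
]

def pairs(row):
    """Pair up the '+' marks of a row: 1st with 2nd, 3rd with 4th, and so on."""
    marks = [k for k, ch in enumerate(row) if ch == "+"]
    return list(zip(marks[0::2], marks[1::2]))

def run(rows, values):
    """Push concrete values through the network; the smaller value goes to the lower index."""
    wires = list(values)
    for row in rows:
        for i, j in pairs(row):
            if wires[i] > wires[j]:
                wires[i], wires[j] = wires[j], wires[i]
    return wires

def counterexample(rows, channels=9):
    """Return the first 0/1 input that comes out unsorted, or None if there is none."""
    for bits in product((0, 1), repeat=channels):
        if run(rows, bits) != sorted(bits):
            return bits
    return None

print(sum(len(pairs(row)) for row in ROWS))  # number of comparators
print(counterexample(ROWS))                  # full network
print(counterexample(ROWS[:-1]))             # same network with the last row removed
```

Finding no counterexample for the full network corresponds to Z3 answering `unsat`. With the last row removed the search does return an input, which is roughly what a `sat` answer with a model would mean in the Z3 version.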

There is also a smaller 4-channel network, which I copypasted from Wikipedia:

    ...

    l=line(l, " + +")
    l=line(l, "+ + ")
    l=line(l, "++++")
    l=line(l, " ++ ")

    ...


( The full source code: https://github.com/DennisYurichev/yurichev.com/blob/master/blog/sorting_network/test4.py )

It is also proved to be correct, but it's interesting to see what Z3Py expression we get at each of the 4 outputs:

    If(If(a < c, a, c) < If(b < d, b, d),
       If(a < c, a, c),
       If(b < d, b, d))

    If(If(If(a < c, a, c) > If(b < d, b, d),
          If(a < c, a, c),
          If(b < d, b, d)) <
       If(If(a > c, a, c) < If(b > d, b, d),
          If(a > c, a, c),
          If(b > d, b, d)),
       If(If(a < c, a, c) > If(b < d, b, d),
          If(a < c, a, c),
          If(b < d, b, d)),
       If(If(a > c, a, c) < If(b > d, b, d),
          If(a > c, a, c),
          If(b > d, b, d)))

    If(If(If(a < c, a, c) > If(b < d, b, d),
          If(a < c, a, c),
          If(b < d, b, d)) >
       If(If(a > c, a, c) < If(b > d, b, d),
          If(a > c, a, c),
          If(b > d, b, d)),
       If(If(a < c, a, c) > If(b < d, b, d),
          If(a < c, a, c),
          If(b < d, b, d)),
       If(If(a > c, a, c) < If(b > d, b, d),
          If(a > c, a, c),
          If(b > d, b, d)))

    If(If(a > c, a, c) > If(b > d, b, d),
       If(a > c, a, c),
       If(b > d, b, d))


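The first and the last are shorter than the 2nd and the 3rd: they are just the minimum and the maximum of all four inputs, min(min(a,c),min(b,d)) and max(max(a,c),max(b,d)), which have the same values as min(min(min(a,b),c),d) and max(max(max(a,b),c),d).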
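The min/max reading of the four outputs can be checked by brute force in plain Python. This sketch assumes that the elided part of the script starts from `l=[a, b, c, d]` (an assumption, though it is consistent with the expressions printed above); `ROWS4`, `pairs`, `run` and `check4` are names introduced here.

```python
from itertools import product

# The 4 rows of the 4-channel network, copied from the fragment above.
ROWS4 = [" + +", "+ + ", "++++", " ++ "]

def pairs(row):
    """Pair up the '+' marks of a row: 1st with 2nd, 3rd with 4th."""
    marks = [k for k, ch in enumerate(row) if ch == "+"]
    return list(zip(marks[0::2], marks[1::2]))

def run(rows, values):
    """Push concrete values through the network; the smaller value goes to the lower index."""
    wires = list(values)
    for row in rows:
        for i, j in pairs(row):
            wires[i], wires[j] = min(wires[i], wires[j]), max(wires[i], wires[j])
    return wires

def check4():
    """Compare every output with its min/max formula on all inputs drawn from 0..3."""
    for a, b, c, d in product(range(4), repeat=4):
        out = run(ROWS4, [a, b, c, d])  # assumed starting order: l=[a, b, c, d]
        low = max(min(a, c), min(b, d))   # larger of the two pairwise minimums
        high = min(max(a, c), max(b, d))  # smaller of the two pairwise maximums
        expected = [min(min(a, c), min(b, d)), min(low, high),
                    max(low, high), max(max(a, c), max(b, d))]
        if out != expected or out != sorted([a, b, c, d]):
            return False
    return True

print(check4())
```

Values from 0..3 cover every possible ordering of four inputs, including ties, so 256 cases are enough here.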