#!/usr/bin/env python3

# ATTENTION. This code is for demonstration and learning only.
# It is not intended for serious/production work.
# Beware! It may have bugs and vulnerabilities.
# For serious work, use OpenSSH, libssh, etc.

# https://pypi.org/project/hexdump/
# pip3 install hexdump
import hexdump

import socket, struct, random, math
import hashlib, hmac
import sys, base64
import itertools

# pip3 install pycryptodome   (provides the "Crypto" package; the old "pycrypto" is unmaintained)
# (openbsd) pkg_add py3-cryptodome
# https://pycryptodome.readthedocs.io/en/latest/src/installation.html
# https://pycryptodome.readthedocs.io/en/latest/src/cipher/aes.html
import Crypto.Cipher.AES
# https://pycryptodome.readthedocs.io/en/latest/src/util/util.html
import Crypto.Util.Counter

# May need to be upgraded:
# python3 -m pip install --upgrade cryptography
# pip3 install cryptography
# (openbsd) pkg_add py3-cryptography
import cryptography.hazmat.primitives.serialization
import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.asymmetric.padding

# Connection / authentication settings. These are module-level globals;
# parse_command_line() overwrites them from sys.argv.
HOST=None
PORT=22
USERNAME=None
PASSWORD=None
ENCRYPTION=False # False for the first (unencrypted) packets; read by my_recv() and wrap_packet()
VERBOSITY=0 # 0: quiet, 1: info, 2: info + hexdumps, 3: even more (see the usage text)
COMMAND="uname -a" # default
PUBKEY_FNAME=None
PRIKEY_FNAME=None
HOSTBOUND=False
SAVE_SERV_PUB_KEY=False

# Key-exchange methods offered in our KEXINIT, in the order listed; joined
# into the comma-separated name-list that goes on the wire (overridable
# with -kex).
KEX_ALGOS_list=[
"diffie-hellman-group-exchange-sha256",
"diffie-hellman-group1-sha1",
"diffie-hellman-group14-sha1",
"diffie-hellman-group14-sha256",
"diffie-hellman-group16-sha512",
"diffie-hellman-group18-sha512",
"ecdh-sha2-nistp256",
"ecdh-sha2-nistp384",
"ecdh-sha2-nistp521"]
KEX_ALGOS=",".join(KEX_ALGOS_list)

# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
# Maps each ECDH KEX method name to its NIST curve object.
# NOTE(review): the "ec" submodule is not imported explicitly at the top of
# the file; it is presumably reachable only because the other cryptography
# imports load it as a side effect.
KEX_algo_EC_func={
"ecdh-sha2-nistp256": cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(),
"ecdh-sha2-nistp384": cryptography.hazmat.primitives.asymmetric.ec.SECP384R1(),
"ecdh-sha2-nistp521": cryptography.hazmat.primitives.asymmetric.ec.SECP521R1()}

# Ciphers offered for both directions (overridable with -cipher).
CIPHER_ALGOS_list=[
"aes256-ctr",
"aes128-ctr",
"none"]
CIPHER_ALGOS=",".join(CIPHER_ALGOS_list)

# MAC algorithms offered for both directions (overridable with -mac).
MAC_ALGOS_list=[
"hmac-sha2-256",
"hmac-sha2-512",
"hmac-sha1"]
MAC_ALGOS=",".join(MAC_ALGOS_list)

# Host-key / signature algorithms we accept from the server
# (overridable with -server_host_algo).
SERVER_HOST_ALGOS_list=[
"ecdsa-sha2-nistp256",
"rsa-sha2-256",
"rsa-sha2-512",
"ssh-rsa",
"ssh-dss"]
SERVER_HOST_ALGOS=",".join(SERVER_HOST_ALGOS_list)

def parse_command_line():
    """Parse sys.argv and store the results in the module-level settings globals.

    Flags (-v, -vv, -vvv, -hostbound, -save_serv_pub_key) take no value; every
    other recognised option consumes the following argument. With no arguments
    a usage summary is printed and the program exits. An unknown option, an
    option missing its value, or a non-numeric port is reported and the
    program exits.
    """
    global HOST, PORT, USERNAME, PASSWORD, VERBOSITY, COMMAND, PUBKEY_FNAME, PRIKEY_FNAME, HOSTBOUND
    # MAC_ALGOS must be declared global too: without it "-mac" only assigned a
    # function-local variable and the option was silently ignored.
    global KEX_ALGOS, CIPHER_ALGOS, MAC_ALGOS, SERVER_HOST_ALGOS, SAVE_SERV_PUB_KEY

    if len(sys.argv)==1:
        print ("Error: no options supplied. Use these:")
        print ("  -v                 - increase verbosity level. but no dumps")
        print ("  -vv                - increase verbosity level. dumps")
        print ("  -vvv               - increase verbosity level. even more info")
        print ("  -h host")
        print ("  -port port         - default is 22")
        print ("  -c command         - run command on SSH host")
        print ("  -save_serv_pub_key - save it")
        print ("if password auth is used:")
        print ("  -u user")
        print ("  -pass password")
        print ("if publickey auth is used:")
        print ("  -u user")
        print ("  -pubkey fname - like id_rsa.pub")
        print ("  -prikey fname - like id_rsa")
        print ("Mostly used for testing:")
        print ("  -kex algos              - forcibly set KEX algos")
        print ("  -cipher algos           - forcibly set cipher algos. may be 'none'")
        print ("  -server_host_algo algos - forcibly set server host key algos")
        print ("  -mac algos              - forcibly set MAC algos")
        print ("  -hostbound")
        sys.exit(0)

    # options that consume the next command-line argument as their value
    options_with_value=("-h", "-port", "-u", "-pass", "-pubkey", "-prikey",
                        "-c", "-kex", "-mac", "-cipher", "-server_host_algo")
    idx=1
    while idx<len(sys.argv):
        opt=sys.argv[idx]
        value=None
        if opt in options_with_value:
            if idx+1>=len(sys.argv):
                print (f"Fatal error: option {opt} requires an argument")
                sys.exit(0)
            value=sys.argv[idx+1]
            idx+=2
        else:
            idx+=1

        if opt=="-v":
            VERBOSITY=1
        elif opt=="-vv":
            VERBOSITY=2
        elif opt=="-vvv":
            VERBOSITY=3
        elif opt=="-hostbound":
            HOSTBOUND=True
        elif opt=="-save_serv_pub_key":
            SAVE_SERV_PUB_KEY=True
        elif opt=="-h":
            HOST=value
        elif opt=="-port":
            try:
                PORT=int(value)
            except ValueError:
                print (f"Fatal error: bad port number: {value}")
                sys.exit(0)
        elif opt=="-u":
            USERNAME=value
        elif opt=="-pass":
            PASSWORD=value
        elif opt=="-pubkey":
            PUBKEY_FNAME=value
        elif opt=="-prikey":
            PRIKEY_FNAME=value
        elif opt=="-c":
            COMMAND=value
        elif opt=="-kex":
            KEX_ALGOS=value
        elif opt=="-mac":
            MAC_ALGOS=value
        elif opt=="-cipher":
            CIPHER_ALGOS=value
        elif opt=="-server_host_algo":
            SERVER_HOST_ALGOS=value
        else:
            print (f"Fatal error: unknown command option: {opt}")
            sys.exit(0)

# Group sizes in bits (min, preferred, max) sent in SSH_MSG_KEX_DH_GEX_REQUEST
# by pack_DH_GEX_request() and repeated in the exchange hash by calc_kexgex_hash().
#DH_GEX_min, DH_GEX_bits, DH_GEX_max = 2048, 8192, 8192 # slow in Python!
DH_GEX_min, DH_GEX_bits, DH_GEX_max = 2048, 2048, 2048
#DH_GEX_min, DH_GEX_bits, DH_GEX_max = 512, 512, 8192 # for test

# Our identification string (RFC 4253, section 4.2). The literal carries no
# trailing CR LF; presumably it is what gets hashed as V_C in calc_kexgex_hash().
client_banner=b'SSH-2.0-ToySSH_0.?.?'

# SSH message numbers (RFC 4250 section 4.1; transport: RFC 4253,
# user authentication: RFC 4252, connection protocol: RFC 4254).
SSH_MSG_DISCONNECT=1
SSH_MSG_IGNORE=2
SSH_MSG_UNIMPLEMENTED=3
SSH_MSG_DEBUG=4
SSH_MSG_SERVICE_REQUEST=5
SSH_MSG_SERVICE_ACCEPT=6
SSH2_MSG_EXT_INFO=7 # from OpenSSH 9.0 (extension negotiation, RFC 8308)

SSH_MSG_KEXINIT=20
SSH_MSG_NEWKEYS=21 # 0x15

# Numbers 30-49 are key-exchange-method specific and get reused:
# 30/31 are KEXDH_INIT/KEXDH_REPLY for fixed-group DH (RFC 4253) and also
# KEX_ECDH_INIT/KEX_ECDH_REPLY for ECDH (RFC 5656); in DH group exchange
# (RFC 4419) 31 is SSH_MSG_KEX_DH_GEX_GROUP, which carries p and g.
SSH_MSG_KEXDH_INIT=30 # 0x1e
SSH_MSG_KEXDH_REPLY=31 # 0x1f

# DH group exchange (RFC 4419)
SSH_MSG_KEX_DH_GEX_INIT=32 # 0x20
SSH_MSG_KEX_DH_GEX_REPLY=33 # 0x21
SSH_MSG_KEX_DH_GEX_REQUEST=34 # 0x22

SSH_MSG_USERAUTH_REQUEST=50 # 0x32
SSH_MSG_USERAUTH_FAILURE=51 # 0x33
SSH_MSG_USERAUTH_SUCCESS=52 # 0x34
SSH_MSG_USERAUTH_BANNER=53 # 0x35

SSH_MSG_GLOBAL_REQUEST=80 # 0x50
SSH_MSG_CHANNEL_OPEN=90 # 0x5a
SSH_MSG_CHANNEL_OPEN_CONFIRMATION=91 # 0x5b
SSH_MSG_CHANNEL_WINDOW_ADJUST=93 # 0x5d
SSH_MSG_CHANNEL_DATA=94 # 0x5e
SSH_MSG_CHANNEL_EXTENDED_DATA=95 # 0x5f
SSH_MSG_CHANNEL_EOF=96 # 0x60
SSH_MSG_CHANNEL_CLOSE=97 # 0x61
SSH_MSG_CHANNEL_REQUEST=98 # 0x62
SSH_MSG_CHANNEL_SUCCESS=99 # 0x63

# Message code -> symbolic name, presumably for debug output elsewhere in the
# file. Code 31 is listed only under its KEXDH_REPLY name, even though it is
# also used for the GEX group and ECDH reply messages.
msg_str={
    SSH_MSG_DISCONNECT                : "SSH_MSG_DISCONNECT",
    SSH_MSG_IGNORE                    : "SSH_MSG_IGNORE",
    SSH_MSG_UNIMPLEMENTED             : "SSH_MSG_UNIMPLEMENTED",
    SSH_MSG_DEBUG                     : "SSH_MSG_DEBUG",
    SSH_MSG_SERVICE_REQUEST           : "SSH_MSG_SERVICE_REQUEST",
    SSH_MSG_SERVICE_ACCEPT            : "SSH_MSG_SERVICE_ACCEPT",
    SSH2_MSG_EXT_INFO                 : "SSH2_MSG_EXT_INFO",
    SSH_MSG_KEXINIT                   : "SSH_MSG_KEXINIT",
    SSH_MSG_NEWKEYS                   : "SSH_MSG_NEWKEYS",
    SSH_MSG_KEXDH_REPLY               : "SSH_MSG_KEXDH_REPLY",
    SSH_MSG_KEXDH_INIT                : "SSH_MSG_KEXDH_INIT",
    SSH_MSG_KEX_DH_GEX_INIT           : "SSH_MSG_KEX_DH_GEX_INIT",
    SSH_MSG_KEX_DH_GEX_REPLY          : "SSH_MSG_KEX_DH_GEX_REPLY",
    SSH_MSG_KEX_DH_GEX_REQUEST        : "SSH_MSG_KEX_DH_GEX_REQUEST",
    SSH_MSG_USERAUTH_REQUEST          : "SSH_MSG_USERAUTH_REQUEST",
    SSH_MSG_USERAUTH_FAILURE          : "SSH_MSG_USERAUTH_FAILURE",
    SSH_MSG_USERAUTH_SUCCESS          : "SSH_MSG_USERAUTH_SUCCESS",
    SSH_MSG_USERAUTH_BANNER           : "SSH_MSG_USERAUTH_BANNER",
    SSH_MSG_GLOBAL_REQUEST            : "SSH_MSG_GLOBAL_REQUEST",
    SSH_MSG_CHANNEL_OPEN              : "SSH_MSG_CHANNEL_OPEN",
    SSH_MSG_CHANNEL_OPEN_CONFIRMATION : "SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
    SSH_MSG_CHANNEL_WINDOW_ADJUST     : "SSH_MSG_CHANNEL_WINDOW_ADJUST",
    SSH_MSG_CHANNEL_DATA              : "SSH_MSG_CHANNEL_DATA",
    SSH_MSG_CHANNEL_EXTENDED_DATA     : "SSH_MSG_CHANNEL_EXTENDED_DATA",
    SSH_MSG_CHANNEL_EOF               : "SSH_MSG_CHANNEL_EOF",
    SSH_MSG_CHANNEL_CLOSE             : "SSH_MSG_CHANNEL_CLOSE",
    SSH_MSG_CHANNEL_REQUEST           : "SSH_MSG_CHANNEL_REQUEST",
    SSH_MSG_CHANNEL_SUCCESS           : "SSH_MSG_CHANNEL_SUCCESS",
}

# first list has higher priority!
def first_common_element (lst1, lst2):
    """Return the first item of lst1 that also occurs in lst2, or None if none does.

    The order of lst1 decides the winner; the order of lst2 does not matter.
    Both arguments must support len(). Items are compared as
    ``item_of_lst1 == item_of_lst2``.
    """
    if not len(lst1) or not len(lst2):
        return None
    for candidate in lst1:
        if any(candidate == offered for offered in lst2):
            return candidate
    return None

def first_common_element_test():
    """Print first_common_element() for three hand-picked cases (expected output: 5, y, None)."""
    cases=(
        ([1,5,7,9], [0,7,9,5]),
        (['x', 'y', 'z'], [1, 2, 'y']),
        (['x', 'y', 'z'], [1, 2, 3]),
    )
    for preferred, offered in cases:
        print (first_common_element(preferred, offered))

recv_seqno=0
def my_recv_plain(s):
    """Read one unencrypted, MAC-less SSH binary packet from socket s.

    Returns the 4-byte packet_length field followed by packet_length bytes
    (padding_length, payload, padding). Increments recv_seqno. Exits the
    program if the peer closes the connection before the whole packet arrives.
    """
    global recv_seqno

    def recv_exact(n):
        # recv() may return fewer bytes than requested (TCP is a byte stream),
        # so keep reading until n bytes arrive or the peer closes. Each call is
        # capped so a bogus length does not allocate a huge buffer at once.
        chunks=[]
        remaining=n
        while remaining>0:
            chunk=s.recv(min(remaining, 65536))
            if not chunk:
                break
            chunks.append(chunk)
            remaining-=len(chunk)
        return b"".join(chunks)

    tmp_len=recv_exact(4)
    if len(tmp_len)<4:
        print (f"my_recv_plain() fatal error. incomplete read while reading packet len, got {len(tmp_len)} bytes")
        hexdump.hexdump(tmp_len)
        sys.exit(0)
    # packet_length is a uint32 (RFC 4253, section 6)
    pkt_len=struct.unpack (">I", tmp_len)[0]
    if VERBOSITY>=2:
        print ("my_recv_plain() pkt_len: %d or 0x%x" % (pkt_len, pkt_len))
    tmp=recv_exact(pkt_len)
    if len(tmp)<pkt_len:
        print (f"my_recv_plain() fatal error. incomplete read, expected {pkt_len} bytes, got only {len(tmp)}")
        hexdump.hexdump(tmp_len+tmp)
        sys.exit(0)
    if VERBOSITY>=2:
        print ("my_recv_plain() got:")
        hexdump.hexdump(tmp_len+tmp)
    recv_seqno+=1
    return tmp_len+tmp

def my_recv_MAC(s):
    """Read one unencrypted SSH packet that is followed by a MAC_SIZE-byte MAC.

    Returns the 4-byte packet_length field followed by packet_length bytes of
    packet body and then MAC_SIZE bytes of MAC (the MAC is not checked here).
    Increments recv_seqno. Exits the program if the peer closes the
    connection before the whole packet arrives.
    """
    global recv_seqno

    def recv_exact(n):
        # recv() may return fewer bytes than requested (TCP is a byte stream),
        # so keep reading until n bytes arrive or the peer closes. Each call is
        # capped so a bogus length does not allocate a huge buffer at once.
        chunks=[]
        remaining=n
        while remaining>0:
            chunk=s.recv(min(remaining, 65536))
            if not chunk:
                break
            chunks.append(chunk)
            remaining-=len(chunk)
        return b"".join(chunks)

    tmp_len=recv_exact(4)
    if len(tmp_len)<4:
        print (f"my_recv_MAC() incomplete read while reading packet len, got only {len(tmp_len)} bytes")
        hexdump.hexdump(tmp_len)
        sys.exit(0)
    # packet_length is a uint32 (RFC 4253, section 6)
    pkt_len=struct.unpack (">I", tmp_len)[0]
    if VERBOSITY>=2:
        print ("my_recv_MAC() pkt_len: %d or 0x%x" % (pkt_len, pkt_len))
    pkt_len+=MAC_SIZE # the MAC follows the packet and is not counted in packet_length
    tmp=recv_exact(pkt_len)
    if len(tmp)<pkt_len:
        print (f"my_recv_MAC() fatal error. incomplete read, expected {pkt_len} bytes, got only {len(tmp)}")
        hexdump.hexdump(tmp_len+tmp)
        sys.exit(0)
    if VERBOSITY>=2:
        print ("my_recv_MAC() got:")
        hexdump.hexdump(tmp_len+tmp)

    recv_seqno+=1
    return tmp_len+tmp

# AES block size. my_recv_encrypted() first reads one block of this size to
# decrypt the packet_length field before reading the rest of the packet.
AES_BLOCK_SIZE_IN_BITS=128
AES_BLOCK_SIZE_IN_BYTES=AES_BLOCK_SIZE_IN_BITS//8

def my_recv_encrypted(s):
    """Read one AES-CTR encrypted SSH packet from socket s and decrypt it.

    Uses key_D (truncated to the key size of CIPHER_ALGO) and the
    server-to-client counter serv_to_client_ctr. Returns the decrypted packet
    (4-byte length, padding length, payload, padding) followed by the raw
    MAC_SIZE-byte MAC, which is not verified here. Advances
    serv_to_client_ctr by the number of AES blocks consumed and increments
    recv_seqno. Exits the program on EOF, a short read, or a decrypted packet
    length that cannot be valid.
    """
    global recv_seqno, serv_to_client_ctr

    def recv_exact(n):
        # recv() may return fewer bytes than requested (TCP is a byte stream),
        # so keep reading until n bytes arrive or the peer closes. Each call is
        # capped so a bogus length does not allocate a huge buffer at once.
        chunks=[]
        remaining=n
        while remaining>0:
            chunk=s.recv(min(remaining, 65536))
            if not chunk:
                break
            chunks.append(chunk)
            remaining-=len(chunk)
        return b"".join(chunks)

    if CIPHER_ALGO=="aes128-ctr":
        KEY=key_D[0:128//8]
        assert len(KEY)==16
    elif CIPHER_ALGO=="aes256-ctr":
        KEY=key_D[0:256//8]
        assert len(KEY)==32
    else:
        assert False

    if VERBOSITY>=2:
        print ("my_recv_encrypted() begin")
    first_block=recv_exact(AES_BLOCK_SIZE_IN_BYTES)
    if len(first_block)==0:
        print ("my_recv_encrypted() fatal error. recv() returned nothing")
        sys.exit(0)
    if len(first_block)<AES_BLOCK_SIZE_IN_BYTES:
        print (f"my_recv_encrypted() fatal error. incomplete read while reading packet len, got only {len(first_block)} bytes")
        hexdump.hexdump(first_block)
        sys.exit(0)

    # One stateful CTR cipher object for the whole packet: decrypting the
    # rest after the first block continues the same keystream, so the first
    # block does not have to be decrypted a second time.
    ctr = Crypto.Util.Counter.new(128, little_endian=False, initial_value=serv_to_client_ctr)
    t = Crypto.Cipher.AES.new(KEY, Crypto.Cipher.AES.MODE_CTR, counter = ctr)
    first_block_decrypted = t.decrypt(first_block)
    if VERBOSITY>=2:
        print ("first block decrypted:")
        hexdump.hexdump(first_block_decrypted)
    # packet_length is a uint32 (RFC 4253, section 6)
    pkt_len=struct.unpack(">I", first_block_decrypted[0:4])[0]
    if VERBOSITY>=2:
        print ("my_recv_encrypted() pkt_len: %d or 0x%x" % (pkt_len, pkt_len))

    # RFC 4253, section 6: length field + packet must be a whole number of
    # cipher blocks and at least one block long.
    total_len=pkt_len+4
    if total_len<AES_BLOCK_SIZE_IN_BYTES or total_len%AES_BLOCK_SIZE_IN_BYTES!=0:
        print (f"my_recv_encrypted() fatal error. invalid packet length {pkt_len}")
        sys.exit(0)

    more_to_read=total_len-AES_BLOCK_SIZE_IN_BYTES
    rest=recv_exact(more_to_read)
    if len(rest)<more_to_read:
        print (f"my_recv_encrypted() fatal error. incomplete read, expected {more_to_read} more bytes, got only {len(rest)}")
        sys.exit(0)

    all_decrypted = first_block_decrypted + (t.decrypt(rest) if rest else b"")
    blocks_total=total_len//AES_BLOCK_SIZE_IN_BYTES
    if VERBOSITY>=2:
        print ("decrypted:")
        hexdump.hexdump(all_decrypted)
    # the SDCTR counter is a 128-bit integer incremented modulo 2**128 (RFC 4344)
    serv_to_client_ctr=(serv_to_client_ctr+blocks_total) % (1<<128)

    mac=recv_exact(MAC_SIZE)
    if len(mac)<MAC_SIZE:
        print (f"my_recv_encrypted() fatal error. incomplete read while reading MAC, got only {len(mac)} bytes")
        sys.exit(0)
    if VERBOSITY>=2:
        print ("read MAC")
        hexdump.hexdump(mac)

    recv_seqno+=1
    return all_decrypted + mac

def my_recv(s, MAC=False):
    """Receive one packet from socket s with the reader matching the current state.

    Without MAC the packet is read as plain; with MAC it is read encrypted if
    ENCRYPTION is set, otherwise as plaintext followed by a MAC.
    """
    if not MAC:
        return my_recv_plain(s)
    return my_recv_encrypted(s) if ENCRYPTION else my_recv_MAC(s)

send_seqno=0
def my_send(s, buf):
    """Send the already-framed packet buf over socket s and increment send_seqno.

    Uses sendall(): a plain send() may transmit only part of buf, which would
    silently truncate the packet on the wire.
    """
    global send_seqno
    if VERBOSITY>=2:
        print ("my_send() sending:")
        hexdump.hexdump(buf)
    s.sendall(buf)
    send_seqno+=1

# From http://docs.oracle.com/javase/specs/jvms/se7/html/jvms-3.html (3.3)
def align2grain (i, grain):
    """Round i up to the nearest multiple of grain (grain must be > 0).

    The bit-mask form ``(i + grain-1) & ~(grain-1)`` is only correct when grain
    is a power of two. Ceiling division gives the same results for powers of
    two and also works for any other positive grain.
    """
    return -(-i // grain) * grain

def get_buf (buf, idx):
    """Parse an SSH ``string`` (uint32 length, then that many bytes) from buf at offset idx.

    Returns (data, next_idx). Raises struct.error if fewer than 4 bytes remain
    for the length field and ValueError if the declared length runs past the
    end of buf (previously a short slice was returned silently, and a length
    with the top bit set was read as negative).
    """
    length=struct.unpack (">I", buf[idx:idx+4])[0]
    end=idx+4+length
    if end>len(buf):
        raise ValueError(f"get_buf(): string of length {length} at offset {idx} exceeds buffer of {len(buf)} bytes")
    return buf[idx+4:end], end

def get_str (buf, idx):
    """Parse an SSH string at offset idx and decode it as UTF-8; returns (text, next_idx)."""
    raw, next_idx = get_buf(buf, idx)
    text = raw.decode("utf-8")
    return text, next_idx

def get_mpint (buf, idx):
    """Parse an SSH mpint at offset idx; returns (value, next_idx).

    The bytes are read as an unsigned big-endian integer, so negative mpints
    are not supported.
    """
    raw, next_idx = get_buf(buf, idx)
    return int.from_bytes(raw, byteorder='big'), next_idx

def get_u32 (buf, idx):
    """Parse a big-endian uint32 from buf at offset idx; returns (value, idx + 4).

    Unpacked as unsigned (">I"); the former ">i" turned values >= 2**31 into
    negative numbers. Raises struct.error if fewer than 4 bytes remain.
    """
    rt=struct.unpack (">I", buf[idx:idx+4])[0]
    return rt, idx+4

def pack_buf(buf):
    """Encode buf as an SSH ``string``: a big-endian uint32 length followed by the bytes.

    The length is packed unsigned (">I"), as RFC 4251 defines it, instead of
    as a signed int32.
    """
    return struct.pack(">I", len(buf)) + buf

def pack_str(s):
    """Encode text s as UTF-8 and wrap it as an SSH string (length-prefixed via pack_buf)."""
    encoded = s.encode("utf-8")
    return pack_buf(encoded)

def binlog(x):
    """Return ceil(log2(x)) for x > 0.

    For ints this is computed exactly as (x-1).bit_length(). math.log(x, 2)
    works in floating point and can overshoot at exact powers of two (e.g.
    math.log(2**31, 2) is slightly above 31, so ceil() gave 32). Non-int
    inputs fall back to the floating-point formula.

    Raises ValueError for x <= 0, as math.log does.
    """
    if isinstance(x, int):
        if x<=0:
            raise ValueError("math domain error")
        return (x-1).bit_length()
    return int(math.ceil(math.log(x, 2)))

def bits_for_mpint(i):
    """Return binlog(i) (i.e. ceil(log2(i))) as an int; used to size mpint encodings."""
    estimate = binlog(i)
    return int(math.ceil(estimate))

def bytes_for_mpint(i):
    """Return bits_for_mpint(i) // 8 + 1: a byte count leaving room for a leading zero byte."""
    whole_bytes, _ = divmod(bits_for_mpint(i), 8)
    return whole_bytes + 1

# we add zero byte if most significant bit=1 and this number can be treated as negative
# but we must send number in unsigned form, so extend it
def cvt_to_mpint_and_add_zero_if_needed(i):
    """Encode a non-negative int as the body of an SSH mpint (RFC 4251, section 5).

    The result is the minimal big-endian two's-complement form. Zero becomes
    the empty string (it used to crash in math.log). A 0x00 byte is prepended
    when the top bit of the magnitude is set, so the value is not read as
    negative.

    Raises ValueError for negative i; only unsigned values are sent here.
    """
    if i<0:
        raise ValueError("cvt_to_mpint_and_add_zero_if_needed(): negative values are not supported")
    if i==0:
        return b""
    # bit_length()//8 + 1 bytes always leave the most significant bit clear,
    # so this length already includes the extra 0x00 byte exactly when needed.
    return i.to_bytes(i.bit_length()//8 + 1, byteorder='big')

def pack_mpint(i):
    """Serialize int i as an SSH mpint: uint32 length followed by its big-endian two's-complement bytes."""
    return pack_buf(cvt_to_mpint_and_add_zero_if_needed(i))

def unwrap_packet(buf):
    """Strip the SSH binary-packet framing from buf and return the payload.

    buf must be exactly: uint32 packet_length, byte padding_length, payload,
    padding (no MAC). Exits the program if packet_length does not match the
    number of bytes that follow it, or if padding_length exceeds the data
    present.
    """
    packet_len, padding_len = struct.unpack (">IB", buf[0:4+1])

    buf_with_padding=buf[4:]
    if packet_len!=len(buf_with_padding):
        print ("fatal error in unwrap_packet")
        print (f"{packet_len=}")
        print (f"{len(buf_with_padding)=}")
        print ("exiting")
        sys.exit(0)
    if padding_len>len(buf_with_padding)-1:
        print ("fatal error in unwrap_packet")
        print (f"{padding_len=} is larger than the packet body")
        print ("exiting")
        sys.exit(0)
    # remove padding; use an explicit end index because [1:-0] would be empty
    payload=buf_with_padding[1:len(buf_with_padding)-padding_len]
    return payload

def remove_padding(buf):
    """Given padding_length byte + payload + padding, return the payload.

    Uses an explicit end index: the former ``buf[1:-padding_len]`` returned an
    empty payload when padding_length was 0. Raises ValueError if
    padding_length is larger than the bytes present.
    """
    padding_len = struct.unpack (">B", buf[0:1])[0]
    end=len(buf)-padding_len
    if end<1:
        raise ValueError(f"remove_padding(): padding length {padding_len} exceeds buffer of {len(buf)} bytes")
    return buf[1:end]

def unpack_serv_KEX(buf):
    """Parse the server's SSH_MSG_KEXINIT packet and store its name-lists in globals.

    buf is a complete unencrypted packet (length, padding length, payload,
    padding) as accepted by unwrap_packet(). Sets these globals, each a
    comma-separated name-list string:
    from_serv_kex_algorithms, from_serv_server_host_algorithms,
    from_serv_encryption_algorithms, from_serv_mac_algorithms and
    from_serv_compression_algorithms. It also sets serv_KEX_blob, the payload
    without its leading message-code byte, which is later hashed.

    The client-to-server and server-to-client lists are asserted to be
    identical, and first_kex_packet_follows/reserved are asserted to be 0. A
    server that sends anything else makes this raise AssertionError.
    """
    if VERBOSITY>=1:
        print ("unpack_serv_KEX() begin")
    global from_serv_kex_algorithms, from_serv_server_host_algorithms, from_serv_encryption_algorithms
    global from_serv_mac_algorithms, from_serv_compression_algorithms, serv_KEX_blob

    serv_KEX=unwrap_packet(buf)
    idx=0
    if VERBOSITY>=2:
        print ("KEX from serv without len and padlen and msg_code:") # will be hashed as blob
        hexdump.hexdump(serv_KEX[1:])
    serv_KEX_blob=serv_KEX[1:]

    msg_code = struct.unpack (">B", serv_KEX[idx:idx+1])[0]
    idx+=1
    assert msg_code==SSH_MSG_KEXINIT
    # 16-byte random cookie; not used further here
    cookie=serv_KEX[idx:idx+0x10]
    idx+=0x10
    #print ("cookie", cookie)

    # name-lists follow in the RFC 4253 section 7.1 order
    s, idx = get_str(serv_KEX, idx)
    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"kex_algorithms: {t}")
    from_serv_kex_algorithms=s

    s, idx = get_str(serv_KEX, idx)
    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"server_host_algorithms: {t}")
    from_serv_server_host_algorithms=s

    # encryption_algorithms_client_to_server
    s, idx = get_str(serv_KEX, idx)
    from_serv_encryption_algorithms=s

    # encryption_algorithms_server_to_client: only identical lists are handled
    s, idx = get_str(serv_KEX, idx)
    assert (s==from_serv_encryption_algorithms) # usually the same!

    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"encryption_algorithms: {t}")

    # mac_algorithms_client_to_server
    s, idx = get_str(serv_KEX, idx)
    from_serv_mac_algorithms=s

    # mac_algorithms_server_to_client: must match the client_to_server list
    s, idx = get_str(serv_KEX, idx)
    assert s==from_serv_mac_algorithms

    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"mac_algorithms: {t}")

    # compression_algorithms_client_to_server
    s, idx = get_str(serv_KEX, idx)
    from_serv_compression_algorithms=s

    # compression_algorithms_server_to_client: must match
    s, idx = get_str(serv_KEX, idx)
    assert s==from_serv_compression_algorithms

    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"compression_algorithms: {t}")

    # languages_client_to_server
    s, idx = get_str(serv_KEX, idx)
    languages_client_to_server=s

    # languages_server_to_client: must match; usually empty
    s, idx = get_str(serv_KEX, idx)
    assert s==languages_client_to_server
    if VERBOSITY>=1 and len(s)>0:
        for t in s.split(","):
            print (f"languages: {t}")

    # boolean first_kex_packet_follows + uint32 reserved (both expected to be 0)
    first_KEX_packet_follows, reserverd = struct.unpack (">Bi", serv_KEX[idx:idx+1+4])
    assert first_KEX_packet_follows==0
    assert reserverd==0

def unpack_KEXDH_REPLY(buf):
    """Parse the server's DH group-exchange message carrying the modulus p and generator g.

    Despite the name, the message parsed here contains two mpints, p then g.
    Message code 31 (SSH_MSG_KEXDH_REPLY in this file) is also
    SSH_MSG_KEX_DH_GEX_GROUP in RFC 4419, which is the message that carries
    (p, g). buf is a full unencrypted packet as accepted by unwrap_packet().

    Returns (g, p) -- note the order: generator first, then modulus.
    """
    payload=unwrap_packet(buf)

    if VERBOSITY>=2:
        print ("unpack_KEXDH_REPLY():")
        hexdump.hexdump(payload)

    idx=0
    msg_code = struct.unpack (">B", payload[idx:idx+1])[0]
    assert msg_code==SSH_MSG_KEXDH_REPLY
    idx+=1
    p, idx = get_mpint(payload, idx)
    # one from /etc/ssh/moduli
    if VERBOSITY>=1:
        print (f"DH GEX modulus (P): {hex(p)}")
        print (f"binlog(P)={binlog(p)}")
    g, idx = get_mpint(payload, idx)
    if VERBOSITY>=1:
        print (f"DH GEX base (G): {g}")
    assert len(payload)==idx # be sure nothing else left in buffer
    return g, p

def wrap_packet(payload):
    """Frame payload as an SSH binary packet (without MAC).

    Layout: uint32 packet_length, byte padding_length, payload, padding.
    Padding bytes are 0xAA. The padding brings 4+1+len(payload)+padding up to
    a multiple of 8 bytes before encryption is active (ENCRYPTION false) or of
    16 bytes (the AES block size) after, and is never shorter than 4 bytes.
    """
    if VERBOSITY>=3:
        print ("wrap_packet() payload:")
        hexdump.hexdump(payload)
    payload_len=len(payload)
    if VERBOSITY>=3:
        print (f"wrap_packet() {payload_len=}")

    block = 8 if ENCRYPTION==False else 16
    header_len = 4 + 1 + payload_len
    padlen = align2grain(header_len, block) - header_len
    if padlen < 4:
        # RFC 4253 requires at least 4 bytes of padding; add one more block
        padlen += block
    if VERBOSITY>=3:
        print (f"wrap_packet() {padlen=}")

    body = b"".join((struct.pack(">B", padlen), payload, b'\xAA'*padlen))
    return struct.pack(">i", len(body)) + body

def pack_client_KEX():
    """Build our SSH_MSG_KEXINIT packet and remember it for the exchange hash.

    The same cipher, MAC and compression lists are offered in both directions.
    "ext-info-c" is appended to the KEX list when EXT_INFO_C is set. Only
    "none" compression is offered when the cipher list contains "none". The
    message without its leading code byte is stored in the global
    client_KEX_blob. Returns the framed packet from wrap_packet().
    """
    global client_KEX_blob
    if VERBOSITY>=2:
        print ("pack_client_KEX()")

    # NOTE: original remark on the plain variant: "doesn't work with crypto=none, dunno why"
    kex_algorithms = KEX_ALGOS+",ext-info-c" if EXT_INFO_C else KEX_ALGOS

    if VERBOSITY>=2:
        print (f"our kex_algorithms: {kex_algorithms=}")

    # NOTE: original remark: zlib compression "doesn't work with crypto=none"
    compression = "none,zlib@openssh.com" if "none" not in CIPHER_ALGOS else "none"

    # name-lists in RFC 4253 section 7.1 order
    name_lists = (
        kex_algorithms,
        SERVER_HOST_ALGOS,
        CIPHER_ALGOS, CIPHER_ALGOS,     # encryption client->server, server->client
        MAC_ALGOS, MAC_ALGOS,           # mac client->server, server->client
        compression, compression,       # compression client->server, server->client
        "", "",                         # languages client->server, server->client
    )
    cookie = b'\x11'*16
    buf = struct.pack(">B", SSH_MSG_KEXINIT) + cookie + b"".join(pack_str(n) for n in name_lists)
    # boolean first_kex_packet_follows = 0, uint32 reserved = 0
    buf += struct.pack (">Bi", 0, 0)

    if VERBOSITY>=2:
        print ("KEX from client without len and padlen and msg_code and padding:") # will be hashed as blob
        hexdump.hexdump(buf[1:])
    client_KEX_blob=buf[1:]

    return wrap_packet(buf)

def pack_DH_GEX_request():
    """Build SSH_MSG_KEX_DH_GEX_REQUEST carrying our (min, preferred, max) group sizes in bits."""
    if VERBOSITY>=2:
        print ("pack_DH_GEX_request()")
    msg_code = struct.pack(">B", SSH_MSG_KEX_DH_GEX_REQUEST)
    sizes = struct.pack(">iii", DH_GEX_min, DH_GEX_bits, DH_GEX_max)
    return wrap_packet(msg_code + sizes)

def pack_DH_GEX_INIT(e):
    """Build SSH_MSG_KEX_DH_GEX_INIT carrying our DH public value e as an mpint."""
    if VERBOSITY>=2:
        print ("pack_DH_GEX_INIT()")
    return wrap_packet(bytes((SSH_MSG_KEX_DH_GEX_INIT,)) + pack_mpint(e))

def pack_SSH_MSG_KEXDH_INIT_raw(e):
    """Build an SSH_MSG_KEXDH_INIT (code 30) whose value is sent as an SSH string.

    Unlike pack_SSH_MSG_KEXDH_INIT(), e is raw bytes and is length-prefixed
    as-is rather than encoded as an mpint. This is presumably used for the
    ECDH client public point Q_C, which RFC 5656 sends as a string; confirm
    at the caller.
    """
    if VERBOSITY>=2:
        # distinct label so verbose output tells the two INIT builders apart
        print ("pack_SSH_MSG_KEXDH_INIT_raw")
    buf=struct.pack(">B", SSH_MSG_KEXDH_INIT)
    buf+=pack_buf(e)
    return wrap_packet(buf)

def pack_SSH_MSG_KEXDH_INIT(e):
    """Build SSH_MSG_KEXDH_INIT (code 30) with the DH public value e encoded as an mpint."""
    if VERBOSITY>=2:
        print ("pack_SSH_MSG_KEXDH_INIT")
    return wrap_packet(struct.pack(">B", SSH_MSG_KEXDH_INIT) + pack_mpint(e))

def pack_new_keys():
    """Build the one-byte SSH_MSG_NEWKEYS message, framed by wrap_packet()."""
    if VERBOSITY>=2:
        print ("pack_new_keys()")
    return wrap_packet(bytes([SSH_MSG_NEWKEYS]))

def send_new_keys(s):
    """Send SSH_MSG_NEWKEYS over socket s through my_send(), which advances send_seqno."""
    packet = pack_new_keys()
    if VERBOSITY>=2:
        print ("send_new_keys():")
        hexdump.hexdump(packet)
    my_send(s, packet)

def unpack_RSA_host_key(host_key):
    """Parse an "ssh-rsa" public-key blob into the globals RSA_e and RSA_modulus_n.

    Blob layout: string "ssh-rsa", mpint e, mpint n. Asserts on a wrong key
    type or trailing bytes.
    """
    global RSA_e, RSA_modulus_n

    host_key_type, pos = get_str(host_key, 0)
    assert host_key_type=="ssh-rsa"
    RSA_e, pos = get_mpint(host_key, pos)
    RSA_modulus_n, pos = get_mpint(host_key, pos)
    assert len(host_key)==pos # be sure nothing else left in buffer

    if VERBOSITY>=1:
        print (f"{host_key_type=}")
        print (f"{RSA_e=}")
        print (f"{RSA_modulus_n=}")
        print (f"{binlog(RSA_modulus_n)=}")

def unpack_ECDSA_host_key(host_key):
    """Parse an "ecdsa-sha2-nistp256" public-key blob into the globals ECDSA_curve_id and ECDSA_Q.

    Blob layout (RFC 5656, section 3.1): string key type, string curve
    identifier, string Q (encoded EC point). Like the RSA and DSA parsers,
    this asserts that no bytes are left over.
    """
    global ECDSA_curve_id, ECDSA_Q
    buf=host_key
    idx=0
    host_key_type, idx = get_str(buf, idx)
    assert host_key_type=="ecdsa-sha2-nistp256"
    if VERBOSITY>=2:
        print ("unpack_ECDSA_host_key() host_key:")
        hexdump.hexdump(buf[idx:])
    ECDSA_curve_id, idx = get_str(buf, idx)
    if VERBOSITY>=1:
        print (f"unpack_ECDSA_host_key() {ECDSA_curve_id=}:")
    ECDSA_Q, idx=get_buf (buf, idx)
    if VERBOSITY>=2:
        print ("unpack_ECDSA_host_key() ECDSA_Q:")
        hexdump.hexdump(ECDSA_Q)
    assert len(buf)==idx # be sure nothing else left in buffer

def unpack_DSA_host_key(host_key):
    """Parse an "ssh-dss" public-key blob into the globals DSA_p, DSA_q, DSA_g, DSA_y.

    Blob layout: string "ssh-dss", mpint p, mpint q, mpint g, mpint y.
    Asserts on a wrong key type or trailing bytes.
    """
    global DSA_p, DSA_q, DSA_g, DSA_y

    host_key_type, pos = get_str(host_key, 0)
    assert host_key_type=="ssh-dss"
    DSA_p, pos = get_mpint(host_key, pos)
    DSA_q, pos = get_mpint(host_key, pos)
    DSA_g, pos = get_mpint(host_key, pos)
    DSA_y, pos = get_mpint(host_key, pos)
    assert len(host_key)==pos # be sure nothing else left in buffer

    if VERBOSITY>=1:
        print (f"{host_key_type=}")
        params = (("DSA_p", DSA_p), ("DSA_q", DSA_q), ("DSA_g", DSA_g), ("DSA_y", DSA_y))
        for label, value in params:
            print (f"{label}={value!r}")
        for label, value in params:
            print (f"binlog({label})={binlog(value)!r}")

def print_fingerprint(KEX_host_key):
    """Print the server host-key fingerprint in OpenSSH's "SHA256:<base64>" format.

    As in ssh-keygen -l / ssh's host-key prompt, the base64 digest is printed
    without its trailing '=' padding, so it can be compared with OpenSSH's
    output character for character.
    """
    hashed=hashlib.sha256(KEX_host_key).digest()
    fingerprint=base64.b64encode(hashed).decode("utf-8").rstrip("=")
    print (f"Server fingerprint: {SERVER_HOST_ALGO} SHA256:"+fingerprint)

def unpack_KEX_H_sig(KEX_H_sig):
    """Parse the server's exchange-hash signature blob: string algorithm name, string signature.

    The raw signature bytes are stored in the global KEX_H_sig_buf. The
    algorithm name must equal the negotiated SERVER_HOST_ALGO; this is checked
    on every run, not only in verbose mode, and a mismatch prints an error and
    exits. Asserts that nothing follows the signature.
    """
    global KEX_H_sig_buf
    idx=0
    s, idx = get_str(KEX_H_sig, idx)
    if VERBOSITY>=1:
        print (f"unpack_KEX_H_sig(): {s}")
    # it must be equal to the SERVER_HOST_ALGO string
    if s!=SERVER_HOST_ALGO:
        print (f"Error. Server signature algorithm {s} doesn't match negotiated {SERVER_HOST_ALGO}")
        print ("Exiting")
        sys.exit(0)
    tmp, idx=get_buf (KEX_H_sig, idx)
    if VERBOSITY>=2:
        hexdump.hexdump(tmp)
    KEX_H_sig_buf=tmp
    assert len(KEX_H_sig)==idx # be sure nothing else left in buffer

def chk_RSA_kexgex_hash(kexgex_hash):
    """Verify the server's RSA (PKCS#1 v1.5) signature over the exchange hash.

    Reads these globals:
    - KEX_H_sig_buf: the raw signature.
    - RSA_e and RSA_modulus_n: the server public key.
    - SERVER_HOST_ALGO.
    - SERVER_HOST_ALGO_PTR: the hashlib-style constructor for the negotiated
      algorithm (SHA-1 for "ssh-rsa", SHA-256/512 for "rsa-sha2-*").

    The decrypted signature must equal the complete EMSA-PKCS1-v1_5 encoding
    0x00 0x01 0xFF..0xFF 0x00 || DigestInfo || HASH(kexgex_hash) (RFC 8017,
    section 9.2). The former check compared only the trailing hash digits, so
    it accepted forged signatures whose padding and DigestInfo were garbage.
    Prints an error and exits if the signature is wrong.
    """
    global KEX_H_sig_buf # in case of RSA: 2k bits
    global RSA_e
    global RSA_modulus_n # 2k bits

    if VERBOSITY>=1:
        print ("chk_RSA_kexgex_hash()")

    # DER-encoded DigestInfo prefixes, RFC 8017 section 9.2, note 1
    digestinfo_prefixes={
        "sha1":   bytes.fromhex("3021300906052b0e03021a05000414"),
        "sha256": bytes.fromhex("3031300d060960864801650304020105000420"),
        "sha512": bytes.fromhex("3051300d060960864801650304020305000440"),
    }

    def reject():
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        sys.exit(0)

    k=(RSA_modulus_n.bit_length()+7)//8 # modulus length in bytes
    KEX_H_sig_mpint=int.from_bytes(KEX_H_sig_buf, byteorder='big')
    # the signature representative must be smaller than the modulus (RFC 8017, 5.2.2)
    if len(KEX_H_sig_buf)>k or KEX_H_sig_mpint>=RSA_modulus_n:
        reject()

    plaintext_must_be=pow (KEX_H_sig_mpint, RSA_e, RSA_modulus_n)
    if VERBOSITY>=1:
        print (f"{hex(plaintext_must_be)=}")

    hashed=SERVER_HOST_ALGO_PTR(kexgex_hash)
    if VERBOSITY>=2:
        print ("kexgex_hash:")
        hexdump.hexdump(kexgex_hash)
        print (SERVER_HOST_ALGO+"(kexgex_hash):")
        hexdump.hexdump(hashed.digest())

    prefix=digestinfo_prefixes.get(getattr(hashed, "name", None))
    if prefix is None:
        print (f"Error. Unsupported hash for RSA signature check: {SERVER_HOST_ALGO}")
        print ("Exiting")
        sys.exit(0)
    T=prefix+hashed.digest()
    if k<len(T)+11:
        reject() # modulus too short for this hash (RFC 8017, 9.2 step 3)
    expected=b"\x00\x01"+b"\xff"*(k-len(T)-3)+b"\x00"+T
    if VERBOSITY>=2:
        print ("expected PKCS#1 v1.5 encoding:")
        hexdump.hexdump(expected)

    if hmac.compare_digest(plaintext_must_be.to_bytes(k, byteorder='big'), expected):
        if VERBOSITY>=1:
            print ("Server signature is correct")
    else:
        reject()

# chk ECDSA signature
def chk_ECDSA_kexgex_hash(kexgex_hash):
    """Verify the server's ecdsa-sha2-nistp256 signature over the exchange hash.

    KEX_H_sig_buf holds two mpints r and s (RFC 5656, section 3.1.2). They are
    DER-encoded and checked with the cryptography library against the public
    point ECDSA_Q on SECP256R1, using SHA-256. Asserts that the curve is
    nistp256 and that no bytes trail the signature. Prints an error and exits
    if the signature is wrong.
    """
    # Import the submodules used below explicitly rather than relying on the
    # file's other cryptography imports having loaded them as a side effect.
    import cryptography.exceptions
    import cryptography.hazmat.primitives.asymmetric.ec
    import cryptography.hazmat.primitives.asymmetric.utils
    import cryptography.hazmat.primitives.hashes

    global ECDSA_curve_id, ECDSA_Q

    if VERBOSITY>=1:
        print ("chk_ECDSA_kexgex_hash()")
    if VERBOSITY>=2:
        print ("kexgex_hash:")
        hexdump.hexdump(kexgex_hash)

    if VERBOSITY>=2:
        print ("KEX_H_sig_buf:")
        hexdump.hexdump(KEX_H_sig_buf)

    idx=0

    sign_r, idx=get_mpint (KEX_H_sig_buf, idx)
    sign_s, idx=get_mpint (KEX_H_sig_buf, idx)
    assert len(KEX_H_sig_buf)==idx # be sure nothing else left in buffer

    if VERBOSITY>=1:
        print (f"{sign_r=}")
        print (f"{sign_s=}")

    # convert r,s pair to ASN1-encoded packet
    # https://cryptography.io/en/latest/hazmat/primitives/asymmetric/utils/#cryptography.hazmat.primitives.asymmetric.utils.encode_dss_signature
    asn1_sig=cryptography.hazmat.primitives.asymmetric.utils.encode_dss_signature(sign_r, sign_s)

    assert ECDSA_curve_id=="nistp256"

    pub_key=cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point(
        data=ECDSA_Q,
        curve=cryptography.hazmat.primitives.asymmetric.ec.SECP256R1())
    if VERBOSITY>=1:
        print (f"ECDSA public key from server: {pub_key.public_numbers()}")

    try:
        pub_key.verify(asn1_sig, kexgex_hash, cryptography.hazmat.primitives.asymmetric.ec.ECDSA(cryptography.hazmat.primitives.hashes.SHA256()))
        if VERBOSITY>=1:
            print ("Server signature is correct")
    except cryptography.exceptions.InvalidSignature:
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        sys.exit(0)

# chk DSA signature
def chk_DSA_kexgex_hash(kexgex_hash):
    """Verify the server's ssh-dss signature over the exchange hash (DSA with SHA-1).

    Reads the globals DSA_p, DSA_q, DSA_g, DSA_y (server public key) and
    KEX_H_sig_buf. For ssh-dss, KEX_H_sig_buf is r || s: two 160-bit unsigned
    integers of 20 bytes each (RFC 4253, section 6.6).

    Verification follows FIPS 186:
    - reject unless 0 < r < q and 0 < s < q;
    - compute w = s^-1 mod q, u1 = z*w mod q and u2 = r*w mod q, where z is
      the leftmost N bits of SHA-1(kexgex_hash);
    - require v = ((g^u1 * y^u2) mod p) mod q to equal r.

    Prints an error and exits if the signature is malformed or wrong.
    """
    global DSA_p, DSA_q, DSA_g, DSA_y, KEX_H_sig_buf

    def reject():
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        sys.exit(0)

    if VERBOSITY>=1:
        print ("chk_DSA_kexgex_hash()")

    # exact bit lengths (the float log2 could be off for values near powers of two)
    L=DSA_p.bit_length()
    N=DSA_q.bit_length()

    if VERBOSITY>=1:
        print (f"{L=}, {N=}")

    if VERBOSITY>=2:
        print ("KEX_H_sig_buf:")
        hexdump.hexdump(KEX_H_sig_buf)

    # the ssh-dss signature blob is exactly 40 bytes: 20 for r, then 20 for s
    if len(KEX_H_sig_buf)!=40:
        reject()
    buf_r=KEX_H_sig_buf[:20]
    buf_s=KEX_H_sig_buf[20:]

    if VERBOSITY>=2:
        print ("buf_r:")
        hexdump.hexdump(buf_r)
        print ("buf_s:")
        hexdump.hexdump(buf_s)

    r=int.from_bytes(buf_r, 'big')
    s=int.from_bytes(buf_s, 'big')
    # out-of-range values must be rejected; s == 0 would also make the inverse below fail
    if not (0<r<DSA_q and 0<s<DSA_q):
        reject()

    z=int.from_bytes(hashlib.sha1(kexgex_hash).digest()[0:N//8], 'big')

    w=pow(s, -1, DSA_q) # mod inverse of s mod q
    u1=(z*w) % DSA_q # IOW, u1=z/s mod q
    u2=(r*w) % DSA_q # IOW, u2=r/s mod q
    v=(pow(DSA_g, u1, DSA_p) * pow(DSA_y, u2, DSA_p) % DSA_p) % DSA_q

    if VERBOSITY>=2:
        print (f"{hex(v)=}")

    if r==v:
        if VERBOSITY>=1:
            print ("Server signature is correct")
    else:
        reject()

def save_serv_pub_key(KEX_host_key):
    """Save the server host key to "<HOST>.ssh_host_<type>_key.pub".

    The line written is "<key type> <base64 blob> root@<HOST>", the OpenSSH
    .pub layout. RSA, ECDSA and (newly) ssh-dss host keys are supported; for
    any other SERVER_HOST_ALGO this asserts. The file is closed even if the
    write fails.
    """
    _type, _ = get_str(KEX_host_key, 0)
    to_save=_type+" "+base64.b64encode(KEX_host_key).decode("utf-8")+" root@"+HOST
    if "rsa" in SERVER_HOST_ALGO:
        fname=HOST+".ssh_host_rsa_key.pub"
    elif "ecdsa" in SERVER_HOST_ALGO:
        fname=HOST+".ssh_host_ecdsa_key.pub"
    elif SERVER_HOST_ALGO=="ssh-dss":
        # OpenSSH names its DSA host key ssh_host_dsa_key
        fname=HOST+".ssh_host_dsa_key.pub"
    else:
        assert False
    with open(fname, "w") as f:
        f.write(to_save+"\n")
    print (f"Server's public key saved to {fname}")

def unpack_KEX_host_key(KEX_host_key):
    """Log, optionally save, and parse the server host-key blob K_S.

    The fingerprint is printed with -v. The key is saved when
    SAVE_SERV_PUB_KEY is set. Parsing is done by the unpack_*_host_key
    function matching SERVER_HOST_ALGO; an unsupported algorithm raises
    KeyError.
    """
    if VERBOSITY>=2:
        print ("KEX_host_key: will be hashed as blob:")
        # same as in /etc/ssh/ssh_host_rsa_key.pub
        # or as in   /etc/ssh/ssh_host_ecdsa_key.pub
        hexdump.hexdump(KEX_host_key)
    if VERBOSITY>=1:
        print ("unpack_KEX_host_key:")
        print_fingerprint(KEX_host_key)
    if SAVE_SERV_PUB_KEY==True:
        save_serv_pub_key(KEX_host_key)

    if SERVER_HOST_ALGO in ("ssh-rsa", "rsa-sha2-256", "rsa-sha2-512"):
        unpack_RSA_host_key(KEX_host_key)
    elif SERVER_HOST_ALGO=="ssh-dss":
        unpack_DSA_host_key(KEX_host_key)
    elif SERVER_HOST_ALGO=="ecdsa-sha2-nistp256":
        unpack_ECDSA_host_key(KEX_host_key)
    else:
        raise KeyError(SERVER_HOST_ALGO)

def unpack_DH_REPLY(buf, GEX):
    """Parse the server's DH reply payload and return (KEX_host_key, server_f).

    With GEX true the message code must be SSH_MSG_KEX_DH_GEX_REPLY; with GEX
    false it must be SSH_MSG_KEXDH_REPLY. Otherwise the program exits. After
    the code byte come: string K_S (host key, parsed via
    unpack_KEX_host_key()), mpint f, and string signature (parsed via
    unpack_KEX_H_sig()). Asserts that nothing trails the signature.
    """
    if VERBOSITY>=2:
        print ("unpack_DH_REPLY():")
        hexdump.hexdump(buf)
    (msg_code,) = struct.unpack (">B", buf[:1])
    if GEX and msg_code!=SSH_MSG_KEX_DH_GEX_REPLY:
        print (f"expecting msg_code==SSH_MSG_KEX_DH_GEX_REPLY, but got {msg_code}")
        exit(0)
    if GEX==False and msg_code!=SSH_MSG_KEXDH_REPLY:
        print (f"expecting msg_code==SSH_MSG_KEXDH_REPLY, but got {msg_code}")
        exit(0)

    KEX_host_key, pos = get_buf(buf, 1)
    unpack_KEX_host_key(KEX_host_key)

    server_f, pos = get_mpint(buf, pos)
    if VERBOSITY>=1:
        print (f"{hex(server_f)=}")
    KEX_H_sig, pos = get_buf(buf, pos)
    if VERBOSITY>=2:
        print ("KEX_H_sig:")
        # this buffer begins with SERVER_HOST_ALGO string
        hexdump.hexdump(KEX_H_sig)
    unpack_KEX_H_sig(KEX_H_sig)
    assert len(buf)==pos # be sure nothing else left in buffer
    return KEX_host_key, server_f

def unpack_ECDH_REPLY(buf):
    """Parse the server's ECDH reply payload and return (KEX_host_key, ECDH_server_Q_S).

    The message code must be 31 (SSH_MSG_KEXDH_REPLY, i.e. KEX_ECDH_REPLY in
    RFC 5656); otherwise the program exits. After the code byte come:
    string K_S (host key, parsed via unpack_KEX_host_key()), string Q_S (the
    server's ephemeral public point), and string signature (parsed via
    unpack_KEX_H_sig()). Asserts that nothing trails the signature.
    """
    if VERBOSITY>=2:
        print ("unpack_ECDH_REPLY():")
        hexdump.hexdump(buf)
    (msg_code,) = struct.unpack (">B", buf[:1])
    if msg_code!=SSH_MSG_KEXDH_REPLY:
        print (f"expecting msg_code==SSH_MSG_KEXDH_REPLY, but got {msg_code}")
        exit(0)

    KEX_host_key, pos = get_buf(buf, 1)
    unpack_KEX_host_key(KEX_host_key)

    ECDH_server_Q_S, pos = get_buf(buf, pos)
    if VERBOSITY>=2:
        print ("ECDH_server_Q_S:")
        hexdump.hexdump(ECDH_server_Q_S)
    KEX_H_sig, pos = get_buf(buf, pos)
    if VERBOSITY>=2:
        print ("KEX_H_sig:")
        # this buffer begins with SERVER_HOST_ALGO string
        hexdump.hexdump(KEX_H_sig)
    unpack_KEX_H_sig(KEX_H_sig)
    assert len(buf)==pos # be sure nothing else left in buffer
    return KEX_host_key, ECDH_server_Q_S

# hash hodgepodge
# this is almost all data exchanged before this moment
# critical part - shared_secret
# [version without GEX (DH groups is hardcoded): see kex_gen_hash() in OpenSSH 9.0]
# [version for GEX: see kexgex_hash() in OpenSSH 9.0]
def calc_kexgex_hash(client_banner, serv_banner, client_KEX_blob, serv_KEX_blob, KEX_host_key, client_e, server_f, shared_secret, g, p, GEX):
    """Compute the key-exchange hash H with KEX_HASH and return its digest.

    Hashed fields, in order:
    - the two banners as strings;
    - both KEXINIT messages, each re-prefixed with its SSH_MSG_KEXINIT byte;
    - the host key blob;
    - for GEX only: the three uint32 group sizes, then mpint p and mpint g;
    - mpints e, f and the shared secret K.
    """
    kexinit_code=struct.pack(">B", SSH_MSG_KEXINIT)
    parts=[
        pack_buf(client_banner),
        pack_buf(serv_banner),
        pack_buf(kexinit_code+client_KEX_blob),
        pack_buf(kexinit_code+serv_KEX_blob),
        pack_buf(KEX_host_key),
    ]
    if GEX:
        parts.append(struct.pack(">iii", DH_GEX_min, DH_GEX_bits, DH_GEX_max))
        parts.append(pack_mpint(p))
        parts.append(pack_mpint(g))
    parts.append(pack_mpint(client_e))
    parts.append(pack_mpint(server_f))
    parts.append(pack_mpint(shared_secret))
    buf=b"".join(parts)
    if VERBOSITY>=3:
        print (f"buf to hash. {len(buf)=}")
        hexdump.hexdump(buf)
    hashed=KEX_HASH(buf).digest()
    if VERBOSITY>=2:
        print ("hashed:")
        hexdump.hexdump(hashed)
    return hashed

def derive_key(K, H, c, session_id):
    """Return HASH(mpint(K) || H || c || session_id) [RFC 4253 7.2].

    K is the shared secret as an int; c is a single letter such as 'A'.
    The full KEX_HASH digest is returned; callers truncate as needed.
    """
    assert type(K)==int # let it be here
    hasher=KEX_HASH()
    for part in (pack_mpint(K), H, str.encode(c), session_id):
        hasher.update(part)
    key=hasher.digest()
    if VERBOSITY>=2:
        print ("key_%c:" % c)
        hexdump.hexdump(key)
    return key

# required - in bytes
def derive_key2(K, H, c, session_id, required):
    """Derive a key of exactly `required` bytes [RFC 4253 7.2].

    K1 = HASH(K || H || c || session_id).  While the key is too short, it is
    extended with K(n+1) = HASH(K || H || K1 || ... || Kn), and the result is
    truncated to `required` bytes.
    """
    assert type(K)==int # let it be here
    # K || H is the common prefix of the first block and every extension block
    prefix=KEX_HASH()
    prefix.update(pack_mpint(K))
    prefix.update(H)

    first=prefix.copy()
    first.update(str.encode(c))
    first.update(session_id)
    key=first.digest()

    # running hash of K || H || K1 || K2 || ...; may not be used at all
    chain=prefix.copy()
    block=key
    while len(key)<required:
        chain.update(block)
        block=chain.digest()
        key+=block
    assert len(key)>=required
    key=key[:required]
    if VERBOSITY>=2:
        print ("key_%c:" % c)
        hexdump.hexdump(key)
    return key

# [See kex.c, derive_key() in OpenSSH 9.0, also RFC 4253 7.2]
def derive_keys(K, H, session_id):
    """Derive the session keys into module-level globals.

    Only when ENCRYPTION is set: key_A / key_B (IVs; their first 16 bytes
    seed client_to_serv_ctr / serv_to_client_ctr as big-endian AES-CTR
    counters) and key_C / key_D (CIPHER_KEY_SIZE-byte cipher keys).
    Always: key_E / key_F (MAC_SIZE-byte MAC keys).
    Letters A, C, E are used for client-to-server, B, D, F for server-to-client.
    """
    global key_A, key_B, key_C, key_D, key_E, key_F, client_to_serv_ctr, serv_to_client_ctr

    if VERBOSITY>=2:
        print ("derive_keys() begin")
    if ENCRYPTION:
        key_A=derive_key(K, H, 'A', session_id) # IV or ctr
        key_B=derive_key(K, H, 'B', session_id) # IV or ctr

        # AES counters come from the first 128 bits of each IV
        assert len(key_A)>=16
        assert len(key_B)>=16
        client_to_serv_ctr=int.from_bytes(key_A[:16], "big")
        serv_to_client_ctr=int.from_bytes(key_B[:16], "big")

        key_C=derive_key2(K, H, 'C', session_id, CIPHER_KEY_SIZE) # encryption key
        key_D=derive_key2(K, H, 'D', session_id, CIPHER_KEY_SIZE) # encryption key

    key_E=derive_key2(K, H, 'E', session_id, MAC_SIZE) # MAC key
    key_F=derive_key2(K, H, 'F', session_id, MAC_SIZE) # MAC key

def my_send_with_MAC(s, msg):
    """Send one already-wrapped packet followed by its MAC.

    The MAC is HMAC(key_E, uint32 sequence number || unencrypted packet)
    using MAC_ALGO [RFC 4253 6.4].  The sequence number used is
    send_seqno-1 (presumably send_seqno is advanced when the packet is
    wrapped -- confirm against wrap_packet()).  With ENCRYPTION the packet
    is AES-CTR encrypted under key_C starting at client_to_serv_ctr, and the
    counter is advanced by the number of cipher blocks consumed.
    """
    global client_to_serv_ctr

    if CIPHER_ALGO=="aes128-ctr":
        KEY=key_C[0:128//8]
        assert len(KEY)==16
    elif CIPHER_ALGO=="aes256-ctr":
        KEY=key_C[0:256//8]
        assert len(KEY)==32
    elif CIPHER_ALGO=="none":
        pass
    else:
        assert False

    if VERBOSITY>=2 and CIPHER_ALGO!="none":
        print (f"my_send_with_MAC() {len(KEY)=}")
        hexdump.hexdump(KEY)

    if VERBOSITY>=2:
        print (f"my_send_with_MAC() {len(msg)=}")
        hexdump.hexdump(msg)

    assert len(key_E)>=MAC_SIZE
    # Sequence numbers are uint32 and wrap to zero after 2**32 packets
    # [RFC 4253 6.4]; packing as signed ">i" would raise past 2**31-1.
    seqno=(send_seqno-1) & 0xFFFFFFFF
    calc_MAC = hmac.digest(key=key_E[0:MAC_SIZE], msg=struct.pack(">I", seqno)+msg, digest=MAC_ALGO)
    if VERBOSITY>=2:
        print ("my_send_with_MAC() calc_MAC:")
        hexdump.hexdump(calc_MAC)
    if not ENCRYPTION:
        assert (len(msg)&0x7)==0
        my_send(s, msg+calc_MAC)
    else:
        assert (len(msg)&0xf)==0

        ctr = Crypto.Util.Counter.new(128, little_endian=False, initial_value=client_to_serv_ctr)
        t = Crypto.Cipher.AES.new(KEY, Crypto.Cipher.AES.MODE_CTR, counter = ctr)
        t2 = t.encrypt(msg)
        # The SSH CTR counter is a 128-bit value incremented modulo 2**128
        # [RFC 4344 4]; keep it in range for the next Counter.new().
        client_to_serv_ctr=(client_to_serv_ctr+len(t2)//AES_BLOCK_SIZE_IN_BYTES) % (1<<128)
        my_send(s, t2+calc_MAC)
    if VERBOSITY>=2:
        print ("my_send_with_MAC() finish")

def my_recv_chk_MAC(s):
    """Receive one packet, verify its trailing MAC and return it without the MAC.

    The buffer comes from my_recv(s, True) (presumably already decrypted --
    confirm).  Its last MAC_SIZE bytes are compared with
    HMAC(key_F, uint32 (recv_seqno-2) || rest) using MAC_ALGO.  Exits on a
    MAC mismatch.
    """
    tmp=my_recv(s, True)
    if VERBOSITY>=3:
        print ("my_recv_chk_MAC() begin. buf:")
        hexdump.hexdump(tmp)
    assert len(tmp)!=0
    msg=tmp[0:-MAC_SIZE]
    msglen=len(msg)
    if VERBOSITY>=3:
        print ("msg to check:")
        hexdump.hexdump(msg)
    received_MAC=tmp[msglen:msglen+MAC_SIZE]
    assert len(received_MAC)==MAC_SIZE
    if VERBOSITY>=3:
        print (f"my_recv_chk_MAC(): {recv_seqno=}")

    assert len(key_F)>=MAC_SIZE
    # sequence numbers are uint32 and wrap around [RFC 4253 6.4]
    seqno=(recv_seqno-2) & 0xFFFFFFFF
    calc_MAC = hmac.digest(key=key_F[0:MAC_SIZE], msg=struct.pack(">I", seqno)+msg, digest=MAC_ALGO)

    # constant-time comparison: timing must not reveal how many bytes matched
    if not hmac.compare_digest(calc_MAC, received_MAC):
        print ("fatal error: received MAC is not correct!")
        if VERBOSITY>=3:
            print ("calculated MAC:")
            hexdump.hexdump(calc_MAC)
        exit(0)
    if VERBOSITY>=3:
        print ("my_recv_chk_MAC(): MAC checked OK")
    return msg

def send_exec(s, cmd):
    """Send an "exec" SSH_MSG_CHANNEL_REQUEST for `cmd` [RFC 4254 6.5].

    Recipient channel is hard-coded to 0 and want_reply is set to 1.
    """
    print (f"send_exec {cmd}")
    request=b"".join([
        struct.pack("B", SSH_MSG_CHANNEL_REQUEST),
        struct.pack(">i", 0),   # recipient channel
        pack_str("exec"),
        struct.pack("B", 1),    # want reply
        pack_str(cmd),
    ])
    my_send_with_MAC(s, wrap_packet(request))

def unpack_exit_status(s, msg):
    """Handle an SSH_MSG_CHANNEL_REQUEST received while a command runs.

    Handles "exit-status" and "exit-signal" [RFC 4254 6.10] by printing
    them, and accepts "eow@openssh.com".  Any other request type fails an
    assertion.  `s` is unused; it is kept for interface compatibility.
    """
    if VERBOSITY>=2:
        print ("unpack_exit_status(), msg:")
        hexdump.hexdump(msg)
    assert msg[0]==SSH_MSG_CHANNEL_REQUEST
    idx=1+4 # skip recipient channel
    tmp, idx = get_str(msg, idx)
    if tmp=="exit-status":
        # [see RFC4254 6.10]
        idx+=1 # skip boolean
        exit_status, idx = get_u32(msg, idx)
        print (f"exit-status {exit_status}")
        assert len(msg)==idx # be sure nothing else left in buffer
    elif tmp=="exit-signal":
        # the command was terminated by a signal [see RFC4254 6.10]:
        # string signal name (no "SIG" prefix), boolean core dumped,
        # string error message, string language tag
        idx+=1 # skip boolean (want reply)
        signal_name, idx = get_str(msg, idx)
        core_dumped = msg[idx]!=0
        idx+=1
        error_message, idx = get_str(msg, idx)
        _language_tag, idx = get_str(msg, idx)
        print (f"exit-signal {signal_name} core_dumped={core_dumped} {error_message}")
        assert len(msg)==idx # be sure nothing else left in buffer
    elif tmp=="eow@openssh.com":
        # for more info, grep "eow@openssh.com" here:
        # https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
        idx+=1 # skip boolean
        assert len(msg)==idx # be sure nothing else left in buffer
    else:
        assert False

# [see RFC4254 5.2]
def unpack_channel_data(msg):
    """Print the data string carried by an SSH_MSG_CHANNEL_DATA message."""
    assert msg[0]==SSH_MSG_CHANNEL_DATA
    # skip the message code byte and the uint32 recipient channel
    data, end = get_str(msg, 1+4)
    for line in ("response:", "==", data, "=="):
        print (line)
    # the whole message must have been consumed
    assert len(msg)==end

# extended data type code for stderr [see RFC4254 5.2]
SSH_EXTENDED_DATA_STDERR=1
def unpack_channel_extended_data(msg):
    """Print the data of an SSH_MSG_CHANNEL_EXTENDED_DATA message.

    Only data type SSH_EXTENDED_DATA_STDERR is accepted (asserted).
    """
    assert msg[0]==SSH_MSG_CHANNEL_EXTENDED_DATA
    pos=1+4 # skip message code and recipient channel
    data_type, pos = get_u32(msg, pos)
    assert data_type==SSH_EXTENDED_DATA_STDERR
    data, pos = get_str(msg, pos)
    for line in ("response (stderr):", "==", data, "=="):
        print (line)
    # the whole message must have been consumed
    assert len(msg)==pos

# see [RFC4254 5.3]
def send_eof(s):
    """Send SSH_MSG_CHANNEL_EOF for recipient channel 0."""
    if VERBOSITY>=2:
        print ("send_eof()")
    # one byte message code followed by the uint32 recipient channel
    payload=struct.pack(">Bi", SSH_MSG_CHANNEL_EOF, 0)
    my_send_with_MAC(s, wrap_packet(payload))

def send_channel_open(s):
    """Send SSH_MSG_CHANNEL_OPEN for a "session" channel [RFC 4254 5.1, 6.1].

    Sender channel 0, initial window 8 MiB, maximum packet size 32 KiB.
    """
    msg=struct.pack("B", SSH_MSG_CHANNEL_OPEN) # 90 or 0x5a
    msg+=pack_str("session")
    # Field order [RFC 4254 5.1]: sender channel, initial window size,
    # maximum packet size.  The window bounds how much data the server may
    # send until the client sends SSH_MSG_CHANNEL_WINDOW_ADJUST, which this
    # client does not appear to do; the former 0x2000 (8 KiB) window could
    # stall commands with larger output.  32768 is the packet size every
    # implementation must handle [RFC 4253 6.1].
    msg+=struct.pack(">iii", 0, 0x800000, 0x8000)
    pkt=wrap_packet(msg)
    my_send_with_MAC(s, pkt)

def expect_msg(s, msg_code):
    """Receive MACed packets until one with `msg_code` arrives; return its payload.

    SSH_MSG_DEBUG packets are always skipped, and SSH_MSG_IGNORE packets are
    skipped unless they are the expected code [RFC 4253 11].  Any other
    message is fatal (the process exits).
    """
    while True:
        if VERBOSITY>=2:
            print (f"expecting msg_code={hex(msg_code)} or {msg_str[msg_code]}")
        msg=unwrap_packet(my_recv_chk_MAC(s))
        if msg[0]==SSH_MSG_DEBUG:
            if VERBOSITY>=2:
                print ("Got SSH_MSG_DEBUG message:")
                hexdump.hexdump(msg)
                print ("Running expect_msg() again")
            continue
        if msg[0]==msg_code:
            return msg
        if msg[0]==SSH_MSG_IGNORE:
            # must be accepted and ignored at any time [RFC 4253 11.2]
            if VERBOSITY>=2:
                print ("Got SSH_MSG_IGNORE message, skipping it")
            continue
        print ("Fatal error in expect_msg().")
        print (f"While waiting for message msg_code={hex(msg_code)}, or {msg_str[msg_code]}")
        print (f"Got message {hex(msg[0])} instead, or: {msg_str[msg[0]]}")
        exit(0)

def expect_new_keys(s):
    """Receive the server's plain (unencrypted, un-MACed) SSH_MSG_NEWKEYS; assert on anything else."""
    if VERBOSITY>=2:
        print ("expecting SSH_MSG_NEWKEYS")
    payload=unwrap_packet(my_recv(s))
    if VERBOSITY>=2:
        print ("expect_new_keys(), got:")
        hexdump.hexdump(payload)
    # the payload must be exactly the single message-code byte
    assert payload==bytes([SSH_MSG_NEWKEYS])

def send_ssh_userauth(s):
    """Send SSH_MSG_SERVICE_REQUEST for the "ssh-userauth" service [RFC 4253 10]."""
    if VERBOSITY>=1:
        print ("sending ssh-userauth")
    request=struct.pack("B", SSH_MSG_SERVICE_REQUEST)+pack_str("ssh-userauth")
    msg=wrap_packet(request)
    if VERBOSITY>=2:
        print ("send_ssh_userauth(), sending:")
        hexdump.hexdump(msg)
    my_send_with_MAC(s, msg)

def send_ssh_connection_none(s):
    """Send a "none" SSH_MSG_USERAUTH_REQUEST [RFC 4252 5.2].

    If the global USERNAME is unset it is set to "root" first.
    """
    global USERNAME
    if VERBOSITY>=1:
        print ("sending ssh-connection")
    if USERNAME is None:
        USERNAME="root"
    request=b"".join([
        struct.pack("B", SSH_MSG_USERAUTH_REQUEST),
        pack_str(USERNAME),
        pack_str("ssh-connection"),
        pack_str("none"),
    ])
    pkt=wrap_packet(request)
    if VERBOSITY>=2:
        print ("send_ssh_connection_none(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

def send_ssh_connection_password(s, username, password):
    """Try "password" authentication [RFC 4252 8] and report the outcome.

    Returns True on SSH_MSG_USERAUTH_SUCCESS and False on
    SSH_MSG_USERAUTH_FAILURE; any other reply fails an assertion.
    """
    if VERBOSITY>=1:
        print ("sending ssh-connection")
    pkt=wrap_packet(b"".join([
        struct.pack("B", SSH_MSG_USERAUTH_REQUEST),
        pack_str(username),
        pack_str("ssh-connection"),
        pack_str("password"),
        b"\x00",             # FALSE: not a password-change request
        pack_str(password),
    ]))
    if VERBOSITY>=2:
        print ("send_ssh_connection_password(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

    msg=unwrap_packet(my_recv_chk_MAC(s))
    reply=msg[0]
    if reply==SSH_MSG_USERAUTH_SUCCESS:
        print ("Password correct.")
        return True
    if reply==SSH_MSG_USERAUTH_FAILURE:
        print ("Fatal error. Password incorrect.")
        return False
    print (f"{hex(msg[0])=}")
    assert False

def get_list_of_auth_modes(s):
    """Return the server's list of allowed authentication methods as a string.

    A "none" auth request is sent; the server is expected to answer with
    SSH_MSG_USERAUTH_FAILURE, whose first field is that name-list.
    """
    send_ssh_connection_none(s)
    failure=expect_msg(s, SSH_MSG_USERAUTH_FAILURE)
    modes=get_str(failure, 1)[0]
    if VERBOSITY>=1:
        print (f"allowed auth modes: {modes}")
    return modes

def send_client_KEX(s):
    """Send the client's KEXINIT packet as built by pack_client_KEX()."""
    kexinit=pack_client_KEX()
    if VERBOSITY>=2:
        print ("client KEX:")
        hexdump.hexdump(kexinit)
    my_send(s, kexinit)

def DH_get_params(s):
    """Request group-exchange DH parameters and return (g, p) from the reply."""
    my_send(s, pack_DH_GEX_request())
    reply=my_recv(s)
    assert len(reply)!=0
    g, p = unpack_KEXDH_REPLY(reply)
    return g, p

def DH_send_client_e(s, g, p, bits, GEX):
    """Choose a DH private exponent x, send e = g^x mod p, and return (x, e).

    x is drawn uniformly from [1, 2**bits) with the `secrets` CSPRNG; the
    `random` module (Mersenne Twister) is predictable and must not be used
    for key material.  GEX selects the group-exchange INIT packet format
    instead of the plain KEXDH INIT.
    """
    import secrets
    # same range as random.randrange(1, 2**bits): 1 .. 2**bits-1
    my_secret=1+secrets.randbelow(2**bits-1)
    client_e=pow(g, my_secret, p)

    if VERBOSITY>=1:
        print (f"{hex(client_e)=}")

    if GEX:
        my_send(s, pack_DH_GEX_INIT(client_e))
    else:
        my_send(s, pack_SSH_MSG_KEXDH_INIT(client_e))
    return my_secret, client_e

def ECDH_send_client_pubkey(s, KEX_algo_cur):
    """Generate an ephemeral EC key for `KEX_algo_cur` and send its public point.

    The point is serialized as an uncompressed X9.62 point and sent with
    pack_SSH_MSG_KEXDH_INIT_raw().
    Returns (private key, public key, raw public point bytes).
    """
    ec=cryptography.hazmat.primitives.asymmetric.ec
    serialization=cryptography.hazmat.primitives.serialization

    private_key=ec.generate_private_key(KEX_algo_EC_func[KEX_algo_cur])
    public_key=private_key.public_key()
    raw_point=public_key.public_bytes(
        serialization.Encoding.X962,
        serialization.PublicFormat.UncompressedPoint)

    if VERBOSITY>=2:
        print ("ECDHE_client_pubkey_raw_bytes:")
        hexdump.hexdump(raw_point)

    my_send(s, pack_SSH_MSG_KEXDH_INIT_raw(raw_point))
    return private_key, public_key, raw_point

def DH_recv_server_f(s, GEX):
    """Receive the server's DH reply and return (KEX_host_key, server_f)."""
    reply=my_recv(s)
    assert len(reply)!=0
    host_key, f_value = unpack_DH_REPLY(unwrap_packet(reply), GEX)
    return host_key, f_value

def ECDH_do_exchange(s):
    """Run an ecdh-sha2-nistp* key exchange for the negotiated KEX_ALGO.

    Sends the client's ephemeral public point, parses the server's reply and
    computes the ECDH shared secret.
    Returns (KEX_host_key, client point bytes, server point bytes, shared secret bytes).
    """
    if KEX_ALGO.startswith("ecdh-sha2-nistp"):
        ec=cryptography.hazmat.primitives.asymmetric.ec
        client_priv, _client_pub, client_raw = ECDH_send_client_pubkey(s, KEX_ALGO)
        host_key, server_raw = unpack_ECDH_REPLY(unwrap_packet(my_recv(s)))

        server_pub=ec.EllipticCurvePublicKey.from_encoded_point(
            curve=KEX_algo_EC_func[KEX_ALGO],
            data=server_raw)

        # ECDH shared secret
        shared_key=client_priv.exchange(ec.ECDH(), server_pub)
        if VERBOSITY>=2:
            print ("shared_key:")
            hexdump.hexdump(shared_key)

        return host_key, client_raw, server_raw, shared_key
    else:
        assert False

def DH_do_exchange(s):
    """Run a finite-field Diffie-Hellman key exchange for the negotiated KEX_ALGO.

    Group-exchange algorithms obtain (g, p) from the server; fixed-group
    algorithms use the MODP primes below with g=2.  The client value e is
    sent, the server's f and host key are received, and the shared secret
    K = f^x mod p is computed.

    Returns (g, p, KEX_host_key, client_e, server_f, shared_secret, GEX).
    Raises AssertionError for an unsupported KEX_ALGO; exits if the server's
    f is out of range.
    """
    if KEX_ALGO in ("diffie-hellman-group-exchange-sha256",
                    "diffie-hellman-group-exchange-sha1"):
        GEX=True
        g, p = DH_get_params(s)
        bits=DH_GEX_bits
    else:
        GEX=False
        g=2
        if KEX_ALGO in ("diffie-hellman-group14-sha256",
                        "diffie-hellman-group14-sha1"): # 2048 bits
            # http://tools.ietf.org/html/rfc3526#section-3
            bits=2048
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF
        elif KEX_ALGO=="diffie-hellman-group16-sha512": # 4096 bits
            # http://tools.ietf.org/html/rfc3526#section-5
            bits=4096
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF
        elif KEX_ALGO=="diffie-hellman-group18-sha512": # 8192 bits
            # https://www.rfc-editor.org/rfc/rfc8268#section-3
            # https://datatracker.ietf.org/doc/html/rfc3526#section-7
            # (this literal was previously split across two lines, which is a SyntaxError)
            bits=8192
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF
        elif KEX_ALGO=="diffie-hellman-group1-sha1": # 1024 bits
            # https://datatracker.ietf.org/doc/html/draft-ietf-secsh-transport-09.txt#section-6.1
            # https://datatracker.ietf.org/doc/html/rfc4253#section-8.1
            # https://datatracker.ietf.org/doc/html/rfc2409#section-6.2
            bits=1024
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
        else:
            # raise explicitly so the check survives `python -O`
            raise AssertionError(f"unsupported KEX algorithm: {KEX_ALGO}")

    my_secret, client_e=DH_send_client_e(s, g, p, bits, GEX=GEX)

    KEX_host_key, server_f = DH_recv_server_f(s, GEX=GEX)

    # Reject degenerate server values [RFC 4253 8: f must lie in [1, p-1]];
    # 1 and p-1 are excluded too, as OpenSSH does, since they force a
    # trivially predictable shared secret.
    if not 1 < server_f < p-1:
        print ("Fatal error. Server's DH public value f is out of range")
        exit(0)

    # calculate DH shared secret:
    shared_secret=pow(server_f, my_secret, p)
    if VERBOSITY>=1:
        print (f"{hex(shared_secret)=}")

    return g, p, KEX_host_key, client_e, server_f, shared_secret, GEX

def expect_service_accept(s):
    """Read MACed packets until SSH_MSG_SERVICE_ACCEPT arrives.

    SSH2_MSG_EXT_INFO and SSH_MSG_IGNORE packets are skipped silently (or
    logged at higher verbosity); any other packet type is reported and
    skipped as well.
    """
    while True:
        msg=unwrap_packet(my_recv_chk_MAC(s))
        code=msg[0]
        if code==SSH_MSG_SERVICE_ACCEPT:
            return
        if code==SSH2_MSG_EXT_INFO:
            if VERBOSITY>=2:
                print ("got SSH2_MSG_EXT_INFO packet:")
        elif code==SSH_MSG_IGNORE:
            # "markus" packet, sent by the server if DEBUG_KEXDH is on
            # [see sshd.c in OpenSSH 9.0]: the first encrypted/MACed packet,
            # sent as a test; it is to be ignored
            if VERBOSITY>=2:
                print ("got 'markus' packet:")
                hexdump.hexdump(msg)
        else:
            print (f"error, unhandled packet type at this stage: {msg[0]}, {msg_str[msg[0]]}")

def login_password(s):
    """Authenticate with the global USERNAME / PASSWORD.

    Exits if the server does not offer "password" or rejects the password.
    """
    serv_auth_modes=get_list_of_auth_modes(s)
    # The server sends a comma-separated name-list; compare whole names
    # rather than doing a substring test on the raw string.
    if "password" not in serv_auth_modes.split(","):
        print ("Fatal error: server doesn't accept password auth")
        exit(0)

    if send_ssh_connection_password(s, USERNAME, PASSWORD)==False:
        exit(0)

def send_ssh_connection_publickey(s, username, client_pubkey_blob):
    """Send an unsigned "publickey" auth query [RFC 4252 7].

    The algorithm name sent is SERVER_HOST_ALGO.
    NOTE(review): RFC 4252 expects the client key's algorithm here; this
    presumably relies on SERVER_HOST_ALGO matching the client key type -- confirm.
    """
    if VERBOSITY>=1:
        print ("sending ssh-connection")
    request=b"".join([
        struct.pack("B", SSH_MSG_USERAUTH_REQUEST),
        pack_str(username),           # user name
        pack_str("ssh-connection"),   # service
        pack_str("publickey"),        # method
        b"\x00",                      # have_sig: FALSE (query only)
        pack_str(SERVER_HOST_ALGO),   # pkalg
        pack_buf(client_pubkey_blob), # pkblob
    ])
    pkt=wrap_packet(request)
    if VERBOSITY>=2:
        print ("send_ssh_connection_publickey(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

def get_client_pubkey_blob_from_file(fname):
    """Read an OpenSSH public key file ("<type> <base64> [comment]").

    Only the first line is used.  Returns (key type string, decoded key blob).
    Raises ValueError if that line does not have at least two fields.
    """
    # `with` closes the file; the previous version leaked the handle
    with open(fname, encoding="utf-8") as f:
        first_line=f.readline()
    # split() tolerates repeated spaces/tabs and the trailing newline
    fields=first_line.split()
    if len(fields)<2:
        raise ValueError(f"{fname}: not an OpenSSH public key line")
    _type=fields[0]
    blob=base64.b64decode(fields[1])
    return _type, blob

def RSA_sign_blob_with_pri_key(message, pri_key_fname):
    """Sign `message` with the OpenSSH RSA private key stored in `pri_key_fname`.

    The key is loaded with an empty password; the signature uses PKCS#1 v1.5
    padding with the hash class SERVER_HOST_ALGO_PTR2.  Returns the raw
    signature bytes.
    """
    with open(pri_key_fname, "rb") as key_file:
        key_data=key_file.read()
    private_key=cryptography.hazmat.primitives.serialization.load_ssh_private_key(data=key_data, password=b"")
    padding=cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15()
    return private_key.sign(message, padding, SERVER_HOST_ALGO_PTR2())

def EC_sign_blob_with_pri_key(message, pri_key_fname):
    """Sign `message` with the OpenSSH ECDSA private key stored in `pri_key_fname`.

    The key is loaded with an empty password and signs with ECDSA over the
    hash class SERVER_HOST_ALGO_PTR2.  The DER (ASN.1) signature returned by
    `cryptography` is converted to SSH form: mpint(r) || mpint(s).
    """
    # Import the submodules used below explicitly instead of relying on them
    # being loaded as a side effect of importing `serialization`.
    import cryptography.hazmat.primitives.asymmetric.ec
    import cryptography.hazmat.primitives.asymmetric.utils

    with open(pri_key_fname, "rb") as f:
        d=f.read()
    t=cryptography.hazmat.primitives.serialization.load_ssh_private_key(data=d, password=b"")

    if VERBOSITY>=2:
        print ("EC_sign_blob_with_pri_key() message to sign:")
        hexdump.hexdump(message)

        print (f"{t.curve=}")
        print (f"{t=}")

    signature = t.sign(
        message,
        cryptography.hazmat.primitives.asymmetric.ec.ECDSA(SERVER_HOST_ALGO_PTR2()))
    # here signature is in ASN1 format
    r_val, s_val=cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(signature)
    return pack_mpint(r_val)+pack_mpint(s_val)

"""
publickey auth method: RFC 4252, section 7
https://datatracker.ietf.org/doc/html/rfc4252#section-7

hostbound auth method (publickey-hostbound-v00@openssh.com):
https://www.openssh.com/agent-restrict.html#authverify
https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L347

userauth_pubkey() in auth2-pubkey.c in OpenSSH:
https://android.googlesource.com/platform/external/openssh/+/40e3686a7062605c5d122885a74ec478a26c2c77/auth2-pubkey.c#90
"""

def send_publickey_signed(s, kexgex_hash, client_pubkey_blob, server_pubkey_blob):
    """Send a signed "publickey" (or hostbound) SSH_MSG_USERAUTH_REQUEST.

    The signed data is string(kexgex_hash) followed by the request body
    itself [RFC 4252 7].  With HOSTBOUND, the method is
    "publickey-hostbound-v00@openssh.com" and the server host key blob is
    appended to the request.  If SERVER_HOST_ALGO contains "rsa" the RSA
    signer is used, otherwise the EC signer.
    """
    method="publickey-hostbound-v00@openssh.com" if HOSTBOUND else "publickey"
    fields=[
        struct.pack("B", SSH_MSG_USERAUTH_REQUEST),
        pack_str(USERNAME),           # user name
        pack_str("ssh-connection"),   # service
        pack_str(method),             # method
        b"\x01",                      # have_sig: TRUE
        pack_str(SERVER_HOST_ALGO),   # pkalg
        pack_buf(client_pubkey_blob), # pkblob
    ]
    if HOSTBOUND:
        fields.append(pack_buf(server_pubkey_blob))
    msg=b"".join(fields)

    # the signed blob is the request body prefixed with the session identifier
    msg_to_be_signed=pack_buf(kexgex_hash)+msg

    if VERBOSITY>=2:
        print ("send_publickey_signed(), packet to be signed:")
        hexdump.hexdump(msg_to_be_signed)

    signer=RSA_sign_blob_with_pri_key if "rsa" in SERVER_HOST_ALGO else EC_sign_blob_with_pri_key
    sig=signer(msg_to_be_signed, PRIKEY_FNAME)

    if VERBOSITY>=2:
        print ("send_publickey_signed(), signature:")
        hexdump.hexdump(sig)

    signature_blob=pack_buf(pack_str(SERVER_HOST_ALGO)+pack_buf(sig))
    pkt=wrap_packet(msg+signature_blob)

    if VERBOSITY>=2:
        print ("send_publickey_signed(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

def login_publickey(s, KEX_host_key, kexgex_hash):
    """Authenticate with the key pair PUBKEY_FNAME / PRIKEY_FNAME.

    First sends an unsigned query for the public key, then the signed
    request.  Returns True on SSH_MSG_USERAUTH_SUCCESS and False on
    SSH_MSG_USERAUTH_FAILURE; exits if "publickey" is not offered, and
    fails an assertion on any other reply.
    """
    serv_auth_modes=get_list_of_auth_modes(s)
    # compare whole names from the comma-separated name-list, not substrings
    if "publickey" not in serv_auth_modes.split(","):
        print ("Fatal error: server doesn't accept publickey auth")
        exit(0)
    _type, client_pubkey_blob=get_client_pubkey_blob_from_file(PUBKEY_FNAME)
    if VERBOSITY>=1:
        print (f"login_publickey() {_type=} from {PUBKEY_FNAME=}")
    send_ssh_connection_publickey(s, USERNAME, client_pubkey_blob)
    # NOTE(review): the reply type is not checked here; presumably it is the
    # server echoing our public key (SSH_MSG_USERAUTH_PK_OK) -- confirm.
    msg=unwrap_packet(my_recv_chk_MAC(s))
    if VERBOSITY>=2:
        print ("login_publickey(), got our public key back:")
        hexdump.hexdump(msg)

    send_publickey_signed(s, kexgex_hash, client_pubkey_blob, KEX_host_key)

    if VERBOSITY>=2:
        print ("login_publickey(), expecting SSH_MSG_USERAUTH_SUCCESS or _FAILURE")

    msg=unwrap_packet(my_recv_chk_MAC(s))
    if msg[0]==SSH_MSG_USERAUTH_SUCCESS:
        print ("Publickey auth finished correctly.")
        return True
    elif msg[0]==SSH_MSG_USERAUTH_FAILURE:
        print ("Fatal error. Publickey auth finished incorrectly.")
        return False
    else:
        # hex() already includes the "0x" prefix
        print (f"Fatal error. msg[0]={hex(msg[0])}")
        assert False

def exec_recv_response(s):
    """Process channel messages after an exec request until SSH_MSG_CHANNEL_CLOSE.

    Channel data, extended (stderr) data and channel requests are passed to
    their unpack_* helpers.  Window adjusts, channel success and EOF are
    ignored, and so are SSH_MSG_IGNORE / SSH_MSG_DEBUG, which RFC 4253 11
    requires implementations to accept at any time.  Any other message is
    fatal (the process exits).
    """
    # receiving loop
    # as in https://en.wikipedia.org/wiki/Inversion_of_control
    # as in GUI
    while True:
        msg=unwrap_packet(my_recv_chk_MAC(s))
        msg_code=msg[0]
        if VERBOSITY>=1:
            print (f"got message {hex(msg_code)} or {msg_str[msg_code]}, {len(msg)=}")
        if msg_code==SSH_MSG_CHANNEL_WINDOW_ADJUST:
            # see [RFC4254 5.2]
            pass
        elif msg_code==SSH_MSG_CHANNEL_SUCCESS:
            pass
        elif msg_code==SSH_MSG_CHANNEL_DATA:
            unpack_channel_data(msg)
        elif msg_code==SSH_MSG_CHANNEL_EXTENDED_DATA:
            unpack_channel_extended_data(msg)
        elif msg_code==SSH_MSG_CHANNEL_EOF:
            pass
        elif msg_code==SSH_MSG_CHANNEL_REQUEST:
            unpack_exit_status(s, msg)
        elif msg_code in (SSH_MSG_IGNORE, SSH_MSG_DEBUG):
            # must be accepted at any time [RFC 4253 11.2, 11.3]
            pass
        elif msg_code==SSH_MSG_CHANNEL_CLOSE:
            break
        else:
            print (f"Fatal error. Unhandled msg_code={hex(msg_code)} or {msg_str[msg_code]}")
            exit(0)

def exec_command(s, cmd):
    """Send an exec request for `cmd`, then process replies until the channel closes."""
    send_exec(s, cmd)
    return exec_recv_response(s)

def expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION(s):
    """Read MACed packets until SSH_MSG_CHANNEL_OPEN_CONFIRMATION; return its payload.

    SSH_MSG_GLOBAL_REQUEST and SSH_MSG_DEBUG packets are skipped.  Any other
    message is fatal (the process exits).
    """
    expected=SSH_MSG_CHANNEL_OPEN_CONFIRMATION
    while True:
        msg=unwrap_packet(my_recv_chk_MAC(s))
        if msg[0]==SSH_MSG_GLOBAL_REQUEST:
            # [see RFC4254 4]: https://www.rfc-editor.org/rfc/rfc4254#section-4
            if VERBOSITY>=1:
                print ("Got SSH_MSG_GLOBAL_REQUEST message")
                if VERBOSITY>=2:
                    hexdump.hexdump(msg)
                idx=1
                tmp, idx = get_str(msg, idx)
                print (f"request name {tmp}")
                print (f"want reply {msg[idx]}")
                assert msg[idx]==0 # we don't handle replying (yet)
                print ("Running expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION() again")
            continue
        if msg[0]==SSH_MSG_DEBUG:
            if VERBOSITY>=1:
                print ("Got SSH_MSG_DEBUG message:")
                if VERBOSITY>=2:
                    hexdump.hexdump(msg)
                print ("Running expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION() again")
            continue
        if msg[0]==expected:
            return msg
        # previously this path referenced an undefined `msg_code` (NameError)
        print ("Fatal error in expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION().")
        print (f"While waiting for message {hex(expected)}, or: {msg_str[expected]}")
        print (f"Got message {hex(msg[0])} instead, or: {msg_str[msg[0]]}")
        exit(0)

# yes! byte-by-byte
# rationale: sometimes banner + next packet is sent by server in one tcp/ip packet
# and we don't know banner's size beforehand
def read_banner(s):
    """Read the server's identification string and return it without trailing whitespace.

    Lines that precede the one starting with b"SSH-" are skipped; RFC 4253
    4.2 allows a server to send such lines first.  Raises ConnectionError if
    the peer closes the connection before a complete banner line arrives
    (previously this looped forever, because recv() keeps returning b"").
    """
    while True:
        line=b""
        while True:
            c=s.recv(1)
            if not c:
                raise ConnectionError("connection closed while reading the SSH banner")
            if c==b"\n":
                break
            line+=c
        if line.startswith(b"SSH-"):
            return line.rstrip()
        # pre-banner line [RFC 4253 4.2]; such lines never start with "SSH-"

def _negotiate_algo(what, ours, theirs):
    """Return first_common_element() of two comma-separated name-lists, or exit.

    *ours* is our list (passed first), *theirs* is the server's list, and
    *what* names the category in the fatal message, e.g. "KEX algorithm".
    If nothing is common, both lists are printed and the program exits.
    """
    picked=first_common_element(ours.split(","), theirs.split(","))
    if picked is None:
        print (f"Fatal error. No common {what}")
        print (f"Server: {theirs}")
        print (f"What we support: {ours}")
        exit(0)
    return picked

def do_all():
    """Run one complete SSH client session against HOST:PORT.

    In order:
    - exchange version banners and KEXINIT messages;
    - negotiate KEX, cipher, server host key and MAC algorithms;
    - run the ECDH or DH key exchange and compute the exchange hash;
    - check it with the chk_*_kexgex_hash() helper for the negotiated host key
      algorithm;
    - derive keys and exchange NEWKEYS;
    - request the userauth service and log in with password or public key;
    - open a channel, run COMMAND, and send EOF.

    Negotiation results are stored in module globals (KEX_ALGO, KEX_HASH,
    CIPHER_ALGO, MAC_ALGO, ...) read by other helpers in this file. Fatal
    conditions print a message and exit. The socket is closed on every way
    out of this function, including a failed login.
    """
    global recv_seqno, KEX_ALGO, KEX_HASH, SERVER_HOST_ALGO, CIPHER_ALGO, SERVER_HOST_ALGO_NIBBLES
    global SERVER_HOST_ALGO_PTR, SERVER_HOST_ALGO_PTR2
    global MAC_SIZE, MAC_ALGO, CIPHER_KEY_SIZE, ENCRYPTION, EXT_INFO_C

    # create_connection() resolves HOST and tries every returned address, so
    # IPv6 hosts work too; "with" closes the socket even on early return/exit
    with socket.create_connection((HOST, PORT)) as s:
        serv_banner=read_banner(s)
        recv_seqno+=1
        if VERBOSITY>=1:
            # the banner is untrusted server data: don't crash on non-UTF-8 bytes
            print ("serv_banner:", serv_banner.decode("utf-8", errors="replace"))

        my_send(s, client_banner+b"\r\n")

        unpack_serv_KEX(my_recv(s))

        send_client_KEX(s)

        KEX_ALGO=_negotiate_algo("KEX algorithm", KEX_ALGOS, from_serv_kex_algorithms)
        if VERBOSITY>=1:
            print (f"Picking KEX algo {KEX_ALGO}")

        KEX_ALGO_hash_func={
        "diffie-hellman-group1-sha1":           hashlib.sha1,
        "diffie-hellman-group14-sha1":          hashlib.sha1,
        "diffie-hellman-group14-sha256":        hashlib.sha256,
        "diffie-hellman-group16-sha512":        hashlib.sha512,
        "diffie-hellman-group18-sha512":        hashlib.sha512,
        "diffie-hellman-group-exchange-sha1":   hashlib.sha1,
        "diffie-hellman-group-exchange-sha256": hashlib.sha256,
        "ecdh-sha2-nistp256":                   hashlib.sha256,
        "ecdh-sha2-nistp384":                   hashlib.sha384,
        "ecdh-sha2-nistp521":                   hashlib.sha512}
        KEX_HASH=KEX_ALGO_hash_func[KEX_ALGO]

        CIPHER_ALGO=_negotiate_algo("cipher algorithm", CIPHER_ALGOS, from_serv_encryption_algorithms)
        if VERBOSITY>=1:
            print (f"Picking cipher: {CIPHER_ALGO}")

        # name -> (encryption enabled?, key size in bytes)
        CIPHER_ALGO_list={
        "aes128-ctr": (True, 16),
        "aes256-ctr": (True, 32),
        "none": (False, 0)}
        ENCRYPTION, CIPHER_KEY_SIZE = CIPHER_ALGO_list[CIPHER_ALGO]

        SERVER_HOST_ALGO=_negotiate_algo("server host algorithm", SERVER_HOST_ALGOS, from_serv_server_host_algorithms)
        if VERBOSITY>=1:
            print (f"Picking server host algo {SERVER_HOST_ALGO}")

        # name -> (digest length in hex nibbles (presumably), hashlib function, cryptography hash class)
        SERVER_HOST_ALGO_list={
        "ssh-dss":             (40,  hashlib.sha1,   cryptography.hazmat.primitives.hashes.SHA1),
        "ssh-rsa":             (40,  hashlib.sha1,   cryptography.hazmat.primitives.hashes.SHA1),
        "rsa-sha2-256":        (64,  hashlib.sha256, cryptography.hazmat.primitives.hashes.SHA256),
        "rsa-sha2-512":        (128, hashlib.sha512, cryptography.hazmat.primitives.hashes.SHA512),
        "ecdsa-sha2-nistp256": (64,  hashlib.sha256, cryptography.hazmat.primitives.hashes.SHA256)}
        SERVER_HOST_ALGO_NIBBLES, SERVER_HOST_ALGO_PTR, SERVER_HOST_ALGO_PTR2 = SERVER_HOST_ALGO_list[SERVER_HOST_ALGO]

        MAC_ALGO=_negotiate_algo("MAC algorithm", MAC_ALGOS, from_serv_mac_algorithms)
        if VERBOSITY>=1:
            print (f"Picking MAC algo {MAC_ALGO}")

        # name -> (MAC length in bytes, hashlib function).
        # Note: the global MAC_ALGO is rebound here from the name to the hash function.
        MAC_ALGO_list={
        "hmac-sha2-256": (0x20,   hashlib.sha256),
        "hmac-sha2-512": (0x40,   hashlib.sha512),
        "hmac-sha1":     (160//8, hashlib.sha1)}
        MAC_SIZE, MAC_ALGO = MAC_ALGO_list[MAC_ALGO]

        if KEX_ALGO.startswith("ecdh-sha2-nistp"):
            KEX_host_key, client_pub, server_pub, shared_secret_buf = ECDH_do_exchange(s)
            # calculate the hash of hodgepodge. the most important part - shared_secret, which is unknown to interceptor
            # all other parts can be easily intercepted!
            client_pub_i=int.from_bytes(client_pub, byteorder='big')
            server_pub_i=int.from_bytes(server_pub, byteorder='big')
            shared_secret=int.from_bytes(shared_secret_buf, byteorder='big')

            kexgex_hash=calc_kexgex_hash(client_banner, serv_banner, client_KEX_blob, serv_KEX_blob, KEX_host_key, client_pub_i, server_pub_i, shared_secret, 0, 0, GEX=False)
        elif KEX_ALGO.startswith("diffie-hellman-group"):
            g, p, KEX_host_key, client_e, server_f, shared_secret, GEX = DH_do_exchange(s)
            # calculate the hash of hodgepodge. the most important part - shared_secret, which is unknown to interceptor
            # all other parts can be easily intercepted!
            kexgex_hash=calc_kexgex_hash(client_banner, serv_banner, client_KEX_blob, serv_KEX_blob, KEX_host_key, client_e, server_f, shared_secret, g, p, GEX)
        else:
            # explicit error instead of `assert False`, which vanishes under
            # "python -O" and would leave shared_secret undefined below
            print (f"Fatal error. KEX algorithm {KEX_ALGO} is not implemented")
            exit(0)
        if VERBOSITY>=2:
            print ("kexgex_hash:")
            hexdump.hexdump(kexgex_hash)

        host_key_checkers={
        "ssh-dss":             chk_DSA_kexgex_hash,
        "ssh-rsa":             chk_RSA_kexgex_hash,
        "rsa-sha2-256":        chk_RSA_kexgex_hash,
        "rsa-sha2-512":        chk_RSA_kexgex_hash,
        "ecdsa-sha2-nistp256": chk_ECDSA_kexgex_hash}
        host_key_checkers[SERVER_HOST_ALGO](kexgex_hash)

        # the first exchange hash also serves as the session id
        derive_keys(shared_secret, kexgex_hash, kexgex_hash)

        expect_new_keys(s)

        send_new_keys(s)

        send_ssh_userauth(s)

        expect_service_accept(s)

        if USERNAME is not None and PASSWORD is not None:
            if login_password(s)==False:
                return
        elif USERNAME is not None and PUBKEY_FNAME is not None and PRIKEY_FNAME is not None:
            if login_publickey(s, KEX_host_key, kexgex_hash)==False:
                return
        else:
            print ("Fatal error: can't go further without username/passsword/pubkey/prikey. exiting.")
            exit(0)

        send_channel_open(s)

        expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION(s)

        exec_command(s, COMMAND)

        # be polite; the socket itself is closed when the "with" block ends
        send_eof(s)

# --- script entry: parse arguments, validate, then run the session ---
parse_command_line()

if HOST is None:
    print ("Fatal error: host not set")
    exit(0)

# EXT_INFO_C is turned on exactly when "none" occurs in the CIPHER_ALGOS string
EXT_INFO_C = "none" in CIPHER_ALGOS

do_all()
