#!/usr/bin/env python3

# ATTENTION. This code is for demonstration and learning purposes only.
# It's not supposed to be used for serious/production work.
# Beware! It may have bugs and vulnerabilities.
# For serious work, use OpenSSH, libssh, etc.

# https://pypi.org/project/hexdump/
# pip3 install hexdump
import hexdump

import socket, struct, random, math
import hashlib, hmac
import sys, base64
import itertools

# pip3 install pycryptodome
# (openbsd) pkg_add py3-cryptodome
# https://pycryptodome.readthedocs.io/en/latest/src/installation.html
# https://pycryptodome.readthedocs.io/en/latest/src/cipher/aes.html
import Crypto.Cipher.AES
# https://pycryptodome.readthedocs.io/en/latest/src/util/util.html
import Crypto.Util.Counter

# May need to be upgraded:
# python3 -m pip install --upgrade cryptography
# pip3 install cryptography
# (openbsd) pkg_add py3-cryptography
import cryptography.hazmat.primitives.serialization
import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.asymmetric.padding

# Connection/session settings. parse_command_line() assigns most of these
# from sys.argv (ENCRYPTION is not among them).
HOST=None
PORT=22
USERNAME=None
PASSWORD=None
ENCRYPTION=False # for first packets; my_recv() and wrap_packet() read it (encrypted path / 16-byte alignment)
VERBOSITY=0 # 0 quiet, 1 "-v" info, 2 "-vv" with hexdumps, 3 "-vvv" even more
COMMAND="uname -a" # default
PUBKEY_FNAME=None
PRIKEY_FNAME=None
HOSTBOUND=False
SAVE_SERV_PUB_KEY=False

# Algorithm name-lists offered in our SSH_MSG_KEXINIT (see pack_client_KEX()).
# Earlier entries are preferred: negotiation picks the first client algorithm
# the server also supports. parse_command_line() can replace KEX_ALGOS,
# CIPHER_ALGOS and SERVER_HOST_ALGOS with strings given on the command line.
KEX_ALGOS_list=[
"diffie-hellman-group-exchange-sha256",
"diffie-hellman-group1-sha1",
"diffie-hellman-group14-sha1",
"diffie-hellman-group14-sha256",
"diffie-hellman-group16-sha512",
"diffie-hellman-group18-sha512",
"ecdh-sha2-nistp256",
"ecdh-sha2-nistp384",
"ecdh-sha2-nistp521"]
KEX_ALGOS=",".join(KEX_ALGOS_list)

# Curve object for each ECDH KEX method.
# NOTE(review): the `ec` module isn't imported explicitly at the top of the file;
# this presumably works because the cryptography modules imported there import
# it themselves - consider importing cryptography.hazmat.primitives.asymmetric.ec directly.
# https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
KEX_algo_EC_func={
"ecdh-sha2-nistp256": cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(),
"ecdh-sha2-nistp384": cryptography.hazmat.primitives.asymmetric.ec.SECP384R1(),
"ecdh-sha2-nistp521": cryptography.hazmat.primitives.asymmetric.ec.SECP521R1()}

# Ciphers offered for both directions; "none" means no encryption.
CIPHER_ALGOS_list=[
"aes256-ctr",
"aes128-ctr",
"none"]
CIPHER_ALGOS=",".join(CIPHER_ALGOS_list)

# MAC algorithms offered for both directions.
MAC_ALGOS_list=[
"hmac-sha2-256",
"hmac-sha2-512",
"hmac-sha1"]
MAC_ALGOS=",".join(MAC_ALGOS_list)

# Server host key / signature algorithms we accept. Only the nistp256 ECDSA
# variant is listed; the signature checkers below handle exactly these.
SERVER_HOST_ALGOS_list=[
"ecdsa-sha2-nistp256",
"rsa-sha2-256",
"rsa-sha2-512",
"ssh-rsa",
"ssh-dss"]
SERVER_HOST_ALGOS=",".join(SERVER_HOST_ALGOS_list)

def parse_command_line():
    """Parse sys.argv into the module-level settings (HOST, PORT, USERNAME, ...).

    With no arguments, print usage and exit. The flags -v/-vv/-vvv,
    -hostbound and -save_serv_pub_key take no value; every other option
    consumes the next argument. An unknown option, an option missing its
    value, or a non-numeric port prints a fatal error and exits.
    """
    # MAC_ALGOS must be declared here as well; otherwise "-mac" would only set
    # a local variable and the option would be silently ignored.
    global HOST, PORT, USERNAME, PASSWORD, VERBOSITY, COMMAND, PUBKEY_FNAME, PRIKEY_FNAME, HOSTBOUND
    global KEX_ALGOS, CIPHER_ALGOS, MAC_ALGOS, SERVER_HOST_ALGOS, SAVE_SERV_PUB_KEY

    if len(sys.argv)==1:
        print ("Error: no options supplied. Use these:")
        print ("  -v                 - increase verbosity level. but no dumps")
        print ("  -vv                - increase verbosity level. dumps")
        print ("  -vvv               - increase verbosity level. even more info")
        print ("  -h host")
        print ("  -port port         - default is 22")
        print ("  -c command         - run command on SSH host")
        print ("  -save_serv_pub_key - save it")
        print ("if password auth is used:")
        print ("  -u user")
        print ("  -pass password")
        print ("if publickey auth is used:")
        print ("  -u user")
        print ("  -pubkey fname - like id_rsa.pub")
        print ("  -prikey fname - like id_rsa")
        print ("Mostly used for testing:")
        print ("  -kex algos              - forcibly set KEX algos")
        print ("  -cipher algos           - forcibly set cipher algos. may be 'none'")
        print ("  -server_host_algo algos - forcibly set server host key algos")
        print ("  -mac algos              - forcibly set MAC algos")
        print ("  -hostbound")
        exit(0)

    def option_value(idx):
        # The argument following sys.argv[idx]; fatal error when it is missing
        # (previously this was an IndexError traceback).
        if idx+1>=len(sys.argv):
            print (f"Fatal error: command option {sys.argv[idx]} requires a value")
            exit(0)
        return sys.argv[idx+1]

    idx=1
    while idx<len(sys.argv):
        opt=sys.argv[idx]
        if opt=="-v":
            VERBOSITY=1
            idx+=1
        elif opt=="-vv":
            VERBOSITY=2
            idx+=1
        elif opt=="-vvv":
            VERBOSITY=3
            idx+=1
        elif opt=="-hostbound":
            HOSTBOUND=True
            idx+=1
        elif opt=="-save_serv_pub_key":
            SAVE_SERV_PUB_KEY=True
            idx+=1
        elif opt=="-h":
            HOST=option_value(idx)
            idx+=2
        elif opt=="-port":
            value=option_value(idx)
            try:
                PORT=int(value)
            except ValueError:
                print (f"Fatal error: port must be a number, got: {value}")
                exit(0)
            idx+=2
        elif opt=="-u":
            USERNAME=option_value(idx)
            idx+=2
        elif opt=="-pass":
            PASSWORD=option_value(idx)
            idx+=2
        elif opt=="-pubkey":
            PUBKEY_FNAME=option_value(idx)
            idx+=2
        elif opt=="-prikey":
            PRIKEY_FNAME=option_value(idx)
            idx+=2
        elif opt=="-c":
            COMMAND=option_value(idx)
            idx+=2
        elif opt=="-kex":
            KEX_ALGOS=option_value(idx)
            idx+=2
        elif opt=="-mac":
            MAC_ALGOS=option_value(idx)
            idx+=2
        elif opt=="-cipher":
            CIPHER_ALGOS=option_value(idx)
            idx+=2
        elif opt=="-server_host_algo":
            SERVER_HOST_ALGOS=option_value(idx)
            idx+=2
        else:
            print (f"Fatal error: unknown command option: {opt}")
            exit(0)

# Group sizes (bits) sent in SSH_MSG_KEX_DH_GEX_REQUEST (RFC 4419): minimal,
# preferred and maximal modulus size. The commented-out variants are kept for
# experiments; big groups are slow here (see the note on the first one).
#DH_GEX_min, DH_GEX_bits, DH_GEX_max = 2048, 8192, 8192 # slow in Python!
DH_GEX_min, DH_GEX_bits, DH_GEX_max = 2048, 2048, 2048
#DH_GEX_min, DH_GEX_bits, DH_GEX_max = 512, 512, 8192 # for test

# Client identification string, "SSH-protoversion-softwareversion" (RFC 4253 sec. 4.2).
client_banner=b'SSH-2.0-ToySSH_0.?.?'

# SSH message numbers (RFC 4250 sec. 4.1 and the RFCs it cites).
SSH_MSG_DISCONNECT=1
SSH_MSG_IGNORE=2
SSH_MSG_UNIMPLEMENTED=3
SSH_MSG_DEBUG=4
SSH_MSG_SERVICE_REQUEST=5
SSH_MSG_SERVICE_ACCEPT=6
SSH2_MSG_EXT_INFO=7 # from OpenSSH 9.0 (defined in RFC 8308)

SSH_MSG_KEXINIT=20
SSH_MSG_NEWKEYS=21 # 0x15

# Value 31 is also SSH_MSG_KEX_DH_GEX_GROUP in RFC 4419 (group exchange);
# this file uses the SSH_MSG_KEXDH_REPLY name for both meanings.
SSH_MSG_KEXDH_INIT=30 # 0x1e
SSH_MSG_KEXDH_REPLY=31 # 0x1f

SSH_MSG_KEX_DH_GEX_INIT=32 # 0x20
SSH_MSG_KEX_DH_GEX_REPLY=33 # 0x21
SSH_MSG_KEX_DH_GEX_REQUEST=34 # 0x22

SSH_MSG_USERAUTH_REQUEST=50 # 0x32
SSH_MSG_USERAUTH_FAILURE=51 # 0x33
SSH_MSG_USERAUTH_SUCCESS=52 # 0x34
SSH_MSG_USERAUTH_BANNER=53 # 0x35

SSH_MSG_GLOBAL_REQUEST=80 # 0x50
SSH_MSG_CHANNEL_OPEN=90 # 0x5a
SSH_MSG_CHANNEL_OPEN_CONFIRMATION=91 # 0x5b
SSH_MSG_CHANNEL_WINDOW_ADJUST=93 # 0x5d
SSH_MSG_CHANNEL_DATA=94 # 0x5e
SSH_MSG_CHANNEL_EXTENDED_DATA=95 # 0x5f
SSH_MSG_CHANNEL_EOF=96 # 0x60
SSH_MSG_CHANNEL_CLOSE=97 # 0x61
SSH_MSG_CHANNEL_REQUEST=98 # 0x62
SSH_MSG_CHANNEL_SUCCESS=99 # 0x63
SSH_MSG_CHANNEL_FAILURE=100

# Message number -> name, for debug/error output. Not exhaustive: numbers not
# defined above (e.g. 92) have no entry, so index it with care.
msg_str={
    SSH_MSG_DISCONNECT                : "SSH_MSG_DISCONNECT",
    SSH_MSG_IGNORE                    : "SSH_MSG_IGNORE",
    SSH_MSG_UNIMPLEMENTED             : "SSH_MSG_UNIMPLEMENTED",
    SSH_MSG_DEBUG                     : "SSH_MSG_DEBUG",
    SSH_MSG_SERVICE_REQUEST           : "SSH_MSG_SERVICE_REQUEST",
    SSH_MSG_SERVICE_ACCEPT            : "SSH_MSG_SERVICE_ACCEPT",
    SSH2_MSG_EXT_INFO                 : "SSH2_MSG_EXT_INFO",
    SSH_MSG_KEXINIT                   : "SSH_MSG_KEXINIT",
    SSH_MSG_NEWKEYS                   : "SSH_MSG_NEWKEYS",
    SSH_MSG_KEXDH_REPLY               : "SSH_MSG_KEXDH_REPLY",
    SSH_MSG_KEXDH_INIT                : "SSH_MSG_KEXDH_INIT",
    SSH_MSG_KEX_DH_GEX_INIT           : "SSH_MSG_KEX_DH_GEX_INIT",
    SSH_MSG_KEX_DH_GEX_REPLY          : "SSH_MSG_KEX_DH_GEX_REPLY",
    SSH_MSG_KEX_DH_GEX_REQUEST        : "SSH_MSG_KEX_DH_GEX_REQUEST",
    SSH_MSG_USERAUTH_REQUEST          : "SSH_MSG_USERAUTH_REQUEST",
    SSH_MSG_USERAUTH_FAILURE          : "SSH_MSG_USERAUTH_FAILURE",
    SSH_MSG_USERAUTH_SUCCESS          : "SSH_MSG_USERAUTH_SUCCESS",
    SSH_MSG_USERAUTH_BANNER           : "SSH_MSG_USERAUTH_BANNER",
    SSH_MSG_GLOBAL_REQUEST            : "SSH_MSG_GLOBAL_REQUEST",
    SSH_MSG_CHANNEL_OPEN              : "SSH_MSG_CHANNEL_OPEN",
    SSH_MSG_CHANNEL_OPEN_CONFIRMATION : "SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
    SSH_MSG_CHANNEL_WINDOW_ADJUST     : "SSH_MSG_CHANNEL_WINDOW_ADJUST",
    SSH_MSG_CHANNEL_DATA              : "SSH_MSG_CHANNEL_DATA",
    SSH_MSG_CHANNEL_EXTENDED_DATA     : "SSH_MSG_CHANNEL_EXTENDED_DATA",
    SSH_MSG_CHANNEL_EOF               : "SSH_MSG_CHANNEL_EOF",
    SSH_MSG_CHANNEL_CLOSE             : "SSH_MSG_CHANNEL_CLOSE",
    SSH_MSG_CHANNEL_REQUEST           : "SSH_MSG_CHANNEL_REQUEST",
    SSH_MSG_CHANNEL_SUCCESS           : "SSH_MSG_CHANNEL_SUCCESS",
    SSH_MSG_CHANNEL_FAILURE           : "SSH_MSG_CHANNEL_FAILURE",
}

# Algorithm negotiation helper: the first list wins ties (it has higher priority).
def first_common_element (lst1, lst2):
    """Return the first item of *lst1* that also occurs in *lst2*, else None.

    Items are compared with ``==``. Either list being empty yields None.
    """
    if len(lst1)==0 or len(lst2)==0:
        return None
    for candidate in lst1:
        for other in lst2:
            if candidate==other:
                return candidate
    return None

def first_common_element_test():
    """Print first_common_element() for three hand-made cases (expected: 5, y, None)."""
    cases=[
        ([1,5,7,9], [0,7,9,5]),
        (['x', 'y', 'z'], [1, 2, 'y']),
        (['x', 'y', 'z'], [1, 2, 3]),
    ]
    for preferred, available in cases:
        print (first_common_element(preferred, available))

def recv_at_least(s, size):
    """Read exactly *size* bytes from socket *s* (never more, despite the name).

    recv() may return fewer bytes than asked, so it is called in a loop.
    Exits the program if the peer closes the connection first.
    A *size* <= 0 returns b"".
    """
    if VERBOSITY>=2:
        print (f"recv_at_least(), {size=}")
    # bytearray grows in amortized O(1); repeated bytes += would be quadratic
    buf=bytearray()
    while len(buf)<size:
        got=s.recv(size-len(buf))
        if len(got)==0:
            print ("Fatal error. recv_at_least(): got nothing. probably it's supposed client should close connection?")
            exit(0)
        buf+=got
    return bytes(buf)

# Number of packets received so far; incremented by every my_recv_*() function.
recv_seqno=0
def my_recv_plain(s):
    """Receive one unencrypted, MAC-less SSH binary packet from socket *s*.

    Returns the raw packet: the 4-byte packet_length field followed by
    packet_length bytes (padding_length, payload, padding). Increments the
    global recv_seqno. Exits on an implausibly large packet length.
    """
    global recv_seqno
    # recv_at_least() returns exactly 4 bytes or exits, so no short-read check is needed
    tmp_len=recv_at_least(s, 4)
    # packet_length is an unsigned uint32; a signed unpack turned huge values negative
    pkt_len=struct.unpack (">I", tmp_len)[0]
    if VERBOSITY>=2:
        print ("my_recv_plain() pkt_len: %d or 0x%x" % (pkt_len, pkt_len))
    # sanity cap against garbage/hostile lengths (OpenSSH uses 256 KiB as its packet limit)
    if pkt_len>256*1024:
        print (f"my_recv_plain() fatal error. packet length is too big: {pkt_len}")
        hexdump.hexdump(tmp_len)
        exit(0)
    tmp=recv_at_least(s, pkt_len)
    if VERBOSITY>=2:
        print ("my_recv_plain() got:")
        hexdump.hexdump(tmp_len+tmp)
    recv_seqno+=1
    return tmp_len+tmp

def my_recv_MAC(s):
    """Receive one unencrypted SSH packet that is followed by a MAC.

    Returns packet_length (4 bytes) + packet body + MAC_SIZE MAC bytes; the
    MAC is not checked here. Increments the global recv_seqno. Exits on an
    implausibly large packet length.
    """
    global recv_seqno
    # recv_at_least() returns exactly 4 bytes or exits
    tmp_len=recv_at_least(s, 4)
    # packet_length is an unsigned uint32
    pkt_len=struct.unpack (">I", tmp_len)[0]
    if VERBOSITY>=2:
        print ("my_recv_MAC() pkt_len: %d or 0x%x" % (pkt_len, pkt_len))
    # sanity cap against garbage/hostile lengths (OpenSSH uses 256 KiB as its packet limit)
    if pkt_len>256*1024:
        print (f"my_recv_MAC() fatal error. packet length is too big: {pkt_len}")
        hexdump.hexdump(tmp_len)
        exit(0)
    tmp=recv_at_least(s, pkt_len+MAC_SIZE)
    if VERBOSITY>=2:
        print ("my_recv_MAC() got:")
        hexdump.hexdump(tmp_len+tmp)

    recv_seqno+=1
    return tmp_len+tmp

AES_BLOCK_SIZE_IN_BITS=128
AES_BLOCK_SIZE_IN_BYTES=AES_BLOCK_SIZE_IN_BITS//8

def my_recv_encrypted(s):
    """Receive and decrypt one aes128-ctr / aes256-ctr encrypted SSH packet.

    Key: the global key_D, truncated to the size CIPHER_ALGO needs. Counter:
    the global serv_to_client_ctr, advanced (mod 2**128) by the number of AES
    blocks consumed. Returns the decrypted packet (packet_length + body)
    followed by the MAC_SIZE raw MAC bytes; the MAC is not verified here.
    Increments the global recv_seqno. Exits on an invalid packet length.
    """
    global recv_seqno, serv_to_client_ctr

    if CIPHER_ALGO=="aes128-ctr":
        KEY=key_D[0:128//8]
        assert len(KEY)==16
    elif CIPHER_ALGO=="aes256-ctr":
        KEY=key_D[0:256//8]
        assert len(KEY)==32
    else:
        assert False, f"my_recv_encrypted(): unsupported cipher {CIPHER_ALGO}"

    if VERBOSITY>=2:
        print ("my_recv_encrypted() begin")
    # recv_at_least() returns exactly the requested size or exits
    first_block=recv_at_least(s, AES_BLOCK_SIZE_IN_BYTES)

    ctr = Crypto.Util.Counter.new(128, little_endian=False, initial_value=serv_to_client_ctr)
    t = Crypto.Cipher.AES.new(KEY, Crypto.Cipher.AES.MODE_CTR, counter = ctr)
    # the first block holds packet_length: decrypt it to learn how much follows
    first_block_decrypted = t.decrypt(first_block)
    if VERBOSITY>=2:
        print ("first block decrypted:")
        hexdump.hexdump(first_block_decrypted)
    # packet_length is an unsigned uint32
    pkt_len=struct.unpack(">I", first_block_decrypted[0:4])[0]
    if VERBOSITY>=2:
        print ("my_recv_encrypted() pkt_len: %d or 0x%x" % (pkt_len, pkt_len))

    # RFC 4253 sec. 6: length field + body must be a multiple of the cipher block
    # size; anything else (or a silly size) would desync the counter or the stream.
    total_len=pkt_len+4
    if pkt_len>256*1024 or total_len<AES_BLOCK_SIZE_IN_BYTES or total_len%AES_BLOCK_SIZE_IN_BYTES!=0:
        print (f"my_recv_encrypted() fatal error. bad packet length: {pkt_len}")
        hexdump.hexdump(first_block_decrypted)
        exit(0)
    rest=recv_at_least(s, total_len-AES_BLOCK_SIZE_IN_BYTES)

    # The CTR cipher object continues from its current keystream position,
    # so the first block doesn't need to be decrypted a second time.
    all_decrypted = first_block_decrypted + t.decrypt(rest)
    if VERBOSITY>=2:
        print ("decrypted:")
        hexdump.hexdump(all_decrypted)
    # RFC 4344: the counter is a 128-bit value incremented modulo 2**128
    serv_to_client_ctr=(serv_to_client_ctr+total_len//AES_BLOCK_SIZE_IN_BYTES) % (1<<128)

    mac=recv_at_least(s, MAC_SIZE)
    if VERBOSITY>=2:
        print ("read MAC")
        hexdump.hexdump(mac)

    recv_seqno+=1
    return all_decrypted + mac

def my_recv(s, MAC=False):
    """Receive one packet, choosing the reader by state.

    MAC false -> my_recv_plain(); MAC true and ENCRYPTION on ->
    my_recv_encrypted(); MAC true and ENCRYPTION off -> my_recv_MAC().
    """
    if not MAC:
        return my_recv_plain(s)
    if ENCRYPTION:
        return my_recv_encrypted(s)
    return my_recv_MAC(s)

# Number of packets sent so far; incremented by my_send().
send_seqno=0
def my_send(s, buf):
    """Send all of *buf* (an already framed packet) on socket *s*.

    Uses sendall(): a single send() may transmit only part of the buffer,
    which would corrupt the SSH stream. Increments the global send_seqno.
    """
    global send_seqno
    if VERBOSITY>=2:
        print ("my_send() sending:")
        hexdump.hexdump(buf)
    s.sendall(buf)
    send_seqno+=1

# Round up to a multiple of a power-of-two grain with a bit mask
# (same trick as http://docs.oracle.com/javase/specs/jvms/se7/html/jvms-3.html (3.3)).
def align2grain (i, grain):
    """Return *i* rounded up to the next multiple of *grain* (a power of two)."""
    mask = grain - 1
    return (i + mask) & ~mask

def get_buf (buf, idx):
    """Read an SSH ``string`` (uint32 length + that many bytes) at *idx*.

    Returns (data, index just past the data). Raises ValueError when *buf*
    is too short for the length field or the announced data. The length is
    unsigned: a signed read turned large values negative and moved the
    index backwards.
    """
    if idx+4>len(buf):
        raise ValueError(f"get_buf(): need 4 length bytes at offset {idx}, buffer has {len(buf)} bytes")
    length=struct.unpack (">I", buf[idx:idx+4])[0]
    end=idx+4+length
    if end>len(buf):
        raise ValueError(f"get_buf(): string of {length} bytes at offset {idx} runs past end of {len(buf)}-byte buffer")
    return buf[idx+4:end], end

def get_str (buf, idx):
    """Read an SSH ``string`` at *idx* and decode it as UTF-8.

    Returns (text, index just past the string).
    """
    raw, next_idx = get_buf(buf, idx)
    text = raw.decode("utf-8")
    return text, next_idx

def get_mpint (buf, idx):
    """Read an SSH ``mpint`` at *idx* as a non-negative big-endian int.

    Returns (value, index just past it). The sign bit is not interpreted.
    """
    raw, next_idx = get_buf(buf, idx)
    return int.from_bytes(raw, 'big'), next_idx

def get_u32 (buf, idx):
    """Read an SSH ``uint32`` (big-endian, unsigned) at *idx*.

    Returns (value, idx+4). Values >= 2**31 stay positive; a signed unpack
    would have made them negative.
    """
    rt=struct.unpack (">I", buf[idx:idx+4])[0]
    return rt, idx+4

def pack_buf(buf):
    """Encode *buf* as an SSH ``string``: uint32 (unsigned) length + bytes (RFC 4251 sec. 5)."""
    return struct.pack(">I", len(buf)) + buf

def pack_str(s):
    """Encode text *s* as an SSH ``string`` (UTF-8 bytes with a length prefix)."""
    encoded = s.encode("utf-8")
    return pack_buf(encoded)

def binlog(x):
    """Return ceil(log2(x)) as an int.

    Ints are handled exactly with int.bit_length(). The float formula
    math.log(x, 2) is inexact for large values; for example it gives
    29.000000000000004 for 2**29, which rounds up to 30. Other numeric
    types fall back to the float formula.

    Raises ValueError for x <= 0, like math.log does.
    """
    if isinstance(x, int):
        if x<=0:
            raise ValueError("math domain error")
        # for an integer x >= 1, (x-1).bit_length() == ceil(log2(x))
        return (x-1).bit_length()
    return int(math.ceil(math.log(x, 2)))

def bits_for_mpint(i):
    """Return ceil(log2(i)) as an int (see binlog())."""
    bits = binlog(i)
    return int(math.ceil(bits))

def bytes_for_mpint(i):
    """Byte count used to serialize *i* before the optional sign byte: 1 + bits_for_mpint(i)//8."""
    return 1 + bits_for_mpint(i)//8

# Data part of an SSH mpint (RFC 4251 sec. 5): big-endian two's complement,
# minimal length. A positive number whose top bit would be set gets a leading
# 0x00 so that it isn't read as negative.
def cvt_to_mpint_and_add_zero_if_needed(i):
    """Return the mpint data bytes (no length prefix) for a non-negative int *i*.

    Zero encodes as b"" (RFC 4251: zero is a string with zero bytes of
    data). Raises ValueError for negative *i*, which this client never
    needs to send.
    """
    if i<0:
        raise ValueError("cvt_to_mpint_and_add_zero_if_needed(): negative numbers are not supported")
    if i==0:
        return b""
    # bit_length()//8 + 1 bytes is the minimal size that still leaves the sign
    # bit clear, so the needed 0x00 prefix (if any) is produced by to_bytes itself.
    # Exact integer sizing also avoids the float error of math.log.
    return i.to_bytes(i.bit_length()//8 + 1, byteorder='big')

def pack_mpint(i):
    """Encode int *i* as a length-prefixed SSH ``mpint``."""
    return pack_buf(cvt_to_mpint_and_add_zero_if_needed(i))

def unwrap_packet(buf):
    """Strip SSH binary-packet framing from *buf* and return the payload.

    *buf* must be exactly: packet_length (uint32) + padding_length (byte)
    + payload + padding, with no MAC. Exits if packet_length disagrees with
    the buffer size or padding_length doesn't fit inside the packet.
    """
    # packet_length is unsigned (uint32)
    packet_len, padding_len = struct.unpack (">IB", buf[0:4+1])

    buf_with_padding=buf[4:]
    if packet_len!=len(buf_with_padding):
        print ("fatal error in unwrap_packet")
        print (f"{packet_len=} (as we got from packet header)")
        print (f"{len(buf_with_padding)=} (actual buf len)")
        print ("exiting")
        exit(0)
    if 1+padding_len>packet_len:
        print ("fatal error in unwrap_packet")
        print (f"{padding_len=} doesn't fit into {packet_len=}")
        print ("exiting")
        exit(0)
    # Drop the padding_length byte and the padding. Use an explicit end index:
    # [1:-padding_len] would be empty when padding_len == 0.
    payload=buf_with_padding[1:len(buf_with_padding)-padding_len]
    return payload

def remove_padding(buf):
    """Given padding_length byte + payload + padding (no length field), return the payload.

    Uses an explicit end index; the former [1:-padding_len] returned b""
    when padding_len was 0.
    """
    padding_len = struct.unpack (">B", buf[0:1])[0]
    return buf[1:len(buf)-padding_len]

def unpack_serv_KEX(buf):
    """Parse the server's SSH_MSG_KEXINIT (RFC 4253 sec. 7.1).

    buf: the whole received packet, framing included (packet_length,
    padding_length, payload, padding, no MAC); unwrap_packet() strips it.

    Stores in globals:
      from_serv_kex_algorithms, from_serv_server_host_algorithms,
      from_serv_encryption_algorithms, from_serv_mac_algorithms,
      from_serv_compression_algorithms - the server's comma-separated lists;
      serv_KEX_blob - the payload without its first (message code) byte,
      printed as "will be hashed as blob" below.

    Exits on SSH_MSG_DISCONNECT. Everything else it doesn't handle is an
    assert: a non-KEXINIT message, client->server and server->client lists
    that differ (MAC lists may differ in order only), and non-zero
    first_kex_packet_follows / reserved fields.
    NOTE(review): RFC 4253 allows a different list per direction; a server
    that uses that would trip these asserts.
    """
    if VERBOSITY>=1:
        print ("unpack_serv_KEX() begin")
    global from_serv_kex_algorithms, from_serv_server_host_algorithms, from_serv_encryption_algorithms
    global from_serv_mac_algorithms, from_serv_compression_algorithms, serv_KEX_blob

    serv_KEX=unwrap_packet(buf)
    idx=0
    if VERBOSITY>=2:
        print ("KEX from serv without len and padlen and msg_code:") # will be hashed as blob
        hexdump.hexdump(serv_KEX[1:])
    serv_KEX_blob=serv_KEX[1:]

    msg_code = struct.unpack (">B", serv_KEX[idx:idx+1])[0]
    idx+=1
    if msg_code==SSH_MSG_DISCONNECT:
        print ("Got SSH_MSG_DISCONNECT, can't proceed further")
        exit(0)
    assert msg_code==SSH_MSG_KEXINIT
    # 16 random bytes chosen by the server; only needed as part of serv_KEX_blob
    cookie=serv_KEX[idx:idx+0x10]
    idx+=0x10
    #print ("cookie", cookie)

    s, idx = get_str(serv_KEX, idx)
    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"kex_algorithms: {t}")
    from_serv_kex_algorithms=s

    s, idx = get_str(serv_KEX, idx)
    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"server_host_algorithms: {t}")
    from_serv_server_host_algorithms=s

    # encryption: client->server, then server->client
    s, idx = get_str(serv_KEX, idx)
    from_serv_encryption_algorithms=s

    s, idx = get_str(serv_KEX, idx)
    assert (s==from_serv_encryption_algorithms) # usually the same!

    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"encryption_algorithms: {t}")

    # MAC: client->server, then server->client
    s, idx = get_str(serv_KEX, idx)
    from_serv_mac_algorithms=s

    s, idx = get_str(serv_KEX, idx)
    # the order may be different second time:
    assert sorted(s.split(","))==sorted(from_serv_mac_algorithms.split(","))

    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"mac_algorithms: {t}")

    # compression: client->server, then server->client
    s, idx = get_str(serv_KEX, idx)
    from_serv_compression_algorithms=s

    s, idx = get_str(serv_KEX, idx)
    assert s==from_serv_compression_algorithms

    if VERBOSITY>=1:
        for t in s.split(","):
            print (f"compression_algorithms: {t}")

    # languages: client->server, then server->client (usually empty)
    s, idx = get_str(serv_KEX, idx)
    languages_client_to_server=s

    s, idx = get_str(serv_KEX, idx)
    assert s==languages_client_to_server
    if VERBOSITY>=1 and len(s)>0:
        for t in s.split(","):
            print (f"languages: {t}")

    # boolean first_kex_packet_follows + uint32 reserved; a guessed KEX packet isn't supported
    first_KEX_packet_follows, reserverd = struct.unpack (">Bi", serv_KEX[idx:idx+1+4])
    assert first_KEX_packet_follows==0
    assert reserverd==0

def unpack_KEXDH_REPLY(buf):
    """Parse the server's DH group offer for group exchange; return (g, p).

    *buf* is a framed packet (unwrap_packet() strips the framing). Its
    payload is: message code 31, mpint p (prime modulus), mpint g
    (generator). RFC 4419 calls this SSH_MSG_KEX_DH_GEX_GROUP, which has
    the same value as SSH_MSG_KEXDH_REPLY checked here. Exits on
    SSH_MSG_DISCONNECT; asserts no bytes are left over.
    """
    payload=unwrap_packet(buf)

    if VERBOSITY>=2:
        print ("unpack_KEXDH_REPLY():")
        hexdump.hexdump(payload)

    msg_code = struct.unpack (">B", payload[0:1])[0]
    if msg_code==SSH_MSG_DISCONNECT:
        print ("Got SSH_MSG_DISCONNECT, can't proceed further")
        exit(0)
    assert msg_code==SSH_MSG_KEXDH_REPLY

    # modulus first; the server typically takes it from /etc/ssh/moduli
    p, pos = get_mpint(payload, 1)
    if VERBOSITY>=1:
        print (f"DH GEX modulus (P): {hex(p)}")
        print (f"binlog(P)={binlog(p)}")
    g, pos = get_mpint(payload, pos)
    if VERBOSITY>=1:
        print (f"DH GEX base (G): {g}")
    assert len(payload)==pos # be sure nothing else left in buffer
    return g, p

def wrap_packet(payload):
    """Frame *payload* as an SSH binary packet, without a MAC.

    Output: packet_length (uint32) + padding_length (byte) + payload +
    padding of 0xAA bytes. The padding brings the total to a multiple of
    8 bytes (16 when ENCRYPTION is on) and is always at least 4 bytes long.
    """
    if VERBOSITY>=3:
        print ("wrap_packet() payload:")
        hexdump.hexdump(payload)
    payload_len=len(payload)
    if VERBOSITY>=3:
        print (f"wrap_packet() {payload_len=}")

    grain = 8 if ENCRYPTION==False else 16
    header_len = 4+1+payload_len
    padlen = align2grain(header_len, grain) - header_len
    # RFC 4253 requires at least 4 bytes of padding; add a full grain to keep alignment
    if padlen<4:
        padlen+=grain
    if VERBOSITY>=3:
        print (f"wrap_packet() {padlen=}")

    body = struct.pack (">B", padlen) + payload + b'\xAA'*padlen
    return struct.pack(">i", len(body)) + body

def pack_client_KEX():
    """Build our SSH_MSG_KEXINIT packet (RFC 4253 sec. 7.1), framed by wrap_packet().

    The name-lists come from KEX_ALGOS, SERVER_HOST_ALGOS, CIPHER_ALGOS and
    MAC_ALGOS (same lists for both directions); "ext-info-c" is appended to
    the KEX list when EXT_INFO_C is set. The payload minus its message-code
    byte is stored in the global client_KEX_blob (hashed into the exchange
    hash later).
    """
    import os  # local import: os.urandom for the cookie

    global client_KEX_blob
    if VERBOSITY>=2:
        print ("pack_client_KEX()")

    if EXT_INFO_C:
        kex_algorithms=KEX_ALGOS+",ext-info-c"
    else:
        kex_algorithms=KEX_ALGOS # doesn't work with crypto=none, dunno why

    if VERBOSITY>=2:
        print (f"our kex_algorithms: {kex_algorithms=}")

    server_host_algorithms=SERVER_HOST_ALGOS
    encryption_algorithms_client_to_server=CIPHER_ALGOS
    encryption_algorithms_server_to_client=encryption_algorithms_client_to_server
    mac_algorithms_client_to_server=MAC_ALGOS
    mac_algorithms_server_to_client=mac_algorithms_client_to_server
    if "none" not in CIPHER_ALGOS:
        compression_algorithms_client_to_server="none,zlib@openssh.com"
    else:
        compression_algorithms_client_to_server="none" # doesn't work with crypto=none
    compression_algorithms_server_to_client=compression_algorithms_client_to_server
    languages_client_to_server=""
    languages_server_to_client=languages_client_to_server

    buf=struct.pack(">B", SSH_MSG_KEXINIT)
    # RFC 4253 sec. 7.1: the cookie MUST be a random value generated by the sender
    cookie=os.urandom(16)
    buf+=cookie
    buf+=pack_str(kex_algorithms)
    buf+=pack_str(server_host_algorithms)
    buf+=pack_str(encryption_algorithms_client_to_server)
    buf+=pack_str(encryption_algorithms_server_to_client)
    buf+=pack_str(mac_algorithms_client_to_server)
    buf+=pack_str(mac_algorithms_server_to_client)
    buf+=pack_str(compression_algorithms_client_to_server)
    buf+=pack_str(compression_algorithms_server_to_client)
    buf+=pack_str(languages_client_to_server)
    buf+=pack_str(languages_server_to_client)
    first_KEX_packet_follows, reserved = 0, 0
    buf+=struct.pack (">Bi", first_KEX_packet_follows, reserved)

    if VERBOSITY>=2:
        print ("KEX from client without len and padlen and msg_code and padding:") # will be hashed as blob
        hexdump.hexdump(buf[1:])
    client_KEX_blob=buf[1:]

    return wrap_packet(buf)

def pack_DH_GEX_request():
    """Build SSH_MSG_KEX_DH_GEX_REQUEST with the (min, preferred, max) group sizes in bits."""
    if VERBOSITY>=2:
        print ("pack_DH_GEX_request()")
    body = struct.pack(">Biii", SSH_MSG_KEX_DH_GEX_REQUEST, DH_GEX_min, DH_GEX_bits, DH_GEX_max)
    return wrap_packet(body)

def pack_DH_GEX_INIT(e):
    """Build SSH_MSG_KEX_DH_GEX_INIT carrying our DH public value *e* as an mpint."""
    if VERBOSITY>=2:
        print ("pack_DH_GEX_INIT()")
    return wrap_packet(struct.pack(">B", SSH_MSG_KEX_DH_GEX_INIT) + pack_mpint(e))

def pack_SSH_MSG_KEXDH_INIT_raw(e):
    """Build SSH_MSG_KEXDH_INIT whose value *e* (bytes) is sent as an SSH string, not an mpint."""
    if VERBOSITY>=2:
        print ("pack_SSH_MSG_KEXDH_INIT")
    return wrap_packet(struct.pack(">B", SSH_MSG_KEXDH_INIT) + pack_buf(e))

def pack_SSH_MSG_KEXDH_INIT(e):
    """Build SSH_MSG_KEXDH_INIT carrying our DH public value *e* (int) as an mpint."""
    if VERBOSITY>=2:
        print ("pack_SSH_MSG_KEXDH_INIT")
    return wrap_packet(struct.pack(">B", SSH_MSG_KEXDH_INIT) + pack_mpint(e))

def pack_new_keys():
    """Build the SSH_MSG_NEWKEYS packet (just the message code)."""
    if VERBOSITY>=2:
        print ("pack_new_keys()")
    return wrap_packet(struct.pack(">B", SSH_MSG_NEWKEYS))

def send_new_keys(s):
    """Send SSH_MSG_NEWKEYS on socket *s*."""
    packet=pack_new_keys()
    if VERBOSITY>=2:
        print ("send_new_keys():")
        hexdump.hexdump(packet)
    my_send(s, packet)

def unpack_RSA_host_key(host_key):
    """Parse an ssh-rsa key blob body (after the type string): mpint e, mpint n.

    Stores them in the globals RSA_e and RSA_modulus_n; asserts nothing is left over.
    """
    global RSA_e, RSA_modulus_n

    RSA_e, pos = get_mpint(host_key, 0)
    RSA_modulus_n, pos = get_mpint(host_key, pos)
    assert len(host_key)==pos # be sure nothing else left in buffer
    if VERBOSITY>=1:
        print (f"{RSA_e=}")
        print (f"{RSA_modulus_n=}")
        print (f"{binlog(RSA_modulus_n)=}")

def unpack_ECDSA_host_key(host_key):
    """Parse an ECDSA key blob body (after the type string): string curve, string Q.

    Stores the curve name in ECDSA_curve_id and the encoded public point in
    ECDSA_Q. Like the RSA and DSA parsers, asserts nothing is left over.
    """
    global ECDSA_curve_id, ECDSA_Q
    buf=host_key
    idx=0
    if VERBOSITY>=2:
        print ("unpack_ECDSA_host_key() host_key:")
        hexdump.hexdump(buf[idx:])
    ECDSA_curve_id, idx = get_str(buf, idx)
    if VERBOSITY>=1:
        print (f"unpack_ECDSA_host_key() {ECDSA_curve_id=}:")
    ECDSA_Q, idx=get_buf (buf, idx)
    if VERBOSITY>=2:
        print ("unpack_ECDSA_host_key() ECDSA_Q:")
        hexdump.hexdump(ECDSA_Q)
    assert len(buf)==idx # be sure nothing else left in buffer

def unpack_DSA_host_key(host_key):
    """Parse an ssh-dss key blob body (after the type string): mpints p, q, g, y.

    Stores them in the globals DSA_p, DSA_q, DSA_g, DSA_y; asserts nothing is left over.
    """
    global DSA_p, DSA_q, DSA_g, DSA_y
    pos=0
    DSA_p, pos = get_mpint(host_key, pos)
    DSA_q, pos = get_mpint(host_key, pos)
    DSA_g, pos = get_mpint(host_key, pos)
    DSA_y, pos = get_mpint(host_key, pos)
    assert len(host_key)==pos # be sure nothing else left in buffer
    if VERBOSITY>=1:
        parts=(("DSA_p", DSA_p), ("DSA_q", DSA_q), ("DSA_g", DSA_g), ("DSA_y", DSA_y))
        for name, value in parts:
            print (f"{name}={value!r}")
        for name, value in parts:
            print (f"binlog({name})={binlog(value)!r}")

def print_fingerprint(KEX_host_key):
    """Print the SHA256 fingerprint of the host key blob the way OpenSSH shows it.

    OpenSSH prints the base64 digest without "=" padding; it is stripped
    here so the output can be compared with `ssh-keygen -l` or the
    known_hosts prompt.
    """
    hashed=hashlib.sha256(KEX_host_key).digest()
    fingerprint=base64.b64encode(hashed).decode("utf-8").rstrip("=")
    print (f"Server fingerprint: {SERVER_HOST_ALGO} SHA256:"+fingerprint)

def unpack_KEX_H_sig(KEX_H_sig):
    """Split the server's signature blob: string algorithm name, string signature.

    The raw signature bytes go to the global KEX_H_sig_buf; asserts nothing is left over.
    """
    global KEX_H_sig_buf
    algo_name, pos = get_str(KEX_H_sig, 0)
    if VERBOSITY>=1:
        # it must be equal to the SERVER_HOST_ALGO string
        print (f"unpack_KEX_H_sig(): {algo_name}")
    signature, pos = get_buf(KEX_H_sig, pos)
    if VERBOSITY>=2:
        hexdump.hexdump(signature)
    KEX_H_sig_buf=signature
    assert len(KEX_H_sig)==pos # be sure nothing else left in buffer

def _rsa_pkcs1_v15_encoded_digest(digest, k):
    """Return the EMSA-PKCS1-v1_5 encoding (RFC 8017 sec. 9.2) of *digest* for a k-byte modulus.

    *digest* must be a SHA-1, SHA-256 or SHA-512 hash; these are told apart
    by length. Returns None for any other length, or when k is too small
    to hold the encoding with at least 8 bytes of 0xFF padding.
    """
    # DER DigestInfo prefixes, RFC 8017 sec. 9.2, note 1
    prefixes={
        20: bytes.fromhex("3021300906052b0e03021a05000414"),              # SHA-1
        32: bytes.fromhex("3031300d060960864801650304020105000420"),      # SHA-256
        64: bytes.fromhex("3051300d060960864801650304020305000440")}      # SHA-512
    prefix=prefixes.get(len(digest))
    if prefix is None:
        return None
    T=prefix+digest
    ps_len=k-3-len(T)
    if ps_len<8:
        return None
    return b"\x00\x01" + b"\xff"*ps_len + b"\x00" + T

def chk_RSA_kexgex_hash(kexgex_hash):
    """Verify the server's RSA signature over the exchange hash; exit if it's bad.

    Reads these globals:
      KEX_H_sig_buf - the signature;
      RSA_e, RSA_modulus_n - the server's public key;
      SERVER_HOST_ALGO - used for messages;
      SERVER_HOST_ALGO_PTR - hashlib constructor of the signature hash
        (SHA-1 for ssh-rsa, SHA-256/512 for rsa-sha2-*).

    The whole recovered block is compared with the expected PKCS#1 v1.5
    encoding. Comparing only its trailing hash bytes, as before, accepted
    forged signatures with arbitrary padding.
    """
    if VERBOSITY>=1:
        print ("chk_RSA_kexgex_hash()")
    k=(RSA_modulus_n.bit_length()+7)//8   # modulus length in bytes
    KEX_H_sig_mpint=int.from_bytes(KEX_H_sig_buf, byteorder='big')
    # RSAVP1 range check (RFC 8017 sec. 5.2.2); also reject over-long signatures
    if len(KEX_H_sig_buf)>k or KEX_H_sig_mpint>=RSA_modulus_n:
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        exit(0)

    plaintext_must_be=pow (KEX_H_sig_mpint, RSA_e, RSA_modulus_n)
    if VERBOSITY>=1:
        print (f"{hex(plaintext_must_be)=}")
    if VERBOSITY>=2:
        print ("kexgex_hash:")
        hexdump.hexdump(kexgex_hash)
    digest=SERVER_HOST_ALGO_PTR(kexgex_hash).digest()
    if VERBOSITY>=2:
        print (SERVER_HOST_ALGO+"(kexgex_hash):")
        hexdump.hexdump(digest)

    expected=_rsa_pkcs1_v15_encoded_digest(digest, k)
    if expected is not None and plaintext_must_be==int.from_bytes(expected, byteorder='big'):
        if VERBOSITY>=1:
            print ("Server signature is correct")
    else:
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        exit(0)

# Check the server's ECDSA (nistp256 only) signature over the exchange hash.
def chk_ECDSA_kexgex_hash(kexgex_hash):
    """Verify an ecdsa-sha2-nistp256 signature over *kexgex_hash*; exit if it's bad.

    The SSH signature blob (global KEX_H_sig_buf) holds mpint r and mpint s.
    They are re-encoded as a DER signature for the `cryptography` verifier,
    which checks them against the point ECDSA_Q using SHA-256.
    """
    if VERBOSITY>=1:
        print ("chk_ECDSA_kexgex_hash()")
    if VERBOSITY>=2:
        print ("kexgex_hash:")
        hexdump.hexdump(kexgex_hash)
        print ("KEX_H_sig_buf:")
        hexdump.hexdump(KEX_H_sig_buf)

    sign_r, pos=get_mpint (KEX_H_sig_buf, 0)
    sign_s, pos=get_mpint (KEX_H_sig_buf, pos)

    if VERBOSITY>=1:
        print (f"{sign_r=}")
        print (f"{sign_s=}")

    # (r, s) -> DER, as expected by cryptography's verify()
    asn1_sig=cryptography.hazmat.primitives.asymmetric.utils.encode_dss_signature(sign_r, sign_s)

    assert ECDSA_curve_id=="nistp256"

    ec_mod=cryptography.hazmat.primitives.asymmetric.ec
    pub_key=ec_mod.EllipticCurvePublicKey.from_encoded_point(
        data=ECDSA_Q,
        curve=ec_mod.SECP256R1())
    if VERBOSITY>=1:
        print (f"ECDSA public key from server: {pub_key.public_numbers()}")

    algorithm=ec_mod.ECDSA(cryptography.hazmat.primitives.hashes.SHA256())
    try:
        pub_key.verify(asn1_sig, kexgex_hash, algorithm)
    except cryptography.exceptions.InvalidSignature:
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        exit(0)
    else:
        if VERBOSITY>=1:
            print ("Server signature is correct")

# Check the server's ssh-dss signature over the exchange hash.
def chk_DSA_kexgex_hash(kexgex_hash):
    """Verify the server's DSA signature over *kexgex_hash*; exit if it's bad.

    Uses the globals DSA_p, DSA_q, DSA_g, DSA_y (server key) and
    KEX_H_sig_buf, which is r || s split into two equal halves. The message
    digest is SHA-1 of kexgex_hash, truncated to N//8 bytes. Follows DSA
    verification (FIPS 186): require 0 < r < q and 0 < s < q, then accept
    iff ((g^u1 * y^u2) mod p) mod q == r.
    """
    if VERBOSITY>=1:
        print ("chk_DSA_kexgex_hash()")

    # exact bit sizes; math.log() can be off by one for big ints
    L=DSA_p.bit_length()
    N=DSA_q.bit_length()

    if VERBOSITY>=1:
        print (f"{L=}, {N=}")

    if VERBOSITY>=2:
        print ("KEX_H_sig_buf:")
        hexdump.hexdump(KEX_H_sig_buf)

    sig_len=len(KEX_H_sig_buf)
    if sig_len==0 or sig_len%2!=0:
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        exit(0)
    buf_r=KEX_H_sig_buf[0:sig_len//2]
    buf_s=KEX_H_sig_buf[sig_len//2:]

    if VERBOSITY>=2:
        print ("buf_r:")
        hexdump.hexdump(buf_r)
        print ("buf_s:")
        hexdump.hexdump(buf_s)

    r=int.from_bytes(buf_r, 'big')
    s=int.from_bytes(buf_s, 'big')

    # FIPS 186 range check; s == 0 would also have no inverse mod q
    if not (0<r<DSA_q and 0<s<DSA_q):
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        exit(0)

    m1 = hashlib.sha1()
    m1.update(kexgex_hash)
    z=int.from_bytes(m1.digest()[0:N//8], 'big')

    w=pow(s, -1, DSA_q) # mod inverse of s mod q
    u1=(z*w) % DSA_q # IOW, u1=z/s mod q
    u2=(r*w) % DSA_q # IOW, u2=r/s mod q
    v=(pow(DSA_g, u1, DSA_p) * pow(DSA_y, u2, DSA_p)) % DSA_p % DSA_q

    if VERBOSITY>=2:
        print (f"{hex(v)=}")

    if r==v:
        if VERBOSITY>=1:
            print ("Server signature is correct")
    else:
        print ("Error. Server signature is incorrect")
        print ("Exiting")
        exit(0)

def save_serv_pub_key(KEX_host_key):
    """Save the server's host key blob as an OpenSSH-style .pub line.

    Writes "<key type> <base64 blob> root@<HOST>" to
    "<HOST>.ssh_host_{rsa,ecdsa,dsa}_key.pub", choosing the name from the
    negotiated SERVER_HOST_ALGO. ssh-dss keys (offered in
    SERVER_HOST_ALGOS_list) used to hit `assert False`.
    """
    key_type, _ = get_str(KEX_host_key, 0)
    to_save=key_type+" "+base64.b64encode(KEX_host_key).decode("utf-8")+" root@"+HOST
    if "rsa" in SERVER_HOST_ALGO:
        fname=HOST+".ssh_host_rsa_key.pub"
    elif "ecdsa" in SERVER_HOST_ALGO:
        fname=HOST+".ssh_host_ecdsa_key.pub"
    elif "dss" in SERVER_HOST_ALGO:
        fname=HOST+".ssh_host_dsa_key.pub"
    else:
        assert False, f"save_serv_pub_key(): unexpected host key algorithm {SERVER_HOST_ALGO}"
    # `with` closes the file even if the write fails
    with open(fname, "w") as f:
        f.write(to_save+"\n")
    print (f"Server's public key saved to {fname}")

def unpack_KEX_host_key(KEX_host_key):
    """Handle the server's host key blob K_S from the DH/ECDH reply.

    Prints its fingerprint (VERBOSITY>=1), saves it if SAVE_SERV_PUB_KEY
    is set, stores its type string in the global HOST_KEY_TYPE, and
    dispatches the rest of the blob to the matching unpack_*_host_key()
    parser. Exits on a key type without a parser (previously a bare KeyError).
    """
    global HOST_KEY_TYPE
    if VERBOSITY>=2:
        print ("KEX_host_key: will be hashed as blob:")
        # same as in /etc/ssh/ssh_host_rsa_key.pub
        # or as in   /etc/ssh/ssh_host_ecdsa_key.pub
        hexdump.hexdump(KEX_host_key)
    if VERBOSITY>=1:
        print ("unpack_KEX_host_key:")
        print_fingerprint(KEX_host_key)
    if SAVE_SERV_PUB_KEY:
        save_serv_pub_key(KEX_host_key)

    HOST_KEY_TYPE, idx = get_str(KEX_host_key, 0)

    # Key blob type -> parser. RSA blobs are typed "ssh-rsa" even when an
    # rsa-sha2-* signature was negotiated (RFC 8332); the extra entries are harmless.
    parsers={
    "ecdsa-sha2-nistp256": unpack_ECDSA_host_key,
    "ssh-dss":             unpack_DSA_host_key,
    "ssh-rsa":             unpack_RSA_host_key,
    "rsa-sha2-256":        unpack_RSA_host_key,
    "rsa-sha2-512":        unpack_RSA_host_key}
    parser=parsers.get(HOST_KEY_TYPE)
    if parser is None:
        print (f"Fatal error: unsupported server host key type: {HOST_KEY_TYPE}")
        exit(0)
    parser(KEX_host_key[idx:])

def unpack_DH_REPLY(buf, GEX):
    """Parse the server's DH reply payload (message code at offset 0).

    GEX selects the expected message: SSH_MSG_KEX_DH_GEX_REPLY (group
    exchange) or SSH_MSG_KEXDH_REPLY (fixed groups). Then it parses:
    - the host key blob, via unpack_KEX_host_key();
    - the server's DH public value f (mpint);
    - the signature blob, via unpack_KEX_H_sig().
    Returns (KEX_host_key, server_f). Exits on an unexpected message code.
    """
    idx=0
    if VERBOSITY>=2:
        print ("unpack_DH_REPLY():")
        hexdump.hexdump(buf)
    msg_code = struct.unpack (">B", buf[idx:idx+1])[0]
    # .get(): msg_str doesn't list every message number, and the error path
    # itself must not die with a KeyError
    msg_name = msg_str.get(msg_code, "unknown message")
    if GEX and msg_code!=SSH_MSG_KEX_DH_GEX_REPLY:
        print (f"expecting msg_code==SSH_MSG_KEX_DH_GEX_REPLY, but got {msg_code} or {msg_name}")
        exit(0)
    if GEX==False and msg_code!=SSH_MSG_KEXDH_REPLY:
        print (f"expecting msg_code==SSH_MSG_KEXDH_REPLY, but got {msg_code} or {msg_name}")
        exit(0)
    idx+=1
    KEX_host_key, idx = get_buf(buf, idx)

    unpack_KEX_host_key(KEX_host_key)

    server_f, idx=get_mpint (buf, idx)
    if VERBOSITY>=1:
        print (f"{hex(server_f)=}")
    KEX_H_sig, idx=get_buf (buf, idx)
    if VERBOSITY>=2:
        print ("KEX_H_sig:")
        # this buffer begins with SERVER_HOST_ALGO string
        hexdump.hexdump(KEX_H_sig)
    unpack_KEX_H_sig(KEX_H_sig)
    assert len(buf)==idx # be sure nothing else left in buffer
    return KEX_host_key, server_f

def unpack_ECDH_REPLY(buf):
    """Parse the server's ECDH reply payload (message code 31 at offset 0).

    Layout: string host key blob, string server ephemeral point Q_S,
    string signature blob. The host key and signature are handed to
    unpack_KEX_host_key() and unpack_KEX_H_sig(). Returns
    (KEX_host_key, Q_S). Exits on an unexpected message code.
    """
    if VERBOSITY>=2:
        print ("unpack_ECDH_REPLY():")
        hexdump.hexdump(buf)
    msg_code = struct.unpack (">B", buf[0:1])[0]
    if msg_code!=SSH_MSG_KEXDH_REPLY:
        print (f"expecting msg_code==SSH_MSG_KEXDH_REPLY, but got {msg_code}")
        exit(0)

    KEX_host_key, pos = get_buf(buf, 1)
    unpack_KEX_host_key(KEX_host_key)

    ECDH_server_Q_S, pos = get_buf(buf, pos)
    if VERBOSITY>=2:
        print ("ECDH_server_Q_S:")
        hexdump.hexdump(ECDH_server_Q_S)

    KEX_H_sig, pos = get_buf(buf, pos)
    if VERBOSITY>=2:
        print ("KEX_H_sig:")
        # this buffer begins with SERVER_HOST_ALGO string
        hexdump.hexdump(KEX_H_sig)
    unpack_KEX_H_sig(KEX_H_sig)

    assert len(buf)==pos # be sure nothing else left in buffer
    return KEX_host_key, ECDH_server_Q_S

# hash hodgepodge
# this is almost all data exchanged before this moment
# critical part - shared_secret
# [version without GEX (DH groups is hardcoded): see kex_gen_hash() in OpenSSH 9.0]
# [version for GEX: see kexgex_hash() in OpenSSH 9.0]
def calc_kexgex_hash(client_banner, serv_banner, client_KEX_blob, serv_KEX_blob, KEX_host_key, client_e, server_f, shared_secret, g, p, GEX):
    """Compute the exchange hash H with KEX_HASH and return its digest.

    Hashed input, in this exact order:
      string V_C (client_banner), string V_S (serv_banner),
      string I_C (KEXINIT code + client_KEX_blob),
      string I_S (KEXINIT code + serv_KEX_blob),
      string K_S (KEX_host_key),
      [GEX only: uint32 min, n, max, mpint p, mpint g]  (RFC 4419),
      mpint e, mpint f, mpint K (shared_secret).
    Any change in order or encoding makes the server's signature check fail.
    NOTE(review): e and f are packed as mpints, which fits the DH methods;
    ECDH (RFC 5656) hashes Q_C/Q_S as strings - confirm that callers only use
    this for DH.
    """
    global KEX_HASH
    buf=b""
    buf+=pack_buf(client_banner)
    buf+=pack_buf(serv_banner)
    # the blobs were stored without their message code, so put it back
    buf+=pack_buf(struct.pack(">B", SSH_MSG_KEXINIT)+client_KEX_blob)
    buf+=pack_buf(struct.pack(">B", SSH_MSG_KEXINIT)+serv_KEX_blob)
    buf+=pack_buf(KEX_host_key)
    if GEX:
        # the sizes we asked for in SSH_MSG_KEX_DH_GEX_REQUEST, plus the group we got
        buf+=struct.pack(">iii", DH_GEX_min, DH_GEX_bits, DH_GEX_max)
        buf+=pack_mpint(p)
        buf+=pack_mpint(g)
    buf+=pack_mpint(client_e)
    buf+=pack_mpint(server_f)
    buf+=pack_mpint(shared_secret)
    if VERBOSITY>=3:
        print (f"buf to hash. {len(buf)=}")
        hexdump.hexdump(buf)
    hashed=KEX_HASH(buf).digest()
    if VERBOSITY>=2:
        print ("hashed:")
        hexdump.hexdump(hashed)
    return hashed

def derive_key(K, H, c, session_id):
    """Derive one raw key block: KEX_HASH(mpint(K) || H || c || session_id).

    RFC 4253 section 7.2. K is the shared secret (must be an int), H the
    exchange hash, c a one-letter label ('A'..'F'). Returns a single digest,
    without the extension step of derive_key2().
    """
    hasher=KEX_HASH()
    assert type(K)==int # let it be here
    for part in (pack_mpint(K), H, str.encode(c), session_id):
        hasher.update(part)
    result=hasher.digest()
    if VERBOSITY>=2:
        print ("key_%c:" % c)
        hexdump.hexdump(result)
    return result

# required - length of the wanted key, in bytes
def derive_key2(K, H, c, session_id, required):
    """Derive a key of exactly `required` bytes (RFC 4253 section 7.2).

    K1 = HASH(K || H || c || session_id); while the material is too short,
    K(n+1) = HASH(K || H || K1 || ... || Kn) is appended. The result is
    truncated to `required` bytes. K is the shared secret as an int and is
    encoded as an mpint.
    """
    hasher=KEX_HASH()
    assert type(K)==int # let it be here
    K_mpint=pack_mpint(K)
    for part in (K_mpint, H, str.encode(c), session_id):
        hasher.update(part)
    material=hasher.digest()

    # Running hash over K || H || K1 || K2 || ...: feeding it each new block in
    # turn yields exactly HASH(K || H || K1 || ... || Kn) for the next block.
    extender=KEX_HASH()
    extender.update(K_mpint)
    extender.update(H)
    block=material
    while len(material)<required:
        extender.update(block)
        block=extender.digest()
        material+=block
    assert len(material)>=required
    material=material[:required]
    if VERBOSITY>=2:
        print ("key_%c:" % c)
        hexdump.hexdump(material)
    return material

# [See kex.c, derive_key() in OpenSSH 9.0, also RFC 4253 7.2]
def derive_keys(K, H, session_id):
    """Derive session keys 'A'..'F' and store them in module globals.

    RFC 4253 section 7.2 labels:
      A/B: initial IV client->server / server->client
      C/D: encryption key client->server / server->client
      E/F: integrity (MAC) key client->server / server->client
    A-D are only derived when ENCRYPTION is set. Their first 16 bytes (A/B)
    then seed the big-endian AES-CTR counters client_to_serv_ctr and
    serv_to_client_ctr. C/D are CIPHER_KEY_SIZE bytes. E/F are always
    derived, MAC_SIZE bytes each.
    """
    global key_A, key_B, key_C, key_D, key_E, key_F, client_to_serv_ctr, serv_to_client_ctr
    global MAC_SIZE, CIPHER_KEY_SIZE, ENCRYPTION

    if VERBOSITY>=2:
        print ("derive_keys() begin")

    if ENCRYPTION:
        key_A=derive_key(K, H, 'A', session_id)
        key_B=derive_key(K, H, 'B', session_id)
        # each IV must hold a full 128-bit initial counter
        for iv in (key_A, key_B):
            assert len(iv)>=16
        client_to_serv_ctr=int.from_bytes(key_A[:16], "big")
        serv_to_client_ctr=int.from_bytes(key_B[:16], "big")

        key_C=derive_key2(K, H, 'C', session_id, CIPHER_KEY_SIZE)
        key_D=derive_key2(K, H, 'D', session_id, CIPHER_KEY_SIZE)

    key_E=derive_key2(K, H, 'E', session_id, MAC_SIZE)
    key_F=derive_key2(K, H, 'F', session_id, MAC_SIZE)

def my_send_with_MAC(s, msg):
    """Send one wrapped packet followed by its MAC, encrypting it if ENCRYPTION is on.

    MAC = HMAC-MAC_ALGO(key_E, uint32(seqno) || plaintext packet), per
    RFC 4253 section 6.4 (the MAC is over the plaintext, encrypt-and-MAC).
    The sequence number used is send_seqno-1, wrapped to 32 bits. The -1
    offset mirrors how send_seqno is maintained elsewhere in this file
    (presumably incremented ahead of this call; confirm against my_send()).

    With ENCRYPTION, the packet is AES-CTR encrypted with key_C, starting at
    client_to_serv_ctr. Afterwards the counter is advanced by the number of
    cipher blocks consumed, modulo 2**128 (RFC 4344 section 4).
    """
    global client_to_serv_ctr

    if CIPHER_ALGO=="aes128-ctr":
        KEY=key_C[0:128//8]
        assert len(KEY)==16
    elif CIPHER_ALGO=="aes256-ctr":
        KEY=key_C[0:256//8]
        assert len(KEY)==32
    elif CIPHER_ALGO=="none":
        pass
    else:
        assert False

    if VERBOSITY>=2 and CIPHER_ALGO!="none":
        print (f"my_send_with_MAC() {len(KEY)=}")
        hexdump.hexdump(KEY)

    if VERBOSITY>=2:
        print (f"my_send_with_MAC() {len(msg)=}")
        hexdump.hexdump(msg)

    assert len(key_E)>=MAC_SIZE
    # the sequence number is a uint32 that wraps around (RFC 4253 section 6.4)
    seqno=(send_seqno-1)&0xFFFFFFFF
    calc_MAC = hmac.digest(key=key_E[0:MAC_SIZE], msg=struct.pack(">I", seqno)+msg, digest=MAC_ALGO)
    if VERBOSITY>=2:
        print ("my_send_with_MAC() calc_MAC:")
        hexdump.hexdump(calc_MAC)
    if ENCRYPTION==False:
        assert (len(msg)&0x7)==0
        my_send(s, msg+calc_MAC)
    else:
        assert (len(msg)&0xf)==0

        ctr = Crypto.Util.Counter.new(128, little_endian=False, initial_value=client_to_serv_ctr)
        t = Crypto.Cipher.AES.new(KEY, Crypto.Cipher.AES.MODE_CTR, counter = ctr)
        encrypted = t.encrypt(msg)
        # The counter is a 128-bit integer incremented modulo 2**128. Without
        # the reduction, Counter.new() would reject a value >= 2**128.
        client_to_serv_ctr=(client_to_serv_ctr+len(encrypted)//AES_BLOCK_SIZE_IN_BYTES)%(1<<128)
        my_send(s, encrypted+calc_MAC)
    if VERBOSITY>=2:
        print ("my_send_with_MAC() finish")

def my_recv_chk_MAC(s):
    """Receive one packet and verify its trailing server->client MAC.

    Expects my_recv(s, True) to return packet || MAC, with a MAC_SIZE-byte
    MAC. The expected MAC is HMAC-MAC_ALGO(key_F, uint32(seqno) || packet),
    with seqno = recv_seqno-2 wrapped to 32 bits. The -2 offset mirrors how
    recv_seqno is counted elsewhere in this file; confirm against my_recv().
    The comparison is constant-time (hmac.compare_digest) so response
    timing does not leak how many MAC bytes matched.

    Returns the packet without the MAC. Prints a diagnostic and exits on a
    MAC mismatch.
    """
    tmp=my_recv(s, True)
    if VERBOSITY>=3:
        print ("my_recv_chk_MAC() begin. buf:")
        hexdump.hexdump(tmp)
    assert len(tmp)!=0
    msglen=len(tmp)-MAC_SIZE
    assert msglen>=0
    msg=tmp[:msglen]
    if VERBOSITY>=3:
        print ("msg to check:")
        hexdump.hexdump(msg)
    received_MAC=tmp[msglen:]
    assert len(received_MAC)==MAC_SIZE
    if VERBOSITY>=3:
        print (f"my_recv_chk_MAC(): {recv_seqno=}")

    assert len(key_F)>=MAC_SIZE
    # the sequence number is a uint32 that wraps around (RFC 4253 section 6.4)
    seqno=(recv_seqno-2)&0xFFFFFFFF
    calc_MAC = hmac.digest(key=key_F[0:MAC_SIZE], msg=struct.pack(">I", seqno)+msg, digest=MAC_ALGO)

    if not hmac.compare_digest(calc_MAC, received_MAC):
        print ("fatal error: received MAC is not correct!")
        if VERBOSITY>=3:
            print ("calculated MAC:")
            hexdump.hexdump(calc_MAC)
        exit(0)
    if VERBOSITY>=3:
        print ("my_recv_chk_MAC(): MAC checked OK")
    return msg

def send_exec(s, cmd):
    """Ask the server to run `cmd` via an "exec" channel request (RFC 4254 section 6.5).

    The request targets recipient channel 0 and sets want_reply, so the
    server answers with SSH_MSG_CHANNEL_SUCCESS or _FAILURE.
    """
    print (f"send_exec {cmd}")
    payload=b"".join((
        struct.pack("B", SSH_MSG_CHANNEL_REQUEST),
        struct.pack(">i", 0),  # recipient channel (hard-coded 0)
        pack_str("exec"),
        struct.pack("B", 1),   # want_reply = TRUE
        pack_str(cmd),
    ))
    my_send_with_MAC(s, wrap_packet(payload))

def unpack_exit_status(s, msg):
    """Handle an SSH_MSG_CHANNEL_REQUEST the server sends about the running command.

    Supported request types:
      * "exit-status" (RFC 4254 section 6.10): prints the command's exit code.
      * "exit-signal" (RFC 4254 section 6.10): prints the signal that
        terminated the command, whether core was dumped, and any error message.
      * "eow@openssh.com" (OpenSSH PROTOCOL file): accepted and ignored.
    Any other request type trips an assertion, as does trailing data.
    `s` is not used; it is kept for interface compatibility.
    """
    if VERBOSITY>=2:
        print ("unpack_exit_status(), msg:")
        hexdump.hexdump(msg)
    assert msg[0]==SSH_MSG_CHANNEL_REQUEST
    idx=1+4 # skip msg code and recipient channel
    tmp, idx = get_str(msg, idx)
    if tmp=="exit-status":
        # [see RFC4254 6.10]
        idx+=1 # skip want_reply boolean
        exit_status, idx = get_u32(msg, idx)
        print (f"exit-status {exit_status}")
    elif tmp=="exit-signal":
        # [see RFC4254 6.10]: boolean FALSE, string signal name (without "SIG"),
        # boolean core dumped, string error message, string language tag
        idx+=1 # skip want_reply boolean
        signal_name, idx = get_str(msg, idx)
        core_dumped=msg[idx]!=0
        idx+=1
        error_message, idx = get_str(msg, idx)
        _language_tag, idx = get_str(msg, idx)
        print (f"exit-signal {signal_name}, core dumped: {core_dumped}")
        if error_message:
            print (f"error message: {error_message}")
    elif tmp=="eow@openssh.com":
        # for more info, grep "eow@openssh.com" here:
        # https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
        idx+=1 # skip want_reply boolean
    else:
        assert False
    assert len(msg)==idx # be sure nothing else left in buffer

# Running total of channel data received so far, summed as len() of what
# get_str() returns. It is compared against INIT_WINDOW_SIZE.
unpack_channel_data_got_total=0

# [see RFC4254 5.2]
def unpack_channel_data(msg):
    """Print the payload of an SSH_MSG_CHANNEL_DATA message.

    Also adds the payload length to unpack_channel_data_got_total. Once the
    total reaches INIT_WINDOW_SIZE, a warning is printed (on every call from
    then on): this client does not send SSH_MSG_CHANNEL_WINDOW_ADJUST, so the
    server may stop sending once the window is consumed.
    """
    global unpack_channel_data_got_total

    assert msg[0]==SSH_MSG_CHANNEL_DATA
    data, end = get_str(msg, 1+4) # skip msg code and recipient channel
    for line in ("response:", "==", data, "=="):
        print (line)
    assert len(msg)==end # be sure nothing else left in buffer

    unpack_channel_data_got_total+=len(data)
    if unpack_channel_data_got_total<INIT_WINDOW_SIZE:
        return
    print ("Warning. If receiving is stuck now, ")
    print (f"this is because {INIT_WINDOW_SIZE=} is too small,")
    print (f"and we already got {unpack_channel_data_got_total} bytes")
    print ("Increase INIT_WINDOW_SIZE or send windows adjust messages to server (TODO)")

# [see RFC4254 5.2]
# data_type_code of SSH_MSG_CHANNEL_EXTENDED_DATA that carries stderr
SSH_EXTENDED_DATA_STDERR=1
def unpack_channel_extended_data(msg):
    """Print the stderr payload of an SSH_MSG_CHANNEL_EXTENDED_DATA message.

    Layout: byte msg code, uint32 recipient channel, uint32 data_type_code,
    string data. Only SSH_EXTENDED_DATA_STDERR is accepted (asserted).
    """
    assert msg[0]==SSH_MSG_CHANNEL_EXTENDED_DATA
    data_type, pos = get_u32(msg, 1+4)
    assert data_type==SSH_EXTENDED_DATA_STDERR
    data, pos = get_str(msg, pos)
    for line in ("response (stderr):", "==", data, "=="):
        print (line)
    assert len(msg)==pos # be sure nothing else left in buffer

# see [RFC4254 5.3]
def send_eof(s):
    """Send SSH_MSG_CHANNEL_EOF for recipient channel 0: we will send no more data."""
    if VERBOSITY>=2:
        print ("send_eof()")
    # byte msg code, uint32 recipient channel (0)
    payload=struct.pack(">Bi", SSH_MSG_CHANNEL_EOF, 0)
    my_send_with_MAC(s, wrap_packet(payload))

# receive window advertised when opening the session channel, in bytes
INIT_WINDOW_SIZE=0x100000

def send_channel_open(s):
    """Open a "session" channel (RFC 4254 section 6.1).

    Advertises sender channel 0, an initial window of INIT_WINDOW_SIZE and a
    maximum packet size of 0x800000.
    """
    code=struct.pack("B", SSH_MSG_CHANNEL_OPEN) # 90 or 0x5a
    # sender channel, initial window size, maximum packet size
    params=struct.pack(">iii", 0, INIT_WINDOW_SIZE, 0x800000)
    my_send_with_MAC(s, wrap_packet(code+pack_str("session")+params))

def expect_msg(s, msg_code):
    """Receive MACed packets until one with `msg_code` arrives; return its payload.

    SSH_MSG_DEBUG and SSH_MSG_USERAUTH_BANNER packets are skipped (dumped at
    VERBOSITY>=2). Any other unexpected message is fatal: a diagnostic is
    printed and the program exits. When SSH_MSG_USERAUTH_SUCCESS arrives
    while _FAILURE was expected, an extra hint about servers that bypass
    SSH-level auth (e.g. Dropbear) is printed.
    """
    # labels for messages that are skipped while waiting
    skippable={
        SSH_MSG_DEBUG: "SSH_MSG_DEBUG",
        SSH_MSG_USERAUTH_BANNER: "SSH_MSG_USERAUTH_BANNER",
    }
    while True:
        if VERBOSITY>=2:
            print (f"expecting msg_code={hex(msg_code)} or {msg_str[msg_code]}")
        msg=unwrap_packet(my_recv_chk_MAC(s))
        if msg[0] not in skippable:
            break
        if VERBOSITY>=2:
            print (f"Got {skippable[msg[0]]} message:")
            hexdump.hexdump(msg)
            print ("Running expect_msg() again")

    if msg[0]==msg_code:
        return msg
    print ("Fatal error in expect_msg().")
    print (f"While waiting for message msg_code={hex(msg_code)}, or {msg_str[msg_code]}")
    print (f"Got message {hex(msg[0])} instead, or: {msg_str[msg[0]]}")

    if msg_code==SSH_MSG_USERAUTH_FAILURE and msg[0]==SSH_MSG_USERAUTH_SUCCESS:
        print ("That means that SSH login mechanism is bypassed,")
        print ("and the real auth will happen on pty level.")
        print ("Dropbear is known to behave like this.")
        print ("Try ssh user@host -vv and see something like:")
        print ("\"Authenticated to IP ([IP]:PORT) using \"none\".\"")
        print ("pty stuff is not yet supported in toyssh.")

    exit(0)

def expect_new_keys(s):
    """Receive the server's SSH_MSG_NEWKEYS (RFC 4253 section 7.3).

    The packet is read with my_recv() (no MAC check, still unencrypted) and
    its payload must be exactly the single SSH_MSG_NEWKEYS byte (asserted).
    """
    if VERBOSITY>=2:
        print ("expecting SSH_MSG_NEWKEYS")
    payload=unwrap_packet(my_recv(s))
    if VERBOSITY>=2:
        print ("expect_new_keys(), got:")
        hexdump.hexdump(payload)
    assert payload==bytes([SSH_MSG_NEWKEYS])

def send_ssh_userauth(s):
    """Request the "ssh-userauth" service (SSH_MSG_SERVICE_REQUEST, RFC 4253 section 10)."""
    if VERBOSITY>=1:
        print ("sending ssh-userauth")
    request=struct.pack("B", SSH_MSG_SERVICE_REQUEST)+pack_str("ssh-userauth")
    pkt=wrap_packet(request)
    if VERBOSITY>=2:
        print ("send_ssh_userauth(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

def send_ssh_connection_none(s):
    """Send a "none" userauth request (RFC 4252 section 5.2).

    Used to learn which auth methods the server accepts (its
    SSH_MSG_USERAUTH_FAILURE reply lists them). If no USERNAME was given,
    the global is set to "root" first.
    """
    global USERNAME
    if VERBOSITY>=1:
        print ("sending ssh-connection")
    if USERNAME is None:
        USERNAME="root"
    request=b"".join((
        struct.pack("B", SSH_MSG_USERAUTH_REQUEST),
        pack_str(USERNAME),
        pack_str("ssh-connection"),  # service
        pack_str("none"),            # method
    ))
    pkt=wrap_packet(request)
    if VERBOSITY>=2:
        print ("send_ssh_connection_none(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

def send_ssh_connection_password(s, username, password):
    """Try password authentication (RFC 4252 section 8) and read the verdict.

    Returns True on SSH_MSG_USERAUTH_SUCCESS and False on
    SSH_MSG_USERAUTH_FAILURE. On SSH_MSG_USERAUTH_BANNER, SSH_MSG_DISCONNECT
    or any other reply, prints a diagnostic and exits.
    """
    if VERBOSITY>=1:
        print ("sending ssh-connection")
    fields=(
        struct.pack("B", SSH_MSG_USERAUTH_REQUEST),
        pack_str(username),
        pack_str("ssh-connection"),
        pack_str("password"),
        b"\x00",  # boolean FALSE: plain login, not a password-change request
        pack_str(password),
    )
    pkt=wrap_packet(b"".join(fields))
    if VERBOSITY>=2:
        print ("send_ssh_connection_password(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

    msg=unwrap_packet(my_recv_chk_MAC(s))
    reply=msg[0]
    if reply==SSH_MSG_USERAUTH_SUCCESS:
        print ("Password correct.")
        return True
    if reply==SSH_MSG_USERAUTH_FAILURE:
        print ("Fatal error. Password incorrect.")
        return False
    if reply==SSH_MSG_USERAUTH_BANNER:
        print ("Fatal error. Got SSH_MSG_USERAUTH_BANNER that we don't handle:")
        hexdump.hexdump(msg)
        exit(0)
    if reply==SSH_MSG_DISCONNECT:
        print ("Got SSH_MSG_DISCONNECT")
        exit(0)
    print (f"Fatal error. Unhandled {hex(msg[0])=}")
    exit(0)

def get_list_of_auth_modes(s):
    """Return the server's list of acceptable auth methods.

    Sends a "none" userauth request. The first field of the
    SSH_MSG_USERAUTH_FAILURE reply is the "authentications that can
    continue" name-list (RFC 4252 section 5.1), returned as get_str() yields
    it (a comma-separated list).
    """
    send_ssh_connection_none(s)
    failure=expect_msg(s, SSH_MSG_USERAUTH_FAILURE)
    methods=get_str(failure, 1)[0]
    if VERBOSITY>=1:
        print (f"allowed auth modes: {methods}")
    return methods

def send_client_KEX(s):
    """Send our SSH_MSG_KEXINIT packet, as built by pack_client_KEX(), unencrypted."""
    kexinit_pkt=pack_client_KEX()
    if VERBOSITY>=2:
        print ("client KEX:")
        hexdump.hexdump(kexinit_pkt)
    my_send(s, kexinit_pkt)

def DH_get_params(s):
    """Do the group-exchange request step (RFC 4419) and return the group (g, p).

    The raw received buffer is passed to unpack_KEXDH_REPLY() as-is (no
    unwrap_packet() here); that helper presumably handles the packet framing.
    Confirm before changing.
    """
    my_send(s, pack_DH_GEX_request())
    reply=my_recv(s)
    assert len(reply)!=0
    generator, prime = unpack_KEXDH_REPLY(reply)
    return generator, prime

def DH_send_client_e(s, g, p, bits, GEX):
    """Choose the client's DH secret x and send e = g^x mod p.

    x is drawn uniformly from [1, 2**bits) using random.SystemRandom, which
    is backed by os.urandom. The module-level Mersenne Twister
    (random.randrange) is predictable from its outputs and must not be used
    for key material.

    Sends the packet built by pack_DH_GEX_INIT() when GEX is true, otherwise
    the one built by pack_SSH_MSG_KEXDH_INIT(). Returns (x, e).
    """
    my_secret=random.SystemRandom().randrange(1, 2**bits)
    client_e=pow(g, my_secret, p)

    if VERBOSITY>=1:
        print (f"{hex(client_e)=}")

    if GEX:
        my_send(s, pack_DH_GEX_INIT(client_e))
    else:
        my_send(s, pack_SSH_MSG_KEXDH_INIT(client_e))
    return my_secret, client_e

def ECDH_send_client_pubkey(s, KEX_algo_cur):
    """Generate an ephemeral ECDH key pair and send its public point.

    The curve is KEX_algo_EC_func[KEX_algo_cur]. The public point Q_C is
    encoded as an uncompressed X9.62 point and sent via
    pack_SSH_MSG_KEXDH_INIT_raw(). Returns (private_key, public_key, Q_C bytes).
    """
    ec_mod=cryptography.hazmat.primitives.asymmetric.ec
    ser=cryptography.hazmat.primitives.serialization

    private_key=ec_mod.generate_private_key(KEX_algo_EC_func[KEX_algo_cur])
    public_key=private_key.public_key()
    encoded_point=public_key.public_bytes(
        ser.Encoding.X962,
        ser.PublicFormat.UncompressedPoint)

    if VERBOSITY>=2:
        print ("ECDHE_client_pubkey_raw_bytes:")
        hexdump.hexdump(encoded_point)

    my_send(s, pack_SSH_MSG_KEXDH_INIT_raw(encoded_point))
    return private_key, public_key, encoded_point

def DH_recv_server_f(s, GEX):
    """Receive the server's DH reply and return (KEX_host_key, server_f)."""
    reply=my_recv(s)
    assert len(reply)!=0
    host_key, f = unpack_DH_REPLY(unwrap_packet(reply), GEX)
    return host_key, f

def ECDH_do_exchange(s):
    """Run the ECDH key exchange (RFC 5656 section 4) for the negotiated KEX_ALGO.

    Sends our ephemeral public point, parses the server's reply, decodes the
    server's point on the matching curve and derives the shared secret with
    `cryptography`'s ECDH exchange().

    Returns (KEX_host_key, client point bytes, server point bytes,
    shared secret bytes). Only "ecdh-sha2-nistp*" algorithms are accepted
    (assertion otherwise).
    """
    if KEX_ALGO.startswith("ecdh-sha2-nistp"):
        ec_mod=cryptography.hazmat.primitives.asymmetric.ec
        client_private_key, _client_public_key, client_point=ECDH_send_client_pubkey(s, KEX_ALGO)
        KEX_host_key, server_point=unpack_ECDH_REPLY(unwrap_packet(my_recv(s)))

        server_public_key=ec_mod.EllipticCurvePublicKey.from_encoded_point(
            curve=KEX_algo_EC_func[KEX_ALGO],
            data=server_point)

        # ECDH shared secret from our private key and the server's public point
        shared_key=client_private_key.exchange(ec_mod.ECDH(), server_public_key)
        if VERBOSITY>=2:
            print ("shared_key:")
            hexdump.hexdump(shared_key)

        return KEX_host_key, client_point, server_point, shared_key
    else:
        assert False

def DH_do_exchange(s):
    """Run a finite-field Diffie-Hellman key exchange for the negotiated KEX_ALGO.

    Group-exchange algorithms (RFC 4419) first obtain (g, p) from the server.
    The fixed-group algorithms use the well-known MODP groups with g = 2.
    The client sends e = g^x mod p, receives the server host key and f,
    and computes the shared secret K = f^x mod p.

    The server's f is rejected unless 1 < f < p-1. RFC 4253 section 8 forbids
    accepting values outside [1, p-1], and OpenSSH's dh_pub_is_valid() also
    rejects 1 and p-1. This stops a broken or hostile server from forcing a
    trivial shared secret. On rejection, a diagnostic is printed and the
    program exits.

    Returns (g, p, KEX_host_key, client_e, server_f, shared_secret, GEX).
    """
    if KEX_ALGO in ("diffie-hellman-group-exchange-sha256",
                    "diffie-hellman-group-exchange-sha1"):
        GEX=True
        g, p = DH_get_params(s)
        bits=DH_GEX_bits
    else:
        GEX=False
        g=2
        if KEX_ALGO in ("diffie-hellman-group14-sha256",
                        "diffie-hellman-group14-sha1"): # 2048 bits
            # http://tools.ietf.org/html/rfc3526#section-3
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF
            bits=2048
        elif KEX_ALGO=="diffie-hellman-group16-sha512": # 4096 bits
            # http://tools.ietf.org/html/rfc3526#section-5
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF
            bits=4096
        elif KEX_ALGO=="diffie-hellman-group18-sha512": # 8192 bits
            # https://www.rfc-editor.org/rfc/rfc8268#section-3
            # https://datatracker.ietf.org/doc/html/rfc3526#section-7
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C93402849236C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AACC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E438777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F5683423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD922222E04A4037C0713EB57A81A23F0C73473FC646CEA306B4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A364597E899A0255DC164F31CC50846851DF9AB48195DED7EA1B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F924009438B481C6CD7889A002ED5EE382BC9190DA6FC026E479558E4475677E9AA9E3050E2765694DFC81F56E880B96E7160C980DD98EDD3DFFFFFFFFFFFFFFFFF
            bits=8192
        elif KEX_ALGO=="diffie-hellman-group1-sha1": # 1024 bits
            # https://datatracker.ietf.org/doc/html/draft-ietf-secsh-transport-09.txt#section-6.1
            # https://datatracker.ietf.org/doc/html/rfc4253#section-8.1
            # https://datatracker.ietf.org/doc/html/rfc2409#section-6.2
            p = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
            bits=1024
        else:
            assert False

    my_secret, client_e=DH_send_client_e(s, g, p, bits, GEX=GEX)

    KEX_host_key, server_f = DH_recv_server_f(s, GEX=GEX)

    # RFC 4253 section 8: reject f outside (1, p-1) before using it
    if not 1<server_f<p-1:
        print (f"Fatal error. Server sent an invalid DH public value f={hex(server_f)}")
        exit(0)

    # calculate DH shared secret:
    shared_secret=pow(server_f, my_secret, p)
    if VERBOSITY>=1:
        print (f"{hex(shared_secret)=}")

    return g, p, KEX_host_key, client_e, server_f, shared_secret, GEX

def expect_service_accept(s):
    """Read MACed packets until SSH_MSG_SERVICE_ACCEPT arrives.

    SSH2_MSG_EXT_INFO and SSH_MSG_IGNORE packets are skipped (logged at
    VERBOSITY>=2). Any other message type is reported and also skipped.
    """
    while True:
        msg=unwrap_packet(my_recv_chk_MAC(s))
        kind=msg[0]
        if kind==SSH_MSG_SERVICE_ACCEPT:
            return
        if kind==SSH2_MSG_EXT_INFO:
            if VERBOSITY>=2:
                print ("got SSH2_MSG_EXT_INFO packet:")
            continue
        if kind==SSH_MSG_IGNORE:
            # "markus" packet, sent by server if DEBUG_KEXDH is on [see sshd.c in OpenSSH 9.0]:
            # the first encrypted/MACed packet the server sends, as a test.
            if VERBOSITY>=2:
                print ("got 'markus' packet:")
                hexdump.hexdump(msg)
            continue
        print (f"error, unhandled packet type at this stage: {msg[0]}, {msg_str[msg[0]]}")

def login_password(s):
    """Log in with USERNAME/PASSWORD.

    Exits the program if the server does not list "password" among its auth
    methods, or if the password is rejected.
    """
    offered=get_list_of_auth_modes(s)
    if "password" in offered:
        if not send_ssh_connection_password(s, USERNAME, PASSWORD):
            exit(0)
        return
    print ("Fatal error: server doesn't accept password auth")
    print (f"Supported are: {offered}")
    exit(0)

def send_ssh_connection_publickey(s, username, client_pubkey_blob):
    """Ask whether the server would accept our public key (RFC 4252 section 7).

    This is the signature-less query (have_sig = FALSE). The public key
    algorithm name sent is the global SERVER_HOST_ALGO. The reply is read
    by the caller.
    """
    if VERBOSITY>=1:
        print ("sending ssh-connection")
    request=b"".join((
        struct.pack("B", SSH_MSG_USERAUTH_REQUEST),
        pack_str(username),            # user name
        pack_str("ssh-connection"),    # service
        pack_str("publickey"),         # method
        b"\x00",                       # have_sig = FALSE
        pack_str(SERVER_HOST_ALGO),    # public key algorithm name
        pack_buf(client_pubkey_blob),  # public key blob
    ))
    pkt=wrap_packet(request)
    if VERBOSITY>=2:
        print ("send_ssh_connection_publickey(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

def get_client_pubkey_blob_from_file(fname):
    """Read an OpenSSH public key file ("<type> <base64 blob> [comment]").

    Only the first line is used. Fields are split on any run of whitespace,
    so extra spaces or tabs do not break parsing. The file is closed even if
    parsing fails.

    Returns (key type string, decoded key blob bytes).
    Raises ValueError if the first line has fewer than two fields.
    """
    with open(fname, encoding="utf-8") as f:
        first_line=f.readline()
    fields=first_line.split()
    if len(fields)<2:
        raise ValueError(f"{fname}: not an OpenSSH public key line: {first_line!r}")
    _type=fields[0]
    blob=base64.b64decode(fields[1])
    return _type, blob

def RSA_sign_blob_with_pri_key(message, pri_key_fname):
    """Sign `message` with the OpenSSH-format RSA private key in pri_key_fname.

    The key is loaded with password b"" (presumably an unencrypted key is
    expected; confirm). Signing uses PKCS#1 v1.5 padding and the hash object
    built by SERVER_HOST_ALGO_PTR2() (presumably the hash matching
    ssh-rsa / rsa-sha2-*; confirm where it is set). The key file is read
    inside a context manager, so it is closed even if reading fails.

    Returns the raw signature bytes.
    """
    with open(pri_key_fname, "rb") as f:
        key_data=f.read()
    private_key=cryptography.hazmat.primitives.serialization.load_ssh_private_key(data=key_data, password=b"")

    return private_key.sign(
        message,
        cryptography.hazmat.primitives.asymmetric.padding.PKCS1v15(),
        SERVER_HOST_ALGO_PTR2())

def EC_sign_blob_with_pri_key(message, pri_key_fname):
    """Sign `message` with the OpenSSH-format ECDSA private key in pri_key_fname.

    The key is loaded with password b"" (presumably an unencrypted key is
    expected; confirm). Signing uses ECDSA with the hash built by
    SERVER_HOST_ALGO_PTR2(). `cryptography` returns a DER-encoded signature,
    which is converted to the SSH wire form mpint(r) || mpint(s)
    (RFC 5656 section 3.1.2). The key file is read inside a context manager,
    so it is closed even if reading fails.
    """
    with open(pri_key_fname, "rb") as f:
        key_data=f.read()
    t=cryptography.hazmat.primitives.serialization.load_ssh_private_key(data=key_data, password=b"")

    if VERBOSITY>=2:
        print ("EC_sign_blob_with_pri_key() message to sign:")
        hexdump.hexdump(message)

        print (f"{t.curve=}")
        print (f"{t=}")

    signature = t.sign(
        message,
        cryptography.hazmat.primitives.asymmetric.ec.ECDSA(SERVER_HOST_ALGO_PTR2()))
    # here signature is in ASN.1 (DER) format
    r, s=cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(signature)
    return pack_mpint(r)+pack_mpint(s)

"""
publickey auth method: RFC 4252, section 7
https://datatracker.ietf.org/doc/html/rfc4252#section-7

"publickey-hostbound-v00@openssh.com" auth method (OpenSSH extension):
https://www.openssh.com/agent-restrict.html#authverify
https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L347

userauth_pubkey() in auth2-pubkey.c in OpenSSH:
https://android.googlesource.com/platform/external/openssh/+/40e3686a7062605c5d122885a74ec478a26c2c77/auth2-pubkey.c#90
"""

def send_publickey_signed(s, kexgex_hash, client_pubkey_blob, server_pubkey_blob):
    """Send the signed publickey userauth request (RFC 4252 section 7).

    The signed data is string(kexgex_hash, used as the session identifier)
    followed by exactly the request fields sent on the wire. So the request
    body is built once and reused as the tail of the data to sign.

    With HOSTBOUND, the OpenSSH "publickey-hostbound-v00@openssh.com" method
    is used and the server host key blob is appended to the request. The
    signature is made with PRIKEY_FNAME: RSA when SERVER_HOST_ALGO contains
    "rsa", ECDSA otherwise.
    """
    if HOSTBOUND:
        method="publickey-hostbound-v00@openssh.com"
    else:
        method="publickey"

    request=struct.pack("B", SSH_MSG_USERAUTH_REQUEST)
    request+=pack_str(USERNAME)            # user name
    request+=pack_str("ssh-connection")    # service
    request+=pack_str(method)              # method
    request+=b"\x01"                       # have_sig = TRUE
    request+=pack_str(SERVER_HOST_ALGO)    # public key algorithm name
    request+=pack_buf(client_pubkey_blob)
    if HOSTBOUND:
        request+=pack_buf(server_pubkey_blob)

    # string(session identifier) || the request fields above
    msg_to_be_signed=pack_buf(kexgex_hash)+request

    if VERBOSITY>=2:
        print ("send_publickey_signed(), packet to be signed:")
        hexdump.hexdump(msg_to_be_signed)

    signer=RSA_sign_blob_with_pri_key if "rsa" in SERVER_HOST_ALGO else EC_sign_blob_with_pri_key
    sig=signer(msg_to_be_signed, PRIKEY_FNAME)

    if VERBOSITY>=2:
        print ("send_publickey_signed(), signature:")
        hexdump.hexdump(sig)

    # signature field: string(string(algorithm name) || string(signature))
    signature_field=pack_buf(pack_str(SERVER_HOST_ALGO)+pack_buf(sig))
    pkt=wrap_packet(request+signature_field)

    if VERBOSITY>=2:
        print ("send_publickey_signed(), sending:")
        hexdump.hexdump(pkt)
    my_send_with_MAC(s, pkt)

def login_publickey(s, KEX_host_key, kexgex_hash):
    """Authenticate with the key pair PUBKEY_FNAME / PRIKEY_FNAME (RFC 4252 section 7).

    Steps:
      1. Check that the server offers "publickey" (exit otherwise).
      2. Send the signature-less query with our public key. If the server
         answers SSH_MSG_USERAUTH_FAILURE, the key is not acceptable and
         False is returned without signing anything.
      3. Otherwise (the server echoed the key back), send the signed request.

    Returns True on SSH_MSG_USERAUTH_SUCCESS and False on
    SSH_MSG_USERAUTH_FAILURE. Any other final reply trips an assertion.
    """
    serv_auth_modes=get_list_of_auth_modes(s)
    if "publickey" not in serv_auth_modes:
        print ("Fatal error: server doesn't accept publickey auth")
        exit(0)
    _type, client_pubkey_blob=get_client_pubkey_blob_from_file(PUBKEY_FNAME)
    if VERBOSITY>=1:
        print (f"login_publickey() {_type=} from {PUBKEY_FNAME=}")
    send_ssh_connection_publickey(s, USERNAME, client_pubkey_blob)
    msg=unwrap_packet(my_recv_chk_MAC(s))
    if msg[0]==SSH_MSG_USERAUTH_FAILURE:
        # the server will not accept this key, so signing would be pointless
        print ("Fatal error. Server does not accept this public key.")
        if VERBOSITY>=2:
            hexdump.hexdump(msg)
        return False
    if VERBOSITY>=2:
        print ("login_publickey(), got our public key back:")
        hexdump.hexdump(msg)

    send_publickey_signed(s, kexgex_hash, client_pubkey_blob, KEX_host_key)

    if VERBOSITY>=2:
        print ("login_publickey(), expecting SSH_MSG_USERAUTH_SUCCESS or _FAILURE")

    msg=unwrap_packet(my_recv_chk_MAC(s))
    if msg[0]==SSH_MSG_USERAUTH_SUCCESS:
        print ("Publickey auth finished correctly.")
        return True
    elif msg[0]==SSH_MSG_USERAUTH_FAILURE:
        print ("Fatal error. Publickey auth finished incorrectly.")
        return False
    else:
        print (f"Fatal error. msg[0]={hex(msg[0])}")
        assert False

def exec_recv_response(s):
    """Process channel messages after an exec request until SSH_MSG_CHANNEL_CLOSE.

    Handling per message type:
      * stdout/stderr data and the exit-status/exit-signal requests: printed.
      * window adjusts: logged at VERBOSITY>=1.
      * CHANNEL_SUCCESS, CHANNEL_EOF and SSH_MSG_IGNORE: ignored.
      * SSH_MSG_DEBUG: informational (RFC 4253 section 11.3), so its text is
        printed and receiving continues instead of aborting.
      * SSH_MSG_DISCONNECT or any unknown message: fatal; the program exits.
    """
    # receiving loop
    # as in https://en.wikipedia.org/wiki/Inversion_of_control
    # as in GUI
    while True:
        msg=unwrap_packet(my_recv_chk_MAC(s))
        msg_code=msg[0]
        if VERBOSITY>=1:
            print (f"got message {hex(msg_code)} or {msg_str[msg_code]}, {len(msg)=}")
        if msg_code==SSH_MSG_CHANNEL_WINDOW_ADJUST:
            # see RFC4254 5.2: https://www.rfc-editor.org/rfc/rfc4254#section-5.2
            channel, size = struct.unpack (">II", msg[1:])
            if VERBOSITY>=1:
                print (f"{channel=} {size=}")
        elif msg_code in (SSH_MSG_CHANNEL_SUCCESS, SSH_MSG_CHANNEL_EOF, SSH_MSG_IGNORE):
            pass
        elif msg_code==SSH_MSG_CHANNEL_DATA:
            unpack_channel_data(msg)
        elif msg_code==SSH_MSG_CHANNEL_EXTENDED_DATA:
            unpack_channel_extended_data(msg)
        elif msg_code==SSH_MSG_CHANNEL_REQUEST:
            unpack_exit_status(s, msg)
        elif msg_code==SSH_MSG_CHANNEL_CLOSE:
            break
        elif msg_code==SSH_MSG_DISCONNECT:
            print ("Fatal error. Got SSH_MSG_DISCONNECT that we don't handle:")
            hexdump.hexdump(msg)
            exit(0)
        elif msg_code==SSH_MSG_DEBUG:
            # https://datatracker.ietf.org/doc/html/rfc4253#section-11.3
            # byte SSH_MSG_DEBUG, boolean always_display, string message, string language tag
            text, _ = get_str(msg, 2) # skip msg code and always_display
            print (f"SSH_MSG_DEBUG: \"{text}\"")
        else:
            print (f"Fatal error. Unhandled msg_code={hex(msg_code)} or {msg_str[msg_code]}")
            exit(0)

def exec_command(s, cmd):
    """Run `cmd` on the already-open session channel and print its output.

    Returns once the server closes the channel.
    """
    send_exec(s, cmd)              # "exec" SSH_MSG_CHANNEL_REQUEST
    return exec_recv_response(s)   # loops until SSH_MSG_CHANNEL_CLOSE

def expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION(s):
    """Receive packets until SSH_MSG_CHANNEL_OPEN_CONFIRMATION arrives; return it.

    Skipped while waiting:
      * SSH_MSG_GLOBAL_REQUEST (RFC 4254 section 4). Its name and want_reply
        flag are parsed regardless of VERBOSITY. When want_reply is set, we
        answer SSH_MSG_REQUEST_FAILURE (82), as RFC 4254 section 4 requires
        for requests the recipient does not support.
      * SSH_MSG_DEBUG.
    Any other message (e.g. SSH_MSG_CHANNEL_OPEN_FAILURE) is fatal: a
    diagnostic is printed and the program exits.
    """
    # RFC 4250 section 4.1.2 message number for SSH_MSG_REQUEST_FAILURE
    request_failure_code=82
    while True:
        msg=unwrap_packet(my_recv_chk_MAC(s))
        code=msg[0]
        if code==SSH_MSG_GLOBAL_REQUEST:
            # [see RFC4254 4]: https://www.rfc-editor.org/rfc/rfc4254#section-4
            request_name, idx = get_str(msg, 1)
            want_reply=msg[idx]
            if VERBOSITY>=1:
                print ("Got SSH_MSG_GLOBAL_REQUEST message")
                if VERBOSITY>=2:
                    hexdump.hexdump(msg)
                print (f"request name {request_name}")
                print (f"want reply {want_reply}")
            if want_reply:
                # we support no global requests, so refuse it
                my_send_with_MAC(s, wrap_packet(struct.pack("B", request_failure_code)))
            if VERBOSITY>=1:
                print ("Running expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION() again")
            continue
        if code==SSH_MSG_DEBUG:
            if VERBOSITY>=1:
                print ("Got SSH_MSG_DEBUG message:")
                if VERBOSITY>=2:
                    hexdump.hexdump(msg)
                print ("Running expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION() again")
            continue
        if code==SSH_MSG_CHANNEL_OPEN_CONFIRMATION:
            if VERBOSITY>=1:
                print ("Got SSH_MSG_CHANNEL_OPEN_CONFIRMATION message:")
                if VERBOSITY>=2:
                    hexdump.hexdump(msg)
                rcpt_channel, sender_channel, init_window_size, max_pkt_size = struct.unpack (">IIII", msg[1:])
                print (f"{rcpt_channel=} {sender_channel=} {init_window_size=} {max_pkt_size=}")
            return msg
        print ("Fatal error in expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION().")
        print (f"While waiting for message {hex(SSH_MSG_CHANNEL_OPEN_CONFIRMATION)}, or: {msg_str[SSH_MSG_CHANNEL_OPEN_CONFIRMATION]}")
        print (f"Got message {hex(code)} instead, or: {msg_str[code]}")
        exit(0)

# yes! byte-by-byte
# rationale: sometimes banner + next packet is sent by server in one tcp/ip packet
# and we don't know banner's size beforehand
def read_banner(s):
    """Read the server's identification string (RFC 4253 section 4.2).

    Bytes are read one at a time, so nothing after the banner's newline is
    consumed from the socket. Lines the server sends before the version line
    (those not starting with b"SSH-") are skipped; RFC 4253 section 4.2
    requires clients to accept them.

    Returns the version line as bytes, with trailing whitespace (including
    the "\\r") stripped.
    Raises ConnectionError if the peer closes the connection (recv() returns
    b"") before a complete version line arrives.
    """
    while True:
        line=bytearray()
        while True:
            c=s.recv(1)
            if not c:
                raise ConnectionError("connection closed while reading the server banner")
            if c==b"\n":
                break
            line+=c
        if line.startswith(b"SSH-"):
            return bytes(line).rstrip()

def _pick_common_algo(ours, theirs, what, picking_label):
    """Return the first algorithm of comma-separated `ours` also listed in `theirs`.

    When there is no overlap, print a diagnostic and exit (the file's usual
    fatal-error convention). `what` names the algorithm class in the error
    text; `picking_label` prefixes the VERBOSITY>=1 "Picking ..." line.
    """
    algo=first_common_element(ours.split(","), theirs.split(","))
    if algo is None:
        print (f"Fatal error. No common {what} algorithm")
        print (f"Server: {theirs}")
        print (f"What we support: {ours}")
        exit(0)
    if VERBOSITY>=1:
        print (f"{picking_label}{algo}")
    return algo

def _run_ssh_session(s):
    """Drive the SSH protocol over the already-connected socket s.

    Only called from do_all(), which owns the socket and always closes it.
    Negotiated parameters are published through module globals for the
    packet-building and packet-parsing helpers.
    """
    global recv_seqno, KEX_ALGO, KEX_HASH, SERVER_HOST_ALGO
    global CIPHER_ALGO, SERVER_HOST_ALGO_NIBBLES
    global SERVER_HOST_ALGO_PTR, SERVER_HOST_ALGO_PTR2
    global MAC_SIZE, MAC_ALGO, CIPHER_KEY_SIZE, ENCRYPTION

    serv_banner=read_banner(s)

    recv_seqno+=1
    if VERBOSITY>=1:
        print ("serv_banner:", serv_banner.decode("utf-8"))

    my_send(s, client_banner+b"\r\n")

    unpack_serv_KEX(my_recv(s))

    send_client_KEX(s)

    # Algorithm negotiation: the first entry of OUR list that the server also offers.
    KEX_ALGO=_pick_common_algo(KEX_ALGOS, from_serv_kex_algorithms, "KEX", "Picking KEX algo ")

    KEX_ALGO_hash_func={
    "diffie-hellman-group1-sha1":           hashlib.sha1,
    "diffie-hellman-group14-sha1":          hashlib.sha1,
    "diffie-hellman-group14-sha256":        hashlib.sha256,
    "diffie-hellman-group16-sha512":        hashlib.sha512,
    "diffie-hellman-group18-sha512":        hashlib.sha512,
    "diffie-hellman-group-exchange-sha1":   hashlib.sha1,
    "diffie-hellman-group-exchange-sha256": hashlib.sha256,
    "ecdh-sha2-nistp256":                   hashlib.sha256,
    "ecdh-sha2-nistp384":                   hashlib.sha384,
    "ecdh-sha2-nistp521":                   hashlib.sha512}
    KEX_HASH=KEX_ALGO_hash_func[KEX_ALGO]

    CIPHER_ALGO=_pick_common_algo(CIPHER_ALGOS, from_serv_encryption_algorithms, "cipher", "Picking cipher: ")

    # cipher name -> (encryption enabled, key size in bytes)
    CIPHER_ALGO_list={
    "aes128-ctr": (True, 16),
    "aes256-ctr": (True, 32),
    "none": (False, 0)}
    ENCRYPTION, CIPHER_KEY_SIZE = CIPHER_ALGO_list[CIPHER_ALGO]

    SERVER_HOST_ALGO=_pick_common_algo(SERVER_HOST_ALGOS, from_serv_server_host_algorithms, "server host", "Picking server host algo ")

    # host-key algo -> (hex digest length in nibbles, hashlib constructor, cryptography hash class)
    SERVER_HOST_ALGO_list={
    "ssh-dss":             (40,  hashlib.sha1,   cryptography.hazmat.primitives.hashes.SHA1),
    "ssh-rsa":             (40,  hashlib.sha1,   cryptography.hazmat.primitives.hashes.SHA1),
    "rsa-sha2-256":        (64,  hashlib.sha256, cryptography.hazmat.primitives.hashes.SHA256),
    "rsa-sha2-512":        (128, hashlib.sha512, cryptography.hazmat.primitives.hashes.SHA512),
    "ecdsa-sha2-nistp256": (64,  hashlib.sha256, cryptography.hazmat.primitives.hashes.SHA256)}
    SERVER_HOST_ALGO_NIBBLES, SERVER_HOST_ALGO_PTR, SERVER_HOST_ALGO_PTR2 = SERVER_HOST_ALGO_list[SERVER_HOST_ALGO]

    MAC_ALGO=_pick_common_algo(MAC_ALGOS, from_serv_mac_algorithms, "MAC", "Picking MAC algo ")

    # MAC name -> (MAC length in bytes, hashlib constructor).
    # Note: the MAC_ALGO global is rebound from the name to the constructor here.
    MAC_ALGO_list={
    "hmac-sha2-256": (0x20,   hashlib.sha256),
    "hmac-sha2-512": (0x40,   hashlib.sha512),
    "hmac-sha1":     (160//8, hashlib.sha1)}
    MAC_SIZE, MAC_ALGO = MAC_ALGO_list[MAC_ALGO]

    if KEX_ALGO.startswith("ecdh-sha2-nistp"):
        KEX_host_key, client_pub, server_pub, shared_secret_buf = ECDH_do_exchange(s)
        # calculate the hash of hodgepodge. the most important part - shared_secret, which is unknown to interceptor
        # all other parts can be easily intercepted!
        client_pub_i=int.from_bytes(client_pub, byteorder='big')
        server_pub_i=int.from_bytes(server_pub, byteorder='big')
        shared_secret=int.from_bytes(shared_secret_buf, byteorder='big')

        kexgex_hash=calc_kexgex_hash(client_banner, serv_banner, client_KEX_blob, serv_KEX_blob, KEX_host_key, client_pub_i, server_pub_i, shared_secret, 0, 0, GEX=False)
    elif KEX_ALGO.startswith ("diffie-hellman-group"):
        g, p, KEX_host_key, client_e, server_f, shared_secret, GEX = DH_do_exchange(s)
        # calculate the hash of hodgepodge. the most important part - shared_secret, which is unknown to interceptor
        # all other parts can be easily intercepted!
        kexgex_hash=calc_kexgex_hash(client_banner, serv_banner, client_KEX_blob, serv_KEX_blob, KEX_host_key, client_e, server_f, shared_secret, g, p, GEX)
    else:
        # Explicit raise (not `assert False`) so it still fires under `python -O`.
        raise AssertionError(f"unhandled KEX algo {KEX_ALGO}")
    if VERBOSITY>=2:
        print ("kexgex_hash:")
        hexdump.hexdump(kexgex_hash)

    # Pass the exchange hash to the checker for the server's host-key type.
    host_key_checkers={
    "ssh-dss":             chk_DSA_kexgex_hash,
    "ssh-rsa":             chk_RSA_kexgex_hash,
    "rsa-sha2-256":        chk_RSA_kexgex_hash,
    "rsa-sha2-512":        chk_RSA_kexgex_hash,
    "ecdsa-sha2-nistp256": chk_ECDSA_kexgex_hash}
    host_key_checkers[HOST_KEY_TYPE](kexgex_hash)

    derive_keys(shared_secret, kexgex_hash, kexgex_hash)

    expect_new_keys(s)

    send_new_keys(s)

    send_ssh_userauth(s)

    expect_service_accept(s)

    # `== False` is deliberate: only an explicit False return aborts the session.
    if USERNAME is not None and PASSWORD is not None:
        if login_password(s)==False:
            return
    elif USERNAME is not None and PUBKEY_FNAME is not None and PRIKEY_FNAME is not None:
        if login_publickey(s, KEX_host_key, kexgex_hash)==False:
            return
    else:
        print ("Fatal error: can't go further without username/passsword/pubkey/prikey. exiting.")
        exit(0)

    send_channel_open(s)

    expect_messages_till_SSH_MSG_CHANNEL_OPEN_CONFIRMATION(s)

    exec_command(s, COMMAND)

    # be polite
    send_eof(s)

def do_all():
    """Connect to HOST:PORT, run one SSH client session, and close the socket.

    Session flow (in _run_ssh_session):
    - banner and KEXINIT exchange;
    - negotiation of the KEX, cipher, server-host-key and MAC algorithms;
    - ECDH or DH key exchange, then the host-key check on the exchange hash;
    - key derivation and NEWKEYS;
    - userauth service request, then password or public-key login;
    - channel open, running COMMAND, and sending EOF.

    The socket is closed on every path: normal completion, failed login,
    fatal exit(), a failed connect(), and exceptions propagating to the
    caller.
    """
    global SERVER_HOST_ALGOS

    if PUBKEY_FNAME is not None:
        _type, _ = get_client_pubkey_blob_from_file(PUBKEY_FNAME)
        if VERBOSITY>=1:
            print (f"{_type=} from {PUBKEY_FNAME=}")
        # Restrict the offered server host-key algorithms to the client key's type.
        if _type=="ssh-rsa":
            SERVER_HOST_ALGOS="rsa-sha2-512,rsa-sha2-256,ssh-rsa"
        else:
            SERVER_HOST_ALGOS=_type

    family = socket.AF_INET6 if ":" in HOST else socket.AF_INET
    s = socket.socket(family)
    try:
        s.settimeout(2*60) # FIXME: make option
        s.connect((HOST, PORT))
        _run_ssh_session(s)
    finally:
        s.close()

parse_command_line()
if HOST is None:
    print ("Fatal error: host not set")
    exit(0)

# ext-info-c is advertised exactly when "none" appears in CIPHER_ALGOS.
# NOTE(review): the rationale for tying these together isn't visible here -- confirm.
if "none" in CIPHER_ALGOS:
    EXT_INFO_C=True
else:
    EXT_INFO_C=False

# Specific OSError subclasses must come before OSError itself; otherwise
# their clauses are unreachable (previously ConnectionResetError was shadowed
# and a duplicate TimeoutError clause could never run).
try:
    do_all()
except TimeoutError as e:
    print ("Exception: TimeoutError:", e)
except ConnectionRefusedError as e:
    print ("Exception: ConnectionRefusedError:", e)
except ConnectionResetError as e:
    print ("Exception: ConnectionResetError:", e)
except OSError as e:
    print ("Exception: OSError:", e)
