#!/usr/bin/python3

import argparse
import errno
import os
import random
import shutil
import stat
import sys
import tempfile
import textwrap
from typing import Iterable

try:
    import ldap
except ImportError:
    print("Please install python-ldap before running this program")
    sys.exit(1)

# Values in this section are substituted when the surrounding ERB template is
# rendered; the Python interpreter only ever sees plain literals here.
basedn = "<%= dc_suffix %>"
# Accounts are searched one level below this DN (SCOPE_ONELEVEL in the main section).
peopledn = f"ou=people,{basedn}"
<%-
  ldap_servers.map! { |l| "'ldaps://#{l}'" }
-%>
# One quoted ldaps:// URI per configured server. The order is shuffled on
# every run (presumably to spread load across servers) and the list is joined
# into a single space-separated string that is handed to ldap.initialize().
uris = [<%= ldap_servers.join(", ") %>]
random.shuffle(uris)
uri = " ".join(uris)
# Network timeout in seconds, applied as ldap.OPT_NETWORK_TIMEOUT.
timeout = 5
# This host binds with its own entry (cn=<fqdn>) under ou=Hosts.
binddn = f"cn=<%= fqdn %>,ou=Hosts,{basedn}"
# Bind-password sources: get_bindpw() tries nslcd_conf_file first and falls
# back to the first line of ldap_secret_file.
ldap_secret_file = "<%= ldap_pwfile %>"
nslcd_conf_file = "<%= nslcd_conf_file %>"
# Matches accounts that are inetOrgPerson + ldapPublicKey + posixAccount and
# carry at least one sshPublicKey value.
# NOTE(review): nothing in this filter excludes disabled accounts, although the
# help text says only enabled accounts are fetched — TODO confirm where that
# is enforced (e.g. ACLs for this host's bind DN).
# uidNumber can't be used with >= filters here, so system accounts are skipped
# by a uidNumber < 500 check in the main loop instead.
objfilter = "(&(objectClass=inetOrgPerson)(objectClass=ldapPublicKey)(objectClass=posixAccount)(sshPublicKey=*))"
# Per-user home directories (and their .ssh/authorized_keys) live under here.
keypathprefix = "/home"

# Command-line interface. Parsed at import time: write_keys() and the main
# section read args.dry_run / args.verbose as globals.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description=textwrap.dedent(f'''\
        Will fetch all enabled user accounts under {peopledn}
        with ssh keys in them and write each one to
        {keypathprefix}/<login>/.ssh/authorized_keys

        It exits with status 1 when no keys are updated and with
        status 0 when one or more keys have changed (or would
        change, with --dry-run).

        This script is intended to be run from cron as root.
        '''))
parser.add_argument('-n', '--dry-run', action='store_true',
                    help="report which authorized_keys files would change "
                         "without writing anything")
parser.add_argument('-v', '--verbose', action='store_true',
                    help="print the users found and every file written")
args = parser.parse_args()


def get_bindpw() -> str:
    """Return the LDAP bind password, exiting the script if none can be read.

    nslcd_conf_file is tried first, then ldap_secret_file. The readers print
    their own diagnostics for the usual failures, so an exception here only
    means "try the next source".
    """
    sources = ((get_nslcd_bindpw, nslcd_conf_file),
               (get_ldap_secret, ldap_secret_file))
    for reader, path in sources:
        try:
            return reader(path)
        except Exception:
            # Deliberately not a bare except: Ctrl-C (KeyboardInterrupt) and
            # sys.exit() (SystemExit) must still propagate.
            continue

    print("Error while reading password file, aborting")
    sys.exit(1)


def get_nslcd_bindpw(pwfile: str) -> str:
    """Return the ``bindpw`` value from the nslcd configuration file *pwfile*.

    The keyword is compared case-insensitively. The value is the rest of the
    line with surrounding whitespace removed, so passwords that contain
    spaces are returned intact (nslcd.conf takes the remainder of the line).
    The first matching line wins. Commented-out lines never match, because
    their first token starts with '#'.

    Raises:
        OSError: *pwfile* cannot be read (also printed to stdout).
        ValueError: no ``bindpw`` line with a value exists (also printed).
    """
    pwfield = "bindpw"
    try:
        with open(pwfile, 'r') as f:
            for line in f:
                # maxsplit=1 keeps embedded spaces in the password itself.
                parts = line.strip().split(None, 1)
                if len(parts) == 2 and parts[0].lower() == pwfield:
                    return parts[1]
    except IOError as e:
        print("Error while reading nslcd file " + pwfile)
        print(e)
        raise

    print("No " + pwfield + " field found in nslcd file " + pwfile)
    raise ValueError(f"no {pwfield} field found in nslcd file {pwfile}")


def get_ldap_secret(pwfile: str) -> str:
    """Read the bind password from the first line of *pwfile*.

    Surrounding whitespace (including the trailing newline) is removed, and
    an empty file yields "". Read errors are printed and then re-raised.
    """
    try:
        with open(pwfile, 'r') as handle:
            # Iterating a text file yields lines exactly like readline();
            # the default covers an empty file.
            first_line = next(handle, "")
    except IOError as err:
        print("Error while reading password file " + pwfile)
        print(err)
        raise
    return first_line.strip()


def _populate_home(userdir: str, uid: int, gid: int) -> None:
    """Create *userdir* from /etc/skel and give the whole tree to uid:gid."""
    shutil.copytree('/etc/skel', userdir)
    # Chown the contents first and the top directory last. Until the top
    # directory belongs to the user, they cannot plant anything inside it, and
    # follow_symlinks=False means no link that is there gets chased.
    for root, dirs, files in os.walk(userdir):
        for name in dirs + files:
            os.chown(os.path.join(root, name), uid, gid, follow_symlinks=False)
    os.chown(userdir, uid, gid)


def _read_authorized_keys(sshdir: str) -> bytes:
    """Return the current bytes of *sshdir*/authorized_keys, or b"" if absent.

    Neither *sshdir* nor the key file is followed if it is a symlink, and a
    non-regular file is not read (a FIFO would block forever). All of these
    count as "no current keys".
    """
    missing_or_link = (errno.ENOENT, errno.ELOOP, errno.ENOTDIR)
    try:
        dir_fd = os.open(sshdir, os.O_RDONLY | os.O_DIRECTORY | os.O_NOFOLLOW)
    except OSError as e:
        if e.errno in missing_or_link:
            return b""
        raise
    try:
        try:
            fd = os.open("authorized_keys",
                         os.O_RDONLY | os.O_NOFOLLOW | os.O_NONBLOCK,
                         dir_fd=dir_fd)
        except OSError as e:
            if e.errno in missing_or_link:
                return b""
            raise
        with os.fdopen(fd, 'rb') as f:
            if not stat.S_ISREG(os.fstat(f.fileno()).st_mode):
                return b""
            return f.read()
    finally:
        os.close(dir_fd)


def _open_ssh_dir(sshdir: str, uid: int, gid: int) -> int:
    """Ensure *sshdir* is a real directory owned by uid:gid with mode 0700.

    Returns a directory file descriptor that the caller must close. Raises
    OSError (errno ELOOP or ENOTDIR) if *sshdir* is a symlink or not a
    directory, so root never chowns or writes through a redirected path.
    """
    try:
        os.makedirs(sshdir, 0o700)
    except FileExistsError:
        pass
    fd = os.open(sshdir, os.O_RDONLY | os.O_DIRECTORY | os.O_NOFOLLOW)
    try:
        # fchmod/fchown act on the opened directory itself, not on whatever
        # the path might point to by now.
        os.fchmod(fd, 0o700)
        os.fchown(fd, uid, gid)
    except BaseException:
        os.close(fd)
        raise
    return fd


def _write_authorized_keys(dir_fd: int, content: bytes,
                           uid: int, gid: int) -> None:
    """Atomically replace authorized_keys inside *dir_fd* with *content*.

    The temporary file is created next to the target (same filesystem), so
    os.replace() is a real atomic rename that keeps the owner and mode set
    here. That fixes the old /tmp + shutil.move approach, which could
    degrade to a copy that lost ownership and followed symlinks.
    """
    tmpname = f"ldap-sshkey2file-{os.urandom(8).hex()}"
    fd = os.open(tmpname,
                 os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_NOFOLLOW,
                 0o600, dir_fd=dir_fd)
    try:
        with os.fdopen(fd, 'wb') as tmp:
            tmp.write(content)
            tmp.flush()
            os.fchmod(tmp.fileno(), 0o600)
            os.fchown(tmp.fileno(), uid, gid)
            os.fsync(tmp.fileno())
        os.replace(tmpname, "authorized_keys",
                   src_dir_fd=dir_fd, dst_dir_fd=dir_fd)
    except BaseException:
        try:
            os.unlink(tmpname, dir_fd=dir_fd)
        except FileNotFoundError:
            pass
        raise


def write_keys(keys: Iterable[bytes], user: bytes, uid: int, gid: int) -> bool:
    """Sync *user*'s LDAP SSH keys into <keypathprefix>/<user>/.ssh/authorized_keys.

    Each key is decoded as UTF-8, stripped, and written on its own line.
    Honours args.dry_run and args.verbose. A missing home directory is
    created from /etc/skel.

    The script runs as root inside directories the user controls. Because of
    that, nothing below the home directory is followed through a symlink. A
    symlinked ~/.ssh is refused (with a message), and a symlinked
    authorized_keys is replaced rather than written through.

    Returns True if the file differs from LDAP (and, outside dry-run, was
    rewritten). Returns False if it is already up to date or the user was
    skipped.
    """
    login = user.decode('utf-8')
    # The login becomes a component of a root-written path. Refuse anything
    # that could escape keypathprefix (e.g. "../root").
    if login in ("", ".", "..") or "/" in login or "\0" in login:
        print(f"Skipping unsafe login name {login!r}")
        return False

    userdir = os.path.join(keypathprefix, login)
    sshdir = os.path.join(userdir, ".ssh")
    keyfile = os.path.join(sshdir, "authorized_keys")

    fromldap = "".join(key.decode("utf-8").strip() + "\n"
                       for key in keys).encode("utf-8")
    if fromldap == _read_authorized_keys(sshdir):
        return False

    if args.dry_run:
        print(f"Would write {keyfile}")
        return True

    if args.verbose:
        print(f"Writing {keyfile}")

    if not os.path.isdir(userdir):
        _populate_home(userdir, uid, gid)

    try:
        dir_fd = _open_ssh_dir(sshdir, uid, gid)
    except OSError as e:
        if e.errno not in (errno.ELOOP, errno.ENOTDIR):
            raise
        print(f"Refusing to write {keyfile}: {sshdir} is not a real directory")
        return False
    try:
        _write_authorized_keys(dir_fd, fromldap, uid, gid)
    finally:
        os.close(dir_fd)
    return True


bindpw = get_bindpw()

# Attributes fetched for every account. Entries lacking any of them (for
# example because an ACL hides one from this host) are skipped with a
# message instead of aborting the whole run with a KeyError.
required_attrs = ['uid', 'sshPublicKey', 'uidNumber', 'gidNumber']

changed = False
try:
    ld = ldap.initialize(uri)
    ld.set_option(ldap.OPT_NETWORK_TIMEOUT, timeout)
    # StartTLS is only needed for plain ldap:// URIs. Only the first URI of
    # the space-separated list is inspected.
    if uri.startswith("ldap:/"):
        ld.start_tls_s()
    ld.bind_s(binddn, bindpw)
    res = ld.search_s(peopledn, ldap.SCOPE_ONELEVEL, objfilter,
                      required_attrs)

    entries = []
    for dn, entry in res:
        # Defensive: a result without a DN (e.g. a search reference) is not
        # an account entry.
        if dn is None:
            continue
        missing = [a for a in required_attrs if a not in entry]
        if missing:
            print(f"Skipping {dn}: missing {', '.join(missing)}")
            continue
        entries.append(entry)

    # A dry run must not touch the filesystem.
    if not args.dry_run:
        try:
            os.makedirs(keypathprefix, 0o701)
        except FileExistsError:
            pass

    if args.verbose:
        print("Found users:",
              ", ".join(sorted(e['uid'][0].decode('utf-8') for e in entries)))

    for entry in entries:
        # Skip possible system users. LDAP can't filter on uidNumber >= 500,
        # so the check happens here.
        if int(entry['uidNumber'][0]) < 500:
            continue
        if write_keys(entry['sshPublicKey'], entry['uid'][0],
                      int(entry['uidNumber'][0]), int(entry['gidNumber'][0])):
            changed = True

    ld.unbind_s()
except Exception:
    print("Error")
    raise

# Exit status: 0 when at least one key file changed (or would change with
# --dry-run), 1 otherwise.
if changed:
    if args.verbose:
        print("SSH keys changed")
    sys.exit(0)

if args.verbose:
    print("No changes in SSH keys")
sys.exit(1)


# vim:ts=4:sw=4:et:ai:si