#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2017 Pavel Raiskup
#
# This program accepts one argument IP or HOSTNAME.  First try to connect to the
# HOSTNAME as 'root' user.  If cloud-init scripts instruct us to use different
# user than 'root', switch to that user and check again.  In the end, print the
# successful username on stdout.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

from re import compile as re_compile
import sys
from os import devnull
from threading import Thread, Event
from argparse import ArgumentParser
from subprocess import Popen, PIPE
import logging

# Console logging through the root logger: StreamHandler() writes to stderr,
# and the threshold starts at INFO.
log = logging.getLogger()
log.setLevel(logging.INFO)

handler = logging.StreamHandler()
log.addHandler(handler)

# Marker the remote side has to echo back for a login to count as working.
expected_output = 'foobar'
inner_cmd = 'echo {0}'.format(expected_output)

# Base ssh command line; the 'user@host' target and the remote command are
# appended by the caller.
ssh = (
    ['ssh']
    # Never ask for confirmation of an unknown host key ...
    + ['-o', 'StrictHostKeyChecking=no']
    # ... and do not remember host keys either.
    + ['-o', 'UserKnownHostsFile=/dev/null']
    # Key-based login only, never fall back to a password prompt.
    + ['-o', 'PasswordAuthentication=no']
    # Give up a connection attempt after ten seconds.
    + ['-o', 'ConnectTimeout=10']
)


class Checker(Thread):
    """Background thread that retries an ssh login until it succeeds.

    The caller must assign ``args`` (the parsed command-line namespace with
    the ``host``, ``cloud_user`` and ``print_user`` attributes) before
    calling ``start()``.  The thread ends on its own after the first
    successful login, or when ``kill()`` is called.
    """

    # Login name tried first; replaced per instance once the remote side
    # names a different user.
    user = 'root'
    # Shape of a user name accepted from the remote hint.
    user_re = '[a-zA-Z0-9_.][a-zA-Z0-9_.-]*[$]?'
    # Hint printed on stdout when the login as the current user is refused
    # in favour of another account.
    re_clouduser = re_compile('Please login as the user "({0})"'.format(user_re))

    def __init__(self, *args, **kwargs):
        """Create the thread; arguments are passed on to ``Thread``."""
        Thread.__init__(self, *args, **kwargs)
        # Set the flag through the Thread.daemon property so the waiting
        # thread never blocks interpreter exit.
        self.daemon = True
        # Per-instance stop flag: stopping one checker must not stop others.
        self.event = Event()
        # Most recently spawned ssh process, None until the first attempt.
        self.child = None

    def _cloud_user(self, output):
        """Return the user name announced in ssh *output*, or None."""
        if isinstance(output, bytes):
            # Search real text, not the repr of a bytes object.
            output = output.decode('utf-8', 'replace')
        match = self.re_clouduser.search(output)
        if match is None:
            return None
        return str(match.group(1))

    def loop(self):
        """Run a single ssh login attempt.

        Returns True when the remote command ran and printed the expected
        marker.  Returns False when the attempt failed, when ssh could not
        be executed, or when the login user was just switched to the one
        announced by the remote side.
        """
        cmd = ssh + [
            '{0}@{1}'.format(self.user, self.args.host),
            inner_cmd,
        ]

        with open(devnull, 'w') as drop:
            log.debug('executing: ' + ' '.join(cmd))
            try:
                self.child = Popen(cmd, stdout=PIPE, stderr=drop)
                (stdout, _) = self.child.communicate()
            except OSError as err:
                # E.g. the ssh binary is missing.  Do not let the thread die
                # here: a finished thread is read as success by the caller.
                log.error('cannot execute ssh: {0}'.format(err))
                return False

        exp = (expected_output + '\n').encode('ascii')
        if self.child.returncode == 0 and stdout == exp:
            if self.args.print_user:
                print(self.user)
            return True

        if self.args.cloud_user:
            detected = self._cloud_user(stdout)
            if detected and detected != self.user:
                self.user = detected
                log.info('cloud user switched to ' + self.user)

        return False

    def run(self):
        """Retry ``loop()`` about once a second until success or ``kill()``."""
        while not self.event.is_set():
            if self.loop():
                # Success!
                return
            # Sleep between attempts, waking up early on kill().
            self.event.wait(1)

        log.debug("stopping per kill event")

    def kill(self):
        """Request the thread to stop, kill the ssh child and wait for exit."""
        self.event.set()

        # Best effort kill: there may be no child yet, or it may already be
        # gone.
        child = self.child
        if child is not None:
            try:
                child.kill()
            except OSError:
                pass
        self.join()


# Command-line interface: one positional host plus optional tuning flags.
parser = ArgumentParser(
    description="Wait till the host's ssh becomes responsive.",
)
parser.add_argument(
    'host',
    help='hostname or IP',
)
parser.add_argument(
    '--timeout',
    type=float,
    default=None,
    help='seconds to wait before failure, default=indefinitely',
)
parser.add_argument(
    '--check-cloud-user',
    dest='cloud_user',
    action='store_true',
    default=False,
    help=(
        'if cloud-init disallows "root" login, try to detect the cloud '
        'user and use that'
    ),
)
parser.add_argument(
    '--print-user',
    dest='print_user',
    action='store_true',
    default=False,
    help='print the username which succeeded to connect on stdout',
)
parser.add_argument(
    '--log',
    dest='log_verbosity',
    default=False,
    help='set the threshold for logging, e.g. debug, info, error, ...',
)


def main():
    """Parse the command line and wait for the checker thread to finish.

    Returns the process exit status: 0 when the checker thread finished on
    its own, 1 when ``--timeout`` expired or the user interrupted the wait.
    An unknown ``--log`` level ends the program through ``parser.error()``.
    """
    # Only this function needs a clock, so the import stays local.
    import time

    sleep_period = 1.0
    args = parser.parse_args()

    if args.log_verbosity:
        level = logging.getLevelName(args.log_verbosity.upper())
        # getLevelName() hands back a "Level XYZ" string for unknown names;
        # report that as a usage error instead of a setLevel() traceback.
        if not isinstance(level, int):
            parser.error('unknown log level: {0}'.format(args.log_verbosity))
        log.setLevel(level)

    # Track an absolute deadline instead of decrementing args.timeout, so
    # the namespace handed to the checker stays untouched and the wait
    # cannot overshoot by a whole sleep period.  python2 has no
    # time.monotonic(), hence the fallback.
    clock = getattr(time, 'monotonic', time.time)
    deadline = None
    if args.timeout is not None:
        deadline = clock() + args.timeout

    checker = Checker()
    checker.args = args
    checker.start()

    try:
        # threading.join() is not Ctrl-C interruptable :( in python2, so we need
        # this ugly infinite loop.
        # https://stackoverflow.com/questions/25676835/signal-handling-in-multi-threaded-python
        while True:
            wait = sleep_period
            if deadline is not None:
                remaining = deadline - clock()
                # Never sleep past the deadline, never pass a negative value.
                wait = max(0.0, min(sleep_period, remaining))
                log.debug("wait {0}s, remains {1}s".format(wait, remaining))

            checker.join(wait)
            if not checker.is_alive():
                # Thread finished without being killed: treated as success.
                return 0

            if deadline is not None and clock() >= deadline:
                log.error("timeout!")
                checker.kill()
                return 1

    except KeyboardInterrupt:
        log.error("interrupt by user")
        checker.kill()
        return 1

# Script entry point: the return value of main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())