#!/usr/bin/env python3

# Hostlist utility

__version__ = "2.2.1"

# Copyright (C) 2008-2018
#                    Kent Engström <kent@nsc.liu.se>,
#                    Thomas Bellman <bellman@nsc.liu.se>,
#                    Pär Lindfors <paran@nsc.liu.se> and
#                    Torbjörn Lönnemark <ketl@nsc.liu.se>,
#                    National Supercomputer Centre
# 
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.

import sys
import os
import optparse
import operator
import re

from hostlist import expand_hostlist, collect_hostlist, numerically_sorted, parse_slurm_tasks_per_node, BadHostlist, __version__ as library_version

# Python 3 compatibility
try:
    from functools import reduce
except ImportError:
    pass # earlier Python versions have this in __builtin__

# Helper functions
def flatten(list_of_lists):
    """Concatenate the items of every inner iterable into one flat list.

    Only one level of nesting is removed, and the items keep their
    original order.
    """
    return [item for inner in list_of_lists for item in inner]

def complain(msg):
    """Write msg to standard error, prefixed with the program name.

    The prefix is sys.argv[0]; a newline is appended to the message.
    """
    program = sys.argv[0]
    sys.stderr.write("{}: {}\n".format(program, msg))

def complain_about_whitespace(s):
    """Warn on stderr when s contains whitespace or a zero-width space.

    The string is only inspected, never modified.  The warning is issued
    through complain() and shows the repr() of s so that invisible
    characters become visible.
    """
    if re.search(r'(\s|\N{ZERO WIDTH SPACE})', s) is not None:
        complain(f"Spurious whitespace seen in argument {s!r}.")

# Operators

def func_union(args):
    """Return the union of all operands in args, folded left to right.

    reduce() gets no initial value, so an empty args raises TypeError.
    """
    return reduce(lambda acc, operand: acc | operand, args)

def func_intersection(args):
    """Return the intersection of all operands in args, folded left to right.

    reduce() gets no initial value, so an empty args raises TypeError.
    """
    return reduce(lambda acc, operand: acc & operand, args)

def func_difference(args):
    """Return the first operand minus every following operand in args.

    reduce() gets no initial value, so an empty args raises TypeError.
    """
    return reduce(lambda acc, operand: acc - operand, args)

def func_xor(args):
    """Return the symmetric difference of all operands in args, folded left to right.

    reduce() gets no initial value, so an empty args raises TypeError.
    """
    return reduce(lambda acc, operand: acc ^ operand, args)

# Build the command-line parser and parse the command line.
# The results are left in the module-level names opts (option values) and
# args (the remaining positional hostlist arguments).
op = optparse.OptionParser(usage="usage: %prog [OPTION]... [HOSTLIST]...")
# Set operations.  All four store their function in opts.func, so they
# share one destination; union is the default when none of them is given.
op.add_option("-u", "--union",
              action="store_const", dest="func", const=func_union,
              default=func_union,
              help="compute the union of the hostlist arguments (default)")
op.add_option("-i", "--intersection",
              action="store_const", dest="func", const=func_intersection,
              help="compute the intersection of the hostlist arguments")
op.add_option("-d", "--difference",
              action="store_const", dest="func", const=func_difference,
              help="compute the difference between the first hostlist argument and the rest")
op.add_option("-x", "--symmetric-difference",
              action="store_const", dest="func", const=func_xor,
              help="compute the symmetric difference between the first hostlist argument and the rest")
# -c stores False in opts.expand.  -e (further down) has no explicit dest,
# so optparse derives "expand" from its long name and it writes to the
# same attribute.
op.add_option("-c", "--collapse",
              action="store_false", dest="expand",
              help="output the result as a hostlist expression (default)")
# No default is given for --offset/--limit, so they are None unless set.
op.add_option("-o", "--offset",
              action="store", type="int",
              help="skip OFFSET hosts from the beginning (default: skip nothing)")
op.add_option("-l", "--limit",
              action="store", type="int",
              help="limit result to the LIMIT first hosts (default: no limit)")
op.add_option("-n", "--count",
              action="store_true",
              help="output the number of hosts instead of a hostlist")
op.add_option("-e", "--expand",
              action="store_true",
              help="output the result as an expanded list of hostnames")
# Deprecated spelling of -e.  It has its own destination so that the script
# can tell it was used and print a deprecation warning.
op.add_option("-w",
              action="store_true",
              dest="expand_deprecated",
              help="DEPRECATED version of -e/--expand")
op.add_option("-q", "--quiet",
              action="store_true",
              help="output nothing (useful with --non-empty)")
op.add_option("-0", "--non-empty",
              action="store_true",
              help="return success only if the resulting hostlist is non-empty")
# Output decoration.  The help texts describe expanded output, but the
# collapsed output path of this script also uses separator/prepend/append
# (between and around the chunks produced by --chop).
op.add_option("-s", "--separator",
              action="store", type="string",  default="\n",
              help="separator to use between hostnames when outputting an expanded list (default is newline)")
op.add_option("-p", "--prepend",
              action="store", type="string",  default="",
              help="string to prepend to each hostname when outputting an expanded list")
op.add_option("-a", "--append",
              action="store", type="string",  default="",
              help="string to append to each hostname when outputting an expanded list")
# The value is split at the first comma into a regular expression and its
# replacement by the code that handles opts.substitute.
op.add_option("-S", "--substitute",
              action="store", type="string",
              help="regular expression substitution ('from,to') to apply to each hostname")
# The two SLURM task options are mutually exclusive; that is checked by the
# script after parsing, not by optparse.
op.add_option("--append-slurm-tasks",
              action="store", type="string",
              help="append task count based on the passed SLURM_TASKS_PER_NODE string")
op.add_option("--repeat-slurm-tasks",
              action="store", type="string",
              help="repeat hostnames based on the passed SLURM_TASKS_PER_NODE string")
op.add_option("--chop",
              action="store", type="int",
              help="chop into chunks of this size when doing --collapse (default: all in one chunk)")
# --version is a plain flag handled by the script itself after parsing.
op.add_option("--version",
              action="store_true",
              help="show version")
# With no explicit argument list, parse_args() reads sys.argv[1:].
(opts, args) = op.parse_args()

# --version: report the script version and the library version, then stop.
if opts.version:
    print("Version %s (library version %s)" % (__version__, library_version))
    sys.exit()

# Expand every positional argument into a set of hostnames.  These sets are
# the operands of the selected set operation.
#
# An argument of "-" means "read the hostlists from standard input".  One
# argument may hold several whitespace-separated hostlist expressions; they
# are all merged into the single set belonging to that argument.
func_args = []
hostlist_expr = None # expression currently being expanded, for error reporting
try:
    for arg in args:
        input_string = sys.stdin.read() if arg == "-" else arg

        hosts = set()
        for hostlist_expr in input_string.split():
            complain_about_whitespace(hostlist_expr)
            hosts.update(expand_hostlist(hostlist_expr))
        func_args.append(hosts)

except BadHostlist as err:
    # Report the expression that actually failed; the raw argument may just
    # be "-" (stdin) or contain several expressions.  Formatting the
    # exception itself, rather than splicing err.args into the tuple, works
    # whatever the number of exception arguments.
    sys.stderr.write("Bad hostlist ``%s'' encountered: %s\n"
                     % (hostlist_expr, err))
    sys.exit(os.EX_DATAERR)

# No hostlist arguments at all: show the help text and fail with a usage error.
if not func_args:
    op.print_help()
    sys.exit(os.EX_USAGE)

# -w still works, but tell the user to switch to -e/--expand
if opts.expand_deprecated:
    sys.stderr.write("WARNING: Option -w is deprecated. Use -e or --expand instead!\n")

# Combine the per-argument host sets with the selected set operation
res = opts.func(func_args)

# --substitute takes 'from,to'.  Only the first comma separates the two
# parts, so the replacement text may itself contain commas.  A value without
# any comma makes the unpacking fail with ValueError.
if opts.substitute:
    try:
        pattern, replacement = opts.substitute.split(",", 1)
        # Building a set drops duplicates the substitution may have created
        res = {re.sub(pattern, replacement, host) for host in res}
    except (ValueError, re.error):
        sys.stderr.write("Bad --substitute option: '%s'\n" % opts.substitute)
        sys.exit(os.EX_DATAERR)

# Sort numerically (res is a set, or whatever opts.func returned, before this)
res = numerically_sorted(res)

# --offset drops hosts from the front, then --limit truncates what is left.
# An option that was not given is None, and a None slice bound leaves that
# end of the slice open, so unset options select everything.
res = res[opts.offset:][:opts.limit]

# Handle options using SLURM task lists
#
# --append-slurm-tasks and --repeat-slurm-tasks both take a
# SLURM_TASKS_PER_NODE string, and only one of them may be given.  The parsed
# list is left in task_list for the output code and must contain exactly one
# entry per host in res.
if opts.append_slurm_tasks and opts.repeat_slurm_tasks:
    sys.stderr.write("You cannot use --append-slurm-tasks and --repeat-slurm-tasks at the same time.\n")
    sys.exit(os.EX_DATAERR)
elif opts.append_slurm_tasks or opts.repeat_slurm_tasks:
    # Exactly one of the two is set here; take whichever it is
    task_spec = opts.append_slurm_tasks or opts.repeat_slurm_tasks

    try:
        task_list = parse_slurm_tasks_per_node(task_spec)
    except BadHostlist as e:
        # Format the exception itself: "%s" % e.args only works when the
        # exception carries exactly one argument.
        sys.stderr.write("Bad task list encountered: %s\n" % (e,))
        sys.exit(os.EX_DATAERR)
    if len(task_list) != len(res):
        sys.stderr.write("Length of tasks list != number of hostnames\n")
        sys.exit(os.EX_DATAERR)

# Output in the right way
if not opts.quiet:
    if opts.count:
        print(len(res))
    elif opts.expand or opts.expand_deprecated:
        # Expanded output: build one decorated entry per line, print once
        if opts.append_slurm_tasks:
            # hostname followed by the append string and its task count
            entries = [opts.prepend + host + opts.append + str(tasks)
                       for host, tasks in zip(res, task_list)]
        elif opts.repeat_slurm_tasks:
            # each hostname repeated as many times as its task count
            entries = [opts.prepend + host + opts.append
                       for host, tasks in zip(res, task_list)
                       for _ in range(tasks)]
        else:
            entries = [opts.prepend + host + opts.append for host in res]
        print(opts.separator.join(entries))
    else:
        # --collapse
        # A missing or non-positive --chop means one chunk with everything
        chunk_size = opts.chop if opts.chop and opts.chop > 0 else len(res)
        try:
            # "or 1" keeps range() legal when res is empty (chunk_size == 0);
            # the loop body then never runs.
            for start in range(0, len(res), chunk_size or 1):
                if start > 0:
                    sys.stdout.write(opts.separator)
                sys.stdout.write(opts.prepend + collect_hostlist(res[start:start+chunk_size]) + opts.append)
        except BadHostlist as e:
            sys.stderr.write("Bad hostname encountered: %s\n" % e.args)
            sys.exit(os.EX_DATAERR)
        sys.stdout.write("\n")

# Exit
# With --non-empty an empty result is reported through the exit status,
# even when --quiet suppressed all output.
if opts.non_empty and not res:
    sys.exit(os.EX_NOINPUT)
