#!/usr/bin/python3
# This file is part of LTSP, https://ltsp.org
# Copyright 2019 the LTSP team, see AUTHORS
# SPDX-License-Identifier: GPL-3.0-or-later
"""
Usage: pwmerge [--sur=] [--sgr=] [--dur=] [--dgr=] [-l] [-q] sdir ddir mdir

Merge passwd, group and optionally shadow and gshadow from source directory
sdir and destination directory ddir into a merge directory mdir.

Options:
  --sur=regex: source user accounts to import
  --sgr=regex: source groups; their member users are also imported
  --dur=regex: destination user accounts to preserve
  --dgr=regex: destination groups; their member users are also preserved
  -l, --ltsp:  activate the LTSP account merging mode; read below
  -q, --quiet: only show warnings and errors, not informational messages

If UIDs or primary GIDs collide in the final merging, execution is aborted.
Using a regex allows one to define e.g.: "^(?!administrator)(a|b|c).*",
which matches accounts starting from a, b, or c, but not administrator.
Group regexes may also match system groups if they are prefixed with ":",
e.g. ":sudo" matches all sudoers. Btw, ".*" = match all, "" = match none.
All regexes default to none; except if sgr = "", then sur defaults to all.

Default mode:
The default mode is useful for migrating the users of an older installation
to a new one. All available [g]shadow information is preserved.
Source [g]shadow must exist. Destination [g]shadow need to exist only if
dur or dgr are specified.
Example use, from the new installation:
    pwmerge /root/pwold/ /etc/ /root/pwmerged/
    cp -a /root/pwmerged/* /etc/

LTSP mode:
By specifying --ltsp, the LTSP mode is activated. This allows the merged
installation to have some users authenticate against a remote SSH server,
and optionally some others authenticate locally.
 * Source [g]shadow aren't processed at all. Password hashes for imported
   users are not stored; they'll authenticate against the SSH server.
 * Users are tagged for pamltsp logins by setting their pw_passwd to "pamltsp".
   The sysadmin may run `passwd user` to set a local password and untag them,
   or `usermod -p pamltsp user` to tag other users.
 * Local passwordless logins for remote users are supported by running
   `usermod -p pamltsp=$base64_password user`.
 * Guest passwordless logins without SSH authentication are supported by
   `usermod -p pamltsp= user`; /home/$user must be NFS or local, not SSHFS.
 * Destination [g]shadow are only processed if dur or dgr are specified.
   Those preserved users will authenticate locally.
"""
import getopt
import os
import re
import shutil
import sys

# Field indexes into the per-user lists built by PwMerge.read_dir: the seven
# passwd fields, followed by the shadow fields and two internal ones.
# Names are from `man getpwent/getgrent/getspent/getsgent`
PW_NAME = 0
PW_PASSWD = 1
PW_UID = 2
PW_GID = 3
PW_GECOS = 4
PW_DIR = 5
PW_SHELL = 6
# SP_NAMP is the same as PW_NAME so we ignore it
SP_PWDP = 7
SP_LSTCHG = 8
SP_MIN = 9
SP_MAX = 10
SP_WARN = 11
SP_INACT = 12
SP_EXPIRE = 13
SP_FLAG = 14
# These additional fields are for internal use
PW_GRNAME = 15  # The primary group name of the user
PW_MARK = 16  # Boolean, True if the user is marked to import/preserve
# read_dir appends this list to every passwd entry, so it provides the
# initial values for indexes SP_PWDP..PW_MARK (ten items in total); the last
# two items are the initial PW_GRNAME ("") and PW_MARK (False).
SP_DEFS = [  # Default values when source shadow doesn't exist (man 5 shadow)
    "!",  # "!" is better than "*", as `useradd` and `passwd -l` use "!"
    "",  # Disable password aging; it's managed on the server
    "", "", "", "", "", "", "", False]
# Field indexes into the per-group lists built by PwMerge.read_dir: the four
# group fields, followed by the gshadow fields.
GR_NAME = 0
GR_PASSWD = 1
GR_GID = 2
GR_MEM = 3
# SG_NAMP is the same as GR_NAME so we ignore it
SG_PASSWD = 4
# Currently we process SG_ADM (group administrator list) as a simple string;
# file a bug report if you need it properly merged:
SG_ADM = 5
# SG_MEM is the same as GR_MEM so we ignore it
# Default values when source gshadow doesn't exist:
# read_dir appends this list to every group entry; the first two items are
# the SG_PASSWD and SG_ADM defaults. The third item lands at index 6, which
# nothing in this file reads (presumably a spare slot for SG_MEM).
SG_DEFS = ["*", "", ""]

# Set to True by main() when -q/--quiet is given; log() then only prints
# the messages that are flagged as errors.
QUIET = False


def log(*args, end='\n', error=False):
    """Write a message to stderr, honouring the --quiet setting.

    Messages flagged with error=True are always printed; all other messages
    are suppressed while QUIET is set. "args" and "end" are passed to print().
    """
    # Check "error" first, so that errors never depend on the QUIET state
    if not error and QUIET:
        return
    print(*args, end=end, file=sys.stderr)


def create(name, mode):
    """Create the file "name" afresh and return it opened for text writing.

    For better security, the file is created directly with the permission
    bits "mode" (subject to the process umask), instead of being created
    with default permissions and chmod'ed afterwards.

    Raises FileExistsError if something recreates "name" between the removal
    of the old file and the creation of the new one.
    """
    # If the file already exists, delete it, otherwise it gets the old mode
    if os.path.lexists(name):
        os.remove(name)
    # O_EXCL guarantees that we really create a new file: if "name" reappears
    # after the removal above (even as a symlink), fail instead of silently
    # writing into a file that has a different mode, target or old contents
    fds = os.open(name, os.O_WRONLY | os.O_CREAT | os.O_EXCL, mode)
    return open(fds, "w")


class PwMerge:
    """Merge passwd and group from source directory "sdir" to "ddir".

    The constructor reads passwd/group (and optionally shadow/gshadow) from
    both directories, merge() combines them in memory, leaving the result in
    self.dpasswd/self.dgroup, and save() writes that result under "mdir".

    self.[sd]passwd and self.[sd]group are dictionaries keyed by user and
    group name; their values are lists that are indexed with the module-level
    PW_*/SP_* and GR_*/SG_* constants.
    """
    def __init__(self, sdir, ddir, mdir, sur="", sgr="", dur="", dgr="",
                 ltsp=False):
        """Read the source and destination databases and store the options.

        sdir, ddir: directories that contain the source/destination files
        mdir: directory where save() writes the merged files
        sur, sgr: regexes for the source users/groups to import
        dur, dgr: regexes for the destination users/groups to preserve
        ltsp: True to activate the LTSP mode; source [g]shadow isn't read
            and transferred users get "pamltsp" as their shadow password
        """
        # Source [g]shadow is skipped in ltsp mode; destination [g]shadow is
        # only read when some destination users are to be preserved
        self.spasswd, self.sgroup = self.read_dir(sdir, not ltsp)
        self.dpasswd, self.dgroup = self.read_dir(ddir, dur or dgr)
        self.mdir = mdir
        # If neither sur nor sgr were specified, import all source users
        if not sur and not sgr:
            self.sur = ".*"
        else:
            self.sur = sur
        self.sgr = sgr
        self.dur = dur
        self.dgr = dgr
        self.ltsp = ltsp
        # Read [ug]id min/max from login.defs or default to 1000/60000
        self.uid_min = 1000
        self.uid_max = 60000
        self.gid_min = 1000
        self.gid_max = 60000
        # Map the login.defs keys to the attributes that they override
        limits = {"UID_MIN": "uid_min", "UID_MAX": "uid_max",
                  "GID_MIN": "gid_min", "GID_MAX": "gid_max"}
        try:
            with open("/etc/login.defs", "r") as file:
                for line in file:
                    words = line.split()
                    # Ignore blank lines and keys without a value
                    if len(words) >= 2 and words[0] in limits:
                        setattr(self, limits[words[0]], int(words[1]))
        except FileNotFoundError:
            # No login.defs on this system; keep the defaults set above
            pass

    @staticmethod
    def read_dir(xdir, read_shadow):
        """Read xdir/{passwd,group,shadow,gshadow} into dictionaries.

        shadow and gshadow are only read if read_shadow is true; otherwise
        the entries keep the SP_DEFS/SG_DEFS default values.
        Returns a (passwd, group) tuple of dictionaries.
        Raises ValueError for lines with an unexpected number of fields, and
        KeyError if the primary gid of a user has no entry in xdir/group.
        """
        # passwd is a dictionary with keys=PW_NAME string, values=PW_ENTRY list
        passwd = {}
        with open("{}/passwd".format(xdir), "r") as file:
            for line in file:
                pwe = line.strip().split(":")
                if len(pwe) != 7:
                    raise ValueError("Invalid passwd line:\n{}".format(line))
                # Add defaults in case shadow doesn't exist
                pwe += SP_DEFS
                # Convert uid/gid to ints to be able to do comparisons
                pwe[PW_UID] = int(pwe[PW_UID])
                pwe[PW_GID] = int(pwe[PW_GID])
                passwd[pwe[PW_NAME]] = pwe
        # g2n is a temporary dictionary to map from gid to group name
        # It's used to construct pwe[PW_GRNAME] and discarded after that
        g2n = {}
        # group is a dictionary with keys=GR_NAME string, values=GR_ENTRY list
        # Note that group["user"][GR_MEM] is a set of members (strings)
        group = {}
        with open("{}/group".format(xdir), "r") as file:
            for line in file:
                gre = line.strip().split(":")
                if len(gre) != 4:
                    raise ValueError("Invalid group line:\n{}".format(line))
                # Add defaults in case gshadow doesn't exist
                gre += SG_DEFS
                gre[GR_GID] = int(gre[GR_GID])
                # Use set for group members, to avoid duplicates
                # Keep only non-empty group values
                gre[GR_MEM] = {x for x in gre[GR_MEM].split(",") if x}
                group[gre[GR_NAME]] = gre
                # Construct g2n
                g2n[gre[GR_GID]] = gre[GR_NAME]
        # Usually system groups are like: "saned:x:121:"
        # while user groups frequently are like: ltsp:x:1000:ltsp
        # For simplicity, explicitly mention the primary user for all groups
        # In the same iteration, set pwe[PW_GRNAME]
        for pwn, pwe in passwd.items():
            grn = g2n[pwe[PW_GID]]
            pwe[PW_GRNAME] = grn
            group[grn][GR_MEM].add(pwn)
        if read_shadow:
            with open("{}/shadow".format(xdir), "r") as file:
                for line in file:
                    pwe = line.strip().split(":")
                    if len(pwe) != 9:
                        # It's invalid; displaying it isn't a security issue
                        raise ValueError(
                            "Invalid shadow line:\n{}".format(line))
                    # Shadow entries without a passwd entry are ignored
                    if pwe[0] in passwd:
                        # List slice: overwrite the eight SP_* defaults
                        passwd[pwe[0]][SP_PWDP:SP_FLAG+1] = pwe[1:9]
            with open("{}/gshadow".format(xdir), "r") as file:
                for line in file:
                    gre = line.strip().split(":")
                    if len(gre) != 4:
                        # It's invalid; displaying it isn't a security issue
                        raise ValueError(
                            "Invalid gshadow line:\n{}".format(line))
                    # Gshadow entries without a group entry are ignored
                    if gre[0] in group:
                        # List slice: overwrite the SG_PASSWD/SG_ADM defaults
                        group[gre[0]][SG_PASSWD:SG_ADM+1] = gre[1:3]

        return (passwd, group)

    def mark_users(self, xpasswd, xgroup, xur, xgr):
        """Mark users in [sd]passwd that match the [sd]ur/[sd]gr regexes.
        Called twice, once for source and once for destination.

        Users are marked by setting their PW_MARK field to True. The xgr
        regex marks all the members of the matching groups; the xur regex
        only marks users with uid_min <= uid <= uid_max.
        The names of all the marked users are logged.
        """
        # Mark all users that match xgr
        if xgr:
            # grn = GRoup Name, gre = GRoup Entry
            for grn, gre in xgroup.items():
                if not self.gid_min <= gre[GR_GID] <= self.gid_max:
                    # Match ":sudo"; don't match "s.*"
                    # grnm = modified GRoup Name used for Matching
                    grnm = ":{}".format(grn)
                else:
                    grnm = grn
                # re.fullmatch needs Python 3.4 (xenial+ /jessie+)
                if not re.fullmatch(xgr, grnm):
                    continue
                for grm in gre[GR_MEM]:
                    # Groups may list stale members without a passwd entry
                    if grm in xpasswd:
                        xpasswd[grm][PW_MARK] = True
        # Mark all users that match xur
        if xur:
            for pwn, pwe in xpasswd.items():
                if not self.uid_min <= pwe[PW_UID] <= self.uid_max:
                    continue
                if re.fullmatch(xur, pwn):
                    pwe[PW_MARK] = True
        for pwn, pwe in xpasswd.items():
            if pwe[PW_MARK]:
                log("", pwn, end="")
        log()

    def merge(self):
        """Merge while storing the result to dpasswd/dgroup.

        Raises ValueError if a transferred user or its primary group already
        exists in the destination with a different uid/gid.
        """
        # Mark all destination users that match dur/dgr
        # Note that non-system groups that do match dgr
        # are discarded in the end if they don't have any members
        log("Marked destination users for regexes '{}', '{}':".format(
            self.dur, self.dgr))
        self.mark_users(self.dpasswd, self.dgroup, self.dur, self.dgr)

        # Remove the unmarked destination users; system users are kept
        log("Removed destination users:")
        for pwn in list(self.dpasswd):  # list() as we're removing items
            pwe = self.dpasswd[pwn]
            if not self.uid_min <= pwe[PW_UID] <= self.uid_max:
                continue
            if not pwe[PW_MARK]:
                log("", pwn, end="")
                del self.dpasswd[pwn]
        log()

        # Remove the destination non-system groups that are empty,
        # to allow source groups with the same gid to be merged
        # Do not delete primary groups
        log("Removed destination groups:")
        for grn in list(self.dgroup):  # list() as we're removing items
            gre = self.dgroup[grn]
            if not self.gid_min <= gre[GR_GID] <= self.gid_max:
                continue
            # Keep the group if at least one of its members still exists
            if not any(grm in self.dpasswd for grm in gre[GR_MEM]):
                log("", grn, end="")
                del self.dgroup[grn]
        log()

        # Mark all source users that match sur/sgr
        log("Marked source users for regexes '{}', '{}':".format(
            self.sur, self.sgr))
        self.mark_users(self.spasswd, self.sgroup, self.sur, self.sgr)

        # Transfer all the marked users and their primary groups
        # Collisions in this step are considered fatal errors
        log("Transferred users:")
        for pwn, pwe in self.spasswd.items():
            if not pwe[PW_MARK]:
                continue
            if pwn in self.dpasswd:
                # dpasswd is keyed by user name; compare with that entry
                if pwe[PW_UID] != self.dpasswd[pwn][PW_UID] or \
                        pwe[PW_GID] != self.dpasswd[pwn][PW_GID]:
                    raise ValueError(
                        "PW_[UG]ID for {} exists in destination".format(pwn))
            self.dpasswd[pwn] = pwe
            # In ltsp mode, mark the user for ssh logins
            if self.ltsp:
                self.dpasswd[pwn][SP_PWDP] = "pamltsp"
            grn = pwe[PW_GRNAME]
            if grn in self.dgroup:
                if pwe[PW_GID] != self.dgroup[grn][GR_GID]:
                    raise ValueError(
                        "GR_GID for {} exists in destination".format(grn))
            self.dgroup[grn] = self.sgroup[grn]
            log("", pwn, end="")
        log()

        # Try to transfer all the additional groups that have marked members,
        # both system and non-system ones, and warn on collisions
        log("Transferred groups:")
        needeol = False
        for grn, gre in self.sgroup.items():
            # Only transfer groups where at least one member is marked
            if not any(grm in self.spasswd and self.spasswd[grm][PW_MARK]
                       for grm in gre[GR_MEM]):
                continue
            if grn not in self.dgroup:
                self.dgroup[grn] = gre
            elif gre[GR_GID] == self.dgroup[grn][GR_GID]:
                # Same gids, just merge members without a warning
                self.dgroup[grn][GR_MEM].update(gre[GR_MEM])
            else:
                self.dgroup[grn][GR_MEM].update(gre[GR_MEM])
                log(" [WARNING: group {} has sgid={}, dgid={}; ".
                    format(grn, gre[GR_GID], self.dgroup[grn][GR_GID]),
                    end="", error=True)
                # If gids are different, keep sgid if dgid is not a system one
                if self.gid_min <= self.dgroup[grn][GR_GID] <= self.gid_max:
                    # gre is the source entry; grn is just the name string
                    self.dgroup[grn][GR_GID] = gre[GR_GID]
                    log("keeping sgid]", end="", error=True)
                else:
                    log("keeping dgid]", end="", error=True)
                needeol = True
            # In all cases, keep source group password
            self.dgroup[grn][GR_PASSWD] = gre[GR_PASSWD]
            log("", grn, end="")
        # If a warning was shown, end its line even in quiet mode
        log(error=needeol)

        # g2n is a temporary dictionary to map from group name to primary user
        # Note that some gids (e.g. sudo) don't have a primary user
        g2n = {}
        for pwn, pwe in self.dpasswd.items():
            g2n[pwe[PW_GRNAME]] = pwn

        # Remove all unknown members from destination groups,
        # remove primary users from groups,
        # and remove non-system groups that have no members
        umem = set()
        log("Removed unknown groups:")
        for grn in list(self.dgroup):  # list() as we're removing items
            gre = self.dgroup[grn]
            for grm in list(gre[GR_MEM]):
                if grm in self.dpasswd and \
                        self.dpasswd[grm][PW_GID] != gre[GR_GID]:
                    continue
                # Here it is either unknown or primary; remove it
                gre[GR_MEM].remove(grm)
                # Don't notify when removing primary users
                if grm not in self.dpasswd:
                    umem.add(grm)
            if not gre[GR_MEM]:
                # Don't remove primary groups if the user exists
                if self.gid_min <= gre[GR_GID] <= self.gid_max and \
                        gre[GR_NAME] not in g2n:
                    del self.dgroup[grn]
                    log("", grn, end="")
        log()
        log("Removed unknown members:")
        for grm in umem:
            log("", grm, end="")
        log()

    def save(self):
        """Save the merged result in mdir/{passwd,group,shadow,gshadow}.

        passwd and group are created with mode 0644, shadow and gshadow with
        mode 0640. When running as root, the group of shadow and gshadow is
        changed to "shadow"; otherwise an error message is logged.
        Group members are written sorted, to get the same output on each run.
        """
        with create("{}/passwd".format(self.mdir), 0o644) as file:
            for pwe in self.dpasswd.values():
                file.write("{}:{}:{}:{}:{}:{}:{}\n".format(
                    *pwe[PW_NAME:PW_SHELL+1]))
        with create("{}/group".format(self.mdir), 0o644) as file:
            for gre in self.dgroup.values():
                file.write("{}:{}:{}:{}\n".format(
                    gre[GR_NAME], gre[GR_PASSWD], gre[GR_GID],
                    ",".join(sorted(gre[GR_MEM]))))
        with create("{}/shadow".format(self.mdir), 0o640) as file:
            for pwe in self.dpasswd.values():
                file.write("{}:{}:{}:{}:{}:{}:{}:{}:{}\n".format(
                    pwe[PW_NAME], *pwe[SP_PWDP:SP_FLAG+1]))
        with create("{}/gshadow".format(self.mdir), 0o640) as file:
            for gre in self.dgroup.values():
                file.write("{}:{}:{}:{}\n".format(
                    gre[GR_NAME],
                    # In python <3.5, unpacked arguments may only be last
                    *(gre[SG_PASSWD:SG_ADM+1] +
                      [",".join(sorted(gre[GR_MEM]))])))
        if os.geteuid() == 0:
            shutil.chown("{}/shadow".format(self.mdir), group="shadow")
            shutil.chown("{}/gshadow".format(self.mdir), group="shadow")
        else:
            log("Need root permissions to chown dst files to root:shadow",
                error=True)


def main(argv):
    """Parse the command line in argv, then merge and save the result.

    The usage text is printed and the program exits with status 1 unless
    exactly three positional arguments (sdir, ddir, mdir) are given.
    """
    global QUIET
    long_options = ["ltsp", "quiet", "sur=", "sgr=", "dur=", "dgr="]
    try:
        opts, args = getopt.getopt(argv[1:], "lq", long_options)
    except getopt.GetoptError as err:
        print("Error in command line parameters:", err, file=sys.stderr)
        args = []  # Force the usage message below
    if len(args) != 3:
        print(__doc__.strip())
        sys.exit(1)
    sdir, ddir, mdir = args
    # Keyword arguments for PwMerge, named after the long options
    kwargs = {"ltsp": False}
    for key, val in opts:
        if key in ("-q", "--quiet"):
            QUIET = True
        elif key in ("-l", "--ltsp"):
            kwargs["ltsp"] = True
        elif key.startswith("--"):
            # --sur=x and friends map directly to PwMerge(sur=x)
            kwargs[key[2:]] = val
        else:
            raise ValueError("Unknown parameter: ", key, val)
    merger = PwMerge(sdir, ddir, mdir, **kwargs)
    merger.merge()
    merger.save()


# Only run main() when this file is executed as a script, not when imported
if __name__ == "__main__":
    main(sys.argv)
