#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-or-later

import argparse
import errno
import fnmatch
import logging
import os
import os.path
import shutil
import stat
import sys
import tempfile

# Module-level logger; main() configures its level from --verbose.
logger = logging.getLogger(__name__)

# Directories to merge, written as "source": "destination".  MergeUsr.run()
# joins both sides onto root/prefix before using them.  Note that "sbin"
# and "usr/sbin" are both folded into "usr/bin".
DIR_MAP = {
    "bin": "usr/bin",
    "sbin": "usr/bin",
    "usr/sbin": "usr/bin",
    "lib": "usr/lib",
    "lib32": "usr/lib32",
    "lib64": "usr/lib64",
    "libx32": "usr/libx32",
}

# fnmatch-style patterns; directory entries whose name matches any of them
# are skipped by MergeUsr.copy_tree().
IGNORE_PATTERNS = (
    ".keep*",
)


def raise_exists(path):
    """Raise FileExistsError (EEXIST) with *path* as the offending filename."""
    code = errno.EEXIST
    message = os.strerror(code)
    raise FileExistsError(code, message, path)

def get_xattrs(path):
    """Yield ``(name, value)`` for each extended attribute of *path*.

    Symlinks are not followed.  If the attribute list cannot be read the
    generator yields nothing; an attribute whose value cannot be read is
    logged as an error and skipped.
    """
    try:
        attr_names = os.listxattr(path, follow_symlinks=False)
    except OSError:
        return
    for attr_name in attr_names:
        try:
            attr_value = os.getxattr(path, attr_name, follow_symlinks=False)
        except OSError as err:
            logger.error("Error getting xattr '%s' from '%s': %s",
                         attr_name, path, err)
        else:
            yield (attr_name, attr_value)

def set_xattrs(path, attrs):
    """Apply every ``(name, value)`` pair in *attrs* to *path*.

    Symlinks are not followed.  A failure to set one attribute is logged
    as an error and does not stop the remaining attributes from being set.
    """
    for (key, data) in attrs:
        try:
            os.setxattr(path, key, data, follow_symlinks=False)
        except OSError as err:
            logger.error("Error setting xattr '%s' on '%s': %s",
                         key, path, err)

def copy_xattrs(src, dst):
    """Copy the extended attributes of *src* onto *dst* (best effort)."""
    # get_xattrs() is a generator, so attributes are read lazily while
    # set_xattrs() consumes them.
    set_xattrs(dst, get_xattrs(src))

def replace_file_with_symlink(src, dst):
    """Replace the file *src* with a relative symlink pointing at *dst*.

    The link is first created under a temporary name in the directory of
    *src*, receives the extended attributes of *src*, and is then renamed
    over *src* so that the path never disappears.

    Args:
        src: Path of the existing file to replace.
        dst: Path the new symlink should point to.

    Raises:
        OSError: If the temporary link cannot be created or cannot be
            renamed over *src*.  The temporary link is removed again in
            the latter case.
    """
    srcdir = os.path.dirname(src)
    target = os.path.relpath(dst, srcdir)
    # os.symlink() never overwrites, so a collision on the generated name
    # surfaces as FileExistsError instead of clobbering another file.
    tmp = tempfile.mktemp(dir=srcdir)
    os.symlink(target, tmp)
    try:
        copy_xattrs(src, tmp)
        os.rename(tmp, src)
    except BaseException:
        # Do not leave the stray temporary link next to src.  Cleanup is
        # best effort; the original failure is what gets reported.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise

def replace_directory_with_symlink(src, dst):
    """Replace the directory *src* with a relative symlink to *dst*.

    The directory is moved aside under a temporary name, the symlink is
    created in its place (carrying over the directory's extended
    attributes), and the moved-aside tree is then deleted.  If the rename
    fails with EXDEV the directory is deleted in place instead.  Failure
    to delete the moved-aside tree is only logged as a warning.
    """
    parent = os.path.dirname(src)
    link_target = os.path.relpath(dst, parent)
    logger.info("Replacing '%s' with a symlink to '%s'", src, link_target)
    # Materialise the attributes now; src is about to stop being a directory.
    saved_attrs = list(get_xattrs(src))
    backup = tempfile.mktemp(dir=parent)
    try:
        os.rename(src, backup)
    except OSError as err:
        # overlayfs returns EXDEV when renaming between layers
        if err.errno != errno.EXDEV:
            raise
        shutil.rmtree(src)
        backup = None
    os.symlink(link_target, src)
    if saved_attrs:
        set_xattrs(src, saved_attrs)
    if backup is None:
        return
    try:
        shutil.rmtree(backup)
    except Exception as err:
        logger.warning("Unable to remove '%s': %s", backup, err)

def check_directory(path):
    """Verify that *path* is either absent or a real directory.

    Symlinks are not followed, so a symlink (even one pointing at a
    directory) counts as a conflict.

    Raises:
        FileExistsError: If *path* exists and is not a directory.
    """
    try:
        mode = os.stat(path, follow_symlinks=False).st_mode
    except FileNotFoundError:
        return
    if stat.S_ISDIR(mode):
        return
    raise_exists(path)

def ensure_directory(path):
    """Create directory *path* with mode 0755 unless something already exists there."""
    try:
        os.mkdir(path)
    except FileExistsError:
        # Pre-existing entry: leave it (and its permissions) untouched.
        return
    # Set the mode explicitly so the result does not depend on the umask.
    os.chmod(path, 0o0755)

class SkipSymlink(Exception):
    """Signal that a source symlink needs no copy because the destination already covers it."""

class MergeUsr:
    """Merge legacy top-level directories into their mapped counterparts.

    For every entry of *dirmap* the source directory is walked and each
    object is checked against the destination.  Unless *dryrun* is set,
    objects are moved into the destination, and a source directory that
    produced no conflicts is finally replaced by a symlink to its
    destination.
    """

    def __init__(self, dryrun, root, prefix, dirmap, ignore_patterns):
        """Store the configuration.

        Args:
            dryrun: When true, only check for conflicts; modify nothing.
            root: Filesystem root the mapped directories live under.
            prefix: Path below *root* prepended to every mapped directory.
            dirmap: Mapping of source directory to destination directory,
                both relative to ``root/prefix``.
            ignore_patterns: fnmatch patterns of entry names to skip.
        """
        self.dryrun = dryrun
        self.root = root
        self.prefix = prefix
        self.dirmap = dirmap
        self.ignore_patterns = ignore_patterns
        # (filetype, path, error) tuples collected for the directory that
        # is currently being processed; run() clears it between directories.
        self.errors = []

    def ignore_name(self, name):
        """Return True if *name* matches one of the ignore patterns."""
        return any(fnmatch.fnmatch(name, pattern) for pattern in self.ignore_patterns)

    def resolve_symlinks(self, path):
        """Return the real path of *path*, interpreting it inside ``self.root``.

        If *path* is a symlink with an absolute target, the target is
        re-rooted below ``self.root`` before being resolved.
        """
        try:
            target = os.readlink(path)
        except OSError:
            # Not a symlink (or unreadable): resolve the path as given.
            pass
        else:
            if os.path.isabs(target):
                path = os.path.join(self.root, os.path.relpath(target, "/"))
        return os.path.realpath(path)

    @staticmethod
    def _same_file(a, b):
        """Return True if *a* and *b* refer to the same file, following symlinks.

        os.path.samefile() stats both paths and therefore raises for
        dangling symlinks; a path that cannot be stat'ed is treated as
        "not the same file" so that the caller's remaining checks still run.
        """
        try:
            return os.path.samefile(a, b)
        except OSError:
            return False

    def check_object(self, src, dst, src_is_symlink):
        """Decide whether *src* can be moved to *dst*.

        Returns:
            None if *dst* does not exist, or if *src* is not a symlink and
            *src* and *dst* resolve to the same file (so *dst* may be
            replaced).

        Raises:
            SkipSymlink: If *src* is a symlink that resolves to the same
                file as *dst*, or if both are symlinks pointing at the same
                location relative to their own directories.
            FileExistsError: If *dst* exists and is a different object.
        """
        try:
            st = os.lstat(dst)
        except FileNotFoundError:
            return
        dst_is_symlink = stat.S_ISLNK(st.st_mode)
        s = self.resolve_symlinks(src) if src_is_symlink else src
        d = self.resolve_symlinks(dst) if dst_is_symlink else dst
        if s == d or self._same_file(src, dst):
            if src_is_symlink:
                raise SkipSymlink()
            # Replace symlink with file.
            return
        if src_is_symlink and dst_is_symlink:
            # Check for links like /lib64/lp64d -> .
            srcdir = os.path.dirname(src)
            dstdir = os.path.dirname(dst)
            if s.startswith(srcdir) and d.startswith(dstdir):
                s = os.path.relpath(s, srcdir)
                d = os.path.relpath(d, dstdir)
                if s == d:
                    raise SkipSymlink()
        raise_exists(dst)

    def check_symlink(self, src, dst):
        """Log and run the conflict check for a source symlink."""
        self.log_compare("symlink", src, dst)
        self.check_object(src, dst, True)

    def check_file(self, src, dst):
        """Log and run the conflict check for a regular source file."""
        self.log_compare("file", src, dst)
        self.check_object(src, dst, False)

    def copy_symlink(self, src, dst):
        """Create a symlink at *dst* equivalent to the symlink *src*.

        Absolute targets are copied verbatim.  Relative targets are
        resolved and rewritten: relative to the source directory when the
        resolved path starts with it, otherwise relative to the
        destination directory.  Extended attributes are copied as well.
        """
        self.log_copy("symlink", src, dst)
        srcdir = os.path.dirname(src)
        dstdir = os.path.dirname(dst)
        target = os.readlink(src)
        if not os.path.isabs(target):
            t = os.path.join(srcdir, target)
            t = self.resolve_symlinks(t)
            if t.startswith(srcdir):
                target = os.path.relpath(t, srcdir)
            else:
                target = os.path.relpath(t, dstdir)
        os.symlink(target, dst)
        copy_xattrs(src, dst)

    def link_or_copy_file(self, src, dst):
        """Move the regular file *src* to *dst* and leave a symlink behind.

        The file is hard-linked (or, if linking fails, copied) to a
        temporary name next to *dst*, renamed over *dst*, and *src* is then
        replaced by a symlink to *dst*.

        Raises:
            FileExistsError: If the temporary name already exists.
            OSError: If copying or renaming fails.
        """
        self.log_copy("file", src, dst)
        tmp = tempfile.mktemp(dir=os.path.dirname(dst))
        try:
            os.link(src, tmp, follow_symlinks=False)
        except FileExistsError:
            raise
        except OSError:
            # Hard links are not always possible (e.g. across filesystems).
            shutil.copy2(src, tmp, follow_symlinks=False)
        os.rename(tmp, dst)
        # rename() is a successful no-op when tmp and dst are already hard
        # links to the same inode (e.g. src and dst were hard-linked by an
        # interrupted earlier run); the temporary name would then be left
        # behind in the destination directory.
        try:
            os.unlink(tmp)
        except FileNotFoundError:
            pass
        replace_file_with_symlink(src, dst)

    def log_compare(self, filetype, src, dst):
        """Emit a debug message about a pending comparison."""
        logger.debug("Comparing %s '%s' to '%s'", filetype, src, dst)

    def log_copy(self, filetype, src, dst):
        """Emit a debug message about a pending copy."""
        logger.debug("Copying %s '%s' to '%s'", filetype, src, dst)

    def log_error(self, filetype, path, error):
        """Record a conflict in ``self.errors`` and log it."""
        self.errors.append((filetype, path, error))
        logger.error("Conflict for %s '%s': %s", filetype, path, error)

    def copy_tree(self, srcdir, dstdir):
        """Recursively check *srcdir* against *dstdir* and move its contents.

        Entries matching the ignore patterns are skipped.  Conflicts are
        recorded through log_error() rather than raised.  In dry-run mode
        nothing is modified, but subdirectories are still descended into
        so that all conflicts are reported.
        """
        with os.scandir(srcdir) as entries:
            for entry in entries:
                if self.ignore_name(entry.name):
                    continue
                src = os.path.join(srcdir, entry.name)
                dst = os.path.join(dstdir, entry.name)
                if entry.is_symlink():
                    try:
                        self.check_symlink(src, dst)
                    except SkipSymlink:
                        logger.info("Skipping symlink '%s'; '%s' already exists", src, dst)
                    except Exception as e:
                        self.log_error("symlink", src, e)
                    else:
                        if not self.dryrun:
                            self.copy_symlink(src, dst)
                elif entry.is_file(follow_symlinks=False):
                    try:
                        self.check_file(src, dst)
                    except Exception as e:
                        self.log_error("file", src, e)
                    else:
                        if not self.dryrun:
                            self.link_or_copy_file(src, dst)
                elif entry.is_dir(follow_symlinks=False):
                    try:
                        check_directory(dst)
                    except Exception as e:
                        self.log_error("directory", src, e)
                    else:
                        if not self.dryrun:
                            ensure_directory(dst)
                        self.copy_tree(src, dst)
                else:
                    self.log_error("special file", src, "Special files are not supported")

    def run(self):
        """Process every directory in ``self.dirmap``.

        A source that is already a symlink, does not exist, or is not a
        directory is skipped.  A source directory that produced conflicts
        is left in place.

        Returns:
            True if no processed directory produced conflicts, False
            otherwise.
        """
        result = True

        for source, destination in self.dirmap.items():
            src = os.path.join(self.root, self.prefix, source)
            dst = os.path.join(self.root, self.prefix, destination)

            if os.path.islink(src):
                logger.info("Already a symlink: '%s'", src)
                continue

            if not os.path.exists(src):
                continue

            if not os.path.isdir(src):
                logger.warning("Not a directory: '%s'", src)
                continue

            logger.info("Migrating files from '%s' to '%s'", src, dst)
            self.copy_tree(src, dst)

            if self.errors:
                logger.error("Leaving '%s' as a directory due to prior errors", src)
                result = False
                # Start the next directory with a clean slate.
                self.errors.clear()
                continue

            if self.dryrun:
                logger.info("No problems found for '%s'", src)
            else:
                replace_directory_with_symlink(src, dst)

        return result

def main():
    """Parse the command line, run the merge and return the exit status.

    Returns:
        0 if MergeUsr.run() reported success, 1 otherwise.
    """
    logging.basicConfig(format="%(levelname)s: %(message)s")

    cli = argparse.ArgumentParser()
    cli.add_argument("--dryrun", action=argparse.BooleanOptionalAction)
    cli.add_argument("--root")
    cli.add_argument("--prefix")
    cli.add_argument("--verbose", action="store_true")

    opts = cli.parse_args()

    logger.setLevel(logging.DEBUG if opts.verbose else logging.INFO)

    # Normalise root to end in exactly one slash; default to the real root.
    root = "/" if opts.root is None else opts.root.rstrip("/") + "/"

    # The prefix is joined below root, so it must not be absolute.
    prefix = "" if opts.prefix is None else opts.prefix.lstrip("/")

    os.umask(0o077)

    merger = MergeUsr(opts.dryrun, root, prefix, DIR_MAP, IGNORE_PATTERNS)
    return 0 if merger.run() else 1

# Run only when executed as a script; main()'s return value becomes the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
