// Licensed to the Software Freedom Conservancy (SFC) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The SFC licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package dev.selenium.tools.modules;

import static com.github.javaparser.ParseStart.COMPILATION_UNIT;
import static net.bytebuddy.jar.asm.Opcodes.ACC_MANDATED;
import static net.bytebuddy.jar.asm.Opcodes.ACC_MODULE;
import static net.bytebuddy.jar.asm.Opcodes.ACC_OPEN;
import static net.bytebuddy.jar.asm.Opcodes.ACC_STATIC_PHASE;
import static net.bytebuddy.jar.asm.Opcodes.ACC_TRANSITIVE;
import static net.bytebuddy.jar.asm.Opcodes.ASM9;
import static net.bytebuddy.jar.asm.Opcodes.V11;

import com.github.bazelbuild.rules_jvm_external.zip.StableZipEntry;
import com.github.javaparser.JavaParser;
import com.github.javaparser.ParseResult;
import com.github.javaparser.ParserConfiguration;
import com.github.javaparser.Provider;
import com.github.javaparser.Providers;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Modifier;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.expr.Name;
import com.github.javaparser.ast.modules.ModuleDeclaration;
import com.github.javaparser.ast.modules.ModuleExportsDirective;
import com.github.javaparser.ast.modules.ModuleOpensDirective;
import com.github.javaparser.ast.modules.ModuleProvidesDirective;
import com.github.javaparser.ast.modules.ModuleRequiresDirective;
import com.github.javaparser.ast.modules.ModuleUsesDirective;
import com.github.javaparser.ast.visitor.VoidVisitorAdapter;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.UncheckedIOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicReference;
import java.util.jar.Attributes;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;
import java.util.jar.JarOutputStream;
import java.util.jar.Manifest;
import java.util.spi.ToolProvider;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import net.bytebuddy.jar.asm.ClassReader;
import net.bytebuddy.jar.asm.ClassVisitor;
import net.bytebuddy.jar.asm.ClassWriter;
import net.bytebuddy.jar.asm.MethodVisitor;
import net.bytebuddy.jar.asm.ModuleVisitor;
import net.bytebuddy.jar.asm.Type;
import org.openqa.selenium.io.TemporaryFilesystem;

/**
 * Command-line tool that builds a jar containing only a {@code module-info.class} for an input jar.
 *
 * <p>The descriptor is seeded from the {@code module-info.java} that {@code jdeps
 * --generate-module-info} writes for the input jar. It is then adjusted from the command-line
 * flags (module name, openness, exported/hidden packages, service uses, {@code --open-to} targets)
 * and compiled to bytecode with ASM.
 */
public class ModuleGenerator {

  private static final String SERVICE_LOADER = ServiceLoader.class.getName().replace('.', '/');

  /** File-name prefix bazel puts on processed jars; it is stripped before jdeps sees the jar. */
  private static final String PROCESSED_JAR_PREFIX = "processed_";

  /** Module-name prefix that results from deriving an automatic module name from such a jar. */
  private static final String PROCESSED_MODULE_PREFIX = "processed.";

  /**
   * Generates the module jar.
   *
   * <p>Every flag takes one value: {@code --exports <pkg>}, {@code --hides <pkg>}, {@code --in
   * <jar>}, {@code --is-open <boolean>}, {@code --module-name <name>}, {@code --module-path <jar>},
   * {@code --open-to <module>}, {@code --output <jar>}, {@code --uses <service>}. {@code
   * --opens-to} takes two values.
   *
   * @param args command-line flags as described above
   * @throws IOException if reading the input jar, running jdeps or writing the output fails
   * @throws IllegalArgumentException on an unknown flag or a flag whose value is missing
   * @throws NullPointerException if {@code --module-name}, {@code --output} or {@code --in} is
   *     not given
   */
  public static void main(String[] args) throws IOException {
    Path outJar = null;
    Path inJar = null;
    String moduleName = null;
    Set<Path> modulePath = new TreeSet<>();
    Set<String> exports = new TreeSet<>();
    Set<String> hides = new TreeSet<>();
    Set<String> uses = new TreeSet<>();

    // Beware the near-identical names: `opensTo` collects the `--opens-to` value pairs, while
    // `openTo` collects the `--open-to` modules that every exported package is opened to.
    // NOTE(review): `opensTo` is parsed but never used when building the descriptor — TODO confirm
    // whether `--opens-to` is meant to emit `opens` directives.
    Map<String, Set<String>> opensTo = new TreeMap<>();
    Set<String> openTo = new TreeSet<>();
    boolean isOpen = false;

    for (int i = 0; i < args.length; i++) {
      String flag = args[i];
      String next = requireValue(args, ++i, flag);
      switch (flag) {
        case "--exports":
          exports.add(next);
          break;

        case "--hides":
          hides.add(next);
          break;

        case "--in":
          inJar = Paths.get(next);
          break;

        case "--is-open":
          isOpen = Boolean.parseBoolean(next);
          break;

        case "--module-name":
          moduleName = next;
          break;

        case "--module-path":
          modulePath.add(Paths.get(next));
          break;

        case "--open-to":
          openTo.add(next);
          break;

        case "--opens-to":
          // This flag consumes a second value.
          opensTo.computeIfAbsent(next, str -> new TreeSet<>()).add(requireValue(args, ++i, flag));
          break;

        case "--output":
          outJar = Paths.get(next);
          break;

        case "--uses":
          uses.add(next);
          break;

        default:
          throw new IllegalArgumentException(String.format("Unknown argument: %s", flag));
      }
    }
    Objects.requireNonNull(moduleName, "Module name must be set.");
    Objects.requireNonNull(outJar, "Output jar must be set.");
    Objects.requireNonNull(inJar, "Input jar must be set.");

    ToolProvider jdeps = ToolProvider.findFirst("jdeps").orElseThrow();
    TemporaryFilesystem tmpFs = TemporaryFilesystem.getDefaultTmpFS();
    File tempDir = tmpFs.createTempDir("module-dir", "");
    File automaticJarsDir = null;
    ModuleDeclaration moduleDeclaration;
    try {
      Path temp = tempDir.toPath();

      List<String> jdepsArgs = new LinkedList<>(List.of("--api-only", "--multi-release", "9"));
      if (!modulePath.isEmpty()) {
        automaticJarsDir = tmpFs.createTempDir("automatic_module_jars", "");
        jdepsArgs.add("--module-path");
        jdepsArgs.add(toJdepsModulePath(modulePath, automaticJarsDir.toPath()));
      }
      jdepsArgs.addAll(List.of("--generate-module-info", temp.toAbsolutePath().toString()));
      jdepsArgs.add(inJar.toAbsolutePath().toString());

      runJdeps(jdeps, jdepsArgs);

      // Parse before the finally block removes the directory jdeps wrote into.
      moduleDeclaration = parseModuleDeclaration(findModuleInfo(temp));
    } finally {
      // Clean up even when jdeps or parsing fails.
      tmpFs.deleteTempDir(tempDir);
      if (automaticJarsDir != null) {
        tmpFs.deleteTempDir(automaticJarsDir);
      }
    }

    moduleDeclaration.setName(moduleName);
    moduleDeclaration.setOpen(isOpen);

    Set<String> allUses = new TreeSet<>(uses);
    allUses.addAll(readServicesFromClasses(inJar));
    for (String service : allUses) {
      moduleDeclaration.addDirective(new ModuleUsesDirective(new Name(service)));
    }

    Set<String> packages = inferPackages(inJar);

    // Determine packages to export. Open modules get no additional exports directives here.
    // Exports already present in the jdeps-generated descriptor are still visited by
    // MyModuleVisitor; only packages listed in `hides` are suppressed there.
    Set<String> exportedPackages = new HashSet<>();
    if (!isOpen) {
      if (!exports.isEmpty()) {
        for (String export : exports) {
          if (!packages.contains(export)) {
            throw new RuntimeException(
                String.format("Exported package '%s' not found in jar. %s", export, packages));
          }
          exportedPackages.add(export);
          moduleDeclaration.addDirective(
              new ModuleExportsDirective(new Name(export), new NodeList<>()));
        }
      } else {
        for (String pkg : packages) {
          if (!hides.contains(pkg)) {
            exportedPackages.add(pkg);
            moduleDeclaration.addDirective(
                new ModuleExportsDirective(new Name(pkg), new NodeList<>()));
          }
        }
      }
    }

    // The opens directive is deliberately built "inside out": the target module is stored as the
    // directive's name. MyModuleVisitor then emits one `opens <pkg> to <module>` per exported
    // package, so the NodeList of names below is not read back.
    for (String module : openTo) {
      moduleDeclaration.addDirective(
          new ModuleOpensDirective(
              new Name(module),
              new NodeList<>(
                  exportedPackages.stream().map(Name::new).collect(Collectors.toList()))));
    }

    ClassWriter classWriter = new ClassWriter(0);
    classWriter.visit(V11, ACC_MODULE, "module-info", null, null, null);
    ModuleVisitor moduleVisitor = classWriter.visitModule(moduleName, isOpen ? ACC_OPEN : 0, null);
    moduleVisitor.visitRequire("java.base", ACC_MANDATED, null);

    // The class loader (input jar + module path) is only used to resolve `provides` class names.
    try (URLClassLoader classLoader = createClassLoader(inJar, modulePath)) {
      moduleDeclaration.accept(
          new MyModuleVisitor(classLoader, exportedPackages, hides, moduleVisitor), null);
    }

    moduleVisitor.visitEnd();
    classWriter.visitEnd();

    writeModuleJar(outJar, classWriter.toByteArray());
  }

  /**
   * Returns {@code args[index]}, which is the value belonging to {@code flag}.
   *
   * @throws IllegalArgumentException if the argument list ends before the value
   */
  private static String requireValue(String[] args, int index, String flag) {
    if (index >= args.length) {
      throw new IllegalArgumentException(String.format("Missing value for argument: %s", flag));
    }
    return args[index];
  }

  /**
   * Builds the {@code --module-path} value for jdeps.
   *
   * <p>Jars whose file name starts with {@code processed_} are first copied into {@code copyDir}
   * without that prefix. jdeps derives automatic module names from file names, so the prefix
   * would otherwise leak into those names.
   */
  private static String toJdepsModulePath(Set<Path> modulePath, Path copyDir) {
    return modulePath.stream()
        .map(jar -> jdepsModulePathEntry(jar, copyDir))
        .collect(Collectors.joining(File.pathSeparator));
  }

  /**
   * Returns the path jdeps should see for {@code jar}, copying it into {@code copyDir} without
   * the {@code processed_} prefix when present.
   */
  private static String jdepsModulePathEntry(Path jar, Path copyDir) {
    String file = jar.getFileName().toString();
    if (!file.startsWith(PROCESSED_JAR_PREFIX)) {
      return jar.toString();
    }

    Path copy = copyDir.resolve(file.substring(PROCESSED_JAR_PREFIX.length()));
    try {
      Files.copy(jar, copy, StandardCopyOption.REPLACE_EXISTING);
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
    return copy.toString();
  }

  /**
   * Runs jdeps with {@code jdepsArgs}, capturing everything it prints.
   *
   * @throws RuntimeException with the command line and captured output if jdeps exits non-zero
   */
  private static void runJdeps(ToolProvider jdeps, List<String> jdepsArgs) {
    PrintStream origOut = System.out;
    PrintStream origErr = System.err;

    // Per the original author's observation, jdeps does not write to the streams passed to
    // ToolProvider.run, so System.out/err are redirected as well to capture its output.
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    // Encode with the same charset used to decode the buffer below.
    PrintStream printStream = new PrintStream(bos, true, StandardCharsets.UTF_8);

    int result;
    try {
      System.setOut(printStream);
      System.setErr(printStream);
      result = jdeps.run(printStream, printStream, jdepsArgs.toArray(new String[0]));
    } finally {
      System.setOut(origOut);
      System.setErr(origErr);
    }
    if (result != 0) {
      throw new RuntimeException(
          String.format(
              "Unable to process module:%njdeps %s%n%s",
              String.join(" ", jdepsArgs), bos.toString(StandardCharsets.UTF_8)));
    }
  }

  /**
   * Finds the first {@code module-info.java} under {@code dir}.
   *
   * @throws RuntimeException if no such file exists
   */
  private static Path findModuleInfo(Path dir) throws IOException {
    AtomicReference<Path> moduleInfo = new AtomicReference<>();
    Files.walkFileTree(
        dir,
        new SimpleFileVisitor<>() {
          @Override
          public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) {
            if ("module-info.java".equals(file.getFileName().toString())) {
              moduleInfo.set(file);
              return FileVisitResult.TERMINATE;
            }
            // Keep looking: stopping at the first file would miss module-info.java whenever
            // another file happens to be visited first.
            return FileVisitResult.CONTINUE;
          }
        });

    if (moduleInfo.get() == null) {
      throw new RuntimeException("Unable to read module info");
    }
    return moduleInfo.get();
  }

  /**
   * Parses {@code moduleInfo} as Java 11 source and returns its module declaration.
   *
   * @throws RuntimeException if the file cannot be parsed or declares no module
   */
  private static ModuleDeclaration parseModuleDeclaration(Path moduleInfo) throws IOException {
    ParserConfiguration parserConfig =
        new ParserConfiguration().setLanguageLevel(ParserConfiguration.LanguageLevel.JAVA_11);

    Provider provider = Providers.provider(moduleInfo);

    ParseResult<CompilationUnit> parseResult =
        new JavaParser(parserConfig).parse(COMPILATION_UNIT, provider);

    CompilationUnit unit =
        parseResult
            .getResult()
            .orElseThrow(() -> new RuntimeException("Unable to parse " + moduleInfo));

    return unit.getModule()
        .orElseThrow(() -> new RuntimeException("No module declaration in " + moduleInfo));
  }

  /** Creates a class loader over the input jar followed by every module-path entry. */
  private static URLClassLoader createClassLoader(Path inJar, Set<Path> modulePath) {
    URL[] urls =
        Stream.concat(Stream.of(inJar.toAbsolutePath()), modulePath.stream())
            .map(ModuleGenerator::toUrl)
            .toArray(URL[]::new);
    return new URLClassLoader(urls);
  }

  /** Converts {@code path} to a URL, wrapping the checked exception for use in streams. */
  private static URL toUrl(Path path) {
    try {
      return path.toUri().toURL();
    } catch (MalformedURLException e) {
      throw new UncheckedIOException(e);
    }
  }

  /**
   * Writes a jar to {@code outJar} containing a minimal manifest and {@code moduleInfo} as {@code
   * module-info.class}.
   */
  private static void writeModuleJar(Path outJar, byte[] moduleInfo) throws IOException {
    Manifest manifest = new Manifest();
    manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION, "1.0");

    try (OutputStream os = Files.newOutputStream(outJar);
        JarOutputStream jos = new JarOutputStream(os)) {
      // ZipOutputStream.STORED is 0, so this sets compression level 0 (no compression). Entries
      // still use the DEFLATED method.
      jos.setLevel(ZipOutputStream.STORED);

      // Write the manifest explicitly instead of via JarOutputStream(OutputStream, Manifest).
      // That constructor stamps the manifest entry with the current time, which would make the
      // output jar differ from run to run.
      ZipEntry manifestEntry = new StableZipEntry("META-INF/MANIFEST.MF");
      jos.putNextEntry(manifestEntry);
      manifest.write(jos);
      jos.closeEntry();

      ZipEntry entry = new StableZipEntry("module-info.class");
      entry.setSize(moduleInfo.length);

      jos.putNextEntry(entry);
      jos.write(moduleInfo);
      jos.closeEntry();
    }
  }

  /**
   * Scans every class in {@code inJar} for calls to {@code ServiceLoader.load} and returns the
   * names of the loaded service types.
   *
   * <p>Only the most recent class-literal constant ({@code LDC} of a {@link Type}) seen in the
   * class is attributed to a {@code load} call. Service types obtained any other way are not
   * detected.
   */
  private static Collection<String> readServicesFromClasses(Path inJar) {
    Set<String> serviceNames = new HashSet<>();

    try (InputStream is = Files.newInputStream(inJar);
        JarInputStream jis = new JarInputStream(is)) {
      for (JarEntry entry = jis.getNextJarEntry(); entry != null; entry = jis.getNextJarEntry()) {
        if (entry.isDirectory() || !entry.getName().endsWith(".class")) {
          continue;
        }

        // Reads the current entry's bytes from the jar stream.
        ClassReader reader = new ClassReader(jis);
        reader.accept(
            new ClassVisitor(ASM9) {
              // Last class literal loaded in this class; shared across its methods.
              private Type serviceClass;

              @Override
              public MethodVisitor visitMethod(
                  int access,
                  String name,
                  String descriptor,
                  String signature,
                  String[] exceptions) {
                return new MethodVisitor(ASM9) {
                  @Override
                  public void visitMethodInsn(
                      int opcode,
                      String owner,
                      String name,
                      String descriptor,
                      boolean isInterface) {
                    if (SERVICE_LOADER.equals(owner) && "load".equals(name)) {
                      if (serviceClass != null) {
                        serviceNames.add(serviceClass.getClassName());
                        serviceClass = null;
                      }
                    }
                  }

                  @Override
                  public void visitLdcInsn(Object value) {
                    if (value instanceof Type) {
                      serviceClass = (Type) value;
                    }
                  }
                };
              }
            },
            0);
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }

    return serviceNames;
  }

  /**
   * Returns the dotted names of all packages containing at least one {@code .class} entry in
   * {@code inJar}.
   *
   * <p>Multi-release entries ({@code META-INF/versions/<N>/...}) are mapped to their un-versioned
   * package. Classes in the default package, versioned or not, are ignored.
   */
  private static Set<String> inferPackages(Path inJar) {
    Set<String> packageNames = new TreeSet<>();

    try (InputStream is = Files.newInputStream(inJar);
        JarInputStream jis = new JarInputStream(is)) {
      for (JarEntry entry = jis.getNextJarEntry(); entry != null; entry = jis.getNextJarEntry()) {
        if (entry.isDirectory() || !entry.getName().endsWith(".class")) {
          continue;
        }

        String name = entry.getName();
        int index = name.lastIndexOf('/');
        if (index == -1) {
          continue;
        }
        name = name.substring(0, index);

        // If we've a multi-release jar, strip the "META-INF/versions/<N>/" prefix too.
        if (name.startsWith("META-INF/versions/")) {
          String[] segments = name.split("/");
          if (segments.length <= 3) {
            // Versioned default-package entry (e.g. META-INF/versions/9/module-info.class).
            // Without this check it would yield the empty package name "".
            continue;
          }

          name = String.join("/", Arrays.copyOfRange(segments, 3, segments.length));
        }

        packageNames.add(name.replace('/', '.'));
      }

      return packageNames;
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  /**
   * Replays the directives of a JavaParser {@link ModuleDeclaration} into an ASM {@link
   * ModuleVisitor}.
   */
  private static class MyModuleVisitor extends VoidVisitorAdapter<Void> {

    private final ClassLoader classLoader;
    // Packages already exported (or hidden); later exports of them are skipped.
    private final Set<String> seenExports;
    // Packages emitted for every `opens` directive.
    private final Set<String> packages;
    private final ModuleVisitor byteBuddyVisitor;

    /**
     * @param classLoader used to check which `provides` class names exist
     * @param packages packages opened for each `opens` directive visited
     * @param excluded packages that must never be exported
     * @param byteBuddyVisitor receives the translated directives
     */
    MyModuleVisitor(
        ClassLoader classLoader,
        Set<String> packages,
        Set<String> excluded,
        ModuleVisitor byteBuddyVisitor) {
      this.classLoader = classLoader;
      this.byteBuddyVisitor = byteBuddyVisitor;

      // Defensive, modifiable copies.
      this.packages = new HashSet<>(packages);
      this.seenExports = new HashSet<>(excluded);
    }

    @Override
    public void visit(ModuleRequiresDirective n, Void arg) {
      String name = n.getNameAsString();
      if (name.startsWith(PROCESSED_MODULE_PREFIX)) {
        // When 'Automatic-Module-Name' is not set, the module name is derived from the jar file
        // name. Therefore, the 'processed.' prefix added by bazel must be removed to get the name.
        name = name.substring(PROCESSED_MODULE_PREFIX.length());
      }
      int modifiers = getByteBuddyModifier(n.getModifiers());
      if (!name.startsWith("org.seleniumhq.selenium.") && !name.startsWith("java.")) {
        // Some people like to exclude jars from the classpath. To allow this, these modules are
        // required statically; otherwise compiling their code would fail with 'module not found'.
        modifiers |= ACC_STATIC_PHASE;
      }
      byteBuddyVisitor.visitRequire(name, modifiers, null);
    }

    @Override
    public void visit(ModuleExportsDirective n, Void arg) {
      String pkg = n.getNameAsString();
      if (!seenExports.add(pkg)) {
        return;
      }

      byteBuddyVisitor.visitExport(
          pkg.replace('.', '/'),
          0,
          n.getModuleNames().stream().map(Name::asString).toArray(String[]::new));
    }

    @Override
    public void visit(ModuleProvidesDirective n, Void arg) {
      byteBuddyVisitor.visitProvide(
          getClassName(n.getNameAsString()),
          n.getWith().stream().map(type -> getClassName(type.asString())).toArray(String[]::new));
    }

    @Override
    public void visit(ModuleUsesDirective n, Void arg) {
      byteBuddyVisitor.visitUse(n.getNameAsString().replace('.', '/'));
    }

    @Override
    public void visit(ModuleOpensDirective n, Void arg) {
      // The directive's name holds the target module; see how main builds these directives.
      packages.forEach(
          pkg -> byteBuddyVisitor.visitOpen(pkg.replace('.', '/'), 0, n.getNameAsString()));
    }

    /** Maps JavaParser `requires` modifiers to the corresponding ASM access flags. */
    private int getByteBuddyModifier(NodeList<Modifier> modifiers) {
      return modifiers.stream()
          .mapToInt(
              mod -> {
                switch (mod.getKeyword()) {
                  case STATIC:
                    return ACC_STATIC_PHASE;
                  case TRANSITIVE:
                    return ACC_TRANSITIVE;
                  default:
                    throw new RuntimeException("Unknown modifier: " + mod);
                }
              })
          .reduce(0, (l, r) -> l | r);
    }

    /**
     * Resolves a source-style class name to its internal (slash-separated) name.
     *
     * <p>If the name is not found as given, it is retried once with the last dot replaced by
     * {@code $}, treating the last segment as a nested class.
     *
     * @throws RuntimeException if neither form can be found
     */
    private String getClassName(String possibleClassName) {
      String name = possibleClassName.replace('/', '.');
      if (lookup(name)) {
        return name.replace('.', '/');
      }

      int index = name.lastIndexOf('.');
      if (index != -1) {
        name = name.substring(0, index) + "$" + name.substring(index + 1);
        if (lookup(name)) {
          return name.replace('.', '/');
        }
      }

      throw new RuntimeException("Cannot find class: " + name);
    }

    /** Returns whether {@code className} can be loaded, without initializing it. */
    private boolean lookup(String className) {
      try {
        Class.forName(className, false, classLoader);
        return true;
      } catch (ClassNotFoundException e) {
        return false;
      }
    }
  }
}
