/* Copyright (C) 2018-2025 by Arm Limited. All rights reserved. */

#include "linux/proc/ProcessChildren.h"

#include "Logging.h"
#include "lib/Error.h"
#include "lib/String.h"
#include "lib/Syscall.h"

#include <cerrno>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <ios>
#include <map>
#include <memory>
#include <set>
#include <unordered_set>

#include <dirent.h>
#include <sys/types.h>
#include <unistd.h>

namespace lnx {
    // NOLINTNEXTLINE(misc-no-recursion)
    void addTidsRecursively(std::set<int> & tids, int tid, tid_enumeration_mode_t tid_enumeration_mode)
    {
        constexpr std::size_t buffer_size = 64; // should be large enough for the proc path

        auto result = tids.insert(tid);
        if (!result.second) {
            return; // we've already added this and its children
        }

        lib::printf_str_t<buffer_size> filename {};

        // try to get all children (forked processes), available since Linux 3.5
        switch (tid_enumeration_mode) {
            case tid_enumeration_mode_t::self_and_threads_and_children: {
                filename.printf("/proc/%d/task/%d/children", tid, tid);
                std::ifstream children {filename, std::ios_base::in};
                if (children) {
                    int child;
                    while (children >> child) {
                        addTidsRecursively(tids, child, tid_enumeration_mode);
                    }
                }
                break;
            }

            case tid_enumeration_mode_t::self_only:
            case tid_enumeration_mode_t::self_and_threads:
            default: {
                // nothing to do
                break;
            }
        }

        // Now add all threads for the process
        // If 'children' is not found then new processes won't be counted on onlined cpu.
        // We could read /proc/[pid]/stat for every process and create a map in reverse
        // but that would likely be time consuming
        switch (tid_enumeration_mode) {
            case tid_enumeration_mode_t::self_and_threads:
            case tid_enumeration_mode_t::self_and_threads_and_children: {
                filename.printf("/proc/%d/task", tid);
                const std::unique_ptr<DIR, int (*)(DIR *)> taskDir {opendir(filename), &closedir};
                if (taskDir != nullptr) {
                    const dirent * taskEntry;
                    // NOLINTNEXTLINE(concurrency-mt-unsafe)
                    while ((taskEntry = readdir(taskDir.get())) != nullptr) {
                        // no point recursing if we're relying on the fall back
                        if (std::strcmp(taskEntry->d_name, ".") != 0 && std::strcmp(taskEntry->d_name, "..") != 0) {
                            const auto child = std::strtol(taskEntry->d_name, nullptr, 10);
                            if (child > 0) {
                                tids.insert(pid_t(child));
                            }
                        }
                    }
                }
                break;
            }
            case tid_enumeration_mode_t::self_only:
            default: {
                break;
            }
        }
    }

    sigstop_result_t send_sigstop(pid_t pid)
    {
        if (lib::kill(pid, SIGSTOP) == -1) {
            return (errno == ESRCH ? sigstop_result_t::failed_no_such_pid //
                                   : sigstop_result_t::failed_other);
        }

        return sigstop_result_t::success;
    }

    // NOLINTNEXTLINE(bugprone-easily-swappable-parameters,misc-no-recursion)
    [[nodiscard]] bool recursively_stop_all_tids(std::set<pid_t> const & pids,
                                                 std::set<pid_t> const & filter_set,
                                                 std::set<pid_t> & result,
                                                 std::unordered_set<pid_t> & recursed_into,
                                                 std::map<pid_t, sig_continuer_t> & paused_tids,
                                                 tid_enumeration_mode_t tid_enumeration_mode)
    {
        bool modified = false;
        std::set<int> tids {};

        LOG_FINE("Called recursively_stop_all_tids(%zu)", pids.size());

        // stop and process each pid
        for (pid_t pid : pids) {

            // already stopped ?
            if (paused_tids.count(pid) > 0) {
                // record it in the result as it is still a tracked pid
                result.insert(pid);

                // now find any children and stop them as well (but for each iteration of stop_all_tids, just do this one per tid)
                if (recursed_into.insert(pid).second) {
                    addTidsRecursively(tids, pid, tid_enumeration_mode);
                }

                // but no need to stop it again
                continue;
            }

            // to be ignored ?
            if (filter_set.count(pid) > 0) {
                // just skip it
                continue;
            }

            // stop it?
            switch (send_sigstop(pid)) {
                case sigstop_result_t::success: {
                    LOG_FINE("Successfully stopped %d", pid);
                    // success
                    paused_tids.emplace(pid, sig_continuer_t {pid});
                    result.insert(pid);
                    modified = true;
                    break;
                }
                case sigstop_result_t::failed_other: {
                    // error
                    auto const error = errno;

                    // add it to the map with an empty entry so as not to poll it again, but dont set modified
                    LOG_WARNING("Could not SIGSTOP %d due to errno=%d", pid, error);
                    paused_tids.emplace(pid, sig_continuer_t {});

                    // add it to 'result'
                    result.insert(pid);
                    break;
                }
                case sigstop_result_t::failed_no_such_pid: {
                    // error
                    auto const error = errno;

                    // add it to the map with an empty entry so as not to poll it again, but dont set modified
                    LOG_WARNING("Could not SIGSTOP %d due to errno=%d", pid, error);
                    paused_tids.emplace(pid, sig_continuer_t {});

                    // don't add it to the result

                    break;
                }
            }

            // now find any children and stop them as well (but for each iteration of stop_all_tids, just do this one per tid)
            if (recursed_into.insert(pid).second) {
                addTidsRecursively(tids, pid, tid_enumeration_mode);
            }
        }

        if (!tids.empty()) {
            LOG_FINE("Found %zu child pids/tids", tids.size());
            modified |=
                recursively_stop_all_tids(tids, filter_set, result, recursed_into, paused_tids, tid_enumeration_mode);
        }

        return modified;
    }

    // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
    std::set<pid_t> stop_all_tids(std::set<pid_t> const & pids,
                                  std::set<pid_t> const & filter_set,
                                  std::map<pid_t, sig_continuer_t> & paused_tids,
                                  tid_enumeration_mode_t tid_enumeration_mode)
    {
        constexpr unsigned sleep_usecs = 100;

        std::set<pid_t> result {};

        LOG_FINE("Called stop_all_tids (%zu, %d)", pids.size(), int(tid_enumeration_mode));

        std::unordered_set<pid_t> recursed_into {};

        while (recursively_stop_all_tids(pids, filter_set, result, recursed_into, paused_tids, tid_enumeration_mode)) {
            // clear on each iteration so that once per this loop we check all pids
            recursed_into.clear();

            LOG_FINE("Pause before re-scanning");
            // sleep some tiny amount of time so that the signals can propogate before checking again
            usleep(sleep_usecs);
        }

        return result;
    }

    void sig_continuer_t::signal() noexcept
    {
        pid_t pid {std::exchange(this->pid, 0)};

        if (pid != 0) {
            if (lib::kill(pid, SIGCONT) == -1) {
                auto const error = errno;

                LOG_WARNING("Could not SIGCONT %d due to errno=%d (%s)", pid, error, lib::strerror(error));
            }
            else {
                LOG_FINE("Resumed %d PID", pid);
            }
        }
    }
}
