/* Copyright (C) 2020-2021 by Arm Limited. All rights reserved. */

#define _POSIX_C_SOURCE 20200101L

#include <errno.h>
#include <fcntl.h>
#include <libgen.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <stdatomic.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/prctl.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

/*
 * Line prefixes for everything this wrapper writes to stdout. Every output
 * line starts with exactly one of these tags so a reader can demultiplex it.
 */
/* pid of the forked child, printed once after fork() */
#define TAG_PID         "P: "
/* data read from the child's stdout pipe */
#define TAG_STDOUT      "O: "
/* data read from the child's stderr pipe */
#define TAG_STDERR      "E: "
/* errors reported by the wrapper process itself */
#define TAG_ERROR       "R: "
/* error messages the forked child writes to pipe_log before/at exec */
#define TAG_CHILD_ERROR "r: "
/* child's exit status (WEXITSTATUS) */
#define TAG_EXITCODE    "X: "
/* number and name of the signal that terminated the child */
#define TAG_EXITSIGNAL  "S: "
/* wrapper debug output (only when LOG_DEBUG is enabled) */
#define TAG_DEBUG       "D: "
/* child debug output read from pipe_debug */
#define TAG_CHILD_DEBUG "d: "
/* number of descriptors polled by the worker: stdin + out/err/log/debug pipes */
#define NFDS 5

/*
 * Debug logging, compiled in only when LOG_DEBUG is defined and non-zero.
 * DEBUG prints to the wrapper's stdout; CHILD_DEBUG writes to pipe_debug[1]
 * and is used from the forked child.
 */
#if (defined(LOG_DEBUG) && (LOG_DEBUG != 0))
#  define DEBUG(...)            printf(TAG_DEBUG __VA_ARGS__)
#  define CHILD_DEBUG(...)      dprintf(pipe_debug[1], __VA_ARGS__)
#else
#  define DEBUG(...)
#  define CHILD_DEBUG(...)
#endif

/* Single-byte values sent over pipe_do_exec: the child execs only on GO. */
#define EXEC_STATE_FLAG_GO      0
#define EXEC_STATE_FLAG_ABORT   1

/*
 * pid that sig_handler forwards signals to; 0 means "no child to signal".
 * Set by main after the handlers are installed and cleared once the direct
 * child has been reaped.
 * NOTE(review): shared with a signal handler but not declared
 * volatile sig_atomic_t / atomic — consider tightening.
 */
static pid_t childpid = 0;
/* wrapper stdin -> child stdin (worker writes pipe_in[1]) */
static int pipe_in[2] = {-1, -1};
/* child stdout -> worker (read end pipe_out[0]) */
static int pipe_out[2] = {-1, -1};
/* child stderr -> worker (read end pipe_err[0]) */
static int pipe_err[2] = {-1, -1};
/* child's pre-exec error messages; close-on-exec, so it closes on a successful exec */
static int pipe_log[2] = {-1, -1};
/* child debug messages (CHILD_DEBUG); close-on-exec */
static int pipe_debug[2] = {-1, -1};
/* go/abort handshake: parent writes one EXEC_STATE_FLAG_* byte, child reads it before exec */
static int pipe_do_exec[2] = {-1, -1};
/*
 * Output prefix per polled descriptor, in the same order as the worker's fds
 * array. "<stdin>" is never printed as a prefix; it only appears in
 * read-error messages.
 */
static const char * TAGS[NFDS] = {"<stdin>", TAG_STDOUT, TAG_STDERR, TAG_CHILD_ERROR, TAG_CHILD_DEBUG};
/* not referenced anywhere else in this file */
static atomic_bool flag_can_exec;
/* set by main after all children are reaped; the worker exits on its next idle poll timeout */
static atomic_bool flag_terminated;

/*
 * Rebuild the poll() set from the descriptors that are still open.
 *
 * For each entry of fds[0..n) whose isclosed flag is zero, one slot of pfds
 * is filled (watching POLLIN), the matching TAGS entry is stored in tags[]
 * and the original position is stored in indexes[] so poll results can be
 * mapped back to fds/isclosed.
 *
 * Returns the number of pfds slots written.
 */
static int fill_pollfds(int n, struct pollfd * pfds, int * isclosed, const int * fds, int * indexes, const char ** tags)
{
    int count = 0;

    for (int src = 0; src < n; ++src) {
        if (!isclosed[src]) {
            pfds[count] = (struct pollfd) { .fd = fds[src], .events = POLLIN, .revents = 0 };
            tags[count] = TAGS[src];
            indexes[count++] = src;
        }
    }

    return count;
}

/*
 * Send the one-byte go/abort flag (EXEC_STATE_FLAG_*) to the forked child,
 * which blocks reading pipe_do_exec[0] until it arrives.
 *
 * Returns 0 on success, -1 after printing an error line.
 */
static int enable_exec(char state)
{
    const ssize_t written = write(pipe_do_exec[1], &state, sizeof state);

    if (written >= 0) {
        return 0;
    }

    printf(TAG_ERROR "Failed to write(do_exec). Error was %d :: %s\n",
           errno, strerror(errno));
    return -1;
}

/*
 * Signals the wrapper installs sig_handler for, terminated by a 0 entry.
 * All of them except SIGCHLD are forwarded to the wrapped command; SIGCHLD
 * is caught as well (the handler only logs it) while the main thread reaps
 * children with waitpid().
 */
static const int SIGS_TO_CATCH[] = {
    SIGHUP, SIGINT, SIGQUIT,
    SIGUSR1, SIGUSR2, SIGTERM,
    SIGCHLD,    /* not forwarded; children are reaped by the waitpid() loop */
    0           /* end-of-list marker */
};

/*
 * Handler installed for every entry of SIGS_TO_CATCH.
 *
 * SIGCHLD is only logged: the main thread reaps children with waitpid().
 * Any other signal is forwarded to the wrapped command while childpid is
 * non-zero; a failed kill() is reported with a TAG_ERROR line.
 *
 * errno is saved and restored so that the code this handler interrupted
 * (e.g. the errno checks after waitpid()/poll()) still sees its own error
 * value rather than one left behind by kill()/printf()/strsignal().
 */
static void sig_handler(int signo)
{
    const int saved_errno = errno;

    if (signo == SIGCHLD) {
        DEBUG("Received SIGCHLD\n");
    }
    else {
        // snapshot once: main clears childpid after reaping the direct child
        const pid_t pid = childpid;

        DEBUG("Forwarding signal %d (%s)\n", signo, strsignal(signo));
        if ((pid != 0) && (kill(pid, signo) < 0)) {
            // NOTE(review): printf()/strsignal()/strerror() are not
            // async-signal-safe; kept for the existing output format.
            printf(TAG_ERROR "Failed to kill(%d, %d (%s)). Error was %d :: %s\n",
                   pid, signo, strsignal(signo), errno, strerror(errno));
        }
    }

    errno = saved_errno;
}

static int add_signal_handlers(void)
{
    for (int i = 0; SIGS_TO_CATCH[i] != 0; ++i) {
        if (signal(SIGS_TO_CATCH[i], sig_handler) == SIG_ERR) {
            printf(TAG_ERROR "Failed to signal(%s). Error was %d :: %s\n",
                   strsignal(SIGS_TO_CATCH[i]), errno, strerror(errno));
            return -1;
        }
    }

    return 0;
}

/* positions in the worker's fds/isclosed arrays */
#define NDX_STDIN   0
#define NDX_STDOUT  1
#define NDX_STDERR  2

/*
 * Write one chunk of child output to stdout as a tagged, escaped record.
 *
 * Printable ASCII (32..126) other than backslash is copied through. A
 * backslash becomes two backslashes. A newline becomes the two characters
 * backslash + 'n' and, if more data follows in this chunk, a real newline
 * plus a fresh copy of `tag`. Any other byte becomes backslash + "xNN".
 * The record is terminated with a real newline.
 */
static void print_escaped(const char * tag, const char * data, size_t len)
{
    size_t run = 0;     // start of the pending run of pass-through chars

    printf("%s", tag);
    for (size_t pos = 0; pos < len; ++pos) {
        const unsigned char c = (unsigned char) data[pos];

        if ((c >= 32) && (c < 127) && (c != '\\')) {
            continue;
        }

        // flush pass-through chars by length: data is not NUL-terminated
        // and may fill the caller's whole buffer
        if (run < pos) {
            fwrite(data + run, 1, pos - run, stdout);
        }

        if (c == '\n') {
            printf("\\n");
            if ((pos + 1) < len) {
                printf("\n%s", tag);
            }
        }
        else if (c == '\\') {
            printf("\\\\");
        }
        else {
            printf("\\x%02x", c);
        }

        run = pos + 1;
    }

    if (run < len) {
        fwrite(data + run, 1, len - run, stdout);
    }
    printf("\n");
}

/*
 * Write all of data[0..len) to the child's stdin (pipe_in[1]).
 *
 * Returns 0 when everything was written, 1 when the child has closed its end
 * of the pipe (EPIPE), or -1 after printing an error line for any other
 * write failure. EAGAIN/EWOULDBLOCK/EINTR are retried.
 */
static int forward_to_child(const char * data, size_t len)
{
    size_t done = 0;

    while (done < len) {
        const ssize_t nwrite = write(pipe_in[1], data + done, len - done);

        if (nwrite >= 0) {
            done += (size_t) nwrite;
        }
        else if (errno == EPIPE) {
            return 1;
        }
        else if ((errno != EAGAIN) && (errno != EWOULDBLOCK) && (errno != EINTR)) {
            printf(TAG_ERROR "Failed to write(STDIN). Error was %d :: %s\n",
                   errno, strerror(errno));
            return -1;
        }
    }

    return 0;
}

/*
 * Output thread body.
 *
 * Polls the wrapper's stdin plus the child's stdout/stderr/log/debug pipes.
 * Stdin data is forwarded to the child. Everything else is printed with the
 * matching TAGS prefix through print_escaped(). Each descriptor is dropped
 * from the poll set once it reaches EOF or reports HUP/ERR/NVAL.
 *
 * Once the poll set is built, the child is told to exec (EXEC_STATE_FLAG_GO).
 * The thread returns when every descriptor is closed, when an idle poll
 * timeout (1s) sees flag_terminated set, or on an unrecoverable error.
 * Always returns NULL.
 */
static void * worker_fn(void * data)
{
    (void) data;

    char buffer[65536];

    const char * tags[NFDS];
    const int fds[NFDS] = {STDIN_FILENO, pipe_out[0], pipe_err[0], pipe_log[0], pipe_debug[0]};
    struct pollfd pfds[NFDS];
    int isclosed[NFDS] = {0, 0, 0, 0, 0};
    int indexes[NFDS] = {-1, -1, -1, -1, -1};
    int nfds = fill_pollfds(NFDS, pfds, isclosed, fds, indexes, tags);

    // thread is ready, allow exec
    if (enable_exec(EXEC_STATE_FLAG_GO) != 0) {
        return NULL;
    }

    while (nfds > 0) {
        const int result = poll(pfds, nfds, 1000);
        if (result < 0) {
            // poll() is never restarted after a signal handler runs (e.g.
            // SIGCHLD delivered to this thread); just poll again
            if (errno == EINTR) {
                continue;
            }
            printf(TAG_ERROR "Failed to poll(). Error was %d :: %s\n",
                   errno, strerror(errno));
            return NULL;
        }

        DEBUG("----------- %d :: %d\n", result, nfds);

        // only give up on a quiet timeout, so pending output is drained first
        if ((result == 0) && atomic_load(&flag_terminated)) {
            return NULL;
        }

        int modified = 0;
        for (int i = 0; i < nfds; ++i) {
            const int ndx = indexes[i];

            DEBUG("[%d/%d] %d EVENT 0x%x\n", i, ndx, pfds[i].fd, pfds[i].revents);

            if (pfds[i].revents & POLLIN) {
                const ssize_t nread = read(pfds[i].fd, buffer, sizeof(buffer));

                DEBUG("%d READ %zd\n", pfds[i].fd, nread);

                if (nread < 0) {
                    // transient: the descriptor will be reported again by the next poll
                    if ((errno == EINTR) || (errno == EAGAIN) || (errno == EWOULDBLOCK)) {
                        continue;
                    }
                    printf(TAG_ERROR "Failed to read(%s). Error was %d :: %s\n",
                           tags[i], errno, strerror(errno));
                    return NULL;
                }
                else if (nread == 0) {
                    DEBUG("%d CLOSED BY READ\n", pfds[i].fd);
                    isclosed[ndx] = 1;
                    modified = 1;
                    close(pfds[i].fd);
                    // propagate EOF on our stdin to the child's stdin
                    if (ndx == NDX_STDIN) {
                        close(pipe_in[1]);
                    }
                }
                else if (ndx == NDX_STDIN) {
                    const int rc = forward_to_child(buffer, (size_t) nread);
                    if (rc < 0) {
                        return NULL;
                    }
                    if (rc > 0) {
                        // the child closed its stdin: stop forwarding
                        DEBUG("STDIN CLOSED BY WRITE\n");
                        isclosed[ndx] = 1;
                        modified = 1;
                        close(pipe_in[1]);
                        close(STDIN_FILENO);
                    }
                }
                else {
                    print_escaped(tags[i], buffer, (size_t) nread);
                }
            }
            else if (pfds[i].revents & (POLLERR | POLLHUP | POLLNVAL)) {
                DEBUG("DEBUG %d CLOSED\n", pfds[i].fd);
                isclosed[ndx] = 1;
                modified = 1;
                close(pfds[i].fd);
                // close the write end of the pipe as well
                if (ndx == NDX_STDIN) {
                    close(pipe_in[1]);
                }
            }
        }

        if (modified) {
            nfds = fill_pollfds(NFDS, pfds, isclosed, fds, indexes, tags);
        }
    }

    return NULL;
}

static void help(const char * cmd)
{
    printf(TAG_ERROR "Missing arguments.\n");
    printf(TAG_ERROR "Usage %s [-d <working-dir>] <cmd> <args>...\n", cmd);
}

int main(int argc, char** argv)
{
    const char * cmd_name = argv[0];
    int argv_cmd_index = 1;

    if (argc < 2) {
        help(cmd_name);
        return 1;
    }

    if (strcmp(argv[1], "-d") == 0) {
        if (argc < 4) {
            help(cmd_name);
            return 1;
        }
        if (chdir(argv[2]) != 0) {
            printf(TAG_ERROR "Failed to chdir(%s). Error was %d :: %s\n",
                   argv[2], errno, strerror(errno));
            return 2;
        }
        argv_cmd_index = 3;
    }

    // disable buffering
    setvbuf(stdout, NULL, _IONBF, 0);
    setvbuf(stderr, NULL, _IONBF, 0);

    // prepare pipes for fork

    // child stdout pipe
    if (pipe(pipe_in) != 0) {
        printf(TAG_ERROR "Failed to pipe(in). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    // child stdout pipe
    if (pipe(pipe_out) != 0) {
        printf(TAG_ERROR "Failed to pipe(out). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    // child stderr pipe
    if (pipe(pipe_err) != 0) {
        printf(TAG_ERROR "Failed to pipe(err). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    // child log pipe
    if (pipe(pipe_log) != 0) {
        printf(TAG_ERROR "Failed to pipe(log). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    if (fcntl(pipe_log[0], F_SETFD, FD_CLOEXEC) != 0) {
        printf(TAG_ERROR "Failed to fcntl(log[0]). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    if (fcntl(pipe_log[1], F_SETFD, FD_CLOEXEC) != 0) {
        printf(TAG_ERROR "Failed to fcntl(log[1]). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    // child debug message pipe
    if (pipe(pipe_debug) != 0) {
        printf(TAG_ERROR "Failed to pipe(debug). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    if (fcntl(pipe_debug[0], F_SETFD, FD_CLOEXEC) != 0) {
        printf(TAG_ERROR "Failed to fcntl(debug[0]). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    if (fcntl(pipe_debug[1], F_SETFD, FD_CLOEXEC) != 0) {
        printf(TAG_ERROR "Failed to fcntl(debug[1]). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    // do_exec event pipe
    if (pipe(pipe_do_exec) != 0) {
        printf(TAG_ERROR "Failed to pipe(do_exec). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    if (fcntl(pipe_do_exec[0], F_SETFD, FD_CLOEXEC) != 0) {
        printf(TAG_ERROR "Failed to fcntl(do_exec[0]). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    if (fcntl(pipe_do_exec[1], F_SETFD, FD_CLOEXEC) != 0) {
        printf(TAG_ERROR "Failed to fcntl(do_exec[1]). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    // if my parent dies, give me HUP
    if (prctl(PR_SET_PDEATHSIG, SIGHUP) != 0) {
        printf(TAG_ERROR "Failed to prctl(PR_SET_PDEATHSIG, SIGHUP). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }


    // become the reaperman so that we can wait on all children
    if (prctl(PR_SET_CHILD_SUBREAPER, 1) != 0) {
        printf(TAG_ERROR "Failed to prctl(PR_SET_CHILD_SUBREAPER, 1). Error was %d :: %s\n",
               errno, strerror(errno));
        return 3;
    }

    // set the process name, but don't fail if it didn't work, its only really to help cleanup using pgrep
    prctl(PR_SET_NAME, (unsigned long) "terminal_wrapper", 0, 0, 0);

    // this flag is used to terminate the worker thread
    atomic_init(&flag_terminated, 0);

    // fork
    pid_t pid = fork();

    // child
    if (pid == 0) {
        // redirect stdout/stderr
        close(STDIN_FILENO);
        close(STDOUT_FILENO);
        close(STDERR_FILENO);

        if (dup2(pipe_in[0], STDIN_FILENO) < 0) {
            dprintf(pipe_log[1], "Failed to dup2(in). Error was %d :: %s\n",
                    errno, strerror(errno));
            return 100;
        }

        if (dup2(pipe_out[1], STDOUT_FILENO) < 0) {
            dprintf(pipe_log[1], "Failed to dup2(out). Error was %d :: %s\n",
                    errno, strerror(errno));
            return 101;
        }

        if (dup2(pipe_err[1], STDERR_FILENO) < 0) {
            dprintf(pipe_log[1], "Failed to dup2(err). Error was %d :: %s\n",
                    errno, strerror(errno));
            return 102;
        }

        close(pipe_in[0]);
        close(pipe_in[1]);
        close(pipe_out[0]);
        close(pipe_out[1]);
        close(pipe_err[0]);
        close(pipe_err[1]);

        // disable buffering
        setvbuf(stdout, NULL, _IONBF, 0);
        setvbuf(stderr, NULL, _IONBF, 0);

        // if my parent dies, give me HUP
        if (prctl(PR_SET_PDEATHSIG, SIGHUP) != 0) {
            dprintf(pipe_log[1], "Failed to prctl(PR_SET_PDEATHSIG, SIGHUP). Error was %d :: %s\n",
                    errno, strerror(errno));
            return 103;
        }

        // block waiting for the ready event
        char byte = EXEC_STATE_FLAG_ABORT;
        if (read(pipe_do_exec[0], &byte, sizeof(byte)) < 0) {
            dprintf(pipe_log[1], "Failed to read(do_exec). Error was %d :: %s\n",
                    errno, strerror(errno));
            return 104;
        }

        if (byte != EXEC_STATE_FLAG_GO) {
            dprintf(pipe_log[1], "Read(do_exec) returned failure flag %d\n", (int) byte);
            return 105;
        }

        // shift the argv array, and insert null terminator
        memcpy(argv, argv + argv_cmd_index, sizeof(char*) * (argc - argv_cmd_index));
        argv[argc - argv_cmd_index] = NULL;

        CHILD_DEBUG("Executing '%s'", argv[0]);

        // exec the child process
        execvp(argv[0], argv);

        // must of had an error
        dprintf(pipe_log[1], "Failed to exec. Error was %d :: %s\n",
                errno, strerror(errno));

        // send 126/127 as per the shell
        return (errno == ENOENT ? 127 : 126);
    }
    // parent
    else if (pid > 0) {
        // close fds
        close(pipe_in[0]);
        close(pipe_out[1]);
        close(pipe_err[1]);
        close(pipe_log[1]);
        close(pipe_debug[1]);

        // log the child process's pid
        printf(TAG_PID "%d\n", pid);

        // add the signal handlers
        if (add_signal_handlers() != 0) {
            // make sure the child does not hang
            enable_exec(EXEC_STATE_FLAG_ABORT);
            return 4;
        }

        childpid = pid;

        // create output thread, when this starts it will prompt the child process
        // to exec
        pthread_t thread;
        if ((errno = pthread_create(&thread, NULL, &worker_fn, NULL)) != 0) {
            printf(TAG_ERROR "Failed to pthread_create(). Error was %d :: %s\n",
                   errno, strerror(errno));
            // make sure the child does not hang
            enable_exec(EXEC_STATE_FLAG_ABORT);
            return 5;
        }

        // wait for all child processes to terminate
        int child_status;
        while (1) {
            int status;
            pid_t result = waitpid(-1, &status, 0);
            if (result < 0) {
                // musl bug returns errn as result
                if (errno == 0) {
                    errno = -result;
                    DEBUG("Musl waitpid errno bug: %d :: %s\n",
                          errno, strerror(errno));
                }
                // handle error
                if (errno == EINTR) {
                    // interrupted by signal
                    continue;
                }
                else if (errno == ECHILD) {
                    // there are no more children to wait on
                    break;
                }
                else {
                    // error
                    printf(TAG_ERROR "Failed to waitpid(%d). Error was %d :: %s\n",
                           pid, errno, strerror(errno));
                    return 6;
                }
            }
            else {
                DEBUG("Reaping child %d\n", result);
                if (result == pid) {
                    // store child's exit code for later
                    child_status = status;
                    // clear pid so signal handle does not send to wrong process
                    childpid = 0;
                }
            }
        }

        // mark terminated
        atomic_store(&flag_terminated, 1);

        // join data thread
        if ((errno = pthread_join(thread, NULL)) != 0) {
            printf(TAG_ERROR "Failed to pthread_join(). Error was %d :: %s\n",
                   errno, strerror(errno));
            return 7;
        }

        // process exit code
        if (WIFEXITED(child_status)) {
            printf(TAG_EXITCODE "%d\n", WEXITSTATUS(child_status));
            return 0;
        }

        // did it die with a signal?
        if (WIFSIGNALED(child_status)) {
            printf(TAG_EXITSIGNAL "%d '%s'\n", WTERMSIG(child_status), strsignal(WTERMSIG(child_status)));
            return 0;
        }
    }
    // error
    else {
        printf(TAG_ERROR "Failed to fork(). Error was %d :: %s\n",
               errno, strerror(errno));
        return 8;
    }
}
