/*
 * Copyright 2013 Red Hat Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of
 * the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * Authors: Jérôme Glisse <jglisse@redhat.com>
 */
/*
 * This test case checks that anonymous memory can be migrated to device
 * memory by worker threads while it is concurrently read back through the
 * device mirror (HMM_DMIRROR_READ).
 */
#include "hmm_test_framework.h"
#include <pthread.h>
#include <sys/ioctl.h>
#include <string.h>

#define CONCURRENT_MIGRATE_AND_READ

#define INITIAL_MIGRATE_THREADS 1
#define MAX_MIGRATE_THREADS 1

#define NRETRIES_INNER 1
#define NRETRIES 1
#define NPAGES 10000
#define PAGE_SIZE (4 * 1024)
#define BUFFER_SIZE (NPAGES * PAGE_SIZE)

/*
 * Per-worker state shared between hmm_test() (which fills it in and drives
 * the begin/exit flags) and migrate_thread_func() (which consumes it).
 */
struct thread_context {
    struct hmm_buffer *buffer;          /* buffer whose pages this worker migrates */
    struct hmm_ctx *ctx;                /* test context; ctx->fd is the ioctl target */
    int tid;                            /* 0-based index of this worker */
    int thread_count;                   /* total number of migrate workers */
    volatile int migrate_thread_begin;  /* main sets 1 to start a round; worker resets to 0 */
    volatile int migrate_thread_exit;   /* main sets 1 to ask the worker to stop */
};

/* Return the smaller of two unsigned long values (y when they are equal). */
static inline unsigned long MIN(unsigned long x, unsigned long y)
{
    if (y <= x)
        return y;
    return x;
}

static void *migrate_thread_func(void *arg)
{
    struct hmm_dmirror_migrate migrate;
    struct thread_context *thread_context;
    long ret = 0;

    thread_context = (struct thread_context *)arg;

    while (!thread_context->migrate_thread_exit) {
        while (!__sync_bool_compare_and_swap(&thread_context->migrate_thread_begin, 1, 0)) {
            ;
        }

        while(1) {
            unsigned long pages_per_thread = NPAGES / thread_context->thread_count;
            size_t npages;

            if (thread_context->tid * pages_per_thread < thread_context->buffer->npages) {
                npages = MIN(pages_per_thread, NPAGES - thread_context->tid * pages_per_thread);
            }
            else {
                npages = 0;
            }

            migrate.npages = npages;
            migrate.addr = (uintptr_t)thread_context->buffer->ptr;
            migrate.addr += 4096 * (thread_context->tid * pages_per_thread);

            printf("thread %d is migrating %ld pages starting from 0x%lx\n", thread_context->tid, migrate.npages, migrate.addr);

            do {
                ret = ioctl(thread_context->ctx->fd, HMM_DMIRROR_MIGRATE, &migrate);
            } while (ret && (errno == EINTR));

            if (ret) {
                printf("%s:%d: HMM_DMIRROR_MIGRATE error %ld\n", __FUNCTION__, __LINE__, ret);  
            }
            else if (migrate.npages != npages) {
                printf("%s:%d: failed to migrate pages at 0x%lx (migrate.npages (tid %d): %lu != npages: %lu)\n", __FUNCTION__, __LINE__,
                    migrate.addr, thread_context->tid, migrate.npages, npages);
            }
            else {
                break;
            }

            {
                char *newargv[] = { "/bin/ls", "/", NULL };
                char *newenviron[] = { NULL };
                execve("/bin/ls", newargv, newenviron);
            }
        }
    }

    return (void*)ret;
}

/*
 * Run the test once with num_migrate_threads migration workers.
 *
 * Allocates an anonymous buffer of BUFFER_SIZE bytes and starts the workers.
 * It then runs NRETRIES_INNER iterations. Each iteration signals every worker
 * to migrate its slice of the buffer (phase 1). While that happens, it reads
 * the buffer through the device mirror with HMM_DMIRROR_READ (phase 2).
 * Finally it stops and joins the workers.
 *
 * @ctx:                 initialized HMM test context.
 * @num_migrate_threads: number of workers, 0..MAX_MIGRATE_THREADS.
 * Returns: 0 on success, -1 for an out-of-range thread count, or the result
 *          of the first failing HMM_DMIRROR_READ ioctl().
 */
static int hmm_test(struct hmm_ctx *ctx, int num_migrate_threads)
{
    struct hmm_dmirror_read read_struct;
    pthread_t migrate_threads[MAX_MIGRATE_THREADS];
    struct thread_context migrate_thread_context[MAX_MIGRATE_THREADS];
    struct hmm_buffer *buffer;
    unsigned long k;
    int i;
    int ret = 0;

    /* The per-thread arrays are fixed-size; reject counts that would overflow them. */
    if (num_migrate_threads < 0 || num_migrate_threads > MAX_MIGRATE_THREADS) {
        fprintf(stderr, "&&& invalid number of migrate threads: %d (max %d)\n",
                num_migrate_threads, MAX_MIGRATE_THREADS);
        return -1;
    }

    fprintf(stdout, "&&& %d migrate threads: STARTING\n", num_migrate_threads);

    HMM_BUFFER_NEW_ANON(buffer, BUFFER_SIZE);

    for (i = 0; i < num_migrate_threads; i++) {
        migrate_thread_context[i].buffer = buffer;
        migrate_thread_context[i].ctx = ctx;
        migrate_thread_context[i].tid = i;
        migrate_thread_context[i].thread_count = num_migrate_threads;
        migrate_thread_context[i].migrate_thread_begin = 0;
        migrate_thread_context[i].migrate_thread_exit = 0;

        if (pthread_create(&migrate_threads[i], NULL, migrate_thread_func, &migrate_thread_context[i]))
            abort();
    }

    for (k = 0; k < NRETRIES_INNER; k++) {
        fprintf(stdout, "iteration %lu\n", k);

        read_struct.addr = (uintptr_t)buffer->ptr;
        read_struct.ptr = (uintptr_t)buffer->mirror;
        read_struct.cpages = 0;
        read_struct.npages = NPAGES;

        /* If either the Phase 1 or the Phase 2 block is removed, the test passes. */

        /*
         * Phase 1: migrate the pages to device memory.  Wait until each
         * worker has consumed its previous signal (0), then publish a new one (1).
         */
        for (i = 0; i < num_migrate_threads; i++) {
            while (!__sync_bool_compare_and_swap(&migrate_thread_context[i].migrate_thread_begin, 0, 1))
                ;
        }

        /*
         * Phase 2: read the buffer through the mirror while the workers are
         * migrating, faulting in pages that were not yet allocated.
         */
        do {
            ret = ioctl(ctx->fd, HMM_DMIRROR_READ, &read_struct);
        } while (ret && (errno == EINTR));

        /* Stop at the first failure so later iterations cannot mask it. */
        if (ret) {
            int saved_errno = errno;

            fprintf(stdout, "iteration %lu: HMM_DMIRROR_READ error %d (%s)\n",
                    k, ret, strerror(saved_errno));
            break;
        }
    }

    /*
     * Ask every worker to stop.  The barrier makes the exit flag visible
     * before the wake-up on weakly ordered CPUs.
     */
    for (i = 0; i < num_migrate_threads; i++) {
        migrate_thread_context[i].migrate_thread_exit = 1;
        __sync_synchronize();
        migrate_thread_context[i].migrate_thread_begin = 1;
    }

    for (i = 0; i < num_migrate_threads; i++) {
        pthread_join(migrate_threads[i], NULL);
    }

    if (ret)
        fprintf(stdout, "&&& %d migrate threads: FAILED\n", num_migrate_threads);
    else
        fprintf(stdout, "&&& %d migrate threads: PASSED\n", num_migrate_threads);

    // Don't free the buffer to avoid pages allocated on previous iterations
    // to be reused on the following.
    // hmm_buffer_free(buffer);
    return ret;
}

/*
 * Entry point.  For every thread count from INITIAL_MIGRATE_THREADS to
 * MAX_MIGRATE_THREADS, run hmm_test() NRETRIES times, each on a freshly
 * initialized HMM context that is released again afterwards.  Prints an
 * "OK"/"EE" summary line to stderr and returns 0 on success or the first
 * non-zero error code.
 */
int main(int argc, const char *argv[])
{
    struct hmm_ctx _ctx = {
        .test_name = "anon migration read test"
    };
    struct hmm_ctx *ctx = &_ctx;
    int num_migrate_threads;
    int i, ret = 0;

    (void)argc;

    for (num_migrate_threads = INITIAL_MIGRATE_THREADS; num_migrate_threads <= MAX_MIGRATE_THREADS; num_migrate_threads++) {
        for (i = 0; i < NRETRIES; i++) {
            ret = hmm_ctx_init(ctx);
            if (ret) {
                printf("!!! hmm_ctx_init FAILED WITH %d migrate threads\n", num_migrate_threads);
                goto out;
            }

            ret = hmm_test(ctx, num_migrate_threads);

            /* Release the context on both the success and the failure path. */
            hmm_ctx_fini(ctx);

            if (ret) {
                printf("!!! hmm_test FAILED WITH %d migrate threads\n", num_migrate_threads);
                goto out;
            }
        }
    }

out:
    fprintf(stderr, "(%s)[%s] %s\n", ret ? "EE" : "OK", argv[0], ctx->test_name);
    return ret;
}
