/*
 * phe_sum - tool for generating SHA1 sums of files
 * 
 * Uses the VIA PadLock PHE engine for the hashing itself,
 * so it's much faster than the sha1sum(1) tool.
 * 
 * PHE (PadLock Hash Engine) in VIA Esther (C7, C5J)
 * processors always attempts to finalize the SHA1 hash.
 * This is very inconvenient for hashing large amounts
 * of data, e.g. CD or DVD images, which may not even fit
 * into available memory if they're larger than 4GB, or
 * would at least be swapped out if larger than the actual
 * physical RAM. To avoid this behavior Andy Polyakov came
 * up with the idea of triggering an exception (segfault)
 * before PHE can finalize the hash. In that case it stores
 * the intermediate results in the output buffer, so we can
 * later use them as the initial state for subsequent hashing.
 * 
 * By Michal Ludvig <michal@logix.cz> on 2006-04-23
 *    http://www.logix.cz/michal/devel/padlock
 *
 * License: GPL2
 */

/* TODO:
 * - SHA256 mode
 * - MMAP source files
 */

#define _GNU_SOURCE
#define _FILE_OFFSET_BITS       64
#include <stdio.h>
#include <errno.h>
#include <signal.h>
#include <getopt.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/ucontext.h>
#include <arpa/inet.h>

#define PROG_NAME       "phe_sum"
#define PROG_VERSION    "1.0"

/* ====== Features ====== */
#define USE_DIRECT
// #define OPENSSL
// #define VERBOSE

/* ====== Config options ====== */
/* The size of the input buffer is specified in "pages".
 * A page is getpagesize() bytes, typically 4096, so 16 pages
 * means a 64kB buffer. There is a tradeoff between overhead
 * for too small buffers and L1 cache thrashing for too large
 * ones. On the author's system the best results were observed
 * with 14 pages.
 */
#define INPUT_PAGES     14

/* Only used in non-direct mode which is slow anyway. */
#ifndef USE_DIRECT
#define READ_BUFFER     4096
#endif

/* ====== Nothing to configure below this line ====== */
#define MAX_SHA_BIN     64      /* 64 Bytes for SHA512 */
/* Hex string of the largest digest (2 chars per byte) plus the terminating
 * NUL; fully parenthesized so it is safe inside larger expressions. */
#define MAX_SHA_STR     (2 * MAX_SHA_BIN + 1)

#define ALG_SHA1        1
#define ALG_SHA256      2

/* Global variables and their default values.  */

/* Number of pages for the input buffer; changed by -p/--pages and passed to
 * padlock_sha1_init() by checksum(). */
static size_t   input_pages     = INPUT_PAGES;
/* Incremented for every -v/--verbose; enables extra diagnostics and the
 * throughput report printed by main(). */
static int      verbose         = 0;
/* Selected algorithm (ALG_SHA1 / ALG_SHA256).
 * NOTE(review): nothing in this file reads it; hashing is always SHA1. */
static int      hash_alg        = ALG_SHA1;

/* Basename of argv[0]; used as the prefix of error messages. */
char            *argv0_base     = PROG_NAME;

/* --check statistics: lines verified and how many of them mismatched. */
static size_t   total_csums;
static size_t   failed_csums;

/* Bytes hashed over all files, for the --verbose throughput report. */
static off64_t  total_bytes;

#ifdef OPENSSL
#include <openssl/sha.h>
#endif

#ifdef VERBOSE
/*
 * Debug helper: print MSG (when non-NULL) and then a hex dump of the first
 * COUNT bytes at PTR, eight bytes per output line.
 */
static void
dump_mem(char *msg, char *ptr, int count)
{
        int pos;

        if (msg != NULL)
                printf("%s:\n", msg);

        for (pos = 0; pos < count; pos++) {
                int sep = ((pos + 1) % 8 == 0) ? '\n' : ' ';

                printf("%02X%c", (unsigned char)ptr[pos], sep);
        }

        /* Terminate a partially filled last line. */
        if (pos % 8 != 0)
                printf("\n");
}
/* Debug messages go to stderr. */
#define dprintf(format, ...) fprintf(stderr, format, ## __VA_ARGS__)
#else
/* Debugging disabled: both helpers expand to nothing. */
#define dump_mem(a,b,c)
#define dprintf(format, ...)
#endif

/*
 * Convert a binary digest to lowercase hex.
 *
 * SHA_BIN holds LEN/32 32-bit words in host byte order (the layout the PHE
 * hardware leaves in its output buffer); each word is printed as 8 hex
 * digits.  LEN is the digest length in bits and must be 160, 256 or 512.
 *
 * If SHA_STR is NULL an internal static buffer is used (not reentrant);
 * otherwise SHA_STR must hold at least LEN/4 + 1 bytes.
 *
 * Returns the resulting string, or NULL for an unsupported LEN.
 */
static char *
sha_bin2str(const char *sha_bin, char *sha_str, int len)
{
        /* Worst case: 512 bits -> 128 hex digits plus the terminating NUL. */
        static char result[512 / 4 + 1];
        int i;

        if (len != 160 && len != 256 && len != 512) {
                fprintf(stderr, "Bad hash length: %d\n", len);
                return NULL;
        }

        if (!sha_str)
                sha_str = result;

        for (i = 0; i < len / 32; i++) {
                uint32_t word;

                /* memcpy avoids an aliasing violation / unaligned access. */
                memcpy(&word, sha_bin + 4 * i, sizeof(word));
                sprintf(sha_str + 8 * i, "%08x", (unsigned int)word);
        }

        return sha_str;
}

#ifdef OPENSSL
/*
 * Like sha_bin2str(), but SHA_BIN holds big-endian 32-bit words (each word
 * is converted with ntohl() before printing), e.g. a digest from OpenSSL.
 * LEN is the digest length in bits (160, 256 or 512).  With SHA_STR == NULL
 * an internal static buffer is used.  Returns NULL for an unsupported LEN.
 */
static char *
sha_bin2str_be(const char *sha_bin, char *sha_str, int len)
{
        /* Worst case: 512 bits -> 128 hex digits plus the terminating NUL. */
        static char result[512 / 4 + 1];
        int i;

        if (len != 160 && len != 256 && len != 512) {
                fprintf(stderr, "Bad hash length: %d\n", len);
                return NULL;
        }

        if (!sha_str)
                sha_str = result;

        for (i = 0; i < len / 32; i++) {
                uint32_t word;

                memcpy(&word, sha_bin + 4 * i, sizeof(word));
                sprintf(sha_str + 8 * i, "%08x", (unsigned int)ntohl(word));
        }

        return sha_str;
}

#endif

/*
 * State of one SHA1 computation driven by the PadLock helpers below.
 *
 * The buffer at `input` is set up by padlock_sha1_init(): an mmap()ed area
 * followed by a PROT_NONE page, which is what lets XSHA1 be interrupted by a
 * fault before it finalizes the hash.
 */
struct sha_ctx {
        char            output[128];    /* Hash state read and written by XSHA1
                                         * (SHA1 initial values, then intermediate
                                         * and final digest words).
                                         * Leave this field first! We need to ensure
                                         * it's 16-Bytes aligned and we rely on it in
                                         * padlock_xsha1_hash_unfull() */
        uint64_t        total;          /* Bytes passed through the update functions.
                                         * Leave this one second to ensure it's neatly
                                         * aligned as well. And by the way we use it
                                         * in padlock_xsha1_hash_unfull() */
        char            *input;         /* Start of the data buffer handed to XSHA1. */
        size_t          size;           /* Usable bytes at input (64 less than the
                                         * data pages, see padlock_sha1_init()). */
        size_t          used;           /* Bytes currently buffered, not yet hashed. */
        size_t          mapped;         /* Whole mapping length incl. the guard
                                         * page, for munmap(). */
} __attribute__((aligned(16)));

/*
 * SIGSEGV handler installed only around the deliberately faulting XSHA1 in
 * padlock_xsha1_nonfinalizing().  Rather than letting the fault kill us, it
 * advances the saved instruction pointer past the faulting instruction so
 * execution continues after it, with the non-finalized hash state left in
 * the output buffer.
 */
static void
segv_action(int sig, siginfo_t *info, void *uctxp)
{
        ucontext_t *uctx = uctxp;

        (void)info;

        if (sig != SIGSEGV) {
                fprintf(stderr, "WTF, not for us?!\n");
                exit(33);
        }

        /* The saved instruction pointer points to the first byte of the
         * offending instruction ("rep xsha1" in our case), which is 4 bytes
         * long, so we simply step over it and return.
         * Its slot in gregs[] differs per ABI: index 14 is EIP only on i386;
         * on x86-64 that index is RCX and the IP lives in REG_RIP.
         * This is very very special to our case, don't take it as a
         * design pattern ;-) */
#if defined(__x86_64__)
        uctx->uc_mcontext.gregs[REG_RIP] += 4;
#else
        uctx->uc_mcontext.gregs[REG_EIP] += 4;
#endif
}

/*
 * Execute the PadLock "xsha1" instruction once.
 *
 * INPUT/COUNT describe the data, OUTPUT is the hash state buffer
 * (struct sha_ctx.output).  Returns the value the instruction leaves in
 * EAX/RAX, which callers compare with the number of bytes they expected to
 * be hashed.
 *
 * Operands are register-sized (size_t) so the full RAX/RCX are defined on
 * x86-64.  ECX is declared as modified, and "memory" is clobbered because
 * the instruction reads *input and reads/writes *output behind the
 * compiler's back.
 */
static inline size_t
padlock_xsha1_lowlevel(char *input, char *output, size_t count)
{
        size_t done = 0;

        asm volatile ("xsha1"
                      : "+S"(input), "+D"(output), "+a"(done), "+c"(count)
                      :
                      : "memory", "cc");
        return done;
}

/*
 * Hash the ctx->size bytes at ctx->input into ctx->output WITHOUT letting
 * the hardware finalize (pad) the digest.
 *
 * XSHA1 is told to hash 64 bytes more than are available so that it runs
 * into the PROT_NONE page set up by padlock_sha1_init(); segv_action() then
 * skips the instruction, leaving the intermediate state in ctx->output.
 *
 * Returns 0 on success, or -1 if the SIGSEGV handler could not be installed
 * or the hardware did not report exactly ctx->size hashed bytes.
 */
static int
padlock_xsha1_nonfinalizing(struct sha_ctx *ctx)
{
        struct sigaction act, oldact;
        size_t hashed;

        memset(&act, 0, sizeof(act));
        act.sa_sigaction = segv_action;
        act.sa_flags = SA_SIGINFO;
        sigemptyset(&act.sa_mask);

        /* Without our handler the deliberate fault below kills the process. */
        if (sigaction(SIGSEGV, &act, &oldact) < 0) {
                perror("sigaction");
                return -1;
        }

        hashed = padlock_xsha1_lowlevel(ctx->input, ctx->output, ctx->size + 64);

        sigaction(SIGSEGV, &oldact, NULL);

        if (hashed != ctx->size) {
                fprintf(stderr, "%s(): hashed(%zu) != ctx->size(%zu)\n",
                        __func__, hashed, ctx->size);
                return -1;
        }
        return 0;
}

/*
 * Hash the ctx->used bytes still buffered at ctx->input without hardware
 * finalization, updating ctx->output.  padlock_sha1_final() calls this after
 * it has appended the SHA1 padding and length in software.
 *
 * The buffered bytes are moved inside the input buffer, but every field
 * except ctx->output is restored afterwards (this relies on output being the
 * first struct member).
 *
 * Returns 0 on success, -1 on failure (see padlock_xsha1_nonfinalizing()).
 */
static int
padlock_xsha1_hash_unfull(struct sha_ctx *ctx)
{
        struct sha_ctx backup_ctx;
        int ret;

        /* Nothing buffered: ctx->output already holds the complete state, so
         * there is no reason to run XSHA1 over the slack area at all. */
        if (ctx->used == 0)
                return 0;

        backup_ctx = *ctx;

        /* Move the input data so it ends exactly where the usable area ends.
         * We need to trigger segfault to avoid HW finalizing. */
        memmove(ctx->input + ctx->size - ctx->used, ctx->input, ctx->used);
        ctx->input += ctx->size - ctx->used;
        ctx->size = ctx->used;

        ret = padlock_xsha1_nonfinalizing(ctx);

        /* Copy back everything except the ctx->output buffer. */
        memcpy(&(ctx->total), &(backup_ctx.total),
               sizeof(backup_ctx) - sizeof(backup_ctx.output));
        return ret;
}

/*
 * Prepare CTX for a new SHA1 computation using INPUT_PAGES pages of buffer.
 *
 * Maps the data pages plus one trailing guard page, which is made PROT_NONE
 * so that XSHA1 faults instead of finalizing when it runs past the data (see
 * padlock_xsha1_nonfinalizing()).  The usable buffer (ctx->size) is 64 bytes
 * shorter than the data pages because PadLock needs 64 accessible bytes
 * after the end of the input.  ctx->output receives the SHA1 initial values.
 *
 * Exits the program on any failure.
 */
static void
padlock_sha1_init(struct sha_ctx *ctx, size_t input_pages)
{
        /* Initial constants for SHA1 */
        static const uint32_t sha1_iv[5] = {
                0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
        };
        size_t  page_size = (size_t)getpagesize();
        size_t  prot_size = page_size;          /* one guard page */
        size_t  input_size;
        char    *prot_ptr;
        void    *map;

        /* Need at least one data page (size is input_size - 64), and the
         * whole mapping length must not overflow size_t. */
        if (input_pages == 0 || input_pages > ((size_t)-1) / page_size - 1) {
                fprintf(stderr, "Invalid number of input pages: %zu\n", input_pages);
                exit(1);
        }
        input_size = input_pages * page_size;

        memset(ctx, 0, sizeof(*ctx));
        ctx->mapped = input_size + prot_size;

        /* mmap() signals failure with MAP_FAILED, not NULL; errno must not be
         * inspected after a successful call (it may hold a stale value). */
        map = mmap(NULL, ctx->mapped, PROT_READ | PROT_WRITE,
                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
        if (map == MAP_FAILED) {
                perror("mmap");
                exit(1);
        }
        ctx->input = map;

        /* Pointer to the protected page. Just for convenience. */
        prot_ptr = ctx->input + input_size;

        /* PadLock needs at least 64 Bytes of accessible memory
         * _after_ the end of input buffer. Oh well... */
        ctx->size = input_size - 64;

        /* Best effort: try to keep the buffer resident; failure is harmless. */
        (void)mlock(ctx->input, ctx->mapped);

        /* The last page must be unaccessible to trigger exception.
         * This is absolutely necessary to succeed. */
        if (mprotect(prot_ptr, prot_size, PROT_NONE) < 0) {
                perror("mprotect");
                exit(1);
        }

        memcpy(ctx->output, sha1_iv, sizeof(sha1_iv));
}

/*
 * Feed LEN bytes from DATA into the context by copying them into the input
 * buffer.  Whenever the buffer becomes exactly full it is hashed
 * (non-finalizing) and emptied.  Returns the number of buffers hashed
 * during this call.
 */
static int
padlock_sha1_update(struct sha_ctx *ctx, const char *data, size_t len)
{
        int flushes = 0;

        while (len != 0) {
                size_t room;

                if (ctx->used + len < ctx->size) {
                        /* Fits with space to spare: just buffer it. */
                        memcpy(ctx->input + ctx->used, data, len);
                        ctx->used += len;
                        ctx->total += len;
                        break;
                }

                /* Top the buffer up to exactly ctx->size bytes, then hash it. */
                room = ctx->size - ctx->used;
                memcpy(ctx->input + ctx->used, data, room);
                ctx->used = ctx->size;
                ctx->total += room;
                data += room;
                len -= room;

                padlock_xsha1_nonfinalizing(ctx);
                ctx->used = 0;
                flushes++;
        }

        return flushes;
}

/*
 * Finish the SHA1 computation, store the 20-byte digest in MD and release the
 * buffers allocated by padlock_sha1_init().
 *
 * If nothing has been hashed yet (all data still sits in the buffer), the
 * hardware is allowed to finalize by itself.  Otherwise the SHA1 padding and
 * the 64-bit big-endian bit length are appended in software and the rest is
 * hashed without hardware finalization.
 *
 * MD receives the five state words in host byte order, exactly as left in
 * ctx->output (sha_bin2str() expects that layout).
 *
 * Returns the message length in bytes, not counting padding.  Exits the
 * program if the hardware reports an unexpected byte count.
 */
static off64_t
padlock_sha1_final(struct sha_ctx *ctx, char *md)
{
        /* Captured before padding is appended, so both paths report the
         * same thing: the number of message bytes. */
        off64_t msg_bytes = ctx->total;

        if (ctx->used == ctx->total) {
                /* Sweet, this is the first run, leave
                 * finalizing to the hardware. */
                size_t hashed = padlock_xsha1_lowlevel(ctx->input, ctx->output, ctx->used);

                if (hashed != ctx->used) {
                        fprintf(stderr,
                                "%s(): hashed(%zu) != ctx->used(%zu)\n",
                                __func__, hashed, ctx->used);
                        exit(1);
                }
        } else {
                /* Hardware already hashed some buffers.
                 * Do finalizing manually.  Padding pattern: 0x80, then zeros. */
                static const char padding[64] = { (char)0x80 };
                uint64_t bits = ctx->total * 8;
                unsigned char bits_be[8];
                size_t lastblocklen, padlen;
                int i;

                /* Message length in bits as big-endian 64-bit, built with
                 * shifts instead of type-punning through uint32_t pointers. */
                for (i = 0; i < 8; i++)
                        bits_be[i] = (unsigned char)(bits >> (56 - 8 * i));

                /* Pad to 56 mod 64, leaving space for the 8-byte length. */
                lastblocklen = ctx->total & 63;
                padlen = (lastblocklen < 56) ? (56 - lastblocklen)
                                             : ((64 + 56) - lastblocklen);
                padlock_sha1_update(ctx, padding, padlen);
                padlock_sha1_update(ctx, (const char *)bits_be, sizeof(bits_be));

                /* A failure here would yield a bogus digest; the cause has
                 * already been reported on stderr. */
                if (padlock_xsha1_hash_unfull(ctx) < 0)
                        exit(1);
        }

        /* Copy out the result. */
        memcpy(md, ctx->output, 20);

        /* Clean up the allocations. */
        memset(ctx->input, 0, ctx->size);
        munmap(ctx->input, ctx->mapped);
        ctx->input = NULL;
        memset(ctx->output, 0, sizeof(ctx->output));

        return msg_bytes;
}

/* Address where the caller may write the next input bytes directly (e.g.
 * with read(2)): just past the bytes already buffered. */
static inline char *
padlock_sha1_direct_getaddr(struct sha_ctx *ctx)
{
        char *free_space = &ctx->input[ctx->used];

        return free_space;
}

/* Number of bytes that may still be written at
 * padlock_sha1_direct_getaddr() before the buffer is full. */
static inline size_t
padlock_sha1_direct_getlength(struct sha_ctx *ctx)
{
        size_t room = ctx->size;

        room -= ctx->used;
        return room;
}

/*
 * Account for LENGTH bytes the caller has written directly into the buffer at
 * padlock_sha1_direct_getaddr().  Once the buffer is exactly full it is
 * hashed (non-finalizing) and emptied.
 *
 * WARNING: LENGTH is not validated.  It must not exceed
 * padlock_sha1_direct_getlength(), or the counters get confused.
 */
static inline void
padlock_sha1_direct_update(struct sha_ctx *ctx, size_t length)
{
        ctx->total += length;
        ctx->used += length;

        if (ctx->used == ctx->size) {
                padlock_xsha1_nonfinalizing(ctx);
                ctx->used = 0;
        }
}

#ifdef USE_DIRECT
static int
checksum(int fd, char *md)
{
        struct sha_ctx ctx;
        char    output[MAX_SHA_BIN];

        padlock_sha1_init(&ctx, input_pages);

        do {
                char    *buf = padlock_sha1_direct_getaddr(&ctx);
                size_t  rd = padlock_sha1_direct_getlength(&ctx);

                rd = read(fd, buf, rd);

                if (rd <= 0)
                        break;

                padlock_sha1_direct_update(&ctx, rd);
        } while (1);

        total_bytes += padlock_sha1_final(&ctx, output);
        sha_bin2str(output, md, 160);

        return 0;
}
#else   /* !USE_DIRECT, i.e. regular padlock_sha1_update()s */
static int
checksum(int fd, char *md)
{
        struct sha_ctx ctx;
        char    buf[READ_BUFFER];
        char    output[MAX_SHA_BIN];
        int     rd;
#ifdef OPENSSL
        char    output_openssl[MAX_SHA_BIN], md_openssl[MAX_SHA_STR];
#endif

        padlock_sha1_init(&ctx, input_pages);

#ifdef OPENSSL
        SHA_CTX c;
        SHA1_Init(&c);
#endif

        do {
                size_t  len = sizeof(buf);

                rd = read(fd, buf, len);

                if (rd <= 0)
                        break;

#ifdef OPENSSL
                SHA1_Update(&c, buf, rd);
#endif
                padlock_sha1_update(&ctx, buf, rd);
        } while (1);

        total_bytes += padlock_sha1_final(&ctx, output);

        sha_bin2str(output, md, 160);

#ifdef OPENSSL
        SHA1_Final(output_be, &c);
        sha_bin2str_be(output_be, md_openssl, 160);

        int i;
        for (i = 0; i < 5; i++)
                 if (((int*)md)[i] != ntohl(((int*)output_be)[i])) {
                         dprintf("PadLock(%s) != OpenSSL(%s)\n", md, md_openssl);
                         return 1;
#endif
        return 0;
}
#endif  /* USE_DIRECT */

static const struct option long_options[] =
{
        /* Standard options for compatibility 
         * with coreutils' sha1sum.  */
        { "binary", no_argument, NULL, 'b' },
        { "text", no_argument, NULL, 't' },
        { "warn", no_argument, NULL, 'w' },
        { "check", no_argument, NULL, 'c' },
        { "help", no_argument, NULL, 'h' },
        { "version", no_argument, NULL, 'V' },

        /* Options specific to phe_sum.  */
        { "sha1", no_argument, NULL, 1 },
        { "sha256", no_argument, NULL, 2 },
        { "pages", required_argument, NULL, 'p' },
        { "verbose", no_argument, NULL, 'v' },

        /* Barrier.  */
        { NULL, 0, NULL, 0 }
};

/* Print the program name, version, author and homepage to stdout. */
static void
print_version(void)
{
        printf("phe_sum version %s\n", PROG_VERSION);
        putchar('\n');
        fputs("Written by Michal Ludvig <michal@logix.cz> (c) 2006\n", stdout);
        fputs("           http://www.logix.cz/michal/devel/padlock\n", stdout);
}

/* Print the GPLv2 license notice to stdout. */
static void
print_license(void)
{
        fputs("This is free software.  You may redistribute copies of it under the terms of\n"
              "the GNU General Public License version 2 <http://www.gnu.org/licenses/gpl2.txt>.\n"
              "There is NO WARRANTY, to the extent permitted by law.\n",
              stdout);
}

/* Print version, option summary and license notice to stdout (--help). */
static void
print_help(void)
{
        print_version();
        fputs("\n"
              "phe_sum computes SHA1 or SHA256 digests on VIA x86 processors\n"
              "that support PHE, the PadLock Hash Engine core.\n"
              "\n"
              "  -c, --check             Read checksums from FILEs and check them\n"
              "  -v, --verbose           Increase verbosity, print some statistics\n"
              "\n"
              "      --help              This help\n",
              stdout);
        printf("      --version           Print program version (which is \"%s\")\n", PROG_VERSION);
        fputs("\n"
              "      --sha1              Do SHA1 (160-bits) checksum (default)\n"
              "      --sha256            Do SHA256 (256-bits) checksum\n"
              "\n"
              "  -p, --pages=<NN>        Use <NN> number of pages for input buffer\n",
              stdout);
        printf("                          Default is %d pages, each page is %d bytes\n", INPUT_PAGES, getpagesize());
        putchar('\n');
        print_license();
}

/*
 * Compute the checksum of FNAME ("-" means standard input) and store its hex
 * form in MD.  Named inputs must be regular files.
 *
 * Errors are reported on stderr as "<prog>: <file>: <reason>".
 * Returns 0 on success, 1 on failure.
 */
static int
gen_checksum(const char *fname, char *md)
{
        int is_stdin = (strcmp(fname, "-") == 0);
        struct stat st;
        int ret = 1;
        int fd;

        /* stdin is dup()ed so that the close() below never closes fd 0. */
        fd = is_stdin ? dup(0) : open(fname, O_RDONLY | O_LARGEFILE);
        if (fd < 0) {
                /* Report the real open()/dup() error instead of a later EBADF. */
                fprintf(stderr, "%s: %s: %s\n", argv0_base, fname, strerror(errno));
                return 1;
        }

        if (!is_stdin) {
                if (fstat(fd, &st) < 0) {
                        fprintf(stderr, "%s: %s: %s\n", argv0_base, fname, strerror(errno));
                        goto cleanup;
                }
                if (!S_ISREG(st.st_mode)) {
                        fprintf(stderr, "%s: %s: Not a regular file\n", argv0_base, fname);
                        goto cleanup;
                }
        }

        ret = checksum(fd, md);

cleanup:
        close(fd);

        return ret;
}

/*
 * Verify the checksum list in FNAME ("-" means standard input).
 *
 * Each line must start with 40 (SHA1) or 64 (SHA256) hex digits, followed by
 * one separator character, optional blanks or '*', and the file name.
 * Every listed file is hashed, and "<name>: OK" or "<name>: FAILED" is
 * printed on stdout.  The global counters total_csums and failed_csums are
 * updated.
 *
 * Returns 0 if every line was valid and matched, 1 otherwise.
 */
static int
check_checksum(const char *fname)
{
        int ret = 0;
        char csum_buf[1024], csum_md[MAX_SHA_STR];
        FILE *csum_file;

        if (strcmp(fname, "-") == 0)
                csum_file = stdin;
        else
                csum_file = fopen(fname, "r");

        if (!csum_file) {
                fprintf(stderr, "%s: %s: %s\n", argv0_base, fname, strerror(errno));
                return 1;
        }

        while (fgets(csum_buf, sizeof(csum_buf), csum_file)) {
                char csum_sum[64 + 1];
                char *csum_fname;
                size_t csum_len, csum_fnlen;

                /* The checksum is the leading run of hex digits; it must be
                 * 40 or 64 characters for SHA1 (160 bits) or SHA256.
                 * (strspn needs no allocation, unlike the old "%a[" scanf.) */
                csum_len = strspn(csum_buf, "0123456789abcdefABCDEF");
                if (csum_len != 40 && csum_len != 64)
                        goto invalid_line;
                memcpy(csum_sum, csum_buf, csum_len);
                csum_sum[csum_len] = '\0';

                /* Skip the separator that ends the hex run, then any blanks
                 * or the binary-mode '*', never running past the NUL. */
                csum_fname = csum_buf + csum_len;
                if (*csum_fname == '\0')
                        goto invalid_line;
                csum_fname++;
                csum_fname += strspn(csum_fname, " \t*");

                /* Strip the line terminator and trailing blanks; the length
                 * is checked first so csum_fname[-1] is never read. */
                csum_fnlen = strlen(csum_fname);
                while (csum_fnlen > 0 && strchr("\r\n \t", csum_fname[csum_fnlen - 1]))
                        csum_fname[--csum_fnlen] = '\0';
                if (csum_fnlen == 0)
                        goto invalid_line;

                /* We don't support STDIN in --check mode */
                if (strcmp(csum_fname, "-") == 0)
                        goto invalid_line;

                if (gen_checksum(csum_fname, csum_md) != 0) {
                        /* csum_md was not filled in: don't compare stale data. */
                        printf("%s: FAILED open or read\n", csum_fname);
                        ret = 1;
                        continue;
                }

                if (strcasecmp(csum_md, csum_sum) != 0) {
                        printf("%s: FAILED\n", csum_fname);
                        failed_csums++;
                        ret = 1;
                } else {
                        printf("%s: OK\n", csum_fname);
                }
                total_csums++;
                continue;

        invalid_line:
                if (verbose)
                        fprintf(stderr, "Invalid input: %s\n", csum_buf);
                ret = 1;
        }

        if (ferror(csum_file)) {
                fprintf(stderr, "%s: %s: %s\n", argv0_base, fname, strerror(errno));
                ret = 1;
        }

        if (csum_file != stdin)
                fclose(csum_file);

        return ret;
}

/*
 * Entry point.  Parses options, then either prints "<digest>  <file>" for
 * each FILE argument (standard input when none is given) or, with --check,
 * verifies the checksum lists named on the command line.
 *
 * Exit status is 0 when everything succeeded and 1 when any input failed
 * (including, in --check mode, invalid or mismatching lines).
 */
int
main(int argc, char *argv[])
{
        int     argi, ret = 0, opt, mode_check = 0;
        char    md[MAX_SHA_STR];
        char    *filename;
        struct timeval  tv1, tv2;

        argv0_base = strrchr(argv[0], '/');
        if (argv0_base)
                argv0_base++;
        else
                argv0_base = argv[0];

        if (strcmp(argv0_base, "sha1sum") == 0)
                hash_alg = ALG_SHA1;
        else if (strcmp(argv0_base, "sha256sum") == 0)
                hash_alg = ALG_SHA256;

        /* Short options mirror the long ones so "-c", "-h" and "-V" work. */
        while ((opt = getopt_long_only(argc, argv, "btwcvhVp:", long_options, NULL)) != -1) {
                switch (opt) {
                case 'b':
                case 't':
                case 'w':
                        /* ignore these options */
                        break;
                case 'c':
                        mode_check = 1;
                        break;
                case 'v':
                        verbose++;
                        break;
                case 'h':
                        print_help();
                        exit(0);
                case 'V':
                        print_version();
                        printf("\n");
                        print_license();
                        exit(0);
                case ALG_SHA1:
                        hash_alg = opt;
                        break;
                case ALG_SHA256:
                        fprintf(stderr, "SHA256 is not (yet) supported, sorry.\n");
                        exit(1);
                case 'p': {
                        char *end;
                        unsigned long pages;

                        /* strtoul instead of sscanf("%u") into a size_t (UB);
                         * reject empty, negative, zero, out-of-range values
                         * and trailing garbage. */
                        errno = 0;
                        pages = strtoul(optarg, &end, 10);
                        if (optarg[0] == '-' || end == optarg || *end != '\0' ||
                            errno == ERANGE || pages == 0) {
                                fprintf(stderr, "Invalid argument for --pages. Number required.\n");
                                exit(1);
                        }
                        input_pages = pages;
                        if (verbose)
                                printf("Using %zu pages (%zu bytes) for input buffer.\n",
                                       input_pages, input_pages * (size_t)getpagesize());
                        break;
                }
                case '?':
                        /* Unrecognised option.  */
                        exit(1);
                default:
                        fprintf(stderr, "Internal error: getopt code=0x%02x, please report\n", opt);
                        exit(1);
                }
        }

        gettimeofday(&tv1, NULL);
        if (optind >= argc) {
                filename = "-";
                if (mode_check)
                        ret = check_checksum(filename);
                else {
                        ret = gen_checksum(filename, md);
                        if (ret == 0)
                                printf("%s  %s\n", md, filename);
                }
        } else {
                for (argi = optind; argi < argc; argi++) {
                        filename = argv[argi];
                        if (mode_check) {
                                /* Don't let a later success hide an earlier failure. */
                                if (check_checksum(filename) != 0)
                                        ret = 1;
                        } else if (gen_checksum(filename, md) == 0) {
                                printf("%s  %s\n", md, filename);
                        } else {
                                ret = 1;
                        }
                }
        }
        gettimeofday(&tv2, NULL);

        if (verbose) {
                static const char mult[] = " kMGT";
                int mult_i = 0;
                double secdiff = ((tv2.tv_sec*1000000.0 + tv2.tv_usec) -
                                  (tv1.tv_sec*1000000.0 + tv1.tv_usec)) / 1000000.0;
                double tput = secdiff > 0 ? (double)total_bytes / secdiff : 0.0;

                /* Scale down, but never step past the last prefix ('T'). */
                while (tput > 1024 && mult[mult_i + 1] != '\0') {
                        tput /= 1024;
                        mult_i++;
                }
                printf("Throughput: %llu bytes in %.3f seconds, %.3f %cB/s\n",
                       (unsigned long long)total_bytes, secdiff, tput, mult[mult_i]);
        }
        if (mode_check && failed_csums)
                fprintf(stderr, "%s: WARNING: %zu of %zu computed checksum did NOT match\n",
                        argv0_base, failed_csums, total_csums);

        return ret;
}