/*
 * aufs sample -- ULOOP driver and ulohttp
 *
 * Copyright (C) 2007 Junjiro Okajima
 *
 * This program, aufs is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

/* $Id: ulohttp.c,v 1.1 2007/08/06 00:38:07 sfjro Exp $ */

#include <linux/uloop.h>
#include <linux/unistd.h>

#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>

#include <curl/curl.h>

/*
 * Usage:
 * $0
 * [ -b bitmap ]	default /tmp/<pid>.bitmap
 * [ -c cache ]		default /tmp/<pid>.cache
 * [ -t timeout ]	default 30 sec
 * /dev/lttpN url_for_fs_image_file
 *
 * and then, "mount /dev/lttpN /wherever/you/want"
 */

/* HTTP timeout in seconds (-t option), and the system page size */
static int timeout = 30;
static int pagesize;

/*
 * Size of the remote image, size of the local cache (rounded up to whole
 * pages by init()), and the image size modulo the page size.
 */
static unsigned long long tgt_size;
static unsigned long long cache_size;
static unsigned long long mod;

/* a file path paired with its open descriptor */
struct path_fd {
	char *path;
	int fd;
};
static struct path_fd bitmap;	/* bitmap file handed to the ULOOP driver */
static struct path_fd cache;	/* cache file attached to the loop device */
static struct path_fd uloop;	/* the loop device named on the command line */

/* subject printed in error messages (NULL once reported), and the URL */
static char *me;
static char *g_url;

/* the single libcurl easy handle shared by all transfers */
static CURL *ezcurl;

/* "start-end" text used for CURLOPT_RANGE */
static char range[32];

/* destination buffer state consumed by store_from_curl() */
struct arg_for_curl {
	char	*p;
	int	written;
	int	size;
	int	err;
};
static struct arg_for_curl arg_for_curl;

/* debug print prefixed with the calling function and line */
#define Dbg(fmt, args...) printf("%s:%d:" fmt, __func__, __LINE__, ##args)
/*
 * Print a non-zero error code.  Wrapped in do/while so that
 * "if (x) DbgErr(e); else ..." binds the else to the caller's if.
 */
#define DbgErr(e) do { if (e) Dbg("err %d\n", (e)); } while (0)

/* signal the ULOOP driver raises to announce a pending request */
#define ULO_SIGNAL SIGUSR1

/* ---------------------------------------------------------------------- */

/*
 * Report a libcurl failure on stderr as "<me>: <curl message>".
 *
 * `me` is cleared afterwards, so main()'s "if (err && me) perror(me)" does
 * not report the same failure a second time.  Returns the errno value that
 * was current on entry.
 */
static int err_curl(CURLcode curle)
{
	const int saved_errno = errno;
	const char *msg = curl_easy_strerror(curle);

	fprintf(stderr, "%s: %s\n", me, msg);
	me = NULL;
	return saved_errno;
}

/*
 * libcurl write/header callback: append the received chunk to the buffer
 * described by the global arg_for_curl (the `arg` user pointer is unused).
 *
 * Returns the number of bytes consumed (size * nmemb), as libcurl's
 * callback contract requires.  Returns 0 when the chunk is empty, when an
 * earlier chunk already failed, or when the chunk would not fit in the
 * remaining arg_for_curl.size bytes; for a non-empty chunk libcurl treats
 * that as a write error and aborts the transfer.  On overflow,
 * arg_for_curl.err is incremented.
 */
static size_t store_from_curl(void *got, size_t size, size_t nmemb, void *arg)
{
	size_t real_bytes;

	(void)arg;
	if (!size || !nmemb || arg_for_curl.err)
		return 0;

	real_bytes = size * nmemb;
	/* compare in size_t so a large chunk cannot overflow an int sum */
	if (real_bytes > (size_t)(arg_for_curl.size - arg_for_curl.written)) {
		arg_for_curl.err++;
		return 0;
	}

	memcpy(arg_for_curl.p, got, real_bytes);
	arg_for_curl.written += real_bytes;
	arg_for_curl.p += real_bytes;
	return real_bytes;
}

/*
 * Download bytes [start, start + size) of the remote image into the cache.
 *
 * The request is widened down to a page boundary, because mmap(2) needs a
 * page-aligned file offset.  It is clipped at tgt_size, and that window of
 * cache.fd is mapped and filled by store_from_curl() via arg_for_curl.
 *
 * Returns 0 on success.  On failure returns non-zero:
 *  - -1 with errno set for local errors: range text too long, mmap/munmap
 *    failure, or EIO when the server sent fewer bytes than requested;
 *  - the value of err_curl() (forced non-zero) for libcurl errors, whose
 *    message has then already been printed.
 */
static int store(CURL *ezcurl, unsigned long long start, int size)
{
	CURLcode curle;
	int err, n;
	unsigned long long m;
	char *o;

	assert(start + size <= cache_size);

	m = start % pagesize;
	start -= m;
	arg_for_curl.size = size + m;
	if (tgt_size < start + arg_for_curl.size)
		arg_for_curl.size = tgt_size - start;

	n = snprintf(range, sizeof(range), "%llu-%llu",
		     start, start + arg_for_curl.size - 1);
	if (n < 0 || (size_t)n >= sizeof(range)) {
		errno = ERANGE;
		return -1;
	}

	o = mmap(NULL, arg_for_curl.size, PROT_WRITE, MAP_SHARED, cache.fd, start);
	if (o == MAP_FAILED)
		return -1;
	arg_for_curl.p = o;
	arg_for_curl.written = 0;
	arg_for_curl.err = 0;

	/*
	 * libcurl (>= 7.17.0) copies string options, so the freshly formatted
	 * range must be handed over again before every transfer.
	 */
	curle = curl_easy_setopt(ezcurl, CURLOPT_RANGE, range);
	if (curle == CURLE_OK)
		curle = curl_easy_perform(ezcurl);

	err = munmap(o, arg_for_curl.size);
	if (curle != CURLE_OK) {
		me = g_url;
		err = err_curl(curle);
		/* errno may be 0 after a curl failure; never report success */
		return err ? err : -1;
	}
	if (arg_for_curl.written != arg_for_curl.size) {
		/* the server returned a short body: the cache window is incomplete */
		errno = EIO;
		return -1;
	}
	return err;
}

/*
 * Serve ULOOP requests until something fails.
 *
 * Blocks ULO_SIGNAL and then repeats these steps:
 *  1. register this process with the driver (ULOCTL_READY);
 *  2. wait for ULO_SIGNAL;
 *  3. fetch the request (ULOCTL_RCVREQ);
 *  4. download the requested range into the cache with store();
 *  5. acknowledge it (ULOCTL_SNDRES).
 *
 * The first non-zero result ends the loop.  The curl handle is then
 * cleaned up and that result is returned.
 */
static int io_loop(void)
{
	int err, sig;
	pid_t pid;
	static sigset_t sigset;
	union uloop_ctl arg;

	pid = getpid();
	err = sigemptyset(&sigset);
	if (!err)
		err = sigaddset(&sigset, ULO_SIGNAL);
	if (!err)
		err = sigprocmask(SIG_BLOCK, &sigset, NULL);

	while (!err) {
		arg.ready.signum = ULO_SIGNAL;
		arg.ready.pid = pid;
		err = ioctl(uloop.fd, ULOCTL_READY, &arg);
		DbgErr(err);
		if (err)
			break;

		/*
		 * sigwaitinfo() can fail with EINTR (on Linux e.g. after a
		 * SIGSTOP/SIGCONT cycle); that is not a request, wait again.
		 */
		do {
			sig = sigwaitinfo(&sigset, NULL);
		} while (sig == -1 && errno == EINTR);
		if (sig == ULO_SIGNAL)
			err = ioctl(uloop.fd, ULOCTL_RCVREQ, &arg);
		else
			err = sig;
		DbgErr(err);
		if (!err)
			err = store(ezcurl, arg.rcvreq.start, arg.rcvreq.size);
		if (!err) {
			arg.sndres.start = arg.rcvreq.start;
			arg.sndres.size = arg.rcvreq.size;
			err = ioctl(uloop.fd, ULOCTL_SNDRES, &arg);
			DbgErr(err);
		}
	}
	curl_easy_cleanup(ezcurl);
	return err;
}

/* ---------------------------------------------------------------------- */

static unsigned long long get_size(void)
{
	unsigned long long size;
	CURLcode curle;
	char *header;
	const int hsz = 1024;
	char *p;

	size = ULONG_MAX; /* error */
	header = malloc(hsz);
	if (!header)
		return size;
	arg_for_curl.p = header;
	arg_for_curl.size = hsz;
	arg_for_curl.written = 0;
	arg_for_curl.err = 0;

	curle = curl_easy_setopt(ezcurl, CURLOPT_HEADERFUNCTION,
				 store_from_curl);
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_RANGE, "0-1");
#if 0
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_WRITEHEADER, &arg_for_curl);
#endif
	if (curle == CURLE_OK)
		curle = curl_easy_perform(ezcurl);
	if (curle != CURLE_OK) {
		err_curl(curle);
		return size;
	}
	if (arg_for_curl.err) {
		fprintf(stderr, "%s: internal error.\n", me);
		errno = EINVAL;
		return size;
	}

	p = strstr(header, "Content-Range: bytes ");
	if (p)
		p = strchr(p, '/');
	if (!p) {
		fprintf(stderr, "%s: no range header, %s\n", me, g_url);
		errno = EINVAL;
		return size;
	}
	size = strtoull(p + 1, NULL, 10);
	free(header);

	/* reset */
	curle = curl_easy_setopt(ezcurl, CURLOPT_HEADERFUNCTION, NULL);
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_RANGE, NULL);
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_WRITEHEADER, NULL);
	if (curle == CURLE_OK)
		return size; /* success */

	err_curl(curle);
	return ULONG_MAX; /* error */
}

/*
 * Create the libcurl easy handle for g_url, learn the remote image size and
 * prepare the handle for the ranged downloads made later.
 *
 * On success tgt_size and cache_size both hold the remote size and 0 is
 * returned (init() later rounds cache_size up to a page multiple).
 * On failure -1 is returned:
 *  - errno is ENOMEM if the handle could not be created;
 *  - get_size() failures leave errno as get_size() set it;
 *  - libcurl option errors are printed through err_curl().
 */
static int init_curl_and_size(void)
{
	CURLcode curle;

	errno = ENOMEM;
	ezcurl = curl_easy_init();
	if (!ezcurl)
		return -1;

	curle = curl_easy_setopt(ezcurl, CURLOPT_URL, g_url);
	/* curl_easy_setopt() is variadic: CURLOPT_TIMEOUT must get a long */
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_TIMEOUT, (long)timeout);
#if 0
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_VERBOSE, 1L);
#endif
#if 0
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_NOPROGRESS, 1L);
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_FAILONERROR, 1L);
#endif
	if (curle != CURLE_OK)
		goto out_curl;

	errno = ERANGE;
	cache_size = tgt_size = get_size();
	if (tgt_size == ULONG_MAX)
		return -1;

	curle = curl_easy_setopt(ezcurl, CURLOPT_WRITEFUNCTION,
				 store_from_curl);
#if 0
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_WRITEDATA,
					 &arg_for_curl);
#endif
	/*
	 * libcurl (>= 7.17.0) copies string options: this passes the current,
	 * still empty, contents of range[], not a live reference to it.
	 */
	if (curle == CURLE_OK)
		curle = curl_easy_setopt(ezcurl, CURLOPT_RANGE, range);
	if (curle == CURLE_OK)
		return 0;

 out_curl:
	err_curl(curle);
	return -1;
}

/* ---------------------------------------------------------------------- */

/*
 * Open (creating with mode 0644 if needed) the file at `path` for reading
 * and writing.
 *
 * A file that is already exactly `size` bytes long is returned untouched.
 * Otherwise a single NUL byte is written at offset size - 1, which extends
 * a shorter file to `size` bytes.  A longer file is not shrunk.
 *
 * Returns the open descriptor, or -1 with errno set on failure.  The
 * descriptor is closed on every failure path after open().
 */
static int create_size(char *path, unsigned long long size)
{
	int fd, e;
	struct stat st;

	fd = open(path, O_RDWR | O_CREAT, 0644);
	if (fd < 0)
		return fd;
	if (fstat(fd, &st))
		goto out_close;
	if (st.st_size == size)
		return fd; /* success */

	if (lseek(fd, size - 1, SEEK_SET) == -1)
		goto out_close;
	if (write(fd, "\0", 1) != 1)
		goto out_close;
	return fd; /* success */

 out_close:
	/* keep the errno of the failing call for the caller's perror() */
	e = errno;
	close(fd);
	errno = e;
	return -1;
}

/*
 * Bind the cache file to the loop device and register the bitmap with the
 * ULOOP driver.
 *
 * The steps are:
 *  1. open uloop.path into uloop.fd;
 *  2. attach cache.fd with LOOP_SET_FD;
 *  3. set the loop status (file name = cache.path, encrypt type ULOOP_HTTP);
 *  4. pass bitmap.fd and the page size with ULOCTL_SETBMP.
 *
 * Returns 0 on success.  On failure returns the failing call's non-zero
 * result with errno preserved from that call.  The loop device is detached
 * again if it had been attached, and uloop.fd is closed and reset to -1.
 */
static int init_loop(void)
{
	int err, e;
	struct loop_info64 loinfo64;
	union uloop_ctl	arg;

	err = uloop.fd = open(uloop.path, O_RDONLY);
	if (uloop.fd < 0)
		return err;

	err = ioctl(uloop.fd, LOOP_SET_FD, cache.fd);
	if (err)
		goto out_close;

	memset(&loinfo64, 0, sizeof(loinfo64));
	/* snprintf() always NUL-terminates, unlike strncpy() */
	snprintf((char *)loinfo64.lo_file_name, LO_NAME_SIZE, "%s", cache.path);
	loinfo64.lo_encrypt_type = ULOOP_HTTP;
	//strncpy((void*)(loinfo64.lo_crypt_name), "ulttp", LO_NAME_SIZE);
	//loinfo64.lo_sizelimit = cache_size;
	err = ioctl(uloop.fd, LOOP_SET_STATUS64, &loinfo64);
	if (err)
		goto out_loop;

	arg.setbmp.fd = bitmap.fd;
	arg.setbmp.pagesize = pagesize;
	err = ioctl(uloop.fd, ULOCTL_SETBMP, &arg);
	e = errno;	/* DbgErr() calls printf, which may change errno */
	DbgErr(err);
	if (!err)
		return 0;
	errno = e;

 out_loop:
	e = errno;
	ioctl(uloop.fd, LOOP_CLR_FD, cache.fd);
	errno = e;
 out_close:
	e = errno;
	close(uloop.fd);
	uloop.fd = -1;
	errno = e;
	return err;
}

static int init(void)
{
	int err;
	char *p;
	unsigned long long sz, m;

	pagesize = sysconf(_SC_PAGESIZE);
	assert(pagesize > 0);

	p = g_url;
	err = init_curl_and_size();
	if (err)
		goto out;

	p = cache.path;
	mod = cache_size % pagesize;
	if (mod)
		cache_size += pagesize - mod;
	assert(!(cache_size % pagesize));
	cache.fd = create_size(p, cache_size);
	if (cache.fd < 0)
		goto out;

	p = bitmap.path;
	sz = cache_size;
	sz /= pagesize;
	sz /= CHAR_BIT;
	if (sz < pagesize)
		sz = pagesize;
	else {
		m = sz % pagesize;
		if (m)
			sz += pagesize - m;
	}
	assert(!(sz % pagesize));
	bitmap.fd = create_size(p, sz);
	if (bitmap.fd < 0)
		goto out;

	p = uloop.path;
	err = init_loop();
	if (!err)
		return 0;
 out:
	me = p;
	return err;
}

/*
 * Print the command line synopsis to stderr, prefixed by `me`, and
 * terminate the process with exit status EINVAL.
 */
static void usage(void)
{
	fprintf(stderr, "%s"
		" [-b bitmap]"
		" [-c cache]"
		" [-t timeout]"
		" /dev/loopN url_for_fs_image_file\n",
		me);
	exit(EINVAL);
}

/*
 * Parse the command line.
 *
 * Recognised options are -b bitmap, -c cache and -t timeout (seconds).
 * These are followed by exactly two operands: the loop device path and the
 * image URL.  Missing -b/-c paths default to /tmp/<pid>.bitmap and
 * /tmp/<pid>.cache.  Unknown options or a wrong operand count call usage(),
 * which exits.
 *
 * Returns 0 on success.  For a bad -t value it returns EINVAL (not a
 * number) or ERANGE (negative or too large), with errno set to the same
 * value and `me` pointing at the offending argument.
 */
static int parse(int argc, char *argv[])
{
	int opt;
	long val;
	char *end;
	/* "/tmp/" + any decimal pid + suffix + NUL fits comfortably */
	static char bitmap_def[32], cache_def[32];

	while ((opt = getopt(argc, argv, "b:c:t:")) != -1) {
		switch (opt) {
		case 'b':
			bitmap.path = optarg;
			break;
		case 'c':
			cache.path = optarg;
			break;
		case 't':
			errno = 0;
			val = strtol(optarg, &end, 0);
			if (end == optarg || *end) {
				me = optarg;
				errno = EINVAL;
				return EINVAL;
			}
			if (errno || val < 0 || val > INT_MAX) {
				me = optarg;
				errno = ERANGE;
				return ERANGE;
			}
			timeout = val;
			break;
		default:
			usage();
			break;
		}
	}

	if (argc - optind != 2) {
		usage();
		return EINVAL;
	}

	uloop.path = argv[optind];
	g_url = argv[optind + 1];

	if (!cache.path) {
		snprintf(cache_def, sizeof(cache_def), "/tmp/%d.cache",
			 (int)getpid());
		cache.path = cache_def;
	}
	if (!bitmap.path) {
		snprintf(bitmap_def, sizeof(bitmap_def), "/tmp/%d.bitmap",
			 (int)getpid());
		bitmap.path = bitmap_def;
	}

	return 0;
}

/* ---------------------------------------------------------------------- */

/*
 * Parse arguments and initialise.  Then fork: the child serves requests in
 * io_loop(), while the parent sleeps one second and exits.
 *
 * On failure the error is reported with perror(me), unless it was already
 * printed (me == NULL), and the error value is returned as the exit status.
 */
int main(int argc, char *argv[])
{
	int err;
	pid_t pid;

	me = argv[0];
	err = parse(argc, argv);
	if (!err)
		err = init();
	/* only fork once setup succeeded; pid is meaningless otherwise */
	if (!err) {
		pid = fork();
		if (!pid)
			err = io_loop();
		else if (pid > 0)
			sleep(1);
		else
			err = -1; /* fork(2) failed, errno is set */
	}

	if (err && me)
		perror(me);
	return err;
}
