/* Copyright (c) 2008, AbiSource Corporation B.V.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the AbiSource Corporation B.V. nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY AbiSource Corporation B.V. ''AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTOR BE LIABLE 
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <string>
#include <vector>
#include <boost/bind.hpp>
#include <boost/function.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/lexical_cast.hpp>
#include <shoes++.h>

using std::string;
using std::vector;
using boost::shared_ptr;
using namespace asio;

namespace shoes {

/*
 * Resolves host/port with a resolver bound to the io_service that owns
 * sock, and returns the address of the first endpoint the resolver yields.
 * The socket is only used to reach its io_service; nothing is sent or
 * received on it. Errors raised by resolve() propagate to the caller (the
 * C entry points in this file catch asio::system_error around this call).
 */
static ip::address get_address(ip::tcp::socket& sock, const string& host,
		unsigned short port)
{
	const string service = boost::lexical_cast<string>(port);
	ip::tcp::resolver resolver(sock.get_io_service());
	ip::tcp::resolver::iterator first =
		resolver.resolve(ip::tcp::resolver::query(host, service));
	const ip::tcp::endpoint ep = *first;
	return ep.address();
}

namespace socks4 {

/*
 * Builds a SOCKS4 CONNECT request:
 *   VN (0x04) | CD (0x01) | DSTPORT (big-endian) | DSTIP (4 bytes) |
 *   USERID | NUL
 * user_id may be null, in which case only the terminating NUL is emitted.
 */
static shared_ptr< vector<unsigned char> > connect_request(
		const ip::address_v4& dst_host, unsigned short dst_port,
		const char* user_id)
{
	const ip::address_v4::bytes_type octets = dst_host.to_bytes();
	shared_ptr< vector<unsigned char> > msg(new vector<unsigned char>());
	vector<unsigned char>& out = *msg;
	out.push_back(0x04); // SOCKS version
	out.push_back(0x01); // CONNECT command code
	out.push_back(static_cast<unsigned char>(dst_port >> 8));   // port, high byte
	out.push_back(static_cast<unsigned char>(dst_port & 0xFF)); // port, low byte
	out.insert(out.end(), octets.begin(), octets.end());         // destination ip
	if (user_id != 0) {
		for (const char* p = user_id; *p != '\0'; ++p)
			out.push_back(static_cast<unsigned char>(*p));
	}
	out.push_back('\0'); // USERID terminator
	return msg;
}

/*
 * Performs a SOCKS4 CONNECT handshake on an already-connected socket to
 * the proxy, asking it to open a tunnel to dst_host:dst_port.
 *
 * user_id may be null (an empty USERID is sent).
 * Returns 0 if the proxy granted the request (reply code 90), -1 on any
 * other reply code or on an I/O error.
 */
SHOES_MODULE_EXPORT
int connect(ip::tcp::socket& sock, const ip::address_v4& dst_host,
		unsigned short dst_port, const char* user_id)
{
	try {
		// construct and send the CONNECT request; write() keeps sending
		// until the whole request is out, unlike socket::send()
		shared_ptr< vector<unsigned char> > request_ptr = connect_request(
			dst_host, dst_port, user_id);
		asio::write(sock, buffer(*request_ptr));
		// receive the complete 8-byte CONNECT reply:
		// VN | CD | DSTPORT (2) | DSTIP (4)
		vector<unsigned char> reply(8);
		asio::read(sock, buffer(reply));
		if (reply[1] == 90) // request granted
			return 0;
	} catch (const asio::system_error&) {
		return -1;
	}
	return -1;
}

} /* namespace socks4 */

namespace socks5 {

/*
 * Builds a SOCKS5 method-selection message:
 *   VER (0x05) | NMETHODS | METHODS...
 * NMETHODS is a single byte, so callers should offer at most 255 methods;
 * a larger count would be truncated.
 */
static shared_ptr< vector<unsigned char> > method_request(
		const vector<unsigned char>& methods)
{
	shared_ptr< vector<unsigned char> > msg(new vector<unsigned char>());
	msg->reserve(2 + methods.size());
	msg->push_back(0x05); // SOCKS version
	msg->push_back(static_cast<unsigned char>(methods.size())); // NMETHODS
	msg->insert(msg->end(), methods.begin(), methods.end());    // offered methods
	return msg;
}

/*
 * Offers `methods` to the SOCKS5 server and returns the method it selected.
 *
 * Returns the selected method byte, or 0xFF when the server accepted none
 * of the offered methods, answered with a non-SOCKS5 version byte, or an
 * I/O error occurred.
 */
static unsigned char negotiate_method(ip::tcp::socket& sock, 
		const vector<unsigned char>& methods)
{
	try {
		// construct and send the complete method negotiation request
		shared_ptr< vector<unsigned char> > method_ptr = method_request(
			methods);
		asio::write(sock, buffer(*method_ptr));
		// receive the complete method negotiation reply: VER | METHOD
		vector<unsigned char> method_reply(2);
		asio::read(sock, buffer(method_reply));
		if (method_reply[0] != 0x05)
			return 0xFF; // not a SOCKS5 reply; treat as "nothing acceptable"
		return method_reply[1];
	} catch (const asio::system_error&) {
		return 0xFF;	// FIXME: this abuses the "none of the authentication
						// methods are acceptable" reply, which is clearly not
						// the case here
	}
}

/*
 * Builds a SOCKS5 CONNECT request:
 *   VER (0x05) | CMD (0x01) | RSV (0x00) | ATYP | DST.ADDR | DST.PORT
 * ATYP is 0x01 with a 4-byte address for IPv4 and 0x04 with a 16-byte
 * address for IPv6; the port is big-endian. Returns an empty pointer when
 * dst_host is neither IPv4 nor IPv6.
 */
static shared_ptr< vector<unsigned char> > connect_request(
		const ip::address& dst_host, unsigned short dst_port)
{
	const bool v4 = dst_host.is_v4();
	if (!v4 && !dst_host.is_v6())
		return shared_ptr< vector<unsigned char> >();

	shared_ptr< vector<unsigned char> > msg(new vector<unsigned char>());
	vector<unsigned char>& out = *msg;
	out.push_back(0x05);              // SOCKS version
	out.push_back(0x01);              // CONNECT command code
	out.push_back(0x00);              // reserved
	out.push_back(v4 ? 0x01 : 0x04);  // address type
	if (v4) {
		ip::address_v4::bytes_type octets = dst_host.to_v4().to_bytes();
		out.insert(out.end(), octets.begin(), octets.end());
	} else {
		ip::address_v6::bytes_type octets = dst_host.to_v6().to_bytes();
		out.insert(out.end(), octets.begin(), octets.end());
	}
	out.push_back(static_cast<unsigned char>(dst_port >> 8));   // port, high byte
	out.push_back(static_cast<unsigned char>(dst_port & 0xFF)); // port, low byte
	return msg;
}

/*
 * Sub-negotiation for method 0x00 ("no authentication"): nothing is
 * exchanged with the server, so it always reports success (0).
 */
static int sub_method_null()
{
	const int nothing_to_negotiate = 0;
	return nothing_to_negotiate;
}

/*
 * Performs the username/password sub-negotiation (method 0x02) on sock:
 *   request: VER (0x01) | ULEN | UNAME | PLEN | PASSWD
 *   reply:   VER | STATUS (0x00 means success)
 *
 * Returns 0 when the server accepts the credentials. Returns -1 in any of
 * these cases:
 *   - the server rejects the credentials;
 *   - either credential is longer than the 255 bytes its one-byte length
 *     field can encode;
 *   - an I/O error occurs.
 * The 0-on-success convention matches what socks5::connect expects from a
 * sub-method.
 */
static int sub_method_username_password(ip::tcp::socket& sock, 
		const string& username, const string& password)
{
	// ULEN/PLEN are single bytes; refuse rather than silently truncate.
	if (username.size() > 255 || password.size() > 255)
		return -1;

	vector<unsigned char> request;
	request.reserve(3 + username.size() + password.size());
	request.push_back(0x01); // version
	request.push_back(static_cast<unsigned char>(username.size()));
	request.insert(request.end(), username.begin(), username.end());
	request.push_back(static_cast<unsigned char>(password.size()));
	request.insert(request.end(), password.begin(), password.end());
	try {
		asio::write(sock, buffer(request));
		vector<unsigned char> reply(2);
		asio::read(sock, buffer(reply));
		return reply[1] == 0x00 ? 0 : -1; // status, 0x00 is success
	} catch (const asio::system_error&) {
		return -1;
	}
}

/*
 * Runs the authentication sub-negotiation and then the SOCKS5 CONNECT
 * exchange for dst_host:dst_port on an already-negotiated socket.
 *
 * sub_method must return 0 on success and non-zero on failure.
 * The bound address/port in the server's reply is read and discarded, so
 * that the stream is left positioned at the first byte of tunnelled data.
 * Returns 0 on success. Returns -1 in any of these cases:
 *   - authentication fails;
 *   - dst_host has an unsupported address family;
 *   - the reply has a bad version or a non-zero reply code;
 *   - the reply has an unknown address type;
 *   - an I/O error occurs.
 */
static int connect(ip::tcp::socket& sock, const ip::address& dst_host,
		unsigned short dst_port, boost::function<int ()> sub_method)
{
	try {
		// perform the authentication handshake
		if (sub_method() != 0)
			return -1;

		// construct and send the CONNECT request
		shared_ptr< vector<unsigned char> > request_ptr = connect_request(
				dst_host, dst_port);
		if (!request_ptr)
			return -1; // address family cannot be encoded in a request
		asio::write(sock, buffer(*request_ptr));
		
		// receive the fixed first part of the reply: VER | REP | RSV | ATYP
		vector<unsigned char> reply(4);
		asio::read(sock, buffer(reply));
		if (reply[0] != 0x05 || reply[1] != 0x00)
			return -1; // we don't report the error code yet

		// work out how many BND.ADDR + BND.PORT bytes follow
		size_t remaining = 0;
		switch (reply[3]) { // address type
			case 0x01: // IPv4
				remaining = 4 + 2;
				break;
			case 0x03: { // domainname, prefixed by a one-byte length
				vector<unsigned char> name_len(1);
				asio::read(sock, buffer(name_len));
				remaining = name_len[0] + 2;
				break;
			}
			case 0x04: // IPv6
				remaining = 16 + 2;
				break;
			default:
				return -1; // invalid response
		}
		vector<unsigned char> bound(remaining);
		asio::read(sock, buffer(bound));
		// TODO: reconnect to the bind address/port?
		return 0; // success
	} catch (const asio::system_error&) {
		return -1;
	}
}

/*
 * SOCKS5 CONNECT without credentials: offers only method 0x00 ("no
 * authentication") and proceeds only if the server selects it.
 * Returns 0 on success, -1 otherwise (including when the server answers
 * 0xFF, "no acceptable methods", or picks a method that was not offered).
 */
SHOES_MODULE_EXPORT
int connect(ip::tcp::socket& sock, const ip::address& dst_host,
		unsigned short dst_port)
{
	vector<unsigned char> offered;
	offered.push_back(0x00); // no authentication
	const unsigned char chosen = negotiate_method(sock, offered);
	if (chosen != 0x00)
		return -1;
	return connect(sock, dst_host, dst_port, &sub_method_null);
}

/*
 * SOCKS5 CONNECT offering both "no authentication" (0x00) and
 * username/password (0x02). The server's choice decides which
 * sub-negotiation runs before the CONNECT request.
 * Returns 0 on success, -1 otherwise (including when the server answers
 * 0xFF, "no acceptable methods", or picks a method that was not offered).
 */
SHOES_MODULE_EXPORT
int connect(ip::tcp::socket& sock, const ip::address& dst_host, 
		unsigned short dst_port, const string& username, const string& password)
{
	vector<unsigned char> offered;
	offered.push_back(0x00); // no authentication
	offered.push_back(0x02); // username password authentication
	const unsigned char chosen = negotiate_method(sock, offered);
	if (chosen == 0x00)
		return connect(sock, dst_host, dst_port, &sub_method_null);
	if (chosen == 0x02)
		return connect(sock, dst_host, dst_port, boost::bind(
			&sub_method_username_password, boost::ref(sock), username,
				password));
	// 0xFF (nothing acceptable) or a method we never offered
	return -1;
}

} /* namespace socks5 */

} /* namespace shoes */

/*
 * Duplicates the caller's socket descriptor/handle into sock_clone. The
 * clone is handed to an asio socket in the C entry points, which closes it
 * when it goes out of scope; the caller's original descriptor stays open.
 * Returns 0 on success, -1 on failure.
 *
 * NOTE(review): on Windows this uses DuplicateHandle; Winsock documents
 * WSADuplicateSocket for duplicating sockets — confirm DuplicateHandle is
 * reliable for the socket providers in use.
 */
static int clone_socket(native_socket ns, native_socket& sock_clone)
{
#ifdef _WIN32
	const HANDLE self = GetCurrentProcess();
	const BOOL duplicated = DuplicateHandle(self, reinterpret_cast<HANDLE>(ns),
		self, reinterpret_cast<HANDLE*>(&sock_clone), 0, FALSE,
		DUPLICATE_SAME_ACCESS);
	return duplicated ? 0 : -1;
#else
	sock_clone = dup(ns);
	if (sock_clone == -1)
		return -1;
	return 0;
#endif
}

/*
 * C entry point: performs a SOCKS4 CONNECT to dst_host:dst_port over the
 * already-connected proxy socket `sock`. The caller keeps ownership of
 * `sock`; a duplicate is used internally and closed before returning.
 *
 * dst_host must be non-null and must resolve to an IPv4 address (SOCKS4
 * cannot carry IPv6). user_id may be null.
 * Returns 0 on success, -1 on any failure. No C++ exception escapes.
 */
extern "C"
SHOES_MODULE_EXPORT
int shoes_socks4_connect(native_socket sock, const char* dst_host,
		unsigned short dst_port, const char* user_id)
{
	// std::string (used by the resolver) cannot be built from a null pointer.
	if (!dst_host)
		return -1;
	try {
		// Declared first so it is destroyed after the socket that uses it.
		io_service io_service;
		native_socket sock_clone;
		if (clone_socket(sock, sock_clone) != 0)
			return -1;
		// NOTE(review): if this constructor throws, sock_clone is not closed.
		ip::tcp::socket socket(io_service, ip::tcp::v4(), sock_clone);
		ip::address dst_addr = shoes::get_address(socket, dst_host,
			dst_port);
		if (!dst_addr.is_v4())
			return -1;
		return shoes::socks4::connect(socket, dst_addr.to_v4(), dst_port,
			user_id);
	} catch (const asio::system_error&) {
		return -1;
	} catch (...) {
		// Nothing may propagate across the C ABI boundary.
		return -1;
	}
}

/*
 * C entry point: performs an unauthenticated SOCKS5 CONNECT to
 * dst_host:dst_port over the already-connected proxy socket `sock`. The
 * caller keeps ownership of `sock`; a duplicate is used internally and
 * closed before returning.
 *
 * dst_host must be non-null; it is resolved locally before the request.
 * Returns 0 on success, -1 on any failure. No C++ exception escapes.
 */
extern "C"
SHOES_MODULE_EXPORT
int shoes_socks5_connect(native_socket sock, const char* dst_host,
		unsigned short dst_port)
{
	// std::string (used by the resolver) cannot be built from a null pointer.
	if (!dst_host)
		return -1;
	try {
		// Declared first so it is destroyed after the socket that uses it.
		io_service io_service;
		native_socket sock_clone;
		if (clone_socket(sock, sock_clone) != 0)
			return -1;
		// FIXME: don't assume it is IPv4
		// NOTE(review): if this constructor throws, sock_clone is not closed.
		ip::tcp::socket socket(io_service, ip::tcp::v4(), sock_clone);
		ip::address dst_addr = shoes::get_address(socket, dst_host,
			dst_port);
		return shoes::socks5::connect(socket, dst_addr, dst_port);
	} catch (const asio::system_error&) {
		return -1;
	} catch (...) {
		// Nothing may propagate across the C ABI boundary.
		return -1;
	}
}

/*
 * C entry point: performs a SOCKS5 CONNECT to dst_host:dst_port over the
 * already-connected proxy socket `sock`. It offers both "no authentication"
 * and username/password authentication. The caller keeps ownership of
 * `sock`; a duplicate is used internally and closed before returning.
 *
 * dst_host, username and password must all be non-null.
 * Returns 0 on success, -1 on any failure. No C++ exception escapes.
 */
extern "C"
SHOES_MODULE_EXPORT
int shoes_socks5_connect_username_password(native_socket sock, const char* dst_host,
		unsigned short dst_port, const char* username,
		const char* password)
{
	// All three are turned into std::string, which rejects null pointers.
	if (!dst_host || !username || !password)
		return -1;
	try {
		// Declared first so it is destroyed after the socket that uses it.
		io_service io_service;
		native_socket sock_clone;
		if (clone_socket(sock, sock_clone) != 0)
			return -1;
		// FIXME: don't assume it is IPv4
		// NOTE(review): if this constructor throws, sock_clone is not closed.
		ip::tcp::socket socket(io_service, ip::tcp::v4(), sock_clone);
		ip::address dst_addr = shoes::get_address(socket, dst_host,
			dst_port);
		return shoes::socks5::connect(socket, dst_addr, dst_port,
			username, password);
	} catch (const asio::system_error&) {
		return -1;
	} catch (...) {
		// Nothing may propagate across the C ABI boundary.
		return -1;
	}
}
