#!/usr/bin/python2

import fcntl
import errno
import posix
import time
import signal
import os
import pwd
import sys
import getopt
import traceback
import datetime
import mimetypes
try:
	from urllib.parse import urlparse
	from urllib.parse import urlunparse
except ImportError:
	from urlparse import urlparse
	from urlparse import urlunparse
import socket
import select
import subprocess

"""Http server based on recipes 511453,511454 from code.activestate.com by Pierre Quentel"""
"""Added support for indexes, access tests, proper handling of the SystemExit exception, fixes for a couple of errors and vulnerabilities, getopt, lockfiles, daemonizing etc. by Jakub Kruszona-Zawadzki"""

# the dictionary holding one client handler for each connected client
# key = client socket, value = instance of (a subclass of) ClientHandler
client_handlers = {}

def emptybuff():
	"""Return an empty buffer of the type used for socket data.

	The b'' literal is a str on Python 2 and a bytes object on Python 3,
	so no interpreter version check is needed.
	"""
	return b''

# type of the raw data chunks exchanged with sockets:
# str on Python 2, bytes on Python 3
buff_type = str if sys.version<'3' else bytes

# =======================================================================
# The server class. Creating an instance starts a server on the specified
# host and port
# =======================================================================
class Server:
	"""Listening TCP socket of the server.

	Creating an instance binds a non-blocking IPv4 stream socket to
	(host, port) and starts listening on it. The host name 'any' is
	replaced by the empty string before binding.
	"""
	def __init__(self,host='localhost',port=80):
		bind_host = '' if host=='any' else host
		self.host = bind_host
		self.port = port
		listener = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
		listener.setblocking(0)
		listener.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
		listener.bind((bind_host,port))
		# accept queue of up to 50 pending connections
		listener.listen(50)
		self.socket = listener

# =====================================================================
# Generic client handler. An instance of this class is created for each
# request sent by a client to the server
# =====================================================================
class ClientHandler:
	"""Generic handler of a single client connection.

	The main loop calls handle_read(), handle_write() and handle_error()
	when select() reports the client socket as readable, writable or in
	an exceptional state. Subclasses implement a protocol by overriding
	request_complete() and make_response().
	"""
	blocksize = 2048	# bytes read at once from a file-like response item

	def __init__(self, server, client_socket, client_address):
		"""Register the non-blocking client socket and reset all buffers"""
		self.server = server
		self.client_address = client_address
		self.client_socket = client_socket
		self.client_socket.setblocking(0)
		self.host = socket.getfqdn(client_address[0])
		self.incoming = emptybuff() # receives incoming data
		self.outgoing = emptybuff() # chunk that is currently being sent
		self.writable = False # True once a response is ready to be sent
		self.close_when_done = True # close the socket after the response ?
		self.response = [] # pending response items (buffers or file objects)

	def handle_error(self):
		"""Called when select() reports an exceptional condition"""
		self.close()

	def handle_read(self):
		"""Reads the data received"""
		try:
			buff = self.client_socket.recv(1024)
			if not buff:  # the connection is closed
				self.close()
			else:
				# buffer the data in self.incoming
				self.incoming += buff
				self.process_incoming()
		except socket.error:
			self.close()

	def process_incoming(self):
		"""Test if request is complete ; if so, build the response
		and set self.writable to True"""
		if not self.request_complete():
			return
		self.response = self.make_response()
		self.outgoing = emptybuff()
		self.writable = True

	def request_complete(self):
		"""Return True if the request is complete, False otherwise
		Override this method in subclasses"""
		return True

	def make_response(self):
		"""Return the list of buffers or file objects whose content will
		be sent to the client
		Override this method in subclasses"""
		# must be a buff_type item: handle_write() treats anything else
		# as a file object and calls read() on it
		return [b"xxx"]

	def handle_write(self):
		"""Send (a part of) the response on the socket
		Finish the request if the whole response has been sent
		self.response is a list of buffers or file objects
		"""
		if len(self.outgoing)==0 and self.response:
			# nothing left from the previous chunk : fetch the next one
			if isinstance(self.response[0],buff_type):
				self.outgoing = self.response.pop(0)
			else:
				self.outgoing = self.response[0].read(self.blocksize)
				if not self.outgoing:
					# file exhausted : drop it from the list
					self.response[0].close()
					self.response.pop(0)
		if self.outgoing:
			try:
				sent = self.client_socket.send(self.outgoing)
			except socket.error:
				self.close()
				return
			# keep the part that was not accepted by the socket
			if sent < len(self.outgoing):
				self.outgoing = self.outgoing[sent:]
			else:
				self.outgoing = emptybuff()
		if len(self.outgoing)==0 and not self.response:
			if self.close_when_done:
				self.close() # close socket
			else:
				# reset for next request
				self.writable = False
				self.incoming = emptybuff()

	def close(self):
		"""Close pending file objects, unregister the handler and close
		the client socket

		May be called more than once for the same connection : a handler
		that is already unregistered is not treated as an error."""
		while self.response:
			item = self.response.pop(0)
			if not isinstance(item,buff_type):
				item.close()
		client_handlers.pop(self.client_socket,None)
		self.client_socket.close()

# ============================================================================
# Main loop, calling the select() function on the sockets to see if new 
# clients are trying to connect, if some clients have sent data and if those
# for which the response is complete are ready to receive it
# For each event, call the appropriate method of the server or of the instance
# of ClientHandler managing the dialog with the client : handle_read() or 
# handle_write()
# ============================================================================
def loop(server,handler,timeout=30):
	"""Run the select() loop of the server; never returns

	server  - object whose "socket" attribute is the listening socket
	handler - callable (server, client_socket, client_address) creating
	          the handler object of a newly accepted connection
	timeout - maximal time in seconds of a single select() call

	Handlers are kept in the global client_handlers dictionary. A handler
	may unregister itself during any callback, so every socket returned
	by select() is looked up again before its handler is called.
	"""
	while True:
		k = list(client_handlers.keys())
		# w = sockets to which there is something to send
		# we must test if we can send data
		w = [ cl for cl in k if client_handlers[cl].writable ]
		# the heart of the program ! "r" will have the sockets that have sent
		# data, and the server socket if a new client has tried to connect
		r,w,e = select.select(k+[server.socket],w,k,timeout)
		for e_socket in e:
			client = client_handlers.get(e_socket)
			if client is not None:
				client.handle_error()
		for r_socket in r:
			if r_socket is server.socket:
				# server socket readable means a new connection request
				try:
					client_socket,client_address = server.socket.accept()
					client_handlers[client_socket] = handler(server,client_socket,client_address)
				except socket.error:
					pass
			else:
				# the client connected on r_socket has sent something
				# (skip it if its handler has been closed in the meantime)
				client = client_handlers.get(r_socket)
				if client is not None:
					client.handle_read()
		for w_socket in w:
			# skip sockets closed while handling errors or reads
			client = client_handlers.get(w_socket)
			if client is not None:
				client.handle_write()


# =============================================================
# An implementation of the HTTP protocol, supporting persistent
# connections and CGI
# =============================================================

class HTTP(ClientHandler):
	"""HTTP/1.0 and HTTP/1.1 request handler serving static files and
	CGI scripts

	GET, POST and HEAD are supported. Only files located inside HTTP.root
	are served ; URL paths ending with ".cgi" are executed as CGI scripts.
	Malformed requests are answered with "400" instead of raising
	an exception.
	"""
	# parameters to override if necessary
	root = os.getcwd()				# the directory to serve files from
	index_files = ['index.cgi','index.html']	# index files for directories
	logging = True					# print logging info for each request ?
	blocksize = 2 << 16				# size of blocks to read from files and send

	def _bad_request(self):
		"""Mark the current request as malformed ; make_response will
		answer it with "400 Bad request"
		Always returns True (the request needs no more data)"""
		self.method = None
		self.protocol = "HTTP/1.1"
		self.postbody = None
		return True

	def request_complete(self):
		"""In the HTTP protocol, a request is complete if the "end of headers"
		sequence (CR LF CR LF) has been received
		If the request is POST, the request is complete when the whole body
		announced by the content-length header has been received ; the body
		is then stored in self.postbody
		Returns False if more data is needed, True otherwise (also for
		malformed requests, which are marked with self.method = None)"""
		py3 = sys.version_info[0]>=3
		term = '\r\n\r\n'
		if py3:
			term = term.encode('ascii')
		terminator = self.incoming.find(term)
		if terminator == -1:
			return False
		head = self.incoming[:terminator]
		if py3:
			try:
				head = head.decode('ascii')
			except UnicodeDecodeError:
				# keep only the ASCII part of the request line, so it can
				# be safely used in the error response and in the log
				self.requestline = head.decode('ascii','ignore').split('\r\n')[0]
				return self._bad_request()
		lines = head.split('\r\n')
		self.requestline = lines[0]
		try:
			self.method,self.url,self.protocol = lines[0].strip().split()
		except ValueError:
			return self._bad_request()
		if self.protocol not in ("HTTP/1.0","HTTP/1.1"):
			return self._bad_request()
		# a NUL character can never be part of a valid file system path
		if '\x00' in self.url:
			return self._bad_request()
		# put request headers in a dictionary
		self.headers = {}
		for line in lines[1:]:
			if ':' not in line:
				return self._bad_request()
			k,v = line.split(':',1)
			self.headers[k.lower().strip()] = v.strip()
		# persistent connection - decided separately for each request, so
		# a connection kept alive before is closed when not requested again
		close_conn = self.headers.get("connection","")
		self.close_when_done = not (self.protocol == "HTTP/1.1" and close_conn.lower() == "keep-alive")
		# parse the url
		try:
			scheme,netloc,path,params,query,fragment = urlparse(self.url)
		except ValueError:
			return self._bad_request()
		self.path,self.rest = path,(params,query,fragment)

		if self.method == 'POST':
			# for POST requests, read the request body
			# its length must be specified in the content-length header
			try:
				content_length = int(self.headers.get('content-length',0))
			except ValueError:
				return self._bad_request()
			body = self.incoming[terminator+4:]
			# request is incomplete if not all message body received
			if len(body)<content_length:
				return False
			self.postbody = body
		else:
			self.postbody = None

		return True

	def make_response(self):
		"""Build the response : a list of buffers or files"""
		if self.method is None: # bad request
			return self.err_resp(400,'Bad request : %s' %self.requestline)
		if self.method not in ('GET','POST','HEAD'):
			return self.err_resp(501,'Unsupported method (%s)' %self.method)
		file_name = self.file_name = self.translate_path()
		# symbolic links are already resolved by translate_path, so this
		# test rejects everything that is located outside of the root
		if not file_name.startswith(HTTP.root+os.path.sep) and not file_name==HTTP.root:
			return self.err_resp(403,'Forbidden')
		if not os.path.exists(file_name):
			return self.err_resp(404,'File not found')
		if self.managed():
			response = self.mngt_method()
			self.log(200)
			return response
		if not os.access(file_name,os.R_OK):
			return self.err_resp(403,'Forbidden')
		fstatdata = os.stat(file_name)
		# 0xF000 is the file type mask (S_IFMT)
		if (fstatdata.st_mode & 0xF000) == 0x4000:	# directory (S_IFDIR)
			for index in self.index_files:
				if os.path.exists(file_name+'/'+index) and os.access(file_name+'/'+index,os.R_OK):
					# absolute path - a relative location would be wrong
					# for a directory url without the trailing slash
					return self.redirect_resp(self.path.rstrip('/')+'/'+index)
		if (fstatdata.st_mode & 0xF000) != 0x8000:	# not a regular file (S_IFREG)
			return self.err_resp(403,'Forbidden')
		ext = os.path.splitext(file_name)[1]
		c_type = mimetypes.types_map.get(ext,'text/plain')
		resp_line = "%s 200 Ok\r\n" %self.protocol
		size = fstatdata.st_size
		resp_headers = "Content-Type: %s\r\n" %c_type
		resp_headers += "Content-Length: %s\r\n" %size
		resp_headers += '\r\n'
		if sys.version_info[0]>=3:
			resp_line = resp_line.encode('ascii')
			resp_headers = resp_headers.encode('ascii')
		response = [resp_line + resp_headers]
		if self.method == "HEAD":
			pass # headers only
		elif size > HTTP.blocksize:
			# big file : sent in blocks and closed by handle_write / close
			response.append(open(file_name,'rb'))
		else:
			with open(file_name,'rb') as small_file:
				response[0] += small_file.read()
		self.log(200)
		return response

	def translate_path(self):
		"""Translate URL path into a path in the file system"""
		return os.path.realpath(os.path.join(HTTP.root,*self.path.split('/')))

	def managed(self):
		"""Test if the request can be processed by a specific method
		If so, set self.mngt_method to the method used
		This implementation tests if the url points to a cgi script"""
		if self.is_cgi():
			self.mngt_method = self.run_cgi
			return True
		return False

	def is_cgi(self):
		"""Test if url points to cgi script"""
		return self.path.endswith(".cgi")

	def _cgi_head(self,output):
		"""Return only the header section of the output of a CGI script,
		with CR LF line endings and the terminating empty line
		The argument and the result are buffers (str / bytes)"""
		py3 = sys.version_info[0]>=3
		if py3:
			output = output.decode('latin-1')
		head_lines = []
		for line in output.split('\n'):
			line = line.rstrip('\r')
			# the first empty line ends the header section
			if not line:
				break
			head_lines.append(line+'\r\n')
		head = ''.join(head_lines)+'\r\n'
		if py3:
			head = head.encode('latin-1')
		return head

	def run_cgi(self):
		"""Run the CGI script self.file_name and return the response
		HEAD - the script is run, but only its headers are sent
		POST - the request body is passed to the standard input of
		       the script and the whole output is collected
		GET  - the output pipes of the script are returned as
		       response items and sent as they are read
		Standard error of the script is sent after its standard output
		If the script cannot be run then the traceback is sent"""
		if not os.access(self.file_name,os.X_OK):
			return self.err_resp(403,'Forbidden')
		# set CGI environment variables
		e = self.make_cgi_env()
		# close connection in case there is no content-length header
		self.close_when_done = True
		resp_line = "%s 200 Ok\r\n" %self.protocol
		if sys.version_info[0]>=3:
			resp_line = resp_line.encode('ascii')
		try:
			if self.method == "HEAD":
				proc = subprocess.Popen(self.file_name, env=e, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
				cgiout, cgierr = proc.communicate()
				output = cgiout + cgierr
			elif self.postbody is not None:
				proc = subprocess.Popen(self.file_name, env=e, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
				cgiout, cgierr = proc.communicate(self.postbody)
				# the output stays a buffer - it is sent as it is
				output = cgiout + cgierr
			else:
				proc = subprocess.Popen(self.file_name, env=e, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
				return [resp_line,proc.stdout,proc.stderr]
		except Exception:
			output = "Content-type: text/plain\r\n\r\n" + traceback.format_exc()
			if sys.version_info[0]>=3:
				output = output.encode('latin-1','replace')
		if self.method == "HEAD":
			# for HEAD request, don't send message body even if the script
			# returns one (RFC 3875)
			output = self._cgi_head(output)
		return [resp_line + output]

	def make_cgi_env(self):
		"""Return the dictionary of CGI environment variables"""
		env = {}
		env['SERVER_SOFTWARE'] = "AsyncServer"
		env['SERVER_NAME'] = "AsyncServer"
		env['GATEWAY_INTERFACE'] = 'CGI/1.1'
		env['DOCUMENT_ROOT'] = HTTP.root
		env['SERVER_PROTOCOL'] = "HTTP/1.1"
		env['SERVER_PORT'] = str(self.server.port)

		env['REQUEST_METHOD'] = self.method
		env['REQUEST_URI'] = self.url
		env['PATH_TRANSLATED'] = self.translate_path()
		env['SCRIPT_NAME'] = self.path
		env['PATH_INFO'] = urlunparse(("","","",self.rest[0],"",""))
		env['QUERY_STRING'] = self.rest[1]
		if not self.host == self.client_address[0]:
			env['REMOTE_HOST'] = self.host
		env['REMOTE_ADDR'] = self.client_address[0]
		env['CONTENT_LENGTH'] = str(self.headers.get('content-length',''))
		# scripts need the content type to decode the body of a POST request
		env['CONTENT_TYPE'] = str(self.headers.get('content-type',''))
		for k in ['USER_AGENT','COOKIE','ACCEPT','ACCEPT_CHARSET',
			'ACCEPT_ENCODING','ACCEPT_LANGUAGE','CONNECTION']:
			hdr = k.lower().replace("_","-")
			env['HTTP_%s' %k.upper()] = str(self.headers.get(hdr,''))
		return env

	def redirect_resp(self,redirurl):
		"""Return redirect message (the connection is closed afterwards)"""
		# the last CR LF is the empty line ending the header section
		resp_line = "%s 301 Moved Permanently\r\nLocation: %s\r\n\r\n" % (self.protocol,redirurl)
		if sys.version_info[0]>=3:
			resp_line = resp_line.encode('ascii')
		self.close_when_done = True
		self.log(301)
		return [resp_line]

	def err_resp(self,code,msg):
		"""Return an error message (the connection is closed afterwards)"""
		# the last CR LF is the empty line ending the header section
		resp_line = "%s %s %s\r\n\r\n" %(self.protocol,code,msg)
		if sys.version_info[0]>=3:
			resp_line = resp_line.encode('ascii')
		self.close_when_done = True
		self.log(code)
		return [resp_line]

	def log(self,code):
		"""Write a trace of the request on stderr"""
		if HTTP.logging:
			date_str = datetime.datetime.now().strftime('[%d/%b/%Y %H:%M:%S]')
			sys.stderr.write('%s - - %s "%s" %s\n' %(self.host,date_str,self.requestline,code))


def mylock(filename):
	"""Try to take an exclusive, non-blocking lock on the file "filename"

	The file is created when it does not exist. On success the pid of
	the current process is stored in the file and the descriptor is
	intentionally left open - the lock is held as long as it is open.

	Returns:
	   0 - lock obtained
	  >0 - file is locked by another owner ; the value is the pid read
	       from the file
	  -1 - file can't be opened or locked
	  -2 - file is locked, but it doesn't contain a valid pid
	"""
	# depending on the Python version these calls raise IOError or OSError
	try:
		fd = posix.open(filename,posix.O_RDWR|posix.O_CREAT,438) # 438 = 0o666
	except (IOError,OSError):
		return -1
	try:
		fcntl.flock(fd,fcntl.LOCK_EX|fcntl.LOCK_NB)
	except (IOError,OSError):
		ex = sys.exc_info()[1]
		try:
			if ex.errno != errno.EAGAIN:
				return -1
			# locked by somebody else - report the pid of the owner
			try:
				return int(posix.read(fd,100).strip())
			except ValueError:
				return -2
		finally:
			posix.close(fd)
	posix.ftruncate(fd,0)
	posix.write(fd,("%u" % posix.getpid()).encode('ascii'))
	return 0

def wdlock(fname,runmode,timeout):
	"""Obtain the lockfile "fname" according to the run mode

	runmode:
	  0 - restart : terminate the current lock owner (if any) and
	      take the lock
	  1 - start   : take the lock, fail if it is already locked
	  2 - stop    : terminate the current lock owner
	  3 - test    : only print whether the lock is held
	timeout - number of attempts made in one second intervals

	Returns 1 when the lock has been obtained and the server should be
	started, 0 when the requested action is done and there is nothing
	more to do, -1 on error.
	"""
	killed = 0
	for i in range(timeout):
		l = mylock(fname)
		if l==0:
			# lock obtained by this process
			if runmode==2:
				if killed:
					return 0
				else:
					print("can't find process to terminate")
					return -1
			if runmode==3:
				print("mfscgiserv is not running")
				return 0
			print("lockfile created and locked")
			return 1
		elif l<0:
			if l<-1:
				print("lockfile is damaged (can't obtain pid - kill prevoius instance manually)")
			else:
				print("lockfile error")
			return -1
		else:
			# lock is held by the process with pid "l"
			if runmode==3:
				print("mfscgiserv pid:%u" % l)
				return 0
			if runmode==1:
				print("can't start: lockfile is already locked by another process")
				return -1
			if killed!=l:
				print("sending SIGTERM to lock owner (pid:%u)" % l)
				try:
					posix.kill(l,signal.SIGTERM)
				except OSError:
					ex = sys.exc_info()[1]
					# ESRCH : the owner has just exited on its own, so
					# the lock should be free during the next attempt
					if ex.errno != errno.ESRCH:
						print("can't send SIGTERM to lock owner (pid:%u): %s" % (l,ex.strerror))
						return -1
				killed = l
			if (i%10)==0 and i>0:
				print("about %u seconds passed and lock still exists" % i)
			time.sleep(1)
	print("about %u seconds passed and lockfile is still locked - giving up" % timeout)
	return -1

if __name__=="__main__":
	locktimeout = 60
	daemonize = 1
	verbose = 0
	host = 'any'
	port = 9425
	rootpath="/usr/share/mfscgi"
	datapath="/var/lib/mfs"

	opts,args = getopt.getopt(sys.argv[1:],"hH:P:R:t:fv")
	for opt,val in opts:
		if opt=='-h':
			print("usage: %s [-H bind_host] [-P bind_port] [-R rootpath] [-t locktimeout] [-f [-v]] [start|stop|restart|test]\n" % sys.argv[0])
			print("-H bind_host : local address to listen on (default: any)\n-P bind_port : port to listen on (default: 9425)\n-R rootpath : local path to use as HTTP document root (default: /usr/share/mfscgi)\n-t locktimeout : how long to wait for lockfile (default: 60s)\n-f : run in foreground\n-v : log requests on stderr")
			# os._exit doesn't flush buffers of standard streams
			sys.stdout.flush()
			os._exit(0)
		elif opt=='-H':
			host = val
		elif opt=='-P':
			port = int(val)
		elif opt=='-R':
			rootpath = val
		elif opt=='-t':
			locktimeout = int(val)
		elif opt=='-f':
			daemonize = 0
		elif opt=='-v':
			verbose = 1

	lockfname = datapath + os.path.sep + '.mfscgiserv.lock'

	# run mode : 1 = start, 2 = stop, 3 = test, 0 = anything else
	# (restart, unknown word or no word at all)
	mode = {'start':1,'stop':2,'test':3}.get(args[0] if args else None,0)

	rootpath = os.path.realpath(rootpath)

	# pipe used by the daemon to tell the waiting parent process that
	# the initialization is finished
	pipefd = posix.pipe()

	if (mode==1 or mode==0) and daemonize:
# daemonize
		try:
			pid = os.fork()
		except OSError:
			e = sys.exc_info()[1]
			raise Exception("fork error: %s [%d]" % (e.strerror, e.errno))
		if pid>0:
			# close our copy of the write end - then read returns (with
			# end of file) also when the daemon dies before reporting
			posix.close(pipefd[1])
			posix.read(pipefd[0],1)
			os._exit(0)
		os.setsid()
		try:
			pid = os.fork()
		except OSError:
			e = sys.exc_info()[1]
			raise Exception("fork error: %s [%d]" % (e.strerror, e.errno))
		if pid>0:
			os._exit(0)

	if wdlock(lockfname,mode,locktimeout)==1:

		print("starting simple cgi server (host: %s , port: %u , rootpath: %s)" % (host,port,rootpath))

		if daemonize:
			# send out buffered messages before descriptors are replaced
			sys.stdout.flush()
			sys.stderr.flush()
			os.close(0)
			os.close(1)
			os.close(2)
			if os.open("/dev/null",os.O_RDWR)!=0:
				raise Exception("can't open /dev/null as 0 descriptor")
			os.dup2(0,1)
			os.dup2(0,2)

			# wake up the parent process waiting on the pipe
			posix.write(pipefd[1],b'0')

		posix.close(pipefd[0])
		posix.close(pipefd[1])

		server = Server(host, port)

# launch the server on the specified port
		if not daemonize:
			if host!='any':
				print("Asynchronous HTTP server running on %s:%s" % (host,port))
			else:
				print("Asynchronous HTTP server running on port %s" % port)
		# requests are logged only in foreground mode with -v
		HTTP.logging = bool(not daemonize and verbose)
		HTTP.root = rootpath
		# try to drop the effective uid ; best effort - the server keeps
		# running with the current uid when it is not possible
		try:
			os.seteuid(pwd.getpwnam("mfs").pw_uid)
		except Exception:
			try:
				os.seteuid(pwd.getpwnam("nobody").pw_uid)
			except Exception:
				pass
		signal.signal(signal.SIGCHLD, signal.SIG_IGN)
		loop(server,HTTP)

	else:
		# nothing to start (stop / test mode or an error) - print
		# buffered messages and release the parent process, if any
		sys.stdout.flush()
		posix.write(pipefd[1],b'0')
		os._exit(0)
