#!/usr/bin/env python2
#
#  Copyright (c) 2015-2016, The Linux Foundation. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are
#  met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above
#        copyright notice, this list of conditions and the following
#        disclaimer in the documentation and/or other materials provided
#        with the distribution.
#      * Neither the name of The Linux Foundation nor the names of its
#        contributors may be used to endorse or promote products derived
#        from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED
#  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
#  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT
#  ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
#  BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
#  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
#  WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
#  OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
#  IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Drop in replacement for dtbTool
#

from functools import total_ordering
from struct import pack, unpack
import ctypes
import os
import sys
import re
from optparse import OptionParser

# SoC name -> numeric SoC id.
# Keys are the <soc> part of a "qcom,<soc>-..." compatible string as captured
# by QcomIds.pattern; QcomIds ORs the looked-up value with (foundry << 16) to
# form msm_id[0].  A name missing from this table makes QcomIds raise KeyError.
soc_ids = {
		'msm8660'        : 71,
		'msm8960'        : 87,
		'apq8064'        : 109,
		'msm8660a'       : 122,
		'msm8260a'       : 123,
		'apq8060a'       : 124,
		'msm8974'        : 126,
		'msm8960ab'      : 138,
		'apq8060ab'      : 139,
		'msm8260ab'      : 140,
		'msm8660ab'      : 141,
		'apq8064_prime'  : 153,
		'apq8084'        : 178,
		'apq8074'        : 184,
		'msm8274'        : 185,
		'msm8674'        : 186,
		'msm8974_pro'    : 194,
		'ipq8062'        : 201,
		'ipq8064'        : 202,
		'ipq8066'        : 203,
		'msm8916'        : 206,
		'msm8994'        : 207,
		'apq8074_aa'     : 208,
		'apq8074_ab'     : 209,
		'apq8074_pro'    : 210,
		'msm8274_aa'     : 211,
		'msm8274_ab'     : 212,
		'msm8274_pro'    : 213,
		'msm8674_aa'     : 214,
		'msm8674_ab'     : 215,
		'msm8674_pro'    : 216,
		'msm8974_aa'     : 217,
		'msm8974_ab'     : 218,
		'apq8064_au'     : 244,
		'msm8909'        : 245,
		'msm8996'        : 246,
		'apq8016'        : 247,
		'msm8216'        : 248,
		'msm8992'        : 251,
		'apq8092'        : 252,
		'apq8094'        : 253,
		'ipq4019'        : 273,
		'apq8096'        : 291,
		'msm8996sg'      : 305,
	}

# Platform (board type) name -> numeric platform id.
# Keys are the <platform> part of the compatible string as captured by
# QcomIds.pattern; the value becomes the low bits of board_id[0], with the
# board minor/major version and subtype OR-ed in at bits 8, 16 and 24.
# A name missing from this table makes QcomIds raise KeyError.
platform_type = {
		'cdp'         : 1,
		'ffa'         : 2,
		'fluid'       : 3,
		'fusion'      : 4,
		'oem'         : 5,
		'qt'          : 6,
		'mtp_mdm'     : 7,
		'mtp'         : 8,
		'liquid'      : 9,
		'dragonboard' : 10,
		'qrd'         : 11,
		'evb'         : 12,
		'hrd'         : 13,
		'dtv'         : 14,
		'rumi'        : 15,
		'virtio'      : 16,
		'xpm'         : 20,
		'rcm'         : 21,
		'stp'         : 23,
		'sbc'         : 24,
		'cls'         : 29,
		}

# PMIC name -> numeric PMIC model id.
# Keys are the <pmic> part of a "qcom,<pmic>[-v<major>[.<minor>]]" compatible
# string as captured by QcomIds.pmic_pattern; QcomIds ORs the value with
# (major << 16) and (minor << 8) to form one pmic_id entry.
# A name missing from this table makes QcomIds raise KeyError.
pmic_ids = {
		'pm8941'   : 1,
		'pm8841'   : 2,
		'pm8019'   : 3,
		'pm8026'   : 4,
		'pm8110'   : 5,
		'pma8084'  : 6,
		'pmi8962'  : 7,
		'pmd9635'  : 8,
		'pm8994'   : 9,
		'pmi8994'  : 10,
		'pm8916'   : 11,
		'pm8004'   : 12,
		'pm8909'   : 13,
		'pm2433'   : 14,
		'pmd9655'  : 15,
		'pm8950'   : 16,
		'pmi8950'  : 17,
		'pmk8001'  : 18,
		'pmi8996'  : 19,
		'pm8998'   : 20,
		'pmi8998'  : 21,
		'pm8953'   : 22,
		'smb1380'  : 23,
		'pm8005'   : 24,
		'pm8937'   : 25,
	}

# Name -> id tables for boot devices (generic and MDM variants) and panels.
# NOTE(review): none of these three tables is referenced by any code visible
# in this file; they appear to be kept for reference or future use - confirm
# before removing.
boot_device_ids = { 'emmc_sdc1' : 2, 'ufs' : 4 }
boot_device_ids_mdm = { 'emmc_sdc1' : 3 }
panel_ids = { 'hd' : 0, '720p' : 1, 'qHD' : 2, 'fWVGA' : 3 }

# Load the libfdt shared library at import time.  The whole tool depends on
# it, so a missing library terminates the process with a readable message
# instead of a traceback.
try:
	libfdt = ctypes.CDLL('libfdt.so')
except OSError:
	sys.exit("libfdt is missing, please install it")

# ctypes bindings for the three libfdt entry points used below.  In each
# paramflags tuple the leading 1 marks the argument as an input parameter and
# the string is the keyword name under which it may be passed.

# fdt_get_alias(fdt, name): declared to return a C string, so ctypes hands
# back a bytes object, or None when the C function returns NULL.
prototype = ctypes.CFUNCTYPE(ctypes.c_char_p, ctypes.c_void_p, ctypes.c_char_p)
paramflags = (1, "fdt", None), (1, "name", None)
fdt_get_alias = prototype(("fdt_get_alias", libfdt), paramflags)

# fdt_path_offset(fdt, path): returns a C int (the node offset for path).
prototype = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_char_p)
paramflags = (1, "fdt", None), (1, "path", None)
fdt_path_offset = prototype(("fdt_path_offset", libfdt), paramflags)

# fdt_getprop(fdt, nodeoffset, name, lenp): declared to return void*, so
# ctypes hands back the address as a Python integer, or None for NULL.
# lenp is passed as a pointer to a c_int that receives the property length.
prototype = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
paramflags = (1, "fdt", None), (1, "nodeoffset", None), (1, "name", None), (1, "lenp", None)
fdt_getprop = prototype(("fdt_getprop", libfdt), paramflags)

def fdt_get_property_int(blob, nodeoffset, name):
	"""Return an int list of property name's values at nodeoffset

	blob is the device tree (a ctypes reference to the DTB bytes),
	nodeoffset the offset of the node to query and name the property
	name as a text string.

	Returns None when the property does not exist, otherwise a list with
	one int per complete 32-bit cell (an empty list for an empty
	property).  Trailing bytes that do not form a whole cell are ignored.
	"""
	length = ctypes.c_int(0)

	addr = fdt_getprop(blob, nodeoffset, str.encode(name),
			   ctypes.byref(length))
	if addr is None:
		return None

	cells = ctypes.cast(addr, ctypes.POINTER(ctypes.c_uint32))

	# Floor division: with "/" Python 3 yields a float, which lets the loop
	# run one cell past the end of a property whose length is not a
	# multiple of four.
	count = length.value // 4

	# cells[i] reads the stored big-endian word as if it were native;
	# packing that as ">I" and unpacking as "I" restores the real value on
	# either host endianness.
	return [unpack("I", pack(">I", cells[i]))[0] for i in range(count)]

def fdt_get_property_string(blob, nodeoffset, name):
	"""Return a string list of property name's values at nodeoffset

	blob is the device tree (a ctypes reference to the DTB bytes),
	nodeoffset the offset of the node to query and name the property
	name as a text string.

	The property is treated as a sequence of NUL-terminated UTF-8
	strings.  Returns an empty list when the property is missing or
	empty; raises UnicodeDecodeError if a string is not valid UTF-8.
	"""
	length = ctypes.c_int(0)

	val = fdt_getprop(blob, nodeoffset, str.encode(name),
			  ctypes.byref(length))
	if val is None or length.value <= 0:
		return []

	# Copy exactly the property's bytes so that a value lacking its final
	# NUL can never make us read beyond the property.
	data = ctypes.string_at(val, length.value)

	# Drop the terminator of the last string; the remaining NULs are the
	# separators between strings.
	if data.endswith(b"\x00"):
		data = data[:-1]

	# Split on the raw bytes and decode afterwards: stepping through the
	# buffer by decoded length would lose track of the position as soon as
	# a string contains a multi-byte character.
	return [chunk.decode("utf-8") for chunk in data.split(b"\x00")]

def find_pmics(blob):
	"""Return the first compatible string of each PMIC alias in blob

	The aliases "usid0", "usid2", "usid4" and "usid6" are looked up in
	that order.  The result always has four entries; an entry is the
	empty string when the alias is absent, when its path does not resolve
	to a node, or when that node has no "compatible" property.
	"""
	pmics = []

	for usid in ["usid0", "usid2", "usid4", "usid6"]:
		path = fdt_get_alias(blob, str.encode(usid))
		if path is None:
			pmics.append("")
			continue

		off = fdt_path_offset(blob, path)
		if off < 0:
			# libfdt reports failures as negative error codes
			pmics.append("")
			continue

		compat = fdt_get_property_string(blob, off, "compatible")
		# A node without "compatible" yields an empty list; record it as
		# "no PMIC" rather than failing with IndexError.
		pmics.append(compat[0] if compat else "")

	return pmics

class QcomIds(object):
	"""Numeric ids decoded from a board compatible string and PMIC strings

	Attributes set by the constructor:
	  msm_id   - [soc id | foundry << 16, soc major << 16 | soc minor]
	  board_id - [platform | minor << 8 | major << 16 | subtype << 24, 0]
	  pmic_id  - one value per entry of pmics, model | major << 16 |
	             minor << 8, left at 0 for empty entries
	"""

	# Board compatible string:
	#   qcom,<soc>[-v<major>[.<minor>]][-<foundry>]-<platform>[/<sub>][-v<major>[.<minor>]]
	# The dot between major and minor is escaped on purpose: a bare "."
	# matches any character, which turns "-v2-1-mtp" into minor version 1
	# instead of foundry 1.
	pattern = re.compile(r'''
	qcom,(?P<soc>\w+) # SoC (req)
	(-v(?P<soc_major>\d+)(\.(?P<soc_minor>\d+))?)? # SoC version (opt)
	(-(?P<foundry>\d+))? # foundry (opt)
	-(?P<platform>\w+) # Platform (req)
	(/(?P<platform_sub>\d+))? # Number for platform sub (opt)
	(-v(?P<plat_major>\d+)(\.(?P<plat_minor>\d+))?)? # Plat version (opt)
	''', re.VERBOSE)

	# PMIC compatible string: qcom,<pmic>[-v<major>[.<minor>]]
	pmic_pattern = re.compile(r'''
	qcom,(?P<pmic>\w+)(-v(?P<major>\d+)(\.(?P<minor>\d+))?)?
	''', re.VERBOSE)

	def __init__(self, compat, pmics):
		"""Decode compat and the PMIC compatible strings in pmics

		compat must match QcomIds.pattern (callers are expected to
		check this first).  pmics holds at most four strings; empty
		ones are skipped.

		Raises KeyError when the SoC, platform or PMIC name is not in
		soc_ids, platform_type or pmic_ids, and AttributeError when
		compat or a non-empty PMIC string does not match its pattern.
		"""
		self.msm_id = [0, 0]
		self.board_id = [0, 0]
		self.pmic_id = [0, 0, 0, 0]

		matches = QcomIds.pattern.match(compat).groupdict()

		# Optional groups come back as None, hence the "or 0" defaults
		foundry = int(matches['foundry'] or 0)
		self.msm_id[0] = soc_ids[matches['soc']] | (foundry << 16)
		major = int(matches['soc_major'] or 0)
		minor = int(matches['soc_minor'] or 0)
		self.msm_id[1] = major << 16 | minor

		plat = platform_type[matches['platform']]
		subtype = int(matches['platform_sub'] or 0)
		major = int(matches['plat_major'] or 0)
		minor = int(matches['plat_minor'] or 0)
		self.board_id[0] = plat | minor << 8 | major << 16 | subtype << 24
		self.board_id[1] = 0

		for (i, pmic) in enumerate(pmics):
			if len(pmic) == 0:
				continue

			matches = QcomIds.pmic_pattern.match(pmic).groupdict()
			model = pmic_ids[matches['pmic']]
			major = int(matches['major'] or 0)
			minor = int(matches['minor'] or 0)
			self.pmic_id[i] = model | major << 16 | minor << 8

@total_ordering
class DTRecord:
	"""One table entry: a tuple of hardware ids plus the DTB file it refers to

	Equality and ordering look only at the eight id fields, compared in
	the order plat_id, variant_id, subtype_id, soc_rev, pmic0..pmic3.
	offset, size, f, version and duplicate never take part in comparisons.
	Comparing against None gives False for both == and <.
	"""

	# Id fields, most significant first
	_ID_FIELDS = ('plat_id', 'variant_id', 'subtype_id', 'soc_rev',
		'pmic0', 'pmic1', 'pmic2', 'pmic3')

	def __init__(self, plat_id, variant_id, subtype_id, soc_rev, pmic0, pmic1, pmic2, pmic3, offset, size, f, version):
		"""Store the fields, converting every numeric one with int()

		f is stored unchanged; duplicate starts out False.
		"""
		ids = (plat_id, variant_id, subtype_id, soc_rev,
			pmic0, pmic1, pmic2, pmic3)
		for field, value in zip(self._ID_FIELDS, ids):
			setattr(self, field, int(value))
		self.offset = int(offset)
		self.size = int(size)
		self.f = f
		self.version = int(version)
		self.duplicate = False

	def __eq__(self, other):
		"""True when all eight id fields match those of other"""
		if other is None:
			return False
		return all(getattr(self, field) == getattr(other, field)
			for field in self._ID_FIELDS)

	def __lt__(self, other):
		"""Field-by-field ordering; the first differing field decides"""
		if other is None:
			return False
		for field in self._ID_FIELDS:
			mine = getattr(self, field)
			theirs = getattr(other, field)
			if mine != theirs:
				return mine < theirs
		# Every id field is equal, so self is not smaller
		return False

def generate_records(f, pagesize):
	"""Return the list of DTRecord entries described by DTB file f

	f is the path of a .dtb file and pagesize the alignment of the output
	image; every record's size is the file size rounded up to a multiple
	of pagesize, and its offset is left at 0 for the caller to fill in.

	The table version is picked from the properties of the root node:
	  1 - only qcom,msm-id, read as (plat, variant, soc_rev) triples
	  2 - qcom,msm-id as (plat, soc_rev) pairs combined with every
	      (variant, subtype) pair of qcom,board-id
	  3 - as 2, additionally combined with every group of four values
	      of qcom,pmic-id
	When none of these properties yields a version, version 3 records
	are derived from the root "compatible" strings that match
	QcomIds.pattern and from the PMIC aliases.

	Exits the program if f has neither qcom,msm-id nor compatible.
	"""
	version = None
	compats = None
	size = os.stat(f).st_size
	mod = size % pagesize
	if mod != 0:
		size += pagesize - mod

	# The buffer keeps its own copy of the data, so the file can be closed
	# straight away.
	with open(f, 'rb') as fi:
		s = ctypes.create_string_buffer(fi.read())
	blob = ctypes.byref(s)

	msm_id = fdt_get_property_int(blob, 0, "qcom,msm-id")
	if msm_id is None:
		compats = fdt_get_property_string(blob, 0, "compatible")
		if len(compats) == 0:
			sys.exit("%s is missing compatible?" % f)

	board_id = fdt_get_property_int(blob, 0, "qcom,board-id")
	pmic_id = fdt_get_property_int(blob, 0, "qcom,pmic-id")

	# zip() over several references to one iterator groups the flat cell
	# list into tuples and drops an incomplete trailing group.
	# NOTE(review): a DTB with qcom,board-id but without qcom,msm-id ends
	# up in iter(None) below and raises TypeError - decide whether such
	# files should be rejected or fall back to "compatible".
	if board_id:
		x = iter(board_id)
		board_id = list(zip(x, x))
		x = iter(msm_id)
		msm_id = list(zip(x, x))
		version = 2
	elif msm_id:
		x = iter(msm_id)
		msm_id = list(zip(x, x, x))
		version = 1

	if pmic_id:
		x = iter(pmic_id)
		pmic_id = list(zip(x, x, x, x))
		version = 3

	records = []
	if version == 1:
		for (plat_id, variant_id, soc_rev) in msm_id:
			records.append(DTRecord(plat_id, variant_id, 0, soc_rev,
						0, 0, 0, 0, 0, size, f, version))
	elif version == 2:
		for (plat_id, soc_rev) in msm_id:
			for (variant_id, subtype_id) in board_id:
				records.append(DTRecord(plat_id, variant_id,
							subtype_id, soc_rev,
							0, 0, 0, 0, 0, size, f,
							version))
	elif version == 3:
		for (plat_id, soc_rev) in msm_id:
			for (variant_id, subtype_id) in board_id:
				for (pmic0, pmic1, pmic2, pmic3) in pmic_id:
					records.append(DTRecord(plat_id, variant_id,
								subtype_id, soc_rev,
								pmic0, pmic1, pmic2,
								pmic3, 0, size, f,
								version))
	else:
		version = 3
		# "compatible" has only been read when qcom,msm-id was absent;
		# an msm-id that is present but empty also lands here.
		if compats is None:
			compats = fdt_get_property_string(blob, 0, "compatible")
		pmics = find_pmics(blob)
		ids = [QcomIds(compat, pmics) for compat in compats
				if QcomIds.pattern.match(compat)]
		records += [DTRecord(x.msm_id[0], x.board_id[0], x.board_id[1],
				     x.msm_id[1], x.pmic_id[0], x.pmic_id[1],
				     x.pmic_id[2], x.pmic_id[3], 0, size, f,
				     version) for x in ids]

	return records

def write_padding(f, pagesize):
	"""Zero-fill f from its current position up to the next page boundary

	f is a binary file object supporting tell() and write(); pagesize is
	the alignment in bytes.  Nothing is written when the position is
	already a multiple of pagesize.
	"""
	remainder = f.tell() % pagesize
	# Write padding as long as we aren't already aligned to a page
	if remainder != 0:
		f.write(b'\x00' * (pagesize - remainder))

# Command line entry point: collect the records of every *.dtb file in the
# input directory and write one image made of
#   header  - magic "QCDT", table version, number of records
#   table   - one fixed-size entry per record, sorted by its ids
#   a zero word terminating the table, padded to a page boundary
#   the DTB files, each padded to a page boundary and stored only once
if __name__ == "__main__":
	usage = ("""%prog -o <output file> <input DTB directory> [options]""")
	parser = OptionParser(usage=usage)
	# Standard options
	parser.add_option("-o", "--output-file", dest="output", metavar="FILE",
			  help="output file")
	parser.add_option("-p", "--dtc-path", dest="dtc", metavar="PATH",
			  help="path to dtc", default="")
	parser.add_option("-v", "--verbose", action="store_true",
			  dest="verbose", help="verbose")
	parser.add_option("-s", "--page-size", type="int", dest="pagesize",
			  default=2048,
			  help="page size in bytes [default: %default]")
	# New options
	parser.add_option("--version", dest="version", type="int",
			  help="Force version")

	(options, args) = parser.parse_args()

	if options.output is None:
		parser.error("Output file must be specified")
	if len(args) != 1:
		parser.error("Exactly one input directory must be specified")
	# All alignment arithmetic below divides by the page size
	if options.pagesize <= 0:
		parser.error("Page size must be a positive number of bytes")

	pagesize = options.pagesize
	indir = args[0]
	# Sorted so that the result does not depend on directory listing order
	flist = [os.path.join(indir, f)
			for f in sorted(os.listdir(indir))
			if os.path.isfile(os.path.join(indir, f)) and
			   f.endswith('.dtb')]

	records = []
	for f in flist:
		records += generate_records(f, pagesize)

	if len(records) == 0:
		sys.exit("No valid dtbs found")

	records.sort()
	if options.version is None:
		version = records[0].version
	else:
		version = options.version

	try:
		output = open(options.output, 'wb')
	except IOError:
		sys.exit("Can't open %s" % options.output)

	# Size of header plus table: 12 header bytes, 5/6/10 words per entry
	# depending on the version, and the terminating zero word.
	offset = 12
	if version == 1:
		offset += 5 * 4 * len(records)
	elif version == 2:
		offset += 6 * 4 * len(records)
	elif version == 3:
		offset += 10 * 4 * len(records)
	offset += 4
	# Round up to the next page only when not already aligned, exactly as
	# write_padding() does; adding a whole page in the aligned case would
	# leave every recorded offset one page beyond the data.
	mod = offset % pagesize
	if mod != 0:
		offset += pagesize - mod

	# Assign offsets.  Records coming from the same file share one copy of
	# the DTB: only the first one advances the offset, the others are
	# marked as duplicates and reuse its offset.
	found = {}
	p = None
	for r in records:
		if r.f in found:
			r.duplicate = True
		else:
			if p:
				offset += p.size
			found[r.f] = offset
			p = r
		r.offset = found[r.f]

	hdr_format = "<4sII"
	blob = pack(hdr_format, b"QCDT", version, len(records))

	if version == 1:
		# A version 1 entry has no subtype field; its third word is the
		# SoC revision.
		blob_format = "<IIIII"
		blob += b"".join([pack(blob_format,
			  r.plat_id,
			  r.variant_id,
			  r.soc_rev,
			  r.offset,
			  r.size) for r in records])
	elif version == 2:
		blob_format = "<IIIIII"
		blob += b"".join([pack(blob_format,
			  r.plat_id,
			  r.variant_id,
			  r.subtype_id,
			  r.soc_rev,
			  r.offset,
			  r.size) for r in records])
	elif version == 3:
		blob_format = "<IIIIIIIIII"
		blob += b"".join([pack(blob_format,
			  r.plat_id,
			  r.variant_id,
			  r.subtype_id,
			  r.soc_rev,
			  r.pmic0,
			  r.pmic1,
			  r.pmic2,
			  r.pmic3,
			  r.offset,
			  r.size) for r in records])
	# Zero word marking the end of the table
	blob += pack("<I", 0)

	try:
		# "with" closes (and thereby flushes) the image on every path
		with output:
			output.write(blob)
			write_padding(output, pagesize)
			for r in records:
				if r.duplicate:
					continue
				with open(r.f, "rb") as dtb:
					output.write(dtb.read())
				write_padding(output, pagesize)
	except IOError:
		sys.exit("Bad write")
