#!/usr/bin/env python2
#
#  Copyright (c) 2015, The Linux Foundation. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are
#  met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above
#        copyright notice, this list of conditions and the following
#        disclaimer in the documentation and/or other materials provided
#        with the distribution.
#      * Neither the name of The Linux Foundation nor the names of its
#        contributors may be used to endorse or promote products derived
#        from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED
#  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
#  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT
#  ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
#  BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
#  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
#  WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
#  OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
#  IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Drop-in replacement for dtbTool: packs a directory of .dtb files into a QCDT image
#

from functools import total_ordering
from struct import pack
import os
import sys
import subprocess
from optparse import OptionParser

@total_ordering
class DTRecord:
	"""One QCDT table entry: a DTB's match key plus where the DTB is stored.

	Equality and ordering look only at the match key
	(plat_id, variant_id, subtype_id, soc_rev, pmic0..pmic3); offset, size,
	f, version and duplicate are carried along but never compared.
	Numeric arguments are passed through int(), so the string cell values
	parsed from fdtget output can be handed in directly.
	"""

	def __init__(self, plat_id, variant_id, subtype_id, soc_rev, pmic0, pmic1, pmic2, pmic3, offset, size, f, version):
		self.plat_id = int(plat_id)
		self.variant_id = int(variant_id)
		self.subtype_id = int(subtype_id)
		self.soc_rev = int(soc_rev)
		self.pmic0 = int(pmic0)
		self.pmic1 = int(pmic1)
		self.pmic2 = int(pmic2)
		self.pmic3 = int(pmic3)
		self.offset = int(offset)  # byte offset of the DTB in the output image
		self.size = int(size)  # stored (page-padded) size of the DTB in bytes
		self.f = f  # path of the .dtb file this entry was read from
		self.version = int(version)  # QCDT table version (1-3) the entry was built for
		self.duplicate = False  # True when an earlier entry already stores self.f

	def _key(self):
		"""Return the tuple that equality and sort order are defined on."""
		return (self.plat_id, self.variant_id, self.subtype_id,
			self.soc_rev, self.pmic0, self.pmic1, self.pmic2,
			self.pmic3)

	def __eq__(self, other):
		if other is None:
			return False
		return self._key() == other._key()

	def __ne__(self, other):
		# Python 2 does not derive != from __eq__ and total_ordering does
		# not supply it; without this, distinct equal records compare unequal.
		return not self.__eq__(other)

	def __lt__(self, other):
		if other is None:
			return False
		# Lexicographic over the match key, in the field order of _key().
		return self._key() < other._key()

	def __repr__(self):
		return ("DTRecord(plat_id=%d, variant_id=%d, subtype_id=%d, "
			"soc_rev=%d, pmic=(%d, %d, %d, %d), offset=%d, size=%d, "
			"f=%r, version=%d)" % (self.plat_id, self.variant_id,
					       self.subtype_id, self.soc_rev,
					       self.pmic0, self.pmic1, self.pmic2,
					       self.pmic3, self.offset, self.size,
					       self.f, self.version))

def _fdtget(dtc, path, prop):
	"""Return fdtget's text output for root-node property `prop` of `path`.

	dtc is the directory containing the fdtget binary ("" means it is looked
	up on $PATH). Returns None when fdtget exits with a non-zero status
	(presumably because the property is absent). fdtget's stderr is discarded.
	"""
	cmd = [os.path.join(dtc, "fdtget"), path, "/", prop]
	# Send stderr to the null device rather than an unread PIPE, which the
	# subprocess documentation warns against for check_output().
	with open(os.devnull, "w") as devnull:
		try:
			# universal_newlines yields str on both Python 2 and Python 3.
			return subprocess.check_output(cmd, stderr=devnull,
						       universal_newlines=True)
		except subprocess.CalledProcessError:
			return None

def _cells(text, n):
	"""Split whitespace-separated cell values into a list of n-tuples.

	Values left over after the last complete n-tuple are ignored.
	"""
	it = iter(text.split())
	# A real list (not a lazy zip) so callers can iterate it repeatedly.
	return list(zip(*[it] * n))

def generate_records(f, dtc, pagesize):
	"""Return the DTRecords describing one .dtb file.

	f -- path of the .dtb file.
	dtc -- directory holding fdtget ("" to search $PATH).
	pagesize -- page size in bytes; the recorded size is f's size rounded
	            up to a multiple of it.

	The table version follows from which properties f carries:
	  1: qcom,msm-id only, read as (plat_id, variant_id, subtype_id);
	  2: qcom,msm-id read as (plat_id, soc_rev) plus qcom,board-id read
	     as (variant_id, subtype_id);
	  3: as version 2 plus qcom,pmic-id read as (pmic0, pmic1, pmic2, pmic3).
	One record is produced per combination of those tuples. Exits the
	program if qcom,msm-id is missing, or if qcom,pmic-id is present
	without qcom,board-id.
	"""
	# Each DTB is stored page-aligned, so record its padded size.
	size = os.stat(f).st_size
	mod = size % pagesize
	if mod != 0:
		size += pagesize - mod

	msm_id = _fdtget(dtc, f, "qcom,msm-id")
	if msm_id is None:
		sys.exit("%s is missing qcom,msm-id?" % f)
	board_id = _fdtget(dtc, f, "qcom,board-id")
	pmic_id = _fdtget(dtc, f, "qcom,pmic-id")

	if board_id:
		msm_id = _cells(msm_id, 2)
		board_id = _cells(board_id, 2)
		version = 2
	else:
		msm_id = _cells(msm_id, 3)
		version = 1

	if pmic_id:
		# Version 3 entries are built from msm-id/board-id pairs, so a
		# pmic-id cannot be used without a board-id.
		if version != 2:
			sys.exit("%s has qcom,pmic-id but no qcom,board-id" % f)
		pmic_id = _cells(pmic_id, 4)
		version = 3

	records = []
	if version == 1:
		for (plat_id, variant_id, subtype_id) in msm_id:
			records.append(DTRecord(plat_id, variant_id, subtype_id, 0,
						0, 0, 0, 0, 0, size, f, version))
	elif version == 2:
		for (plat_id, soc_rev) in msm_id:
			for (variant_id, subtype_id) in board_id:
				records.append(DTRecord(plat_id, variant_id,
							subtype_id, soc_rev,
							0, 0, 0, 0, 0, size, f,
							version))
	else:
		for (plat_id, soc_rev) in msm_id:
			for (variant_id, subtype_id) in board_id:
				for (pmic0, pmic1, pmic2, pmic3) in pmic_id:
					records.append(DTRecord(plat_id, variant_id,
								subtype_id, soc_rev,
								pmic0, pmic1, pmic2,
								pmic3, 0, size, f,
								version))
	return records

def write_padding(f, pagesize):
	"""Pad file object `f` with NUL bytes up to the next multiple of `pagesize`.

	Nothing is written when f.tell() is already a multiple of pagesize.

	f -- writable binary file object; the padding is written to it.
	pagesize -- alignment in bytes.
	"""
	count = pagesize - (f.tell() % pagesize)
	# count == pagesize means we are already aligned.
	if count != pagesize:
		f.write(b"\x00" * count)

if __name__ == "__main__":
	usage = ("""%prog -o <output file> <input DTB directory> [options]""")
	parser = OptionParser(usage=usage)
	# Standard options
	parser.add_option("-o", "--output-file", dest="output", metavar="FILE",
			  help="output file")
	parser.add_option("-p", "--dtc-path", dest="dtc", metavar="PATH",
			  help="path to dtc", default="")
	# Accepted on the command line but not otherwise used below.
	parser.add_option("-v", "--verbose", action="store_true",
			  dest="verbose", help="verbose")
	parser.add_option("-s", "--page-size", type="int", dest="pagesize",
			  default=2048,
			  help="page size in bytes [default: %default]")
	# New options
	parser.add_option("--version", dest="version", type="int",
			  help="Force version")

	(options, args) = parser.parse_args()

	if options.output is None:
		parser.error("Output file must be specified")
	if len(args) != 1:
		parser.error("Exactly one input directory must be specified")
	if options.pagesize <= 0:
		parser.error("Page size must be a positive number of bytes")
	# Only table versions 1-3 have an entry layout defined below.
	if options.version is not None and options.version not in (1, 2, 3):
		parser.error("Version must be 1, 2 or 3")

	pagesize = options.pagesize
	indir = args[0]
	# Sorted so the output does not depend on directory listing order.
	flist = [os.path.join(indir, f)
			for f in sorted(os.listdir(indir))
			if os.path.isfile(os.path.join(indir, f)) and
			   f.endswith('.dtb')]

	records = []
	for f in flist:
		records += generate_records(f, options.dtc, pagesize)
	if not records:
		sys.exit("No DTB entries found in %s" % indir)

	records.sort()
	# Unless forced, the table version is that of the first sorted record.
	if options.version is None:
		version = records[0].version
	else:
		version = options.version

	try:
		output = open(options.output, 'wb')
	except IOError:
		sys.exit("Can't open %s" % options.output)

	with output:
		# Table layout: 12-byte header ("QCDT", version, entry count), one
		# entry of 5/6/10 little-endian u32 words per record, then a u32 0.
		entry_words = {1: 5, 2: 6, 3: 10}[version]
		offset = 12 + entry_words * 4 * len(records) + 4
		# Round up exactly like write_padding(): only when the table does not
		# already end on a page boundary. Otherwise the recorded offsets
		# would be one page past where the DTBs are actually written.
		if offset % pagesize != 0:
			offset += pagesize - (offset % pagesize)

		# Each distinct .dtb file is stored once, in sorted-record order;
		# every record referring to that file points at the same copy.
		file_offsets = {}
		for r in records:
			if r.f in file_offsets:
				r.duplicate = True
			else:
				file_offsets[r.f] = offset
				offset += r.size
			r.offset = file_offsets[r.f]

		blob = [pack("<4sII", b"QCDT", version, len(records))]
		for r in records:
			if version == 1:
				fields = (r.plat_id, r.variant_id, r.subtype_id,
					  r.offset, r.size)
			elif version == 2:
				fields = (r.plat_id, r.variant_id, r.subtype_id,
					  r.soc_rev, r.offset, r.size)
			else:
				fields = (r.plat_id, r.variant_id, r.subtype_id,
					  r.soc_rev, r.pmic0, r.pmic1, r.pmic2,
					  r.pmic3, r.offset, r.size)
			blob.append(pack("<%dI" % len(fields), *fields))
		blob.append(pack("<I", 0))  # terminating zero word
		output.write(b"".join(blob))

		write_padding(output, pagesize)
		for r in records:
			if r.duplicate:
				continue
			try:
				with open(r.f, "rb") as dtb:
					output.write(dtb.read())
				write_padding(output, pagesize)
			except IOError:
				sys.exit("Bad write")
