#!/bin/bash

# sysfs directory holding one entry per PCI device.
DEVICES_PATH='/sys/bus/pci/devices'
# Kernel module (and PCI driver name) that devices are handed over to.
VFIO_MODULE='vfio-pci'
# sysfs file used to register additional vendor/device IDs with the driver.
VFIO_PCI_NEW_ID="/sys/bus/pci/drivers/${VFIO_MODULE}/new_id"
# Command-line switches; the two booleans are executed as `true`/`false`.
Interactive='true'
UnsafeInts='false'
# Optional owner for the resulting VFIO group node (empty: leave unchanged).
CHOWN_USER=''

# Print all arguments, joined by spaces, as one line on stderr and terminate
# the script with exit status 1.
die() {
	# printf rather than echo: echo would treat a leading "-n" or "-e"
	# argument as an option instead of printing it.
	printf '%s\n' "$*" >&2
	exit 1
}

# Report the command-line synopsis through die, which ends the script.
dieusage() {
	local synopsis='[--unsafe] [--user <username>] [--yes] <vendor> <device>'
	die "Usage: $0 ${synopsis}"
}

# yesno <prompt> <default-command>
#
# Ask a yes/no question.
#   $1 - prompt text, printed followed by a single space
#   $2 - command run to produce the answer when $Interactive is not true
#        (the callers pass `true` or `false`)
# Returns 0 for yes, 1 for no. In non-interactive mode the status of $2 is
# returned. If standard input reaches end-of-file before a valid answer is
# given, the question is treated as answered "no" instead of re-prompting
# forever.
yesno() {
	local answer
	if ! $Interactive; then
		# Deliberately unquoted: $2 is a command name to execute.
		$2
		return
	fi
	while true; do
		printf '%s ' "$1"
		# read fails at EOF; a final line without trailing newline is still
		# delivered in $answer, so only give up when nothing was read.
		if ! read -r answer && [ -z "$answer" ]; then
			echo
			return 1
		fi
		case "$answer" in
		yes|y|ye)
			return 0
			;;
		no|n)
			return 1
			;;
		*)
			echo "Please enter 'yes' or 'no'."
			;;
		esac
	done
}

# parse_args <command-line arguments...>
#
# Interpret the command line and store the result in globals:
#   UnsafeInts   - set to true by --unsafe
#   CHOWN_USER   - value following --user / -u
#   Interactive  - set to false by --yes / -y
#   WANT_VENDOR  - first positional argument
#   WANT_DEVICE  - second positional argument
# Calls dieusage for --help, for an unknown option, for --user without a
# value, and for more than two positional arguments.
parse_args() {
	# Start from a clean slate so values inherited from the environment
	# cannot be mistaken for command-line arguments.
	WANT_VENDOR=
	WANT_DEVICE=
	while [ "$#" -gt 0 ]; do
		case "$1" in
		--help|-h)
			dieusage
			;;
		--unsafe)
			UnsafeInts=true
			;;
		--user|-u)
			# The user name is a separate argument; refuse to run
			# when it is missing rather than silently ignoring it.
			[ "$#" -ge 2 ] || dieusage
			shift
			CHOWN_USER="$1"
			;;
		--yes|-y)
			Interactive=false
			;;
		-*)
			# Anything else that looks like an option is a typo,
			# not a vendor or device ID.
			echo "Unknown option: $1" >&2
			dieusage
			;;
		*)
			if [ -z "$WANT_VENDOR" ]; then
				WANT_VENDOR="$1"
			elif [ -z "$WANT_DEVICE" ]; then
				WANT_DEVICE="$1"
			else
				# A third positional argument is not part of the usage.
				dieusage
			fi
			;;
		esac
		shift
	done
}

parse_args "$@"

# Both IDs are mandatory; the device is assigned last, so checking it suffices.
[ -n "$WANT_DEVICE" ] || dieusage

# normalize_pci_id <id>
# Print the ID lower-cased and without a leading "0x", so that user input such
# as "0x10DE" can be compared with the "0x10de" form read from sysfs.
normalize_pci_id() {
	local id
	id="$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')"
	printf '%s\n' "${id#0x}"
}

# device_matches <sysfs-device-dir> <vendor> <device>
# Succeed only if the directory's "vendor" and "device" files contain exactly
# the requested IDs. The comparison is a whole-string match, so a partial ID
# (or a regular expression) never selects a device by accident.
device_matches() {
	local dev_dir="$1" want_vendor want_device
	if ! [ -r "${dev_dir}/vendor" ] || ! [ -r "${dev_dir}/device" ]; then
		return 1
	fi
	want_vendor="$(normalize_pci_id "$2")"
	want_device="$(normalize_pci_id "$3")"
	if [ -z "$want_vendor" ] || [ -z "$want_device" ]; then
		return 1
	fi
	[ "$(<"${dev_dir}/vendor")" = "0x${want_vendor}" ] &&
		[ "$(<"${dev_dir}/device")" = "0x${want_device}" ]
}

# Main part: for every PCI device with the requested vendor/device ID, list the
# other devices in its IOMMU group, ask for confirmation, then unbind each
# device of the group from its driver and register its ID with the VFIO driver.
modprobe "${VFIO_MODULE}" || die "Failed to load ${VFIO_MODULE} module"
matched_devices=0
for TARGET_DEVICE_DIR in "${DEVICES_PATH}"/*; do
	device_matches "$TARGET_DEVICE_DIR" "${WANT_VENDOR}" "${WANT_DEVICE}" || continue
	TARGET_DEVICE_ID="${TARGET_DEVICE_DIR##*/}"
	matched_devices=$((matched_devices + 1))

	echo "Found $(lspci -s "$TARGET_DEVICE_ID")"
	# Without an IOMMU group there is nothing to hand to VFIO; stop here
	# rather than continuing with an empty group name.
	IOMMU_GROUP_DIR="${TARGET_DEVICE_DIR}/iommu_group"
	[ -d "${IOMMU_GROUP_DIR}/devices" ] || die "$TARGET_DEVICE_ID has no IOMMU group (is the IOMMU enabled?)"

	# extradevs:  group members the user is warned about (they currently
	#             have a driver other than the VFIO one).
	# extradevsq: group members handled without a warning: the target
	#             itself, devices with no driver, devices already on the
	#             VFIO driver, and devices of class 0x060400.
	extradevs=()
	extradevsq=()
	for RELATED_DEVICE_DIR in "${IOMMU_GROUP_DIR}/devices"/*; do
		[ -e "$RELATED_DEVICE_DIR" ] || continue
		RELATED_DEVICE_ID="${RELATED_DEVICE_DIR##*/}"
		if [ "$RELATED_DEVICE_ID" = "$TARGET_DEVICE_ID" ] ||
		 ! [ -e "${DEVICES_PATH}/${RELATED_DEVICE_ID}/driver" ] ||
		 [ "$(basename "$(readlink "${DEVICES_PATH}/${RELATED_DEVICE_ID}/driver")")" = "${VFIO_MODULE}" ] ||
		 grep -q '^0x060400$' "${DEVICES_PATH}/${RELATED_DEVICE_ID}/class"; then
			extradevsq+=("$RELATED_DEVICE_ID")
		else
			extradevs+=("$RELATED_DEVICE_ID")
		fi
	done
	extradevsn=${#extradevs[@]}
	if [ "$extradevsn" -gt 0 ]; then
		if [ "$extradevsn" -gt 1 ]; then
			echo "Enabling VFIO for this device will also disable the following $extradevsn devices:"
		else
			echo "Enabling VFIO for this device will also disable this device:"
		fi
		for RELATED_DEVICE_ID in "${extradevs[@]}"; do
			echo "- $(lspci -s "$RELATED_DEVICE_ID")"
		done
	fi
	yesno 'Enable VFIO?' true || continue

	if $UnsafeInts; then
		echo 1 >/sys/module/vfio_iommu_type1/parameters/allow_unsafe_interrupts || die 'Failed to enable unsafe interrupts'
	fi
	for RELATED_DEVICE_ID in "${extradevsq[@]}" "${extradevs[@]}"; do
		if [ -e "${DEVICES_PATH}/${RELATED_DEVICE_ID}/driver" ]; then
			echo "$RELATED_DEVICE_ID" >"${DEVICES_PATH}/${RELATED_DEVICE_ID}/driver/unbind" || die "Failed to unbind $RELATED_DEVICE_ID"
		fi
		# NOTE(review): if the same vendor/device pair is written twice (several
		# identical devices), the kernel may reject the repeat — confirm.
		echo "$(<"${DEVICES_PATH}/${RELATED_DEVICE_ID}/vendor") $(<"${DEVICES_PATH}/${RELATED_DEVICE_ID}/device")" >"${VFIO_PCI_NEW_ID}" || die "Failed to associate device id with ${VFIO_MODULE} module"
	done
	# The group number is the name of the directory the symlink points to.
	IOMMU_GROUP="$(basename "$(readlink "${IOMMU_GROUP_DIR}")")"
	VFIO_DEVICE="/dev/vfio/$IOMMU_GROUP"
	[ -e "$VFIO_DEVICE" ] || die "$VFIO_DEVICE does not exist"
	if [ -n "${CHOWN_USER}" ]; then
		chown "${CHOWN_USER}" "$VFIO_DEVICE" || die "Failed to chown $VFIO_DEVICE to $CHOWN_USER"
	fi
	echo "VFIO enabled on $VFIO_DEVICE"
done

# Tell the user when the IDs selected nothing instead of exiting silently.
if [ "$matched_devices" -eq 0 ]; then
	echo "No PCI device matching ${WANT_VENDOR}:${WANT_DEVICE} was found" >&2
fi
