# This file is Copyright (c) 2017-2019 Florent Kermarrec <florent@enjoy-digital.fr>
# License: BSD

from migen import *
from migen.genlib.misc import WaitTimer

from litex.soc.interconnect import stream
from litex.soc.interconnect.csr import *

from liteiclink.serwb.kuserdes import KUSerdes
from liteiclink.serwb.s7serdes import S7Serdes


# SERWB Master <--> Slave physical synchronization process:
# 1) Master sends idle patterns (zeroes) to Slave to reset it.
# 2) Master sends K28.5 commas to allow Slave to calibrate, Slave sends idle patterns.
# 3) Slave sends K28.5 commas to allow Master to calibrate, Master sends K28.5 commas.
# 4) Master stops sending K28.5 commas.
# 5) Slave stops sending K28.5 commas.
# 6) Physical link is ready.


@ResetInserter()
class _SerdesMasterInit(Module):
    """Master-side SERWB link initialization / RX calibration FSM.

    Sequence (see the protocol description at the top of the file):
      1. Send idle for ``timeout`` cycles so the Slave resets.
      2. Send K28.5 commas and wait until the Slave stops sending idle
         (``~serdes.rx.idle`` held for ``timeout`` cycles).
      3. Scan the RX delay taps (and, once all taps are exhausted, the
         bitslip values) for a window of delays over which commas are
         received; reject windows touching tap 0 / the last tap or
         narrower than ``taps//16``.
      4. Reset the delay and increment it up to the window center.
      5. Assert ``ready`` and stop sending commas.

    Parameters:
        serdes: serdes exposing ``tx.idle``, ``tx.comma``, ``rx.idle``,
            ``rx.comma``, ``rx.delay_rst``, ``rx.delay_inc`` and
            ``rx.bitslip_value`` (as used below).
        taps: number of RX delay taps.
        timeout: WaitTimer duration, in clock cycles, used for every wait.

    Outputs:
        ready: asserted once calibration is complete.
        error: asserted if no valid window was found for any bitslip value.
    """
    def __init__(self, serdes, taps, timeout):
        self.ready = Signal()
        self.error = Signal()

        # # #

        # Calibration state, also exposed for status reporting.
        self.delay = delay = Signal(max=taps)
        self.delay_min = delay_min = Signal(max=taps)
        self.delay_min_found = delay_min_found = Signal()
        self.delay_max = delay_max = Signal(max=taps)
        self.delay_max_found = delay_max_found = Signal()
        # 40 bitslip positions are scanned (0..39); presumably matches the
        # serdes deserialization ratio — confirm against KUSerdes/S7Serdes.
        self.bitslip = bitslip = Signal(max=40)

        self.submodules.timer = timer = WaitTimer(timeout)

        self.submodules.fsm = fsm = FSM(reset_state="IDLE")
        # Clear all calibration state and the RX delay, then reset the Slave.
        fsm.act("IDLE",
            NextValue(delay, 0),
            NextValue(delay_min, 0),
            NextValue(delay_min_found, 0),
            NextValue(delay_max, 0),
            NextValue(delay_max_found, 0),
            serdes.rx.delay_rst.eq(1),
            NextValue(bitslip, 0),
            NextState("RESET_SLAVE"),
            serdes.tx.idle.eq(1)
        )
        # Keep the link idle for `timeout` cycles so the Slave resets.
        fsm.act("RESET_SLAVE",
            timer.wait.eq(1),
            If(timer.done,
                timer.wait.eq(0),
                NextState("SEND_PATTERN")
            ),
            serdes.tx.idle.eq(1)
        )
        # Send commas until the Slave has left idle for `timeout` cycles.
        fsm.act("SEND_PATTERN",
            If(~serdes.rx.idle,
                timer.wait.eq(1),
                If(timer.done,
                    # Release the timer on the transition (like every other
                    # timed state) so it is re-armed for the comma stability
                    # check in CHECK_PATTERN instead of still reading "done".
                    timer.wait.eq(0),
                    NextState("CHECK_PATTERN")
                )
            ),
            serdes.tx.comma.eq(1)
        )
        # Let the newly applied delay/bitslip settle before sampling.
        fsm.act("WAIT_STABLE",
            timer.wait.eq(1),
            If(timer.done,
                timer.wait.eq(0),
                NextState("CHECK_PATTERN")
            ),
            serdes.tx.comma.eq(1)
        )
        fsm.act("CHECK_PATTERN",
            If(~delay_min_found,
                # Searching the window start: commas must be received
                # continuously for `timeout` cycles at this delay.
                If(serdes.rx.comma,
                    timer.wait.eq(1),
                    If(timer.done,
                        timer.wait.eq(0),
                        NextValue(delay_min, delay),
                        NextValue(delay_min_found, 1)
                    )
                ).Else(
                    NextState("INC_DELAY_BITSLIP")
                ),
            ).Else(
                # Searching the window end: first delay without commas.
                If(~serdes.rx.comma,
                    NextValue(delay_max, delay),
                    NextValue(delay_max_found, 1),
                    NextState("CHECK_SAMPLING_WINDOW")
                ).Else(
                    NextState("INC_DELAY_BITSLIP")
                )
            ),
            serdes.tx.comma.eq(1)
        )
        self.comb += serdes.rx.bitslip_value.eq(bitslip)
        # Step to the next delay tap; after the last tap, restart the delay
        # scan with the next bitslip value (or fail after the last one).
        # NOTE: NextState("ERROR") overrides the NextState("WAIT_STABLE")
        # above it (last assignment wins).
        fsm.act("INC_DELAY_BITSLIP",
            NextState("WAIT_STABLE"),
            If(delay == (taps - 1),
                If(bitslip == (40 - 1),
                    NextState("ERROR")
                ).Else(
                    NextValue(delay_min_found, 0),
                    NextValue(delay_min, 0),
                    NextValue(delay_max_found, 0),
                    NextValue(delay_max, 0),
                    NextValue(bitslip, bitslip + 1)
                ),
                NextValue(delay, 0),
                serdes.rx.delay_rst.eq(1)
            ).Else(
                NextValue(delay, delay + 1),
                serdes.rx.delay_inc.eq(1)
            ),
            serdes.tx.comma.eq(1)
        )
        # Reject windows that touch either end of the tap range or are
        # narrower than taps//16 and resume scanning from the current delay;
        # otherwise reset the delay so it can be walked to the window center.
        fsm.act("CHECK_SAMPLING_WINDOW",
            If((delay_min == 0) |
               (delay_max == (taps - 1)) |
               ((delay_max - delay_min) < taps//16),
               NextValue(delay_min_found, 0),
               NextValue(delay_max_found, 0),
               NextState("WAIT_STABLE")
            ).Else(
                NextValue(delay, 0),
                serdes.rx.delay_rst.eq(1),
                NextState("CONFIGURE_SAMPLING_WINDOW")
            ),
            serdes.tx.comma.eq(1)
        )
        # Increment the delay one tap per cycle up to
        # delay_min + (delay_max - delay_min)//2 (the window center).
        fsm.act("CONFIGURE_SAMPLING_WINDOW",
            If(delay == (delay_min + (delay_max - delay_min)[1:]),
                NextState("READY")
            ).Else(
                NextValue(delay, delay + 1),
                serdes.rx.delay_inc.eq(1)
            ),
            serdes.tx.comma.eq(1)
        )
        # Commas are no longer sent here, which signals the Slave to finish.
        fsm.act("READY",
            self.ready.eq(1)
        )
        fsm.act("ERROR",
            self.error.eq(1)
        )


@ResetInserter()
class _SerdesSlaveInit(Module, AutoCSR):
    """Slave-side SERWB link initialization / RX calibration FSM.

    Sends idle while it waits `timeout` cycles, then scans the RX delay taps
    (and, once all taps are exhausted, the bitslip values) for a window of
    delays over which the Master's K28.5 commas are received. Windows that
    touch tap 0 / the last tap or are narrower than ``taps//16`` are
    rejected. The delay is then reset and incremented to the window center.
    Finally it sends commas until, after `timeout` cycles, commas are no
    longer received from the Master, and asserts ``ready``.

    Parameters:
        serdes: serdes exposing ``tx.idle``, ``tx.comma``, ``rx.comma``,
            ``rx.delay_rst``, ``rx.delay_inc`` and ``rx.bitslip_value``
            (as used below).
        taps: number of RX delay taps.
        timeout: WaitTimer duration, in clock cycles, used for every wait.

    Outputs:
        ready: asserted once the handshake is complete.
        error: asserted if no valid window was found for any bitslip value.

    The ``reset`` input added by ``ResetInserter`` restarts the FSM; it is
    driven by the instantiating module.
    """
    def __init__(self, serdes, taps, timeout):
        self.ready = Signal()
        self.error = Signal()

        # # #

        # Calibration state, also exposed for status reporting.
        self.delay = delay = Signal(max=taps)
        self.delay_min = delay_min = Signal(max=taps)
        self.delay_min_found = delay_min_found = Signal()
        self.delay_max = delay_max = Signal(max=taps)
        self.delay_max_found = delay_max_found = Signal()
        # 40 bitslip positions are scanned (0..39); presumably matches the
        # serdes deserialization ratio — confirm against KUSerdes/S7Serdes.
        self.bitslip = bitslip = Signal(max=40)

        self.submodules.timer = timer = WaitTimer(timeout)

        self.submodules.fsm = fsm = FSM(reset_state="IDLE")
        # Reset: clear all calibration state and the RX delay, and send
        # idle for `timeout` cycles before starting the scan.
        fsm.act("IDLE",
            NextValue(delay, 0),
            NextValue(delay_min, 0),
            NextValue(delay_min_found, 0),
            NextValue(delay_max, 0),
            NextValue(delay_max_found, 0),
            serdes.rx.delay_rst.eq(1),
            NextValue(bitslip, 0),
            timer.wait.eq(1),
            If(timer.done,
                # Overrides the wait above so the timer is re-armed.
                timer.wait.eq(0),
                NextState("WAIT_STABLE"),
            ),
            serdes.tx.idle.eq(1)
        )
        # Let the newly applied delay/bitslip settle before sampling.
        fsm.act("WAIT_STABLE",
            timer.wait.eq(1),
            If(timer.done,
                timer.wait.eq(0),
                NextState("CHECK_PATTERN")
            ),
            serdes.tx.idle.eq(1)
        )
        fsm.act("CHECK_PATTERN",
            If(~delay_min_found,
                # Searching the window start: commas must be received
                # continuously for `timeout` cycles at this delay.
                If(serdes.rx.comma,
                    timer.wait.eq(1),
                    If(timer.done,
                        timer.wait.eq(0),
                        NextValue(delay_min, delay),
                        NextValue(delay_min_found, 1)
                    )
                ).Else(
                    NextState("INC_DELAY_BITSLIP")
                ),
            ).Else(
                # Searching the window end: first delay without commas.
                If(~serdes.rx.comma,
                    NextValue(delay_max, delay),
                    NextValue(delay_max_found, 1),
                    NextState("CHECK_SAMPLING_WINDOW")
                ).Else(
                    NextState("INC_DELAY_BITSLIP")
                )
            ),
            serdes.tx.idle.eq(1)
        )
        self.comb += serdes.rx.bitslip_value.eq(bitslip)
        # Step to the next delay tap; after the last tap, restart the delay
        # scan with the next bitslip value (or fail after the last one).
        # NOTE: NextState("ERROR") overrides the NextState("WAIT_STABLE")
        # above it (last assignment wins).
        fsm.act("INC_DELAY_BITSLIP",
            NextState("WAIT_STABLE"),
            If(delay == (taps - 1),
                If(bitslip == (40 - 1),
                    NextState("ERROR")
                ).Else(
                    NextValue(delay_min_found, 0),
                    NextValue(delay_min, 0),
                    NextValue(delay_max_found, 0),
                    NextValue(delay_max, 0),
                    NextValue(bitslip, bitslip + 1)
                ),
                NextValue(delay, 0),
                serdes.rx.delay_rst.eq(1)
            ).Else(
                NextValue(delay, delay + 1),
                serdes.rx.delay_inc.eq(1)
            ),
            serdes.tx.idle.eq(1)
        )
        # Reject windows that touch either end of the tap range or are
        # narrower than taps//16 and resume scanning from the current delay;
        # otherwise reset the delay so it can be walked to the window center.
        fsm.act("CHECK_SAMPLING_WINDOW",
            If((delay_min == 0) |
               (delay_max == (taps - 1)) |
               ((delay_max - delay_min) < taps//16),
               NextValue(delay_min_found, 0),
               NextValue(delay_max_found, 0),
               NextState("WAIT_STABLE")
            ).Else(
                NextValue(delay, 0),
                serdes.rx.delay_rst.eq(1),
                NextState("CONFIGURE_SAMPLING_WINDOW")
            ),
            serdes.tx.idle.eq(1)
        )
        # Increment the delay one tap per cycle up to
        # delay_min + (delay_max - delay_min)//2 (the window center).
        fsm.act("CONFIGURE_SAMPLING_WINDOW",
            If(delay == (delay_min + (delay_max - delay_min)[1:]),
                NextState("SEND_PATTERN")
            ).Else(
                NextValue(delay, delay + 1),
                serdes.rx.delay_inc.eq(1),
            ),
            serdes.tx.idle.eq(1)
        )
        # Send commas so the Master can calibrate; after `timeout` cycles,
        # become ready as soon as commas are no longer received.
        fsm.act("SEND_PATTERN",
            timer.wait.eq(1),
            If(timer.done,
                If(~serdes.rx.comma,
                    NextState("READY")
                )
            ),
            serdes.tx.comma.eq(1)
        )
        fsm.act("READY",
            self.ready.eq(1)
        )
        fsm.act("ERROR",
            self.error.eq(1)
        )


class _SerdesControl(Module, AutoCSR):
    """CSR front-end for a SERWB init FSM, plus a PRBS error counter.

    The init FSM's ready/error flags and calibration results (delay, delay
    window bounds and bitslip) are mirrored on read-only CSRs of the same
    name.

    Reset handling depends on ``mode``:
      * "master": writing the ``reset`` CSR resets the init FSM (registered).
      * otherwise: the init FSM and the serdes are held in reset (registered)
        while ``serdes.rx.idle`` is high.

    PRBS: writing ``prbs_start`` clears ``prbs_errors`` and starts a run that
    increments ``prbs_errors`` on every cycle ``prbs_error`` is high, until
    the internal cycle counter equals ``prbs_cycles``.
    """
    def __init__(self, serdes, init, mode="master"):
        # NOTE: CSRs keep literal `self.<name> = ...` assignments because
        # their names are inferred from the assignment target.
        if mode == "master":
            self.reset = CSR()
        self.ready = CSRStatus()
        self.error = CSRStatus()

        self.delay = CSRStatus(9)
        self.delay_min_found = CSRStatus()
        self.delay_min = CSRStatus(9)
        self.delay_max_found = CSRStatus()
        self.delay_max = CSRStatus(9)
        self.bitslip = CSRStatus(6)

        self.prbs_error = Signal()
        self.prbs_start = CSR()
        self.prbs_cycles = CSRStorage(32)
        self.prbs_errors = CSRStatus(32)

        # # #

        if mode != "master":
            # Slave: the Master resets us by putting the link in idle, so
            # both the init FSM and the serdes follow rx.idle.
            link_idle = serdes.rx.idle
            self.sync += [
                init.reset.eq(link_idle),
                serdes.reset.eq(link_idle),
            ]
        else:
            # Master: software resets the init FSM, which in turn resets the
            # Slave by putting the link in idle.
            self.sync += init.reset.eq(self.reset.re)

        # Mirror each init status/calibration signal on its same-named CSR.
        for field in ("ready", "error", "delay", "delay_min_found",
                      "delay_min", "delay_max_found", "delay_max", "bitslip"):
            self.comb += getattr(self, field).status.eq(getattr(init, field))

        # PRBS error counter.
        prbs_cycles = Signal(32)
        prbs_errors = self.prbs_errors.status
        prbs_fsm = FSM(reset_state="IDLE")
        self.submodules += prbs_fsm

        start_request = self.prbs_start.re
        run_done = prbs_cycles == self.prbs_cycles.storage

        # Hold the cycle counter at zero; a start request clears the error
        # count and begins a run.
        prbs_fsm.act("IDLE",
            NextValue(prbs_cycles, 0),
            If(start_request,
                NextValue(prbs_errors, 0),
                NextState("CHECK")
            )
        )
        # Count cycles and errors until the configured cycle count is hit.
        prbs_fsm.act("CHECK",
            NextValue(prbs_cycles, prbs_cycles + 1),
            If(self.prbs_error, NextValue(prbs_errors, prbs_errors + 1)),
            If(run_done, NextState("IDLE"))
        )


class SERWBPHY(Module, AutoCSR):
    """SERWB physical layer: device serdes + link init FSM + control CSRs.

    Parameters:
        device: FPGA part name; parts starting with "xcku" use KUSerdes
            (512 delay taps), parts starting with "xc7a" use S7Serdes
            (32 delay taps). Any other part raises NotImplementedError.
        pads: serdes pads, passed through to the serdes.
        mode: "master" or "slave"; selects the init FSM and reset scheme.
        init_timeout: WaitTimer duration, in clock cycles, used by the
            init FSM.

    ``sink``/``source`` are 32-bit data stream endpoints. They are only
    connected to the serdes once the init FSM reports ready; before that,
    received data is drained and nothing is accepted on ``sink``.
    """
    def __init__(self, device, pads, mode="master", init_timeout=2**15):
        # Validate the configuration before building any hardware.
        assert mode in ["master", "slave"], \
            "SERWBPHY: mode must be 'master' or 'slave', got {!r}".format(mode)
        self.sink = sink = stream.Endpoint([("data", 32)])
        self.source = source = stream.Endpoint([("data", 32)])
        if device.startswith("xcku"):
            taps = 512
            self.submodules.serdes = KUSerdes(pads, mode)
        elif device.startswith("xc7a"):
            taps = 32
            self.submodules.serdes = S7Serdes(pads, mode)
        else:
            raise NotImplementedError(
                "SERWBPHY: unsupported device {!r} "
                "(supported: xcku*, xc7a*)".format(device))
        if mode == "master":
            self.submodules.init = _SerdesMasterInit(self.serdes, taps, init_timeout)
        else:
            self.submodules.init = _SerdesSlaveInit(self.serdes, taps, init_timeout)
        self.submodules.control = _SerdesControl(self.serdes, self.init, mode)

        # tx/rx dataflow
        self.comb += [
            If(self.init.ready,
                # Forward user data only when valid; otherwise the serdes
                # keeps transmitting with default (zero) data.
                If(sink.valid,
                    sink.connect(self.serdes.tx.sink),
                ),
                self.serdes.rx.source.connect(source)
            ).Else(
                # Drain received data while the link is not ready.
                self.serdes.rx.source.ready.eq(1)
            ),
            self.serdes.tx.sink.valid.eq(1) # always transmitting
        ]

        # For PRBS test we are using the scrambler/descrambler as PRBS,
        # sending 0 to the scrambler and checking that descrambler
        # output is always 0.
        self.comb += self.control.prbs_error.eq(
                source.valid &
                source.ready &
                (source.data != 0))
