/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import com.amazon.dlic.auth.http.jwt.keybyoidc.BadCredentialsException;
import com.amazon.dlic.auth.http.jwt.keybyoidc.KeyProvider;
import com.google.common.base.Strings;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import java.security.Key;
import java.text.ParseException;
import java.util.Collections;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Verifies signed JWTs against keys supplied by a {@link KeyProvider} and validates their claims.
 *
 * <p>Verification resolves the key named by the token's {@code kid} header, checks the signature,
 * and then validates the claims with {@link DefaultJWTClaimsVerifier} (using the configured clock
 * skew) plus the optional required audience and issuer.
 */
public class JwtVerifier {
    private static final Logger log = LogManager.getLogger(JwtVerifier.class);
    private final KeyProvider keyProvider;
    private final int clockSkewToleranceSeconds;
    private final String requiredIssuer;
    private final String requiredAudience;

    /**
     * Creates a verifier.
     *
     * @param keyProvider source of the keys used for signature verification
     * @param clockSkewToleranceSeconds maximum clock skew, in seconds, handed to the claims verifier
     * @param requiredIssuer issuer the token must carry; {@code null} or empty disables the check
     * @param requiredAudience audience the token must carry; {@code null} or empty disables the check
     */
    public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, String requiredAudience) {
        this.keyProvider = keyProvider;
        this.clockSkewToleranceSeconds = clockSkewToleranceSeconds;
        this.requiredIssuer = requiredIssuer;
        this.requiredAudience = requiredAudience;
    }

    /**
     * Parses the given compact-serialized JWT, verifies its signature and validates its claims.
     *
     * @param encodedJwt the encoded token
     * @return the parsed token, once signature and claims have been accepted
     * @throws BadCredentialsException if the token is {@code null}, cannot be parsed, has an invalid
     *     signature, or fails claim validation
     */
    public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
        // Reject null up front with the exception type callers handle, rather than relying on
        // how the parser reacts to a null argument.
        if (encodedJwt == null) {
            throw new BadCredentialsException("JWT must not be null");
        }
        try {
            SignedJWT jwt = SignedJWT.parse(encodedJwt);

            // The key id is unescaped before it is used to look up the key.
            String kid = jwt.getHeader().getKeyID();
            if (!Strings.isNullOrEmpty(kid)) {
                kid = StringEscapeUtils.unescapeJava(kid);
            }

            JWK key = this.keyProvider.getKey(kid);
            boolean signatureValid = jwt.verify(this.getInitializedSignatureVerifier(key, jwt));

            // Without a key id the first key may simply have been the wrong one, so retry once
            // with whatever the provider returns after a refresh.
            if (!signatureValid && Strings.isNullOrEmpty(kid)) {
                key = this.keyProvider.getKeyAfterRefresh(null);
                signatureValid = jwt.verify(this.getInitializedSignatureVerifier(key, jwt));
            }

            if (!signatureValid) {
                throw new BadCredentialsException("Invalid JWT signature");
            }

            this.validateClaims(jwt);
            return jwt;
        }
        catch (JOSEException | BadJWTException | ParseException e) {
            throw new BadCredentialsException(e.getMessage(), e);
        }
    }

    /**
     * Ensures the algorithm declared by the key matches the algorithm declared in the token header.
     * The check is skipped when either side does not declare an algorithm.
     *
     * @throws BadCredentialsException if both algorithms are present and differ
     */
    private void validateSignatureAlgorithm(JWK key, SignedJWT jwt) throws BadCredentialsException {
        Algorithm keyAlgorithm = key.getAlgorithm();
        JWSAlgorithm tokenAlgorithm = jwt.getHeader().getAlgorithm();
        if (keyAlgorithm == null || tokenAlgorithm == null) {
            return;
        }
        if (!keyAlgorithm.equals(tokenAlgorithm)) {
            throw new BadCredentialsException("Algorithm of JWT does not match algorithm of JWK (" + keyAlgorithm + " != " + tokenAlgorithm + ")");
        }
    }

    /**
     * Builds the signature verifier for the given key and token.
     *
     * @throws BadCredentialsException if no key is available, the algorithms do not match, or no
     *     verifier could be created
     * @throws JOSEException if the verifier factory or the key conversion fails
     */
    private JWSVerifier getInitializedSignatureVerifier(JWK key, SignedJWT jwt) throws BadCredentialsException, JOSEException {
        // The provider's contract for missing keys is not visible here; fail with a credentials
        // error instead of a NullPointerException should it hand back null.
        if (key == null) {
            throw new BadCredentialsException("No key available to verify JWT");
        }
        this.validateSignatureAlgorithm(key, jwt);

        // NOTE(review): every key that is not an octet sequence key is handled as RSA; confirm
        // whether the provider can return other key types.
        Key verificationKey;
        if (key instanceof OctetSequenceKey) {
            verificationKey = key.toOctetSequenceKey().toSecretKey();
        } else {
            verificationKey = key.toRSAKey().toRSAPublicKey();
        }

        JWSVerifier result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), verificationKey);
        if (result == null) {
            throw new BadCredentialsException("Cannot verify JWT");
        }
        return result;
    }

    /**
     * Validates the token's claims with the library verifier and then applies the configured
     * audience and issuer requirements. Tokens without a claims set are accepted as-is.
     *
     * @throws ParseException if the claims set cannot be parsed
     * @throws BadJWTException if any claim is rejected
     */
    private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTException {
        JWTClaimsSet claims = jwt.getJWTClaimsSet();
        if (claims == null) {
            return;
        }
        // An empty configured audience means "not required" everywhere else in this class; pass
        // null so the library verifier is not asked to match against the literal empty string.
        DefaultJWTClaimsVerifier<?> claimsVerifier = new DefaultJWTClaimsVerifier<>(Strings.emptyToNull(this.requiredAudience), null, Collections.emptySet());
        claimsVerifier.setMaxClockSkew(this.clockSkewToleranceSeconds);
        claimsVerifier.verify(claims, null);
        this.validateRequiredAudienceAndIssuer(claims);
    }

    /**
     * Applies the configured audience and issuer requirements. A requirement that is
     * {@code null} or empty is not enforced.
     *
     * @throws BadJWTException if a configured audience is not among the token's audiences, or a
     *     configured issuer differs from the token's issuer
     */
    private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException {
        // The audience claim may list several values; the required one is accepted in any
        // position, not only as the first entry.
        if (!Strings.isNullOrEmpty(this.requiredAudience) && !claims.getAudience().contains(this.requiredAudience)) {
            throw new BadJWTException("Invalid audience");
        }
        if (!Strings.isNullOrEmpty(this.requiredIssuer) && !this.requiredIssuer.equals(claims.getIssuer())) {
            throw new BadJWTException("Invalid issuer");
        }
    }
}

