/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.security.auth.http.jwt.keybyoidc;

import com.nimbusds.jose.jwk.JWKSet;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.concurrent.TimeUnit;
import org.apache.hc.client5.http.cache.HttpCacheContext;
import org.apache.hc.client5.http.cache.HttpCacheStorage;
import org.apache.hc.client5.http.classic.methods.HttpGet;
import org.apache.hc.client5.http.config.RequestConfig;
import org.apache.hc.client5.http.impl.cache.BasicHttpCacheStorage;
import org.apache.hc.client5.http.impl.cache.CacheConfig;
import org.apache.hc.client5.http.impl.cache.CachingHttpClients;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder;
import org.apache.hc.client5.http.io.HttpClientConnectionManager;
import org.apache.hc.client5.http.ssl.DefaultClientTlsStrategy;
import org.apache.hc.client5.http.ssl.TlsSocketStrategy;
import org.apache.hc.core5.http.ClassicHttpRequest;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.protocol.HttpContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.Strings;
import org.opensearch.security.DefaultObjectMapper;
import org.opensearch.security.auth.http.jwt.keybyoidc.AuthenticatorUnavailableException;
import org.opensearch.security.auth.http.jwt.keybyoidc.KeySetProvider;
import org.opensearch.security.auth.http.jwt.oidc.json.OpenIdProviderConfiguration;
import org.opensearch.security.util.SettingsBasedSSLConfigurator;

/**
 * {@link KeySetProvider} that downloads a JSON Web Key Set (JWKS) over HTTP(S).
 *
 * <p>The JWKS location is either configured directly ({@code jwks_uri}) or discovered by fetching
 * the OpenID Connect provider configuration from {@code openid_connect_url} and reading its
 * {@code jwks_uri} entry. Requests may optionally go through an in-memory HTTP cache; in that case
 * simple hit/miss counters are maintained and periodically logged.
 *
 * <p>Instances created with {@link #createForJwksUri} additionally enforce a maximum response size
 * and a maximum number of keys, and send an explicit {@code Accept} header.
 *
 * <p>NOTE(review): the cache counters are plain (non-atomic) fields; confirm that callers do not
 * invoke {@link #get()} concurrently if exact statistics matter.
 */
public class KeySetRetriever implements KeySetProvider {
    private static final Logger log = LogManager.getLogger(KeySetRetriever.class);
    /** Minimum time between two periodic INFO log lines with cache statistics. */
    private static final long CACHE_STATUS_LOG_INTERVAL_MS = 3600000L;
    /** Chunk size used when reading a size-limited JWKS response body. */
    private static final int READ_BUFFER_SIZE = 8192;

    /** URL of the OpenID Connect discovery document; only used when {@code jwksUri} is empty. */
    private String openIdConnectEndpoint;
    /** Optional TLS configuration; when non-null its SSL context is used for outgoing connections. */
    private SettingsBasedSSLConfigurator.SSLConfig sslConfig;
    /** Used as both the connect timeout and the connection-request timeout, in milliseconds. */
    private int requestTimeoutMs = 10000;
    /** Cache settings; remains null when caching is disabled. */
    private CacheConfig cacheConfig;
    /** In-memory HTTP cache; remains null when caching is disabled. */
    private HttpCacheStorage oidcHttpCacheStorage;
    // Cache statistics, updated by logCacheResponseStatus.
    private int oidcCacheHits = 0;
    private int oidcCacheMisses = 0;
    private int oidcCacheHitsValidated = 0;
    private int oidcCacheModuleResponses = 0;
    private long oidcRequests = 0L;
    /** Wall-clock time (epoch millis) of the last periodic cache statistics log line. */
    private long lastCacheStatusLog = 0L;
    /** Directly configured JWKS URL; takes precedence over OIDC discovery when non-empty. */
    private String jwksUri;
    /** Upper bound on the JWKS response body in bytes; a value {@code <= 0} disables the check. */
    private long maxResponseSizeBytes = -1L;
    /** Upper bound on the number of keys in the JWKS; a value {@code <= 0} disables the check. */
    private int maxKeyCount = -1;
    /** Enables the size/key-count limits and the Accept header; only set by createForJwksUri. */
    private boolean enableSecurityValidation = false;

    /**
     * Creates a retriever that discovers the JWKS location via OpenID Connect discovery.
     *
     * @param openIdConnectEndpoint URL of the OpenID provider configuration document
     * @param sslConfig optional TLS configuration, may be {@code null}
     * @param useCacheForOidConnectEndpoint whether HTTP responses should be cached in memory
     */
    KeySetRetriever(String openIdConnectEndpoint, SettingsBasedSSLConfigurator.SSLConfig sslConfig, boolean useCacheForOidConnectEndpoint) {
        this.openIdConnectEndpoint = openIdConnectEndpoint;
        this.sslConfig = sslConfig;
        this.configureCache(useCacheForOidConnectEndpoint);
    }

    /**
     * Creates a retriever that fetches the JWKS directly from {@code jwksUri}.
     *
     * @param sslConfig optional TLS configuration, may be {@code null}
     * @param useCacheForOidConnectEndpoint whether HTTP responses should be cached in memory
     * @param jwksUri URL of the JWKS document
     */
    KeySetRetriever(SettingsBasedSSLConfigurator.SSLConfig sslConfig, boolean useCacheForOidConnectEndpoint, String jwksUri) {
        this.jwksUri = jwksUri;
        this.sslConfig = sslConfig;
        this.configureCache(useCacheForOidConnectEndpoint);
    }

    /**
     * Creates a retriever for a directly configured JWKS URL with response validation enabled.
     *
     * @param sslConfig optional TLS configuration, may be {@code null}
     * @param useCacheForJwksEndpoint whether HTTP responses should be cached in memory
     * @param jwksUri URL of the JWKS document
     * @param maxResponseSizeBytes maximum accepted body size in bytes; {@code <= 0} disables the limit
     * @param maxKeyCount maximum accepted number of keys; {@code <= 0} disables the limit
     * @return a new retriever with security validation enabled
     */
    public static KeySetRetriever createForJwksUri(SettingsBasedSSLConfigurator.SSLConfig sslConfig, boolean useCacheForJwksEndpoint, String jwksUri, long maxResponseSizeBytes, int maxKeyCount) {
        KeySetRetriever retriever = new KeySetRetriever(sslConfig, useCacheForJwksEndpoint, jwksUri);
        retriever.enableSecurityValidation = true;
        retriever.maxResponseSizeBytes = maxResponseSizeBytes;
        retriever.maxKeyCount = maxKeyCount;
        return retriever;
    }

    /**
     * Downloads and parses the JWKS.
     *
     * <p>With security validation enabled, the response body is limited to
     * {@code maxResponseSizeBytes}. The limit is checked against the declared Content-Length and
     * again while reading, so a missing or understated Content-Length cannot bypass it. The key
     * count is limited to {@code maxKeyCount}.
     *
     * @return the parsed key set
     * @throws AuthenticatorUnavailableException if the JWKS location cannot be determined, the
     *     request fails or returns a non-2xx status, the body is missing, too large or unparseable,
     *     or it contains too many keys
     */
    @Override
    public JWKSet get() throws AuthenticatorUnavailableException {
        String uri = this.getJwksUri();
        HttpCacheStorage cacheStorage = this.oidcHttpCacheStorage;
        try (CloseableHttpClient httpClient = this.createHttpClient(cacheStorage)) {
            HttpGet httpGet = new HttpGet(uri);
            httpGet.setConfig(this.buildRequestConfig());
            if (this.enableSecurityValidation) {
                httpGet.setHeader("Accept", "application/json, application/jwk-set+json");
            }
            // A cache context is only needed to read back the cache response status.
            HttpCacheContext httpContext = cacheStorage != null ? new HttpCacheContext() : null;
            try (CloseableHttpResponse response = httpClient.execute(httpGet, httpContext)) {
                if (httpContext != null) {
                    this.logCacheResponseStatus(httpContext, true);
                }
                if (response.getCode() < 200 || response.getCode() >= 300) {
                    throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + response.getReasonPhrase());
                }
                HttpEntity httpEntity = response.getEntity();
                if (httpEntity == null) {
                    throw new AuthenticatorUnavailableException("Error while getting " + uri + ": Empty response entity");
                }
                JWKSet keySet = this.parseKeySet(httpEntity, uri);
                if (this.enableSecurityValidation && this.maxKeyCount > 0 && keySet.getKeys().size() > this.maxKeyCount) {
                    throw new AuthenticatorUnavailableException(String.format("JWKS from %s contains %d keys, but max allowed is %d", uri, keySet.getKeys().size(), this.maxKeyCount));
                }
                return keySet;
            } catch (ParseException e) {
                throw new AuthenticatorUnavailableException("Error parsing JWKS from " + uri + ": " + e.getMessage(), e);
            }
        } catch (IOException e) {
            throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + e, e);
        }
    }

    /**
     * Parses the JWKS contained in {@code httpEntity}, enforcing the response size limit when
     * security validation is enabled.
     */
    private JWKSet parseKeySet(HttpEntity httpEntity, String uri) throws IOException, ParseException, AuthenticatorUnavailableException {
        boolean limitSize = this.enableSecurityValidation && this.maxResponseSizeBytes > 0L;
        if (!limitSize) {
            return JWKSet.load(httpEntity.getContent());
        }
        // Fast rejection based on the declared length (negative when unknown).
        long contentLength = httpEntity.getContentLength();
        if (contentLength > this.maxResponseSizeBytes) {
            throw new AuthenticatorUnavailableException(String.format("JWKS response too large from %s: %d bytes (max: %d)", uri, contentLength, this.maxResponseSizeBytes));
        }
        // Enforce the limit on the bytes actually received; the header may be absent or wrong.
        byte[] body = readAtMost(httpEntity.getContent(), this.maxResponseSizeBytes, uri);
        return JWKSet.parse(new String(body, StandardCharsets.UTF_8));
    }

    /**
     * Reads {@code in} to the end, failing as soon as more than {@code maxBytes} bytes have been read.
     */
    private static byte[] readAtMost(InputStream in, long maxBytes, String uri) throws IOException, AuthenticatorUnavailableException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[READ_BUFFER_SIZE];
        long total = 0L;
        int read;
        while ((read = in.read(buffer)) != -1) {
            total += read;
            if (total > maxBytes) {
                throw new AuthenticatorUnavailableException(String.format("JWKS response too large from %s: exceeds %d bytes", uri, maxBytes));
            }
            out.write(buffer, 0, read);
        }
        return out.toByteArray();
    }

    /**
     * Returns the JWKS URL: the directly configured one if present, otherwise the
     * {@code jwks_uri} from the OpenID provider configuration document.
     *
     * @return a non-empty JWKS URL
     * @throws AuthenticatorUnavailableException if neither source is configured, the discovery
     *     request fails or returns a non-2xx status, or the document does not contain a jwks_uri
     */
    String getJwksUri() throws AuthenticatorUnavailableException {
        if (!Strings.isNullOrEmpty(this.jwksUri)) {
            return this.jwksUri;
        }
        if (Strings.isNullOrEmpty(this.openIdConnectEndpoint)) {
            throw new AuthenticatorUnavailableException("Either openid_connect_url or jwks_uri must be configured for OIDC Authentication backend");
        }
        try (CloseableHttpClient httpClient = this.createHttpClient(this.oidcHttpCacheStorage)) {
            HttpGet httpGet = new HttpGet(this.openIdConnectEndpoint);
            httpGet.setConfig(this.buildRequestConfig());
            HttpCacheContext httpContext = this.oidcHttpCacheStorage != null ? new HttpCacheContext() : null;
            try (CloseableHttpResponse response = httpClient.execute(httpGet, httpContext)) {
                if (httpContext != null) {
                    this.logCacheResponseStatus(httpContext);
                }
                if (response.getCode() < 200 || response.getCode() >= 300) {
                    throw new AuthenticatorUnavailableException("Error while getting " + this.openIdConnectEndpoint + ": " + response.getReasonPhrase());
                }
                HttpEntity httpEntity = response.getEntity();
                if (httpEntity == null) {
                    throw new AuthenticatorUnavailableException("Error while getting " + this.openIdConnectEndpoint + ": Empty response entity");
                }
                OpenIdProviderConfiguration parsedEntity = DefaultObjectMapper.objectMapper.readValue(httpEntity.getContent(), OpenIdProviderConfiguration.class);
                String discoveredJwksUri = parsedEntity != null ? parsedEntity.getJwksUri() : null;
                // Fail here with a clear message instead of an unchecked error on `new HttpGet(null)` later.
                if (Strings.isNullOrEmpty(discoveredJwksUri)) {
                    throw new AuthenticatorUnavailableException("Error while getting " + this.openIdConnectEndpoint + ": No jwks_uri found in OpenID provider configuration");
                }
                return discoveredJwksUri;
            }
        } catch (IOException e) {
            throw new AuthenticatorUnavailableException("Error while getting " + this.openIdConnectEndpoint + ": " + e, e);
        }
    }

    /** Builds the per-request configuration from {@link #getRequestTimeoutMs()}. */
    private RequestConfig buildRequestConfig() {
        long timeoutMs = this.getRequestTimeoutMs();
        return RequestConfig.custom()
            .setConnectionRequestTimeout(timeoutMs, TimeUnit.MILLISECONDS)
            .setConnectTimeout(timeoutMs, TimeUnit.MILLISECONDS)
            .build();
    }

    public int getRequestTimeoutMs() {
        return this.requestTimeoutMs;
    }

    public void setRequestTimeoutMs(int httpTimeoutMs) {
        this.requestTimeoutMs = httpTimeoutMs;
    }

    /** Records cache statistics for a discovery (non-JWKS) request. */
    private void logCacheResponseStatus(HttpCacheContext httpContext) {
        this.logCacheResponseStatus(httpContext, false);
    }

    /**
     * Updates the cache counters from the response status stored in {@code httpContext}.
     *
     * <p>Statistics are only counted for JWKS requests (or when a JWKS URL is configured directly);
     * at most once per {@link #CACHE_STATUS_LOG_INTERVAL_MS}, and after at least two requests, the
     * counters are logged at INFO level.
     */
    private void logCacheResponseStatus(HttpCacheContext httpContext, boolean isJwksRequest) {
        ++this.oidcRequests;
        boolean shouldCountStats = this.jwksUri != null || isJwksRequest;
        if (!shouldCountStats) {
            log.debug("Skipping cache statistics for OIDC discovery request #{}", this.oidcRequests);
            return;
        }
        if (httpContext.getCacheResponseStatus() == null) {
            if (this.oidcHttpCacheStorage != null) {
                ++this.oidcCacheMisses;
                log.debug("Null cache status - counting as cache miss. Total misses: {}", this.oidcCacheMisses);
            }
        } else {
            switch (httpContext.getCacheResponseStatus()) {
                case CACHE_HIT:
                    ++this.oidcCacheHits;
                    break;
                case CACHE_MODULE_RESPONSE:
                    ++this.oidcCacheModuleResponses;
                    break;
                case CACHE_MISS:
                    ++this.oidcCacheMisses;
                    break;
                case VALIDATED:
                    // A revalidated entry still counts as a hit.
                    ++this.oidcCacheHits;
                    ++this.oidcCacheHitsValidated;
                    break;
                default:
                    // Other statuses are intentionally not counted.
                    break;
            }
        }
        long now = System.currentTimeMillis();
        if (this.oidcRequests >= 2L && now - this.lastCacheStatusLog > CACHE_STATUS_LOG_INTERVAL_MS) {
            log.info("Cache status for KeySetRetriever:\noidcCacheHits: {}\noidcCacheHitsValidated: {}\noidcCacheModuleResponses: {}\noidcCacheMisses: {}", this.oidcCacheHits, this.oidcCacheHitsValidated, this.oidcCacheModuleResponses, this.oidcCacheMisses);
            this.lastCacheStatusLog = now;
        }
    }

    /**
     * Creates a new HTTP client, caching when {@code httpCacheStorage} is non-null, honoring JVM
     * system properties, and using the configured SSL context when {@code sslConfig} is set.
     * The caller is responsible for closing the returned client.
     */
    private CloseableHttpClient createHttpClient(HttpCacheStorage httpCacheStorage) {
        // CachingHttpClientBuilder extends HttpClientBuilder, so both branches share one type.
        HttpClientBuilder builder = httpCacheStorage != null
            ? CachingHttpClients.custom().setCacheConfig(this.cacheConfig).setHttpCacheStorage(httpCacheStorage)
            : HttpClients.custom();
        builder.useSystemProperties();
        if (this.sslConfig != null) {
            PoolingHttpClientConnectionManager cm = PoolingHttpClientConnectionManagerBuilder.create()
                .setTlsSocketStrategy(new DefaultClientTlsStrategy(this.sslConfig.getSslContext()))
                .build();
            builder.setConnectionManager(cm);
        }
        return builder.build();
    }

    /** Initializes the in-memory cache (max 10 entries, max 1 MiB per object) when requested. */
    private void configureCache(boolean useCacheForOidConnectEndpoint) {
        if (useCacheForOidConnectEndpoint) {
            this.cacheConfig = CacheConfig.custom().setMaxCacheEntries(10).setMaxObjectSize(0x100000L).build();
            this.oidcHttpCacheStorage = new BasicHttpCacheStorage(this.cacheConfig);
        }
    }

    public int getOidcCacheHits() {
        return this.oidcCacheHits;
    }

    public int getOidcCacheMisses() {
        return this.oidcCacheMisses;
    }

    public int getOidcCacheHitsValidated() {
        return this.oidcCacheHitsValidated;
    }

    public int getOidcCacheModuleResponses() {
        return this.oidcCacheModuleResponses;
    }
}

