Last active
April 13, 2021 06:46
-
-
Save weidongxu-microsoft/c837ddf7c259ff2a4afe8a307e98d024 to your computer and use it in GitHub Desktop.
subclass of ApplicationTokenCredentials, with added method `clearCache`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See License.txt in the project root for
 * license information.
 */
package com.microsoft.azure.azuresdktest; | |
import com.google.common.io.BaseEncoding;
import com.microsoft.aad.adal4j.AsymmetricKeyCredential;
import com.microsoft.aad.adal4j.AuthenticationContext;
import com.microsoft.aad.adal4j.AuthenticationException;
import com.microsoft.aad.adal4j.AuthenticationResult;
import com.microsoft.aad.adal4j.ClientCredential;
import com.microsoft.azure.AzureEnvironment;
import com.microsoft.azure.credentials.AzureTokenCredentials;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Date;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class CustomizedApplicationTokenCredentials extends AzureTokenCredentials { | |
private final ConcurrentHashMap<String, AuthenticationResult> tokens; | |
private final ConcurrentHashMap<String, ReentrantLock> authenticationLocks; | |
private final String clientId; | |
private String clientSecret; | |
private byte[] clientCertificate; | |
private String clientCertificatePassword; | |
private final long timeoutInSeconds = 60; | |
public CustomizedApplicationTokenCredentials(String clientId, String domain, String secret, AzureEnvironment environment) { | |
super(environment, domain); | |
this.clientId = clientId; | |
this.clientSecret = secret; | |
this.tokens = new ConcurrentHashMap<>(); | |
this.authenticationLocks = new ConcurrentHashMap<>(); | |
} | |
public CustomizedApplicationTokenCredentials(String clientId, String domain, byte[] certificate, String password, AzureEnvironment environment) { | |
super(environment, domain); | |
this.clientId = clientId; | |
this.clientCertificate = certificate; | |
this.clientCertificatePassword = password; | |
this.tokens = new ConcurrentHashMap<>(); | |
this.authenticationLocks = new ConcurrentHashMap<>(); | |
} | |
@Override | |
public String getToken(String resource) throws IOException { | |
AuthenticationResult authenticationResult = tokens.get(resource); | |
if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) { | |
ReentrantLock lock; | |
synchronized (authenticationLocks) { | |
lock = authenticationLocks.get(resource); | |
if (lock == null) { | |
lock = new ReentrantLock(); | |
authenticationLocks.put(resource, lock); | |
} | |
} | |
lock.lock(); | |
try { | |
authenticationResult = tokens.get(resource); | |
if (authenticationResult == null || authenticationResult.getExpiresOnDate().before(new Date())) { | |
ExecutorService executor = Executors.newSingleThreadExecutor(); | |
try { | |
authenticationResult = acquireAccessToken(resource, executor).get(timeoutInSeconds, TimeUnit.SECONDS); | |
tokens.put(resource, authenticationResult); | |
} finally { | |
executor.shutdown(); | |
} | |
} | |
} catch (Exception e) { | |
throw new IOException(e.getMessage(), e); | |
} finally { | |
lock.unlock(); | |
} | |
} | |
return authenticationResult.getAccessToken(); | |
} | |
private Future<AuthenticationResult> acquireAccessToken(String resource, ExecutorService executor) throws IOException { | |
String authorityUrl = this.environment().activeDirectoryEndpoint() + this.domain(); | |
AuthenticationContext context = new AuthenticationContext(authorityUrl, false, executor); | |
if (proxy() != null) { | |
context.setProxy(proxy()); | |
} | |
if (sslSocketFactory() != null) { | |
context.setSslSocketFactory(sslSocketFactory()); | |
} | |
try { | |
if (clientSecret != null) { | |
return context.acquireToken( | |
resource, | |
new ClientCredential(clientId, clientSecret), | |
null); | |
} else if (clientCertificate != null && clientCertificatePassword != null) { | |
return context.acquireToken( | |
resource, | |
AsymmetricKeyCredential.create(clientId, new ByteArrayInputStream(clientCertificate), clientCertificatePassword), | |
null); | |
} else if (clientCertificate != null) { | |
return context.acquireToken( | |
resource, | |
AsymmetricKeyCredential.create(clientId, privateKeyFromPem(new String(clientCertificate)), publicKeyFromPem(new String(clientCertificate))), | |
null); | |
} | |
throw new AuthenticationException("Please provide either a non-null secret or a non-null certificate."); | |
} catch (Exception e) { | |
throw new IOException(e.getMessage(), e); | |
} | |
} | |
private static PrivateKey privateKeyFromPem(String pem) { | |
Pattern pattern = Pattern.compile("(?s)-----BEGIN PRIVATE KEY-----.*-----END PRIVATE KEY-----"); | |
Matcher matcher = pattern.matcher(pem); | |
matcher.find(); | |
String base64 = matcher.group() | |
.replace("-----BEGIN PRIVATE KEY-----", "") | |
.replace("-----END PRIVATE KEY-----", "") | |
.replace("\n", "") | |
.replace("\r", ""); | |
byte[] key = BaseEncoding.base64().decode(base64); | |
PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(key); | |
try { | |
KeyFactory kf = KeyFactory.getInstance("RSA"); | |
return kf.generatePrivate(spec); | |
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) { | |
throw new RuntimeException(e); | |
} | |
} | |
private static X509Certificate publicKeyFromPem(String pem) { | |
Pattern pattern = Pattern.compile("(?s)-----BEGIN CERTIFICATE-----.*-----END CERTIFICATE-----"); | |
Matcher matcher = pattern.matcher(pem); | |
matcher.find(); | |
try { | |
CertificateFactory factory = CertificateFactory.getInstance("X.509"); | |
InputStream stream = new ByteArrayInputStream(matcher.group().getBytes()); | |
return (X509Certificate) factory.generateCertificate(stream); | |
} catch (CertificateException e) { | |
throw new RuntimeException(e); | |
} | |
} | |
public void clearCache() { | |
for (String resource : tokens.keySet()) { | |
ReentrantLock lock; | |
synchronized (authenticationLocks) { | |
lock = authenticationLocks.get(resource); | |
if (lock == null) { | |
lock = new ReentrantLock(); | |
authenticationLocks.put(resource, lock); | |
} | |
} | |
lock.lock(); | |
tokens.remove(resource); | |
lock.unlock(); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment