Created
January 9, 2024 08:30
-
-
Save svaponi/586a252f19a3189cd3dcc1c23cd5c9ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.github.svaponi.auth0 | |
import com.fasterxml.jackson.core.JsonProcessingException | |
import com.fasterxml.jackson.core.type.TypeReference | |
import com.fasterxml.jackson.databind.DeserializationFeature | |
import com.fasterxml.jackson.databind.JsonNode | |
import com.fasterxml.jackson.databind.MapperFeature | |
import com.fasterxml.jackson.databind.json.JsonMapper | |
import com.sun.net.httpserver.HttpExchange | |
import com.sun.net.httpserver.HttpHandler | |
import com.sun.net.httpserver.HttpServer | |
import mu.KLogging | |
import org.bouncycastle.asn1.x500.X500Name | |
import org.bouncycastle.asn1.x509.BasicConstraints | |
import org.bouncycastle.asn1.x509.Extension | |
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter | |
import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder | |
import org.bouncycastle.jce.provider.BouncyCastleProvider | |
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder | |
import java.io.File | |
import java.io.FileInputStream | |
import java.io.IOException | |
import java.math.BigInteger | |
import java.net.InetSocketAddress | |
import java.net.ServerSocket | |
import java.net.URL | |
import java.nio.charset.StandardCharsets | |
import java.security.InvalidKeyException | |
import java.security.KeyFactory | |
import java.security.KeyPair | |
import java.security.KeyPairGenerator | |
import java.security.MessageDigest | |
import java.security.NoSuchAlgorithmException | |
import java.security.Signature | |
import java.security.SignatureException | |
import java.security.cert.X509Certificate | |
import java.security.interfaces.RSAPrivateKey | |
import java.security.interfaces.RSAPublicKey | |
import java.security.spec.PKCS8EncodedKeySpec | |
import java.security.spec.X509EncodedKeySpec | |
import java.time.Duration | |
import java.time.Instant | |
import java.util.Base64 | |
import java.util.Date | |
import java.util.UUID | |
import java.util.concurrent.atomic.AtomicBoolean | |
import java.util.function.Function | |
import java.util.regex.Pattern | |
/**
 * File-wide shared helpers: Base64 encoders and a JSON mapper configured once and reused everywhere.
 */
private object Utils {
    /** Standard Base64 encoder (RFC 4648 alphabet, with `=` padding). */
    val encoder: Base64.Encoder = Base64.getEncoder()

    /** URL-safe Base64 encoder; note that it still emits `=` padding. */
    val urlEncoder: Base64.Encoder = Base64.getUrlEncoder()

    /** JSON mapper that silently ignores properties it cannot map when deserializing. */
    val mapper: JsonMapper = JsonMapper.builder().run {
        configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
        build()
    }
}
/**
 * In-process stand-in for an Auth0 tenant.
 *
 * Registered paths:
 * - `/health` – `{"status":"UP"}`
 * - `/oauth/token` – client-credentials grant issuing RS256 JWTs for clients configured in `auth0_server.json`
 * - `/.well-known/jwks.json` – the signing public key as a JWK set
 * - `/.well-known/openid-configuration` – `issuer` and `jwks_uri`
 *
 * The HTTP method is not checked by any handler.
 *
 * @param port port to listen on; values <= 0 let the underlying [Server] pick a free port.
 */
class Auth0Server(port: Int = -1) {
    /** Signs tokens with this server's key; can be used directly to mint tokens without an HTTP round trip. */
    val tokenCreator: TokenCreator
    private val keyPair: KeyPair
    private val serverCert: X509Certificate
    private val kid: String
    private val server: Server

    init {
        // Key material is cached under the JVM temp dir so restarts reuse the same key pair.
        val certDir = File(System.getProperty("java.io.tmpdir"), javaClass.getCanonicalName())
        logger.debug("Certificates dir: {}", certDir)
        val certProvider = CertProvider(certDir)
        keyPair = certProvider.keyPair
        serverCert = certProvider.certificate
        // A fresh key id per instance; it ties issued tokens (header "kid") to the published JWK.
        kid = UUID.randomUUID().toString()
        server = Server(port)
            .registerResponseProvider(Pattern.compile("/health")) { buildHealthResponse() }
            .registerResponseProvider(Pattern.compile("/oauth/token")) { buildTokenResponse(it) }
            .registerResponseProvider(Pattern.compile("/.well-known/jwks.json")) { buildJwksResponse() }
            .registerResponseProvider(Pattern.compile("/.well-known/openid-configuration")) { buildOpenIdConfigResponse() }
        tokenCreator = TokenCreator(keyPair.private as RSAPrivateKey, kid, issuer)
    }

    /** Starts serving; returns this instance for chaining. */
    fun start(): Auth0Server {
        server.start()
        return this
    }

    /** Stops serving; returns this instance for chaining. */
    fun stop(): Auth0Server {
        server.stop()
        return this
    }

    /** Issuer URL (`iss` claim), always ending with `/`. */
    val issuer: String
        get() = "http://localhost:$port/"

    /** Port the server listens on (resolved on first access when a free port was requested). */
    val port: Int
        get() = server.port

    private val status: String = "UP"

    private fun buildHealthResponse(): Server.Response = jsonResponse(mapOf("status" to status))

    /** JWK set published at `/.well-known/jwks.json`: the RSA public key plus its self-signed certificate. */
    private val jwks: Map<*, *> by lazy {
        // RFC 7515/7517: base64url members of a JWK carry no '=' padding.
        val base64Url = Utils.urlEncoder.withoutPadding()
        val publicKey = keyPair.public as RSAPublicKey
        val thumbprint = MessageDigest.getInstance("SHA-1").digest(serverCert.encoded)
        val key = mapOf(
            "alg" to "RS256",
            "kty" to "RSA",
            "use" to "sig",
            // RFC 7518 §6.3.1: n and e are unsigned big-endian integers, without a two's-complement sign byte.
            "n" to base64Url.encodeToString(publicKey.modulus.toUnsignedBytes()),
            "e" to base64Url.encodeToString(publicKey.publicExponent.toUnsignedBytes()),
            "kid" to kid,
            "x5t" to base64Url.encodeToString(thumbprint),
            // x5c uses standard (not URL-safe) base64 of the DER-encoded certificate.
            "x5c" to listOf(Utils.encoder.encodeToString(serverCert.encoded)),
        )
        mapOf("keys" to listOf(key))
    }

    private fun buildJwksResponse(): Server.Response = jsonResponse(jwks)

    private val openIdConfig: Map<*, *> by lazy {
        mapOf(
            // issuer already ends with '/'
            "jwks_uri" to "$issuer.well-known/jwks.json",
            "issuer" to issuer
        )
    }

    private fun buildOpenIdConfigResponse(): Server.Response = jsonResponse(openIdConfig)

    /**
     * Handles the client-credentials grant.
     *
     * The JSON body must contain `grant_type=client_credentials`, `client_id`, `client_secret` and `audience`;
     * the client and the audience must be configured in `auth0_server.json`.
     *
     * @return `200` with `access_token`, `scope`, `expires_in`, `token_type`; `400` for invalid requests
     *   (validation failures, unknown client/audience, malformed JSON); `500` for anything else. Error bodies
     *   carry `error` (exception class name) and `error_description`.
     */
    private fun buildTokenResponse(httpExchange: HttpExchange): Server.Response {
        return try {
            val requestBody = parseRequestBody(httpExchange)
            // require/requireNotNull rather than assert: assert is a no-op without -ea and throws AssertionError
            // (not an Exception) otherwise, so it can neither be relied on nor mapped to a 400.
            val grantType = requestBody.textAt("/grant_type")
            require("client_credentials".equals(grantType, ignoreCase = true)) { "Unsupported grant_type" }
            val clientId = requireNotNull(requestBody.textAt("/client_id")) { "Missing client_id" }
            val clientSecret = requireNotNull(requestBody.textAt("/client_secret")) { "Missing client_secret" }
            val audience = requireNotNull(requestBody.textAt("/audience")) { "Missing audience" }
            val clientConfig = ClientConfigProvider.getByClientId(clientId)
            require(clientSecret == clientConfig.clientSecret) { "Invalid credentials" }
            val tokenConfig: TokenConfig = clientConfig.audiences[audience]
                ?: throw IllegalArgumentException("Unknown audience '$audience' for client_id '$clientId'")
            // claims is optional in the config; copying a null map would throw NullPointerException.
            val claims: MutableMap<String, Any> = LinkedHashMap(tokenConfig.claims.orEmpty())
            claims["gty"] = "client-credentials" // necessary for machine to machine tokens
            val accessToken = tokenCreator.create()
                // if missing subject, use `{clientId}@clients` to simulate Auth0 behaviour
                .subject(tokenConfig.subject ?: "$clientId@clients")
                .audience(audience)
                .claims(claims)
                .expiresIn(tokenConfig.expiresIn)
                .sign()
            val response: MutableMap<String, Any> = LinkedHashMap()
            response["access_token"] = accessToken
            response["scope"] = claims.getOrDefault("scope", "")
            response["expires_in"] = tokenConfig.expiresIn
            response["token_type"] = "Bearer"
            jsonResponse(response)
        } catch (e: Exception) {
            val response: MutableMap<String, Any?> = LinkedHashMap()
            response["error"] = e.javaClass.getCanonicalName()
            response["error_description"] = e.message
            jsonResponse(response, status = if (e is IllegalArgumentException) 400 else 500)
        }
    }

    /** Reads the request body as JSON; malformed JSON is reported as [IllegalArgumentException] (HTTP 400). */
    private fun parseRequestBody(httpExchange: HttpExchange): JsonNode? =
        try {
            Utils.mapper.readTree(httpExchange.requestBody)
        } catch (e: JsonProcessingException) {
            throw IllegalArgumentException("Malformed JSON request body", e)
        }

    /** Text at the JSON pointer [pointer], or null when the node is absent, JSON null or blank. */
    private fun JsonNode?.textAt(pointer: String): String? =
        this?.at(pointer)?.asText(null)?.takeIf { it.isNotBlank() }

    /** Magnitude bytes of a non-negative [BigInteger], dropping the leading sign byte `toByteArray()` may add. */
    private fun BigInteger.toUnsignedBytes(): ByteArray {
        val bytes = toByteArray()
        return if (bytes.size > 1 && bytes[0] == 0.toByte()) bytes.copyOfRange(1, bytes.size) else bytes
    }

    /** Serializes [payload] as compact JSON into a response carrying [defaultHeaders]. */
    private fun jsonResponse(payload: Any, status: Int = 200): Server.Response =
        Server.Response(
            status = status,
            body = Utils.mapper.valueToTree<JsonNode>(payload).toString(),
            headers = defaultHeaders
        )

    companion object : KLogging() {
        private val defaultHeaders: Map<String, List<String>> = mapOf(
            "Content-Type" to listOf("application/json"),
            "Cache-Control" to listOf("private, no-store, no-cache, must-revalidate, post-check=0, pre-check=0, no-transform")
        )
    }
}
/**
 * Client registry loaded once, on first use, from the classpath resource `auth0_server.json`
 * (a JSON array of [ClientConfig]). When the resource is absent the registry is empty.
 */
private object ClientConfigProvider {
    private const val CONFIG_RESOURCE = "auth0_server.json"

    private val configs: Collection<ClientConfig> by lazy {
        // Read through the resource URL's stream instead of File(url.file): that also works for resources packed
        // in a JAR and for paths containing URL-encoded characters (e.g. spaces), and `use` closes the stream.
        Thread.currentThread().contextClassLoader
            ?.getResource(CONFIG_RESOURCE)
            ?.openStream()
            ?.use { input -> Utils.mapper.readValue(input, object : TypeReference<Collection<ClientConfig>>() {}) }
            ?: emptyList()
    }

    /**
     * @return the configuration whose `clientId` equals [clientId].
     * @throws IllegalArgumentException if no client with that id is configured.
     */
    fun getByClientId(clientId: String): ClientConfig {
        return configs.firstOrNull { clientId == it.clientId }
            ?: throw IllegalArgumentException("Unknown client_id '$clientId'")
    }
}
/**
 * One client entry of `auth0_server.json`.
 *
 * @property clientId identifier looked up by the `client_id` of a token request.
 * @property clientSecret secret the request's `client_secret` must equal.
 * @property audiences token settings keyed by audience (the request's `audience`).
 */
private data class ClientConfig(
    val clientId: String? = null,
    val clientSecret: String? = null,
    val audiences: Map<String, TokenConfig> = HashMap<String, TokenConfig>(),
)
/**
 * Token settings for one audience of a configured client.
 *
 * @property subject value of the `sub` claim; when null the token endpoint falls back to `{clientId}@clients`.
 * @property claims additional payload claims copied into the token; may be null.
 * @property expiresIn token lifetime in seconds, also reported to the caller as `expires_in`.
 */
private data class TokenConfig(
    val subject: String? = null,
    val claims: Map<String, Any>? = null,
    val expiresIn: Long = 3_600L,
)
/**
 * Creates RS256-signed JWTs in compact serialization.
 *
 * Every token gets `alg`, `typ` and `kid` header claims and, unless already set, `iss`, `jti` and `iat`
 * payload claims.
 *
 * @param privateKey RSA key used to sign tokens.
 * @param kid key id written to the `kid` header so verifiers can select the matching JWK.
 * @param issuer default value of the `iss` claim.
 */
class TokenCreator(val privateKey: RSAPrivateKey, val kid: String, val issuer: String) {
    /** Standard claim / header parameter names. */
    object PublicClaims {
        // Header
        const val ALGORITHM = "alg"
        const val CONTENT_TYPE = "cty"
        const val TYPE = "typ"
        const val KEY_ID = "kid"

        // Payload
        const val ISSUER = "iss"
        const val SUBJECT = "sub"
        const val EXPIRES_AT = "exp"
        const val NOT_BEFORE = "nbf"
        const val ISSUED_AT = "iat"
        const val JWT_ID = "jti"
        const val AUDIENCE = "aud"
    }

    /** Thrown by [Token.sign] when the token cannot be serialized or signed. */
    class TokenException(message: String, cause: Throwable) : Exception(message, cause)

    /** Starts a new, unsigned token. */
    fun create(): Token = Token()

    /** Single-use JWT builder: set claims, then call [sign] exactly once. */
    inner class Token {
        private val signed: AtomicBoolean = AtomicBoolean(false)

        // Sorted maps give the serialized header/payload a stable alphabetical key order;
        // MapperFeature.SORT_PROPERTIES_ALPHABETICALLY only orders POJO properties, not Map entries.
        private val payloadClaims: MutableMap<String, Any> = sortedMapOf<String, Any>()
        private val headerClaims: MutableMap<String, Any> = sortedMapOf<String, Any>()

        /**
         * @param subject the Subject ("sub") claim.
         */
        fun subject(subject: String): Token {
            payloadClaims[PublicClaims.SUBJECT] = subject
            return this
        }

        /**
         * @param audience the Audience ("aud") claim, as a single string.
         */
        fun audience(audience: String): Token {
            payloadClaims[PublicClaims.AUDIENCE] = audience
            return this
        }

        /**
         * @param audience the Audience ("aud") claim, as an array of strings.
         */
        fun audience(audience: Collection<String?>): Token {
            payloadClaims[PublicClaims.AUDIENCE] = audience
            return this
        }

        /**
         * @param expiresAt the Expires At ("exp") claim, stored as epoch seconds.
         */
        fun expiresAt(expiresAt: Instant): Token {
            payloadClaims[PublicClaims.EXPIRES_AT] = expiresAt.epochSecond
            return this
        }

        /**
         * @param expiresIn time to live in seconds, counted from the moment this method is called.
         */
        fun expiresIn(expiresIn: Long): Token = expiresAt(Instant.now().plusSeconds(expiresIn))

        /**
         * @param claims additional custom payload claims; `null` is ignored. Entries overwrite claims set earlier.
         */
        fun claims(claims: Map<String, Any>?): Token {
            if (claims != null) {
                payloadClaims.putAll(claims)
            }
            return this
        }

        /**
         * Serializes header and payload, signs `header.payload` with SHA256withRSA and returns the compact JWT
         * `header.payload.signature` (each part unpadded base64url). Claims are cleared afterwards, whether
         * signing succeeded or not.
         *
         * @return the signed JWT.
         * @throws TokenException if the signature algorithm is unavailable, the key is invalid, signing fails,
         *   or a claim cannot be serialized to JSON.
         * @throws IllegalStateException if this token was already signed.
         */
        @Throws(TokenException::class)
        fun sign(): String {
            check(signed.compareAndSet(false, true)) { "Token already signed" }
            return try {
                val signer = Signature.getInstance(SIGNATURE) // validates algorithm before proceeding, fast fail
                headerClaims.putIfAbsent(PublicClaims.ALGORITHM, ALGORITHM)
                headerClaims.putIfAbsent(PublicClaims.TYPE, "JWT")
                headerClaims.putIfAbsent(PublicClaims.KEY_ID, kid)
                payloadClaims.putIfAbsent(PublicClaims.ISSUER, issuer)
                payloadClaims.putIfAbsent(PublicClaims.JWT_ID, UUID.randomUUID().toString())
                payloadClaims.putIfAbsent(PublicClaims.ISSUED_AT, Instant.now().epochSecond)
                val header = ENCODER.encodeToString(JSON_MAPPER.writeValueAsBytes(headerClaims))
                val payload = ENCODER.encodeToString(JSON_MAPPER.writeValueAsBytes(payloadClaims))
                signer.initSign(privateKey)
                signer.update("$header.$payload".toByteArray(StandardCharsets.UTF_8))
                val signature = ENCODER.encodeToString(signer.sign())
                "$header.$payload.$signature"
            } catch (e: NoSuchAlgorithmException) {
                throw TokenException("Algorithm $SIGNATURE not found", e)
            } catch (e: SignatureException) {
                throw TokenException("Signing error", e)
            } catch (e: InvalidKeyException) {
                throw TokenException("Invalid key pair", e)
            } catch (e: JsonProcessingException) {
                throw TokenException("Some of the Claims couldn't be converted to a valid JSON format.", e)
            } finally {
                headerClaims.clear()
                payloadClaims.clear()
            }
        }
    }

    companion object {
        private const val SIGNATURE = "SHA256withRSA"
        private const val ALGORITHM = "RS256"

        // JWS compact serialization uses unpadded base64url.
        private val ENCODER: Base64.Encoder = Base64.getUrlEncoder().withoutPadding()
        private val JSON_MAPPER: JsonMapper = JsonMapper.builder()
            .configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true)
            .build()
    }
}
/**
 * Minimal HTTP server on top of the JDK's `com.sun.net.httpserver`. Responses are registered per path regex,
 * either as a fixed response or computed per request. A request whose path matches no pattern gets `404`.
 *
 * @param port port to listen on; values <= 0 select a free port the first time the address is needed.
 * @see Server.registerResponse
 * @see Server.registerResponseProvider
 */
private class Server(port: Int = -1) {
    // Insertion-ordered so that, when several patterns match a path, the first registered one wins.
    private val pathPatternToResponseProvider: MutableMap<Pattern, Function<HttpExchange, Response>> = LinkedHashMap()
    private val address: InetSocketAddress by lazy { InetSocketAddress(if (port > 0) port else findFreePort()) }
    private var httpServer: HttpServer? = null

    /**
     * Response to send back.
     *
     * @property status HTTP status code.
     * @property body body, sent UTF-8 encoded; when null, or for status 204, no body is sent.
     * @property headers additional response headers.
     */
    data class Response(
        val status: Int = 0,
        val body: String? = null,
        val headers: Map<String, List<String>>? = null
    )

    /** Registers a fixed response for paths fully matching [pathPattern]. */
    fun registerResponse(
        pathPattern: Pattern,
        status: Int,
        body: String?,
        headers: Map<String, List<String>>?
    ): Server {
        pathPatternToResponseProvider[pathPattern] =
            Function<HttpExchange, Response> { Response(status, body, headers) }
        return this
    }

    /** Registers a provider computing the response for each request whose path fully matches [pathPattern]. */
    fun registerResponseProvider(pathPattern: Pattern, responseProvider: Function<HttpExchange, Response>): Server {
        pathPatternToResponseProvider[pathPattern] = responseProvider
        return this
    }

    val port: Int
        get() = address.port

    /**
     * Recreates an instance of HttpServer. From Javadoc: "Once stopped, a HttpServer cannot be re-used."
     *
     * @see HttpServer.stop
     */
    private fun createHttpServer(): HttpServer {
        return try {
            // NOTE(review): JDK-internal tuning properties; System properties are JVM-wide.
            System.setProperty("sun.net.httpserver.maxReqTime", "1000")
            System.setProperty("sun.net.httpserver.maxRspTime", "1000")
            val httpServer = HttpServer.create(address, 0)
            httpServer.createContext("/", HttpHandlerImpl()) // handles any request /**
            httpServer
        } catch (e: IOException) {
            throw IllegalStateException("impossible to start server: " + e.message, e)
        }
    }

    private inner class HttpHandlerImpl : HttpHandler {
        @Throws(IOException::class)
        override fun handle(httpExchange: HttpExchange) {
            val method = httpExchange.requestMethod
            // rawPath excludes the query string, so "/health?x=1" still matches "/health".
            val path: String = httpExchange.requestURI.rawPath ?: ""
            try {
                val responseProvider = pathPatternToResponseProvider.entries
                    .firstOrNull { (pathPattern, _) -> pathPattern.matcher(path).matches() }
                    ?.value
                if (responseProvider == null) {
                    httpExchange.sendResponseHeaders(404, -1)
                    logger.warn("$method $path >>> 404 NOT_FOUND")
                    return
                }
                val response: Response = responseProvider.apply(httpExchange)
                response.headers?.let { httpExchange.responseHeaders.putAll(it) }
                val body = response.body
                if (body == null || response.status == 204) {
                    httpExchange.sendResponseHeaders(response.status, -1)
                    logger.info("$method $path >>> ${response.status}")
                } else {
                    // Content-Length must be the encoded byte count, not the character count.
                    val bytes = body.toByteArray(StandardCharsets.UTF_8)
                    httpExchange.sendResponseHeaders(response.status, bytes.size.toLong())
                    httpExchange.responseBody.use { os -> os.write(bytes) }
                    logger.info("$method $path >>> ${response.status} $body")
                }
            } catch (e: Exception) {
                logger.error("$method $path >>> 500 SERVER_ERROR ${e.javaClass.getSimpleName()} ${e.message}")
                // Headers can be sent only once; responseCode stays -1 until they have been.
                if (httpExchange.responseCode == -1) {
                    httpExchange.sendResponseHeaders(500, -1)
                }
            } finally {
                // Always end the exchange, including the no-body paths that never close the response stream.
                httpExchange.close()
            }
        }
    }

    /** Asks the OS for an ephemeral port, retrying a few times on I/O errors. */
    private fun findFreePort(): Int {
        var lastError: IOException? = null
        repeat(FREE_PORT_ATTEMPTS) {
            try {
                return ServerSocket(0).use { it.localPort }
            } catch (e: IOException) {
                lastError = e
            }
        }
        throw IllegalStateException("Unable to find a free port", lastError)
    }

    /** Starts the server unless it is already running; returns this instance for chaining. */
    fun start(): Server {
        if (httpServer != null) {
            logger.debug("Already running on http://localhost:{}", this.port)
            return this
        }
        logger.debug("Starting on http://localhost:{}", this.port)
        val created = createHttpServer()
        created.start()
        httpServer = created
        logger.info("Started on http://localhost:{}", this.port)
        for (pathPattern in pathPatternToResponseProvider.keys) {
            logger.debug("Active url http://localhost:{}{}", this.port, pathPattern.toString())
        }
        return this
    }

    /** Stops the server immediately if running; a stopped server can be started again. */
    fun stop() {
        val running = httpServer
        if (running == null) {
            logger.debug("Already stopped on http://localhost:{}", this.port)
            return
        }
        logger.debug("Stopping on http://localhost:{}", this.port)
        running.stop(0)
        httpServer = null
        logger.info("Stopped on http://localhost:{}", this.port)
    }

    companion object : KLogging() {
        private const val FREE_PORT_ATTEMPTS = 10
    }
}
/**
 * Supplies a key pair and a matching self-signed certificate.
 *
 * The key pair is read from `<algorithm>.key` (PKCS#8) and `<algorithm>.pub` (X.509) in [certDir] when both
 * files exist. Otherwise a new pair is generated with key size 4096 and written there for later runs.
 * The certificate is created once per instance (lazily) and is valid for one day from creation.
 *
 * NOTE(review): the certificate is always signed with "SHA256WithRSA", so presumably only `algorithm = "RSA"`
 * works end to end — confirm before using another algorithm.
 */
private class CertProvider(private val certDir: File, private val algorithm: String = "RSA") {

    val keyPair: KeyPair by lazy {
        val privateKeyFile = File(certDir, "$algorithm.key")
        val publicKeyFile = File(certDir, "$algorithm.pub")
        if (!privateKeyFile.exists() || !publicKeyFile.exists()) {
            generateAndStore(privateKeyFile, publicKeyFile)
        } else {
            loadStored(privateKeyFile, publicKeyFile)
        }
    }

    /** Rebuilds the key pair from the DER-encoded files previously written by [generateAndStore]. */
    private fun loadStored(privateKeyFile: File, publicKeyFile: File): KeyPair {
        val keyFactory = KeyFactory.getInstance(algorithm)
        val encodedPrivate = privateKeyFile.readBytes()
        val encodedPublic = publicKeyFile.readBytes()
        val privateKey = keyFactory.generatePrivate(PKCS8EncodedKeySpec(encodedPrivate))
        val publicKey = keyFactory.generatePublic(X509EncodedKeySpec(encodedPublic))
        return KeyPair(publicKey, privateKey)
    }

    /** Generates a fresh key pair and persists both halves in their default encodings. */
    private fun generateAndStore(privateKeyFile: File, publicKeyFile: File): KeyPair {
        // parent directories must exist before the files can be written
        privateKeyFile.parentFile.mkdirs()
        publicKeyFile.parentFile.mkdirs()
        val generator = KeyPairGenerator.getInstance(algorithm)
        generator.initialize(4096)
        return generator.generateKeyPair().also { generated ->
            privateKeyFile.writeBytes(generated.private.encoded)
            publicKeyFile.writeBytes(generated.public.encoded)
        }
    }

    val certificate: X509Certificate by lazy {
        val issuedAt = Instant.now()
        val signer = JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.private)
        // Issuer and subject are the same name: the certificate is self-signed.
        val name = X500Name("CN=" + javaClass.getCanonicalName())
        val holder = JcaX509v3CertificateBuilder(
            name,
            BigInteger.valueOf(issuedAt.toEpochMilli()),
            Date.from(issuedAt),
            Date.from(issuedAt.plus(Duration.ofDays(1))),
            name,
            keyPair.public
        ).addExtension(Extension.basicConstraints, true, BasicConstraints(true))
            .build(signer)
        JcaX509CertificateConverter()
            .setProvider(BouncyCastleProvider())
            .getCertificate(holder)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment