@thekalinga
Last active September 19, 2022 22:24
Spring Reactor Netty reactive (non-blocking) websocket integration that converts a text file into audio via Azure, without blocking threads. It splits the file into chunks, downloads multiple binary audio fragments for each chunk over websocket (text & binary frames), merges the fragments of each chunk into a chunk file, and finally merges all chunk files into the final mp3.
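A minimal usage sketch of the classes below (the input path is illustrative; blocking is confined to the outermost caller):

// hypothetical caller wiring the downloader end to end
var downloader = new MicrosoftTtsDownloader(false); // debugEnabled = false
downloader.convertToTts(Path.of("/data/chapter-1.txt")).block(); // produces chapter-1.mp3 next to the input file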
package com.acme.tts.converter.microsoft.ssml.websocket;
import com.acme.tts.converter.microsoft.ssml.Prosody;
import com.acme.tts.converter.microsoft.ssml.Speak;
import com.acme.tts.converter.microsoft.ssml.Voice;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.dataformat.xml.XmlMapper;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.log4j.Log4j2;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.WebsocketClientSpec;
import java.nio.file.Path;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static com.acme.tts.converter.microsoft.ssml.websocket.Marker.INSTANCE;
import static io.earcam.unexceptional.Exceptional.get;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.time.Duration.ofSeconds;
import static reactor.core.publisher.Sinks.EmitFailureHandler.FAIL_FAST;
@Log4j2
class ChunksAudioFetcherAndPersister {
public static final String WEBSOCKET_HTTPS_URL_PREFIX =
"https://eastus.api.speech.microsoft.com/cognitiveservices/websocket/v1?TrafficType=AzureDemo&Authorization=bearer%20undefined&X-ConnectionId=";
private static final String CRLF = "\r\n";
private static final String DOUBLE_CRLF = CRLF + CRLF;
private static final Duration PING_MESSAGE_WEBSOCKET_DURATION = ofSeconds(10);
private static final String TURN_START_HEADER_VALUE = "turn\\.start";
private static final String X_RATE_LIMIT_HEADER_MAX_ALLOWED = "X-RateLimit-Limit";
private static final String X_RATE_LIMIT_HEADER_REMAINING = "X-RateLimit-Remaining";
private static final String X_RATE_LIMIT_HEADER_LIMIT_RESET_AT_GMT = "X-RateLimit-Reset";
private static final String X_REQUEST_HEADER_NAME = "X-RequestId";
private static final Pattern X_REQUEST_ID_PATTERN =
Pattern.compile(X_REQUEST_HEADER_NAME + ":(?<requestId>[^\r]+)");
private static final String PATH_HEADER = "Path";
public static final Pattern PATH_TURN_END_HEADER_PATTERN = Pattern.compile(PATH_HEADER + ":\\s*turn\\.end" + CRLF);
private static final Pattern PATH_TURN_START_HEADER_PATTERN = Pattern.compile(PATH_HEADER + ":\\s*" + TURN_START_HEADER_VALUE + CRLF);
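// text frames carry CRLF-separated headers (e.g. Path:turn.start, X-RequestId:<hex>), a blank line (double CRLF), then a JSON body; the patterns above locate the turn boundaries & the request id inside that header section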
private final HttpClient client;
private final WebsocketClientSpec websocketClientSpec;
private final List<String> remainingChunks;
private final Sinks.Many<Integer> retryFromTextChunkIndexOnAbruptCompleteSink;
private final Path dirToDownloadTo;
private final int startingChunkNumberInOriginalChunks;
private final int totalChunkCountOfOriginalChunks;
private final Path inputFilePath;
private final ObjectMapper objectMapper = new XmlMapper();
{
objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
}
ChunksAudioFetcherAndPersister(HttpClient client, WebsocketClientSpec websocketClientSpec,
List<String> remainingChunks, Sinks.Many<Integer> retryFromTextChunkIndexOnAbruptCompleteSink,
Path dirToDownloadTo, int startingChunkNumberInOriginalChunks, int totalChunkCountOfOriginalChunks, Path inputFilePath) {
this.client = client;
this.websocketClientSpec = websocketClientSpec;
this.remainingChunks = remainingChunks;
this.retryFromTextChunkIndexOnAbruptCompleteSink = retryFromTextChunkIndexOnAbruptCompleteSink;
this.dirToDownloadTo = dirToDownloadTo;
this.startingChunkNumberInOriginalChunks = startingChunkNumberInOriginalChunks;
this.totalChunkCountOfOriginalChunks = totalChunkCountOfOriginalChunks;
this.inputFilePath = inputFilePath;
}
Flux<Void> retrieve() {
final var audioChunksPersistingSink = Sinks.many().unicast().<Sinks.Many<byte[]>>onBackpressureBuffer();
final var audioChunkDownloadCompleteInfiniteStreamSink = Sinks.many().multicast().<Marker>onBackpressureBuffer();
final var inputRequestStreamSignalSink = Sinks.many().unicast().onBackpressureBuffer(); // this emits one item per input item, followed by onComplete
final var errorNotifyingSink = Sinks.one();
final var shortCircuitCompletionSink = Sinks.one();
final var allChunksDownloadComplete$ = inputRequestStreamSignalSink.asFlux()
.zipWith(audioChunkDownloadCompleteInfiniteStreamSink.asFlux()).cast(Object.class) // since we are zipping both input text & output mp3 generation, this will complete only when output has produced as many items as the input
.then(Mono.just(INSTANCE).cast(Object.class)) // once both of above are done, lets emit an item.
.mergeWith(errorNotifyingSink.asMono())
.takeUntilOther(shortCircuitCompletionSink.asMono());
final var chunkPersister$ = new ChunksPersister(
audioChunksPersistingSink, audioChunkDownloadCompleteInfiniteStreamSink,
dirToDownloadTo, startingChunkNumberInOriginalChunks, totalChunkCountOfOriginalChunks,
inputFilePath)
.asMono();
final var wsHttpsUri = WEBSOCKET_HTTPS_URL_PREFIX + generateRequestId();
//noinspection UnnecessaryLocalVariable
final var webSocketInBoundOutboundMerged$ = client.websocket(websocketClientSpec)
.uri(wsHttpsUri)
.handle((inbound, outbound) -> {
final var inboundHeaders = inbound.headers();
boolean rateLimitHeadersExist = inboundHeaders.contains(X_RATE_LIMIT_HEADER_MAX_ALLOWED) &&
inboundHeaders.contains(X_RATE_LIMIT_HEADER_REMAINING) &&
inboundHeaders.contains(X_RATE_LIMIT_HEADER_LIMIT_RESET_AT_GMT);
if(rateLimitHeadersExist) {
final var maxAllowedRequests = Integer.parseInt(inboundHeaders.get(
X_RATE_LIMIT_HEADER_MAX_ALLOWED));
final var remainingAllowedRequests = Integer.parseInt(inboundHeaders.get(X_RATE_LIMIT_HEADER_REMAINING));
final var limitResetsAtInGmt = DateTimeFormatter.ISO_DATE_TIME.parse(inboundHeaders.get(X_RATE_LIMIT_HEADER_LIMIT_RESET_AT_GMT), ZonedDateTime::from);
log.trace("Rate limiting from server: maxAllowedRequests: {}, remainingAllowedRequests: {}, limitResetsAt: {}", maxAllowedRequests, remainingAllowedRequests, limitResetsAtInGmt.withZoneSameLocal(
ZoneId.of(ZoneId.SHORT_IDS.get("IST"))));
if (remainingAllowedRequests < 0) {
final var nowInGmt = ZonedDateTime.now(ZoneId.of(ZoneId.SHORT_IDS.get("GMT")));
final var waitPeriod = Duration.between(nowInGmt, limitResetsAtInGmt);
return Mono.delay(waitPeriod)
.doOnSubscribe(__ -> log.debug("Server is rate limiting us: Next attempt would be made in {}", waitPeriod))
.flatMapMany(__ -> retrieve());
}
}
final var downloadCompleteChunkIndex = new AtomicInteger(-1);
final var requestId = new AtomicReference<String>(null);
final var audioStreamCollectingSink = new AtomicReference<Sinks.Many<byte[]>>(null);
final var turnStarted = new AtomicBoolean(false);
final var turnEnded = new AtomicBoolean(false);
final var inbound$ = inbound
.receiveFrames()
// .log("com.acme.inbound-signals")
// .doOnCancel(() -> log.trace("doOnCancel: We cancelled"))
// .doOnComplete(() -> log.trace("doOnComplete: Framework completed - reached end of session / otherside sent completion"))
// .doOnError(e -> System.err.println("doOnError: Framework erred - n/w error / otherside erred out / framework level issue"))
.handle((wsf, downstreamSink) -> {
log.trace("Received frame of type " + wsf.getClass());
if (wsf instanceof TextWebSocketFrame textWebSocketFrame) {
final var frameContent = textWebSocketFrame.text();
// log.trace("↓↓↓↓↓↓↓↓ textWebSocketFrame.text() ↓↓↓↓↓↓↓↓\n" + frameContent);
final var bodyStartIndex = frameContent.indexOf(DOUBLE_CRLF) + DOUBLE_CRLF.length();
final var headerSection = frameContent.substring(0, bodyStartIndex);
log.trace("↓↓↓↓↓↓↓ header section ↓↓↓↓↓↓\n" + headerSection);
boolean turnStartFrame = PATH_TURN_START_HEADER_PATTERN.matcher(headerSection).find();
log.trace("Did turn start? " + turnStartFrame);
if (!turnStarted.get() && turnStartFrame) {
Matcher requestIdMatcher = X_REQUEST_ID_PATTERN.matcher(headerSection);
if (requestIdMatcher.find()) {
final var matchedRequestId = requestIdMatcher.group("requestId");
requestId.set(matchedRequestId);
} else {
// lets fail application
log.debug("Text segment(s) that caused the error are {} \n", remainingChunks);
errorNotifyingSink.tryEmitError(new RuntimeException(
"Received turn.start TextWebsocketFrame. But it did not contain " + X_REQUEST_HEADER_NAME
+ " header"));
}
final var newAudioStreamCollectingSink = Sinks.many().unicast().<byte[]>onBackpressureBuffer();
audioStreamCollectingSink.set(newAudioStreamCollectingSink);
audioChunksPersistingSink.emitNext(newAudioStreamCollectingSink, FAIL_FAST);
turnEnded.set(false);
turnStarted.set(true);
}
boolean turnEndFrame = PATH_TURN_END_HEADER_PATTERN.matcher(headerSection).find();
log.trace("Did turn end? " + turnEndFrame);
if (!turnEnded.get() && turnEndFrame) {
audioStreamCollectingSink.get().emitComplete(FAIL_FAST);
requestId.set(null);
downloadCompleteChunkIndex.incrementAndGet();
turnStarted.set(false);
turnEnded.set(true);
if (downloadCompleteChunkIndex.get() + 1 == remainingChunks.size()) {
log.info("Download of current chunks is complete.");
shortCircuitCompletionSink.emitValue(INSTANCE, FAIL_FAST);
retryFromTextChunkIndexOnAbruptCompleteSink.emitComplete(FAIL_FAST);
}
}
final var frameJsonBody = frameContent.substring(bodyStartIndex);
log.trace("↓↓↓↓↓↓↓ json body ↓↓↓↓↓↓\n" + frameJsonBody);
} else if (wsf instanceof BinaryWebSocketFrame binaryWebSocketFrame) {
final var webSocketContent = binaryWebSocketFrame.content();
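// binary frame layout: the first two bytes hold the header-section length (big-endian short), followed by UTF-8 headers, then the raw audio payload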
final short headerLengthBytes = webSocketContent.readShort();
log.trace("headerLengthBytes = " + headerLengthBytes);
final var headerBytes = new byte[headerLengthBytes];
webSocketContent.readBytes(headerBytes);
final var headerSection = new String(headerBytes, UTF_8);
log.trace("↓↓↓↓↓↓↓ header section ↓↓↓↓↓↓\n" + headerSection);
boolean partIsAudio = Pattern.compile(PATH_HEADER + ":\\s*audio" + CRLF).matcher(headerSection).find();
boolean requestIdMatches = Pattern.compile(X_REQUEST_HEADER_NAME + ":" + requestId.get() + CRLF).matcher(headerSection).find();
if (requestIdMatches) {
if (partIsAudio) {
byte[] bytesRead = new byte[webSocketContent.writerIndex() - webSocketContent.readerIndex()];
webSocketContent.readBytes(bytesRead);
audioStreamCollectingSink.get().emitNext(bytesRead, FAIL_FAST);
}
} else {
// lets fail application
log.debug("Text segment(s) that caused the error are {} \n", remainingChunks);
errorNotifyingSink.tryEmitError(new RuntimeException(
"Received audio in a BinaryWebsocketFrame. But it did not match " + X_REQUEST_HEADER_NAME
+ " header"));
}
} else if (wsf instanceof CloseWebSocketFrame closeWebSocketFrame) {
log.debug("Received close websocket frame (might/might not be an issue). statusCode: {}, reason: {}", closeWebSocketFrame.statusCode(), closeWebSocketFrame.reasonText());
// log.debug("Text segment(s) that caused the error are {} \n", textChunks);
// errorNotifyingSink.tryEmitError(new RuntimeException("Otherside asked us to close the stream. statusCode: " + closeWebSocketFrame.statusCode() + " reason: " + closeWebSocketFrame.reasonText()));
} // we don't care about Ping & Pong frames as the framework handles them automatically when `handlePing(false)` (the default) is in effect
})
.doOnComplete(() -> {
// at times the server sends a FIN (orderly closure of the connection from its end) about 15 seconds into the websocket conversation, in the middle of frame transfer, & the reactor netty client only ACKs the received TCP bytes without sending its own FIN, terminating the application while the underlying TCP connection stays open at the OS level
// the server then sends an RST (forced closure of the dangling TCP connection) after about 120 seconds
// in this case we need to restart the download for the current chunk
if (!turnEnded.get()) {
log.info("Received an unexpected complete (FIN) from server while downloading chunk #{} (index = {}) in current session", (downloadCompleteChunkIndex.get() + 2), (downloadCompleteChunkIndex.get() + 1));
retryFromTextChunkIndexOnAbruptCompleteSink.emitNext((downloadCompleteChunkIndex.get() + 1), FAIL_FAST);
} else {
log.info("Received an unexpected complete (FIN) from server.");
shortCircuitCompletionSink.emitValue(INSTANCE, FAIL_FAST);
retryFromTextChunkIndexOnAbruptCompleteSink.emitComplete(FAIL_FAST);
}
})
.then();
final var requestIdAndTextWebSocketFrameForText$$ = Flux.fromIterable(remainingChunks).map(this::requestIdAndTextWebSocketFrameForText$);
final Flux<TextWebSocketFrame> twf$ = requestIdAndTextWebSocketFrameForText$$
.zipWith(Mono.just(INSTANCE).cast(Object.class).concatWith(audioChunkDownloadCompleteInfiniteStreamSink.asFlux())) // lets emit next batch only when previous chunk is fully downloaded
.doOnNext(__ -> inputRequestStreamSignalSink.emitNext(INSTANCE, FAIL_FAST))
.concatMap(t -> Flux.from(t.getT1()).map(RequestIdAndTextWebSocketFrame::frame))
.doOnComplete(() -> inputRequestStreamSignalSink.emitComplete(FAIL_FAST));
final var pingWsf$ = Flux.interval(PING_MESSAGE_WEBSOCKET_DURATION).map(__ -> new PingWebSocketFrame()); // infinite stream
final var outboundWs$ = initTextWebSocketFrameToSetSpeechConfigForWebsocketSession$()
.concatWith(twf$).cast(Object.class)
.mergeWith(pingWsf$);
final var outbound$ = outbound.sendObject(outboundWs$, __ -> true);
final var closeStatus$ = inbound.receiveCloseStatus()
.doOnNext(status -> {
log.debug("Received close status. statusCode: {}; reason: {}", status.code(), status.reasonText());
// log.debug("Text segment(s) that caused the error are {} \n", textChunks);
// errorNotifyingSink.tryEmitError(new RuntimeException("Otherside asked us to close the stream. statusCode: " + status.code() + " reason: " + status.reasonText()));
})
.then();
return Mono.when(inbound$, outbound$, chunkPersister$, closeStatus$)
.takeUntilOther(allChunksDownloadComplete$); // lets cancel all streams once all chunks are downloaded
});
return webSocketInBoundOutboundMerged$;
}
Mono<TextWebSocketFrame> initTextWebSocketFrameToSetSpeechConfigForWebsocketSession$() {
final var initialRequestBodyTemplate = """
Path: speech.config
X-Timestamp: =timeStamp=
Content-Type: application/json; charset=utf-8
{"context":{"system":{"name":"SpeechSDK","version":"1.19.0","build":"JavaScript","lang":"JavaScript"},"os":{"platform":"Browser/Linux x86_64","name":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.102 Safari/537.36","version":"5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.102 Safari/537.36"}}}
""";
return Mono.just(initialRequestBodyTemplate)
.map(template -> {
final var timeStampHeaderValue = generateTimeStamp();
return template.strip().trim() // get rid of indentation
.replaceAll("=timeStamp=", timeStampHeaderValue)
.replaceAll("\n", "\r\n"); // microsoft server/spec expects the lines to be separated by \r\n instead of \n
})
.map(TextWebSocketFrame::new);
}
private Flux<RequestIdAndTextWebSocketFrame> requestIdAndTextWebSocketFrameForText$(String requestId, String textToConvert) {
// we need this instead of string concatenation because we have to escape special characters within xml properly
Speak speakXmlObj = new Speak("1.0", "en-US", new Voice("en-US", "Male", "en-GB-RyanNeural", new Prosody("0%", "0%", textToConvert)));
String ssmlSpeakXmlAsStr = get(() -> objectMapper.writeValueAsString(speakXmlObj));
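// presumably this serialises speakXmlObj into a standard SSML <speak> document wrapping the chunk text in <voice> and <prosody> elements; the exact attribute names depend on the Speak/Voice/Prosody classes, which are not part of this gist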
// for each chunk we need a `synthesis.context` TextWebSocketFrame followed by an `ssml` TextWebSocketFrame
final var ttsTextWebsocketFrameBodyTemplates = List.of(
"""
Path: synthesis.context
X-RequestId: =requestId=
X-Timestamp: =timeStamp=
Content-Type: application/json; charset=utf-8
{"synthesis":{"audio":{"metadataOptions":{"bookmarkEnabled":false,"sentenceBoundaryEnabled":false,"visemeEnabled":false,"wordBoundaryEnabled":false},"outputFormat":"audio-16khz-64kbitrate-mono-mp3"},"language":{"autoDetection":false}}}
""",
"""
Path: ssml
X-RequestId: =requestId=
X-Timestamp: =timeStamp=
Content-Type: application/ssml+xml
"""+ssmlSpeakXmlAsStr);
return Flux.fromIterable(ttsTextWebsocketFrameBodyTemplates)
.map(ttsTextWebsocketFrameBodyTemplate -> {
final var timeStampHeaderValue = generateTimeStamp();
return ttsTextWebsocketFrameBodyTemplate.strip().trim() // get rid of indentation
.replaceAll("=requestId=", requestId)
.replaceAll("=timeStamp=", timeStampHeaderValue)
.replaceAll("=ssmlSpeakXml=", ssmlSpeakXmlAsStr)
.replaceAll("\n", "\r\n"); // microsoft server/spec expects the lines to be separated by \r\n instead of \n
})
.map(frameBody -> new RequestIdAndTextWebSocketFrame(requestId, new TextWebSocketFrame(frameBody)));
}
private Flux<RequestIdAndTextWebSocketFrame> requestIdAndTextWebSocketFrameForText$(String textToConvert) {
final var requestId = generateRequestId();
return requestIdAndTextWebSocketFrameForText$(requestId, textToConvert);
}
private static String generateTimeStamp() {
return Instant.ofEpochMilli(System.currentTimeMillis()).atZone(ZoneId.of("GMT"))
.format(DateTimeFormatter.ISO_DATE_TIME).replaceAll("\\[[^]]+]", "");
}
private static String generateRequestId() {
return UUID.randomUUID().toString().replaceAll("-", "").toUpperCase();
}
}
package com.acme.tts.converter.microsoft.ssml.websocket;
import lombok.extern.log4j.Log4j2;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks.Many;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.atomic.AtomicInteger;
import static com.acme.tts.converter.ConverterUtil.getPartFileName;
import static com.acme.tts.converter.microsoft.ssml.websocket.Marker.INSTANCE;
import static io.earcam.unexceptional.Exceptional.run;
import static java.nio.file.StandardOpenOption.CREATE_NEW;
import static java.nio.file.StandardOpenOption.WRITE;
import static reactor.core.publisher.Sinks.EmitFailureHandler.FAIL_FAST;
@Log4j2
class ChunksPersister {
private final Many<Many<byte[]>> audioChunksPersistingSink;
private final Many<Marker> audioChunkDownloadCompleteInfiniteStreamSink;
private final Path dirToDownloadTo;
private final AtomicInteger nextChunkNumber;
private final int totalChunkCount;
private final Path inputFilePath;
public ChunksPersister(Many<Many<byte[]>> audioChunksPersistingSink,
Many<Marker> audioChunkDownloadCompleteInfiniteStreamSink,
Path dirToDownloadTo,
int startingChunkNumber,
int totalChunkCount,
Path inputFilePath) {
this.audioChunksPersistingSink = audioChunksPersistingSink;
this.audioChunkDownloadCompleteInfiniteStreamSink =
audioChunkDownloadCompleteInfiniteStreamSink;
this.dirToDownloadTo = dirToDownloadTo;
this.nextChunkNumber = new AtomicInteger(startingChunkNumber);
this.totalChunkCount = totalChunkCount;
this.inputFilePath = inputFilePath;
}
Mono<Void> asMono() {
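// each inner sink emitted via audioChunksPersistingSink carries the audio fragments of one chunk; switchMap subscribes to it and streams those fragments into that chunk's own part file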
return audioChunksPersistingSink.asFlux().switchMap(audioStreamCollectingSink -> {
final Flux<DataBuffer> audioPart$ = audioStreamCollectingSink.asFlux().cast(byte[].class)
.map(DefaultDataBufferFactory.sharedInstance::wrap);
int currChunkNumber = nextChunkNumber.getAndIncrement();
final var chunkDestinationFileName = getPartFileName(currChunkNumber, totalChunkCount);
final var destination = dirToDownloadTo.resolve(chunkDestinationFileName);
// lets delete the file if it already exists
Mono<?> fileCleaner$ = Mono.empty();
if (Files.exists(destination)) {
fileCleaner$ = Mono.create(sink -> {
run(() -> Files.deleteIfExists(destination));
sink.success();
});
}
log.debug("[{}]: (part {} of {}) Downloading to {} …", inputFilePath.getFileName(),
currChunkNumber, totalChunkCount, chunkDestinationFileName);
return fileCleaner$.then(
DataBufferUtils.write(audioPart$, destination, CREATE_NEW, WRITE)
.doOnSuccess(__ -> {
audioChunkDownloadCompleteInfiniteStreamSink.tryEmitNext(INSTANCE);
log.debug("[{}]: (part {} of {}) Downloaded to {}",
inputFilePath.getFileName(), currChunkNumber,
totalChunkCount, chunkDestinationFileName);
})
.doOnError(e -> audioChunkDownloadCompleteInfiniteStreamSink.emitError(e, FAIL_FAST)));
})
.then();
}
}
package com.acme.tts.converter.microsoft.ssml;
import com.acme.tts.converter.FileToTtsConverter;
import com.acme.tts.converter.microsoft.ssml.websocket.WebsocketBasedDownloader;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.dataformat.xml.XmlMapper;
import io.netty.buffer.ByteBufAllocator;
import lombok.extern.log4j.Log4j2;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.IntStream;
import static com.acme.tts.common.Util.launchProcessMergeInputAndOutputStreamsIgnoreInputStream;
import static com.acme.tts.converter.ConverterUtil.getPartFileName;
import static com.acme.tts.converter.ConverterUtil.readAndSanitiseText;
import static com.acme.tts.converter.ConverterUtil.splitTextIntoChunks;
import static io.earcam.unexceptional.Exceptional.get;
import static io.earcam.unexceptional.Exceptional.run;
import static java.lang.String.join;
import static java.nio.file.Files.createDirectory;
import static java.nio.file.Files.exists;
import static java.nio.file.Files.move;
import static java.nio.file.StandardOpenOption.CREATE_NEW;
import static java.nio.file.StandardOpenOption.WRITE;
import static java.util.List.of;
import static java.util.stream.Collectors.joining;
import static org.springframework.util.FileSystemUtils.deleteRecursively;
import static org.springframework.util.StringUtils.stripFilenameExtension;
/**
* Utility created to make the same calls <a href="https://azure.microsoft.com/en-us/services/cognitive-services/text-to-speech/">this page</a> makes. It uses websockets, so we have to use them too
*/
@Log4j2
public class MicrosoftTtsDownloader implements FileToTtsConverter {
private final WebsocketBasedDownloader websocketBasedDownloader;
public MicrosoftTtsDownloader(boolean debugEnabled) {
websocketBasedDownloader = new WebsocketBasedDownloader(debugEnabled);
}
@Override
public Mono<Void> convertToTts(Path inputFilePath) {
final var finalOutputMp3Path = inputFilePath
.resolveSibling(stripFilenameExtension(inputFilePath.getFileName().toString()) + ".mp3");
if (exists(finalOutputMp3Path)) {
log.info("Output file " + finalOutputMp3Path + " already exists. Skipping download.");
return Mono.empty();
}
final var textToTranscript = readAndSanitiseText(inputFilePath);
final List<String> transcribableChunks = splitTextIntoChunks(textToTranscript, 1_000); // though the service supports up to 8000 characters, the microsoft JS client uses 1000-character chunks, so lets do the same
final var audioPartsDir = finalOutputMp3Path.resolveSibling("audio-parts/");
// Since the file system API in Java is blocking (AsynchronousFileChannel helps only with read & write, not with directory operations), we have to use the `subscribeOn(boundedElastic())` hack to get around that issue
return Mono.create(sink -> {
try {
deleteRecursively(audioPartsDir);
createDirectory(audioPartsDir);
sink.success();
} catch (IOException e) {
sink.error(e);
}
})
.subscribeOn(Schedulers.boundedElastic())
.thenEmpty(websocketBasedDownloader.downloadAudioForAllChunks(transcribableChunks, audioPartsDir, inputFilePath))
.then(
Mono.create(mergeAndCleanupSink -> {
// If we have more than 1 part, merge using ffmpeg
// To do so
// 1. Create merge config file
// 2. ffmpeg -f concat -i merge-config.txt -c copy target.mp3
// Sample merge-config.txt
// file 'p1.mp3'
// file 'p2.mp3'
if (transcribableChunks.size() > 1) {
var mergeConfigPath = audioPartsDir.resolve("merge-config.txt");
var mergeConfigFileContent = IntStream.range(1, transcribableChunks.size() + 1)
.mapToObj(transcribableChunkNumber -> getPartFileName(transcribableChunkNumber, transcribableChunks.size()))
.map(partName -> "file " + partName)
.collect(joining("\n"));
final var bufferFactory = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT);
final var mergeConfigDataBuffer = bufferFactory.wrap(mergeConfigFileContent.getBytes());
DataBufferUtils.write(Mono.just(mergeConfigDataBuffer), mergeConfigPath, CREATE_NEW, WRITE)
.doOnSubscribe(__ -> log.info("[{}]: Merging all parts from {} into {}", inputFilePath.getFileName(), audioPartsDir, finalOutputMp3Path))
.doOnSubscribe(__ -> log.info("[{}]: Executing {}", inputFilePath.getFileName(), join(" ", of("ffmpeg", "-f", "concat", "-i", mergeConfigPath.toAbsolutePath().toString(), "-c", "copy", finalOutputMp3Path.toAbsolutePath().toString()))))
.doOnSubscribe(__ -> log.info("Process output is"))
.thenEmpty(
launchProcessMergeInputAndOutputStreamsIgnoreInputStream("ffmpeg", "-f", "concat", "-i", mergeConfigPath.toAbsolutePath().toString(), "-c", "copy", finalOutputMp3Path.toAbsolutePath().toString())
.doOnNext(log::info)
.then()
)
.thenEmpty(
Mono.<Void>create(deletePartDirSink -> {
log.info("[{}]: Merged all parts from {} into {}", inputFilePath.getFileName(), audioPartsDir, finalOutputMp3Path);
try {
deleteRecursively(audioPartsDir);
log.info("[{}]: Cleaned up parts from {}", inputFilePath.getFileName(), audioPartsDir);
deletePartDirSink.success();
} catch (IOException e) {
deletePartDirSink.error(e);
}
})
.subscribeOn(Schedulers.boundedElastic())
)
.subscribe(__ -> {}, mergeAndCleanupSink::error, mergeAndCleanupSink::success);
} else {
Mono.<Void>create(deletePartDirSink -> {
final var tempFilePath = get(() -> Files.list(audioPartsDir)).findFirst().orElseThrow(() -> new IllegalStateException("Shouldn't reach this point. File should have been downloaded by now"));
log.info("[{}]: Moving file {} to {}", inputFilePath.getFileName(), tempFilePath.getFileName(), finalOutputMp3Path.getFileName());
run(() -> move(tempFilePath, finalOutputMp3Path));
try {
deleteRecursively(audioPartsDir);
log.info("[{}]: Cleaned up parts from {}", inputFilePath.getFileName(), audioPartsDir);
deletePartDirSink.success();
} catch (IOException e) {
deletePartDirSink.error(e);
}
})
.subscribeOn(Schedulers.boundedElastic())
.subscribe(__ -> {}, mergeAndCleanupSink::error, mergeAndCleanupSink::success);
}
})
);
}
}
package com.acme.tts.converter.microsoft.ssml.websocket;
import io.netty.handler.logging.LogLevel;
import lombok.extern.log4j.Log4j2;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.WebsocketClientSpec;
import reactor.netty.transport.logging.AdvancedByteBufFormat;
import java.nio.file.Path;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import static com.acme.tts.converter.microsoft.ssml.websocket.Marker.INSTANCE;
import static java.time.Duration.ofSeconds;
import static java.time.temporal.ChronoUnit.NANOS;
import static org.springframework.http.HttpHeaders.ORIGIN;
import static org.springframework.http.HttpHeaders.USER_AGENT;
import static reactor.core.publisher.Sinks.EmitFailureHandler.FAIL_FAST;
/**
* The protocol documentation has since been removed from the microsoft website, but an older version can be found
* <a href="https://web.archive.org/web/20191210051654/https://docs.microsoft.com/en-us/azure/cognitive-services/speech/api-reference-rest/websocketprotocol#connection-establishment">here</a>
* <p>
* Other places of interest (documentation & reference implementation)
* <ol>
* <li><a href="https://github.com/thekalinga/MsEdgeTTS">thekalinga/MsEdgeTTS</a></li>
* <li><a href="https://github.com/thekalinga/ms-bing-speech-service">thekalinga/ms-bing-speech-service</a></li>
* <li><a href="https://github.com/thekalinga/cognitive-services-speech-sdk-js">thekalinga/cognitive-services-speech-sdk-js</a></li>
* </ol>
*/
@Log4j2
public class WebsocketBasedDownloader {
// as per spec
private static final int MAX_WEBSOCKET_FRAME_PAYLOAD_LENGTH = 8_192;
public static final String BROWSER_USER_AGENT_VALUE =
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.102 Safari/537.36";
public static final String BROWSER_ORIGIN_VALUE = "https://azure.microsoft.com";
private final HttpClient client;
private final WebsocketClientSpec websocketClientSpec;
public WebsocketBasedDownloader(boolean debugEnabled) {
HttpClient client = HttpClient.create()
.headers(headers -> headers.add(USER_AGENT, BROWSER_USER_AGENT_VALUE).add(ORIGIN, BROWSER_ORIGIN_VALUE))
.keepAlive(true)
.proxyWithSystemProperties();
if(debugEnabled) {
client = client.wiretap(getClass().getCanonicalName(), LogLevel.TRACE, AdvancedByteBufFormat.HEX_DUMP);
}
this.client = client;
websocketClientSpec = WebsocketClientSpec.builder()
.maxFramePayloadLength(MAX_WEBSOCKET_FRAME_PAYLOAD_LENGTH) // as per spec
.compress(true)
.build();
}
/**
* Downloads mp3 files for all the chunks specified. Here is how it works:
* <ol>
* <li>Makes connection to the websocket server of microsoft azure text-to-speech server</li>
* <li>
* Sends the same request that microsoft javascript client sends
* <ul>
* <li>On outbound channel, Sends a TextWebSocketFrame to setup context (common for the whole websocket session)</li>
* </ul>
* </li>
* <li>
* For each input text chunk, we exchange multiple *WebSocketFrame messages that together constitute a single audio conversion
* <ul>
* <li>On outbound channel, Sends a TextWebSocketFrame to specify the settings for the next SSML request</li>
* <li>On outbound channel, Sends a TextWebSocketFrame with SSML xml</li>
* <li>On inbound channel, Receives a TextWebSocketFrame with header `Path:turn.start`, which indicates we can expect responses from the server in subsequent frames</li>
* <li>On inbound channel, Receives the first BinaryWebSocketFrame with `Path:audio`; its first two bytes specify the header length, which we skip to reach the audio fragment</li>
* <li>On inbound channel, Receives all subsequent BinaryWebSocketFrame messages with `Path:audio` & extracts the audio fragment from each frame body</li>
* <li>On inbound channel, Receives a TextWebSocketFrame with header `Path:turn.end` indicating audio is fully sent & we can consider the audio to be complete</li>
* </ul>
* </li>
* <li>We repeat step (3) for every chunk with a new request id</li>
* </ol>
* @param textChunks All text chunks that need to be converted to audio. The browser client uses chunks of just 1000 characters, so we do the same to avoid suspicion from microsoft
* @param dirToDownloadTo the directory to which all audio parts corresponding to the text chunks will be downloaded
* @return a publisher that completes when download is complete
*/
public Flux<Void> downloadAudioForAllChunks(List<String> textChunks, Path dirToDownloadTo, Path inputFilePath) {
log.debug("[{}]: Will be downloading chunks {}-{}", inputFilePath.getFileName(), 1, textChunks.size());
// for debugging enable this (if you are getting an exception & dont know where its assembled at)
// Hooks.onOperatorDebug();
// (or) if you want to deploy like this at runtime (without stacktraces), use lightweight version of `checkpoint` operator
// For debugging any specific rx operator (runtime behaviour as to what signals are sent up & down the pipe), use log operator
// log("com.acme.tts.converter.microsoft.give-meaningful-operator-name")
final var startedAt = new AtomicReference<Instant>();
final var previousStartIndex = new AtomicInteger(0);
final var retryFromTextChunkIndexOnAbruptCompleteSink = Sinks.many().unicast().<Integer>onBackpressureBuffer();
// lets always keep at least 1 item in the 1st leg of retryFromTextChunkIndexOnAbruptCompleteSink, so withLatestFrom works
// the server sends an abrupt FIN (normal closure of the TCP connection) after a few chunks over the same underlying websocket session
// In this case, we need to start a fresh flow downloading remaining segments
retryFromTextChunkIndexOnAbruptCompleteSink.emitNext(previousStartIndex.get(), FAIL_FAST);
return retryFromTextChunkIndexOnAbruptCompleteSink.asFlux().cast(Integer.class)
.withLatestFrom(Mono.just(textChunks), (relIndexToStartFrom, chunks) -> {
final var nextStartIndex = previousStartIndex.addAndGet(relIndexToStartFrom);
if (chunks.size() < nextStartIndex) {
return Mono.<List<String>>empty();
} else {
log.info("Attempting to download {}-{} chunks in current file with new websocket session", nextStartIndex + 1, chunks.size());
final var chunksRemaining = chunks.subList(nextStartIndex, chunks.size());
return Mono.just(chunksRemaining)
.delayUntil(indexAnd_ -> {
if (nextStartIndex > 0) {
final var waitPeriodInSeconds = ThreadLocalRandom.current().nextInt(1, 10);
log.debug("Premptively delaying next attempt to not overwhelm the server for {} seconds", waitPeriodInSeconds);
final var timeTakenForAllPreviousDownloads = Duration.between(startedAt.get(), Instant.now());
// remaining elements * avg time taken for all past element download + avg 5 seconds of delay per future element
final var estimatedTime = Duration.ofNanos((textChunks.size() - nextStartIndex - 1) * (timeTakenForAllPreviousDownloads.get(NANOS) / (nextStartIndex + 1)))
.plus(Duration.ofSeconds(5L * (textChunks.size() - nextStartIndex - 1)));
log.debug("Estimated time to complete download remaining chunks in current file is: {}", estimatedTime);
return Mono.delay(ofSeconds(waitPeriodInSeconds)) // lets delay each batch by a random 1-10 seconds
.doOnNext(___ -> log.debug("Preemptive wait complete. Resuming next batch download"));
}
return Mono.just(INSTANCE);
});
}
}) // lets restart the process so we can download only pending ones
.doFirst(() -> startedAt.set(Instant.now()))
.switchMap(remainingChunks$ -> {
//noinspection CodeBlock2Expr
return remainingChunks$
.flatMapMany(remainingChunks -> {
final var chunksAudioFetcherAndPersister = new ChunksAudioFetcherAndPersister(client,
websocketClientSpec, remainingChunks, retryFromTextChunkIndexOnAbruptCompleteSink,
dirToDownloadTo, previousStartIndex.get() + 1, textChunks.size(), inputFilePath);
return chunksAudioFetcherAndPersister.retrieve();
});
});
}
}