Skip to content
7 changes: 5 additions & 2 deletions extension/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,11 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
endif()

if(EXECUTORCH_BUILD_LLAMA_JNI)
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
list(APPEND link_libraries extension_llm_runner)
target_sources(
executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/jni_layer_asr.cpp
jni/log.cpp
)
list(APPEND link_libraries extension_llm_runner extension_asr_runner)
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1)

if(QNN_SDK_ROOT)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch.extension.asr

import org.pytorch.executorch.annotations.Experimental

/**
* Callback interface for ASR (Automatic Speech Recognition) module. Users can implement this
* interface to receive the transcribed tokens and completion notification.
*
* Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
interface AsrCallback {
/**
* Called when a new token is available from JNI. Users will keep getting onToken() invocations
* until transcription finishes.
*
* NOTE(review): this method is presumably looked up and invoked by name from native code; confirm
* that a keep rule (or an annotation such as @DoNotStrip) protects it from ProGuard/R8 stripping
* or renaming in release builds.
*
* @param token The decoded text token
*/
fun onToken(token: String)
Copy link

Copilot AI Jan 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `onToken` method is missing the `@DoNotStrip` annotation. JNI callback methods need this annotation so that ProGuard/R8 does not strip them from release builds. The `LlmCallback` interface already uses it (see extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java:28). Adding it also requires the import `com.facebook.jni.annotations.DoNotStrip`.

Copilot uses AI. Check for mistakes.

/**
* Called when transcription is complete.
*
* The default implementation is empty, so implementers only need to override this method when
* they want to react to completion.
*
* @param transcription The complete transcription (may be empty if tokens were streamed)
*/
fun onComplete(transcription: String) {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch.extension.asr

import java.io.Closeable
import java.io.File
import java.util.concurrent.atomic.AtomicLong
import org.pytorch.executorch.annotations.Experimental

/**
 * AsrModule is a wrapper around the ExecuTorch ASR Runner. It provides a simple interface to
 * transcribe audio from WAV files using speech recognition models like Whisper.
 *
 * The module loads a WAV file, optionally preprocesses it using a preprocessor module (e.g., for
 * mel-spectrogram extraction), and then runs the ASR model to generate transcriptions.
 *
 * The constructor validates the supplied paths and creates the native runner immediately; it
 * throws [IllegalArgumentException] when a required path cannot be read and [RuntimeException]
 * when the native side returns a null handle. Call [close] (or [destroy]) to release the native
 * runner once the module is no longer needed.
 *
 * Warning: These APIs are experimental and subject to change without notice
 *
 * @param modelPath Path to the ExecuTorch model file (.pte). The model must expose exactly two
 *   callable methods named "encoder" and "text_decoder" (these names are required).
 * @param tokenizerPath Path to the tokenizer directory containing tokenizer.json
 * @param dataPath Optional path to additional data file (e.g., for delegate data)
 * @param preprocessorPath Optional path to preprocessor .pte for converting raw audio to features.
 *   If not provided, raw audio samples will be passed directly to the model.
 */
@Experimental
class AsrModule(
    modelPath: String,
    tokenizerPath: String,
    dataPath: String? = null,
    preprocessorPath: String? = null,
) : Closeable {

  // Opaque handle returned by nativeCreate(); 0 means "not created" or "already destroyed".
  private val nativeHandle = AtomicLong(0L)

  init {
    val modelFile = File(modelPath)
    require(modelFile.canRead() && modelFile.isFile) { "Cannot load model path $modelPath" }
    // Only existence is checked here: the tokenizer path is documented as a directory.
    val tokenizerFile = File(tokenizerPath)
    require(tokenizerFile.exists()) { "Cannot load tokenizer path $tokenizerPath" }
    if (preprocessorPath != null) {
      val preprocessorFile = File(preprocessorPath)
      require(preprocessorFile.canRead() && preprocessorFile.isFile) {
        "Cannot load preprocessor path $preprocessorPath"
      }
    }

    val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath)
    if (handle == 0L) {
      throw RuntimeException("Failed to create native AsrModule")
    }
    nativeHandle.set(handle)
  }

  companion object {
    init {
      // The native methods below are resolved from this shared library.
      System.loadLibrary("executorch")
    }

    @JvmStatic
    private external fun nativeCreate(
        modelPath: String,
        tokenizerPath: String,
        dataPath: String?,
        preprocessorPath: String?,
    ): Long

    @JvmStatic private external fun nativeDestroy(nativeHandle: Long)

    @JvmStatic private external fun nativeLoad(nativeHandle: Long): Int

    @JvmStatic private external fun nativeIsLoaded(nativeHandle: Long): Boolean

    @JvmStatic
    private external fun nativeTranscribe(
        nativeHandle: Long,
        wavPath: String,
        maxNewTokens: Long,
        temperature: Float,
        decoderStartTokenId: Long,
        callback: AsrCallback?,
    ): Int
  }

  /** Check if the native handle is valid, i.e. the module has not been destroyed. */
  val isValid: Boolean
    get() = nativeHandle.get() != 0L

  /**
   * Check if the module is loaded and ready for inference. Returns false without calling into
   * native code once the module has been destroyed.
   */
  val isLoaded: Boolean
    get() {
      val handle = nativeHandle.get()
      return handle != 0L && nativeIsLoaded(handle)
    }

  /**
   * Releases native resources. Call this when done with the module.
   *
   * The handle is swapped to 0 before the native call, so repeated invocations destroy the native
   * object at most once.
   */
  fun destroy() {
    val handle = nativeHandle.getAndSet(0L)
    if (handle != 0L) {
      nativeDestroy(handle)
    }
  }

  /** Closeable implementation for use with use {} blocks. Delegates to [destroy]. */
  override fun close() {
    destroy()
  }

  /**
   * Force loading the module. Otherwise the model is loaded during first transcribe() call.
   *
   * @return 0 on success, error code otherwise
   * @throws IllegalStateException if the module has been destroyed
   */
  fun load(): Int {
    val handle = nativeHandle.get()
    check(handle != 0L) { "AsrModule has been destroyed" }
    return nativeLoad(handle)
  }

  /**
   * Transcribe audio from a WAV file with default configuration.
   *
   * @param wavPath Path to the WAV audio file
   * @param callback Callback to receive tokens, can be null
   * @return 0 on success, error code otherwise
   * @throws IllegalStateException if the module has been destroyed
   * @throws IllegalArgumentException if [wavPath] is not a readable regular file
   */
  fun transcribe(wavPath: String, callback: AsrCallback? = null): Int =
      transcribe(wavPath, AsrTranscribeConfig(), callback)

  /**
   * Transcribe audio from a WAV file with custom configuration.
   *
   * @param wavPath Path to the WAV audio file
   * @param config Configuration for transcription
   * @param callback Callback to receive tokens, can be null
   * @return 0 on success, error code otherwise
   * @throws IllegalStateException if the module has been destroyed
   * @throws IllegalArgumentException if [wavPath] is not a readable regular file
   */
  fun transcribe(
      wavPath: String,
      config: AsrTranscribeConfig,
      callback: AsrCallback? = null,
  ): Int {
    val handle = nativeHandle.get()
    check(handle != 0L) { "AsrModule has been destroyed" }
    val wavFile = File(wavPath)
    require(wavFile.canRead() && wavFile.isFile) { "Cannot read WAV file: $wavPath" }
    return nativeTranscribe(
        handle,
        wavPath,
        config.maxNewTokens,
        config.temperature,
        config.decoderStartTokenId,
        callback,
    )
  }

  /**
   * Transcribe audio from a WAV file and return the full transcription.
   *
   * This is a blocking call that collects all tokens and returns the complete transcription. If no
   * tokens were delivered through [AsrCallback.onToken], the text passed to
   * [AsrCallback.onComplete] is returned instead, so a result that is only reported on completion
   * is not lost.
   *
   * @param wavPath Path to the WAV audio file
   * @param config Configuration for transcription
   * @return The transcribed text
   * @throws IllegalStateException if the module has been destroyed
   * @throws IllegalArgumentException if [wavPath] is not a readable regular file
   * @throws RuntimeException if transcription fails
   */
  @JvmOverloads
  fun transcribeBlocking(
      wavPath: String,
      config: AsrTranscribeConfig = AsrTranscribeConfig(),
  ): String {
    val streamed = StringBuilder()
    var completed = ""
    val status =
        transcribe(
            wavPath,
            config,
            object : AsrCallback {
              override fun onToken(token: String) {
                streamed.append(token)
              }

              override fun onComplete(transcription: String) {
                // Kept only as a fallback; streamed tokens take precedence when present.
                completed = transcription
              }
            },
        )

    if (status != 0) {
      throw RuntimeException("Transcription failed with error code: $status")
    }

    return if (streamed.isEmpty()) completed else streamed.toString()
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch.extension.asr

import org.pytorch.executorch.annotations.Experimental

/**
 * Immutable settings applied to a single ASR transcription request.
 *
 * Construction fails with [IllegalArgumentException] when [maxNewTokens] is not positive or when
 * [temperature] is not non-negative.
 *
 * Warning: These APIs are experimental and subject to change without notice
 *
 * @property maxNewTokens Maximum number of new tokens to generate (must be positive)
 * @property temperature Temperature for sampling. 0.0 means greedy decoding
 * @property decoderStartTokenId The token ID to start decoding with (e.g., language token for
 *   Whisper)
 */
@Experimental
data class AsrTranscribeConfig(
    val maxNewTokens: Long = 128,
    val temperature: Float = 0.0f,
    val decoderStartTokenId: Long = 0,
) {
  init {
    if (maxNewTokens <= 0) {
      throw IllegalArgumentException("maxNewTokens must be positive")
    }
    // Negated rather than flipped so the comparison itself is exactly the original one.
    if (!(temperature >= 0)) {
      throw IllegalArgumentException("temperature must be non-negative")
    }
  }

  /**
   * Fluent builder for [AsrTranscribeConfig], intended for Java callers that cannot use named or
   * default arguments. Each setter validates its argument eagerly and returns this builder.
   */
  class Builder {
    private var tokenBudget: Long = 128
    private var samplingTemperature: Float = 0.0f
    private var startTokenId: Long = 0

    /** Sets the maximum number of tokens to generate; rejects non-positive values. */
    fun setMaxNewTokens(maxNewTokens: Long): Builder {
      require(maxNewTokens > 0) { "maxNewTokens must be positive" }
      tokenBudget = maxNewTokens
      return this
    }

    /** Sets the sampling temperature; rejects values that are not non-negative. */
    fun setTemperature(temperature: Float): Builder {
      require(temperature >= 0) { "temperature must be non-negative" }
      samplingTemperature = temperature
      return this
    }

    /** Sets the token ID used to start decoding. No validation is applied. */
    fun setDecoderStartTokenId(decoderStartTokenId: Long): Builder {
      startTokenId = decoderStartTokenId
      return this
    }

    /** Creates the configuration from the values collected so far. */
    fun build(): AsrTranscribeConfig =
        AsrTranscribeConfig(tokenBudget, samplingTemperature, startTokenId)
  }
}
Loading
Loading