Java ASR Module binding #16979
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16979
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 1 Pending, 2 Unrelated Failures
As of commit 852108c with merge base e4060ee:
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This pull request adds Java/Kotlin bindings for the ExecuTorch ASR (Automatic Speech Recognition) module, enabling Android applications to use ASR models like Whisper. The implementation follows a similar pattern to the existing LLM module bindings.
Changes:
- Added JNI layer implementation (jni_layer_asr.cpp) to bridge C++ ASR runner with Java/Kotlin
- Created Kotlin API classes: AsrModule, AsrCallback, and AsrTranscribeConfig for Android integration
- Updated build scripts and CMake configuration to include ASR runner support
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| scripts/build_android_library.sh | Adds EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER flag to Android build configuration |
| extension/android/jni/jni_layer_asr.cpp | Implements JNI bindings for ASR runner including native methods and callback handling |
| extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrTranscribeConfig.kt | Defines configuration data class for ASR transcription parameters |
| extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt | Main Kotlin API class for ASR module with transcription methods |
| extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt | Callback interface for receiving transcription tokens and completion events |
| extension/android/CMakeLists.txt | Updates CMake to include ASR JNI layer in the build |
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Use int64_t arithmetic to detect overflow when computing batchSize * timeSteps * featureDim before casting to jsize. This prevents silent overflow that could cause incorrect validation and potential out-of-bounds memory access.
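The overflow check described in this commit could look roughly like the sketch below. It is not the code from the PR: `jsize_like` is a stand-in for JNI's 32-bit `jsize` so the example compiles without `jni.h`, and `checkedElementCount` is a hypothetical helper name. It assumes C++17 for `std::optional`.

```cpp
#include <cstdint>
#include <limits>
#include <optional>

// Stand-in for JNI's jsize (a 32-bit signed integer); assumption made so
// that this sketch builds without jni.h.
using jsize_like = std::int32_t;

// Hypothetical helper: returns batchSize * timeSteps * featureDim if every
// dimension is non-negative and the product fits in jsize_like, otherwise
// std::nullopt. All arithmetic is done in int64_t, and each multiplication
// is guarded by a division check so no intermediate value can overflow.
inline std::optional<jsize_like> checkedElementCount(
    std::int64_t batchSize,
    std::int64_t timeSteps,
    std::int64_t featureDim) {
  constexpr std::int64_t kMax = std::numeric_limits<jsize_like>::max();
  if (batchSize < 0 || timeSteps < 0 || featureDim < 0) {
    return std::nullopt;
  }
  if (batchSize > kMax || timeSteps > kMax || featureDim > kMax) {
    return std::nullopt;
  }
  std::int64_t total = batchSize;
  // total * timeSteps <= kMax holds exactly when total <= kMax / timeSteps.
  if (timeSteps != 0 && total > kMax / timeSteps) {
    return std::nullopt;
  }
  total *= timeSteps;
  if (featureDim != 0 && total > kMax / featureDim) {
    return std::nullopt;
  }
  total *= featureDim;
  return static_cast<jsize_like>(total);
}
```

A caller would compare the returned count against the Java array length and throw a Java exception when the result is empty, rather than casting an unchecked product to `jsize`.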
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
      if (localClass != nullptr) {
        callbackCache.callbackClass = (jclass)env->NewGlobalRef(localClass);
        callbackCache.onTokenMethod = env->GetMethodID(
            callbackCache.callbackClass, "onToken", "(Ljava/lang/String;)V");
        callbackCache.onCompleteMethod = env->GetMethodID(
            callbackCache.callbackClass, "onComplete", "(Ljava/lang/String;)V");
        env->DeleteLocalRef(localClass);
      }
    });

Copilot AI commented on Jan 28, 2026
The initCallbackCache function doesn't handle the case where FindClass returns nullptr. If the AsrCallback class cannot be found, the function will silently continue with null values in the cache, which will cause crashes when the callback methods are invoked. Consider adding error handling and validation after cache initialization.
Suggested change

    -    if (localClass != nullptr) {
    -      callbackCache.callbackClass = (jclass)env->NewGlobalRef(localClass);
    -      callbackCache.onTokenMethod = env->GetMethodID(
    -          callbackCache.callbackClass, "onToken", "(Ljava/lang/String;)V");
    -      callbackCache.onCompleteMethod = env->GetMethodID(
    -          callbackCache.callbackClass, "onComplete", "(Ljava/lang/String;)V");
    -      env->DeleteLocalRef(localClass);
    -    }
    -  });
    +    if (localClass == nullptr) {
    +      // Clear any pending exception from FindClass; we'll report a clearer error below.
    +      if (env->ExceptionCheck()) {
    +        env->ExceptionClear();
    +      }
    +      return;
    +    }
    +    callbackCache.callbackClass = (jclass)env->NewGlobalRef(localClass);
    +    callbackCache.onTokenMethod = env->GetMethodID(
    +        callbackCache.callbackClass, "onToken", "(Ljava/lang/String;)V");
    +    callbackCache.onCompleteMethod = env->GetMethodID(
    +        callbackCache.callbackClass, "onComplete", "(Ljava/lang/String;)V");
    +    env->DeleteLocalRef(localClass);
    +  });
    +  // Validate that the callback cache has been fully initialized.
    +  if (callbackCache.callbackClass == nullptr ||
    +      callbackCache.onTokenMethod == nullptr ||
    +      callbackCache.onCompleteMethod == nullptr) {
    +    // Clear any JNI exception that may have been raised during initialization.
    +    if (env->ExceptionCheck()) {
    +      env->ExceptionClear();
    +    }
    +    jclass exceptionClass = env->FindClass("java/lang/IllegalStateException");
    +    if (exceptionClass != nullptr) {
    +      env->ThrowNew(
    +          exceptionClass,
    +          "Failed to initialize ASR callback JNI cache: AsrCallback class or methods not found.");
    +      env->DeleteLocalRef(exceptionClass);
    +    }
    +  }
    if (scopedCallback) {
      initCallbackCache(env);

      jobject callbackRef = scopedCallback.get();
      tokenCallback = [env, callbackRef, &tokenBuffer](const std::string& token) {
        tokenBuffer += token;
        if (!utf8_check_validity(tokenBuffer.c_str(), tokenBuffer.size())) {
          ET_LOG(
              Info, "Current token buffer is not valid UTF-8. Waiting for more.");
          return;
        }

        std::string completeToken = tokenBuffer;
        tokenBuffer.clear();

        jstring jToken = env->NewStringUTF(completeToken.c_str());
        env->CallVoidMethod(callbackRef, callbackCache.onTokenMethod, jToken);
        if (env->ExceptionCheck()) {
          ET_LOG(Error, "Exception occurred in AsrCallback.onToken");
          env->ExceptionClear();
        }
        env->DeleteLocalRef(jToken);
      };
    }

Copilot AI commented on Jan 28, 2026
After calling initCallbackCache, there's no validation that the callback cache was properly initialized. If initCallbackCache failed to find the class or methods, callbackCache.onTokenMethod and callbackCache.onCompleteMethod will be nullptr, causing crashes at lines 326 and 342. Add validation checks after initCallbackCache to ensure the cache is properly initialized before using the callback.
    auto make_scoped_global_ref(JNIEnv* env, jobject obj) {
      auto deleter = [env](jobject ref) {
        if (ref != nullptr) {
          env->DeleteGlobalRef(ref);
        }
      };
      jobject globalRef = obj ? env->NewGlobalRef(obj) : nullptr;
      return std::unique_ptr<std::remove_pointer_t<jobject>, decltype(deleter)>(
          globalRef, deleter);
    }

Copilot AI commented on Jan 28, 2026
The custom deleter for the scoped global reference captures JNIEnv* by value. This is only safe because the unique_ptr is destroyed before the function returns and all operations are synchronous. If the unique_ptr were to outlive the function or be passed to another thread, this would cause issues since JNIEnv* pointers are thread-local. Consider adding a comment documenting this constraint.
    bool utf8_check_validity(const char* str, size_t length) {
      for (size_t i = 0; i < length; ++i) {
        uint8_t byte = static_cast<uint8_t>(str[i]);
        if (byte >= 0x80) {
          if (i + 1 >= length) {
            return false;
          }
          uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
          if ((byte & 0xE0) == 0xC0 && (next_byte & 0xC0) == 0x80) {
            i += 1;
          } else if (
              (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 &&
              (i + 2 < length) &&
              (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80) {
            i += 2;
          } else if (
              (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 &&
              (i + 2 < length) &&
              (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80 &&
              (i + 3 < length) &&
              (static_cast<uint8_t>(str[i + 3]) & 0xC0) == 0x80) {
            i += 3;
          } else {
            return false;
          }
        }
      }
      return true;
    }

Copilot AI commented on Jan 28, 2026
The utf8_check_validity function is duplicated from jni_layer_llama.cpp (lines 48-79). Consider extracting this utility function into a shared header file (e.g., jni/jni_helper.h) to avoid code duplication and ensure consistent UTF-8 validation across both ASR and LLM modules.
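One possible shape for such a shared helper is sketched below. Everything here is illustrative rather than taken from the PR: the include guard, the `example_jni` namespace, and the names `utf8CheckValidity` and `appendAndTakeIfComplete` are hypothetical. The check is intended to accept the same inputs as the quoted `utf8_check_validity`, which means it is a structural check only (lead byte pattern plus continuation bytes) and does not reject overlong encodings or surrogate code points.

```cpp
#ifndef EXAMPLE_JNI_UTF8_HELPER_H
#define EXAMPLE_JNI_UTF8_HELPER_H

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>

namespace example_jni {

// True if `byte` is a UTF-8 continuation byte (bit pattern 10xxxxxx).
inline bool isContinuation(char byte) {
  return (static_cast<std::uint8_t>(byte) & 0xC0) == 0x80;
}

// Structural UTF-8 check: every lead byte must be followed by the number of
// continuation bytes its bit pattern announces. A sequence cut off at the
// end of the input is reported as invalid.
inline bool utf8CheckValidity(const char* str, std::size_t length) {
  std::size_t i = 0;
  while (i < length) {
    const std::uint8_t lead = static_cast<std::uint8_t>(str[i]);
    std::size_t trailing = 0;
    if (lead < 0x80) {
      trailing = 0; // ASCII
    } else if ((lead & 0xE0) == 0xC0) {
      trailing = 1;
    } else if ((lead & 0xF0) == 0xE0) {
      trailing = 2;
    } else if ((lead & 0xF8) == 0xF0) {
      trailing = 3;
    } else {
      return false; // stray continuation byte or invalid lead byte
    }
    if (trailing > length - i - 1) {
      return false; // sequence is truncated
    }
    for (std::size_t k = 1; k <= trailing; ++k) {
      if (!isContinuation(str[i + k])) {
        return false;
      }
    }
    i += trailing + 1;
  }
  return true;
}

// Appends `token` to `buffer`. If the buffer then passes the check above,
// moves its contents into `out`, empties the buffer, and returns true.
// Otherwise leaves the bytes buffered and returns false.
inline bool appendAndTakeIfComplete(
    std::string& buffer, const std::string& token, std::string& out) {
  buffer += token;
  if (!utf8CheckValidity(buffer.data(), buffer.size())) {
    return false;
  }
  out = std::move(buffer);
  buffer.clear();
  return true;
}

} // namespace example_jni

#endif // EXAMPLE_JNI_UTF8_HELPER_H
```

With a helper like this, both token callbacks could buffer partial multi-byte tokens the same way and only call into Java once `appendAndTakeIfComplete` returns true.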
    // Helper to get a string from jstring
    std::string jstringToString(JNIEnv* env, jstring jstr) {
      if (jstr == nullptr) {
        return "";
      }
      const char* chars = env->GetStringUTFChars(jstr, nullptr);
      std::string result(chars);
      env->ReleaseStringUTFChars(jstr, chars);
      return result;
    }

Copilot AI commented on Jan 28, 2026
The jstringToString helper function is a simple utility that could potentially be shared across JNI layers. Consider adding it to jni/jni_helper.h if it's used in multiple places. However, this is not critical given its simplicity.
    jobject callbackRef = scopedCallback.get();
    tokenCallback = [env, callbackRef, &tokenBuffer](const std::string& token) {
      tokenBuffer += token;
      if (!utf8_check_validity(tokenBuffer.c_str(), tokenBuffer.size())) {
        ET_LOG(
            Info, "Current token buffer is not valid UTF-8. Waiting for more.");
        return;
      }

      std::string completeToken = tokenBuffer;
      tokenBuffer.clear();

      jstring jToken = env->NewStringUTF(completeToken.c_str());
      env->CallVoidMethod(callbackRef, callbackCache.onTokenMethod, jToken);
      if (env->ExceptionCheck()) {
        ET_LOG(Error, "Exception occurred in AsrCallback.onToken");
        env->ExceptionClear();
      }
      env->DeleteLocalRef(jToken);
    };

Copilot AI commented on Jan 28, 2026
The lambda captures JNIEnv* by value, which is only safe if the callback is guaranteed to be invoked synchronously on the same thread. JNIEnv* pointers are thread-local in JNI and cannot be safely used across threads. While the current runner implementation appears to invoke callbacks synchronously, this creates a fragile dependency. Consider documenting this requirement explicitly or adding a runtime check to verify thread safety.
Suggested change

    -  jobject callbackRef = scopedCallback.get();
    -  tokenCallback = [env, callbackRef, &tokenBuffer](const std::string& token) {
    -    tokenBuffer += token;
    -    if (!utf8_check_validity(tokenBuffer.c_str(), tokenBuffer.size())) {
    -      ET_LOG(
    -          Info, "Current token buffer is not valid UTF-8. Waiting for more.");
    -      return;
    -    }
    -    std::string completeToken = tokenBuffer;
    -    tokenBuffer.clear();
    -    jstring jToken = env->NewStringUTF(completeToken.c_str());
    -    env->CallVoidMethod(callbackRef, callbackCache.onTokenMethod, jToken);
    -    if (env->ExceptionCheck()) {
    -      ET_LOG(Error, "Exception occurred in AsrCallback.onToken");
    -      env->ExceptionClear();
    -    }
    -    env->DeleteLocalRef(jToken);
    -  };
    +  JavaVM* jvm = nullptr;
    +  if (env->GetJavaVM(&jvm) != JNI_OK || jvm == nullptr) {
    +    ET_LOG(Error, "Failed to get JavaVM; token callback will be disabled.");
    +  } else {
    +    jobject callbackRef = scopedCallback.get();
    +    tokenCallback =
    +        [jvm, callbackRef, &tokenBuffer](const std::string& token) {
    +          JNIEnv* envLocal = nullptr;
    +          bool detach = false;
    +          jint getEnvResult =
    +              jvm->GetEnv(reinterpret_cast<void**>(&envLocal), JNI_VERSION_1_6);
    +          if (getEnvResult == JNI_EDETACHED) {
    +            if (jvm->AttachCurrentThread(&envLocal, nullptr) != JNI_OK) {
    +              ET_LOG(
    +                  Error,
    +                  "Failed to attach current thread to JVM in token callback.");
    +              return;
    +            }
    +            detach = true;
    +          } else if (getEnvResult != JNI_OK || envLocal == nullptr) {
    +            ET_LOG(
    +                Error,
    +                "Failed to get JNIEnv for current thread in token callback.");
    +            return;
    +          }
    +          tokenBuffer += token;
    +          if (!utf8_check_validity(tokenBuffer.c_str(), tokenBuffer.size())) {
    +            ET_LOG(
    +                Info,
    +                "Current token buffer is not valid UTF-8. Waiting for more.");
    +            if (detach) {
    +              jvm->DetachCurrentThread();
    +            }
    +            return;
    +          }
    +          std::string completeToken = tokenBuffer;
    +          tokenBuffer.clear();
    +          jstring jToken = envLocal->NewStringUTF(completeToken.c_str());
    +          envLocal->CallVoidMethod(
    +              callbackRef, callbackCache.onTokenMethod, jToken);
    +          if (envLocal->ExceptionCheck()) {
    +            ET_LOG(Error, "Exception occurred in AsrCallback.onToken");
    +            envLocal->ExceptionClear();
    +          }
    +          envLocal->DeleteLocalRef(jToken);
    +          if (detach) {
    +            jvm->DetachCurrentThread();
    +          }
    +        };
    +  }
This reverts commit cf116b7.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
    if (scopedCallback) {
      initCallbackCache(env);

Copilot AI commented on Jan 29, 2026
The lambda captures JNIEnv* which is thread-specific in JNI. This is safe because the AsrRunner::transcribe function calls the callback synchronously on the same thread (see extension/asr/runner/runner.cpp:300-314). However, if the AsrRunner implementation changes to invoke callbacks asynchronously or from different threads, this will cause crashes. Consider adding a comment to document this thread safety assumption.
Suggested change

    +  // NOTE: The lambda below captures JNIEnv*, which is only valid on the
    +  // thread on which it was obtained. This is currently safe because
    +  // AsrRunner::transcribe invokes the token callback synchronously on the
    +  // same thread that called this JNI method. If AsrRunner::transcribe is
    +  // ever changed to invoke callbacks asynchronously or from a different
    +  // thread, this capture will become unsafe and must be revisited
    +  // (e.g., by obtaining a JNIEnv* via JavaVM::AttachCurrentThread in the
    +  // callback thread instead of capturing it here).
       *
       * @param token The decoded text token
       */
      fun onToken(token: String)

Copilot AI commented on Jan 29, 2026
Missing @DoNotStrip annotation on the onToken method. This annotation is required for JNI callback methods to prevent ProGuard/R8 from stripping them during release builds. The LlmCallback interface uses this annotation (see extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java:28). You'll need to add the import: import com.facebook.jni.annotations.DoNotStrip
Summary
Add Java bindings for extension/asr/runner/runner.h.
Test plan
CI