Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions example/app/(tabs)/stt.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
import { MLXModel, STT } from 'react-native-nitro-mlx'
import { SafeAreaView } from 'react-native-safe-area-context'

const MODEL_ID = MLXModel.GLM_ASR_Nano_4bit
const MODEL_ID = MLXModel.Qwen3_ASR_0_6B_4bit

type Status = 'idle' | 'loading' | 'ready' | 'listening' | 'transcribing'

Expand Down Expand Up @@ -68,11 +68,11 @@ export default function STTScreen() {
try {
const text = await STT.transcribeBuffer()
if (text) {
streamingRef.current = text
setStreamingText(text)
streamingRef.current = `${streamingRef.current} ${text}`.trim()
setStreamingText(streamingRef.current)
}
} catch {
// buffer too small or not listening, skip
} catch (error) {
console.warn('STT transcribeBuffer error:', error)
} finally {
isTranscribingChunk.current = false
}
Expand All @@ -84,18 +84,16 @@ export default function STTScreen() {

const handleToggleListening = useCallback(async () => {
if (status === 'listening') {
stopPolling()
try {
stopPolling()
setStatus('transcribing')
const finalText = await STT.stopListening()
setTranscript(finalText || streamingRef.current)
setStreamingText('')
streamingRef.current = ''
setStatus('ready')
STT.stop()
} catch (error) {
console.error('STT stopListening error:', error)
setStatus('ready')
console.error('STT stop error:', error)
}
setTranscript(streamingRef.current)
setStreamingText('')
streamingRef.current = ''
setStatus('ready')
} else if (status === 'ready') {
setTranscript('')
setStreamingText('')
Expand Down
26 changes: 25 additions & 1 deletion package/ios/Sources/AudioCaptureManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,38 @@ class AudioCaptureManager {
bufferLock.unlock()

guard samples.count >= 8000 else { return nil }

// Silence gate: skip chunks whose peak amplitude is near the noise
// floor so the ASR model doesn't hallucinate ("The.", "...") on
// silence. Peak-based because measurement-mode capture disables AGC,
// making RMS of quiet speech close to ambient noise.
var peak: Float = 0
for s in samples {
let a = s < 0 ? -s : s
if a > peak { peak = a }
}
guard peak >= 0.005 else { return nil }

return MLXArray(samples)
}

/// Builds an `MLXArray` from the samples accumulated so far while keeping
/// them in `audioBuffer`, so later calls see the same samples plus anything
/// captured since.
///
/// - Returns: An `MLXArray` of the buffered samples, or `nil` when fewer than
///   16000 samples are available (presumably one second at 16 kHz — confirm
///   against the capture format).
func snapshot() -> MLXArray? {
// Take exclusive ownership of the accumulated buffer so the audio tap
// gets fresh empty storage to append into; the expensive MLXArray copy
// then happens off the audio path. Samples are merged back afterward
// so the buffer keeps accumulating across calls.
bufferLock.lock()
// NOTE(review): `samples` is declared twice below (`let`, then `var`); this
// looks like the removed and added lines of a diff shown together — confirm
// which declaration is live. The `defer` block mutates `samples`, which
// requires the `var` form.
let samples = audioBuffer
var samples = audioBuffer
audioBuffer.removeAll()
bufferLock.unlock()

// Runs on every exit path, including the early `nil` return below: under the
// lock, append whatever landed in `audioBuffer` in the meantime after the
// held samples (preserving order) and store the merged array back.
defer {
bufferLock.lock()
samples.append(contentsOf: audioBuffer)
audioBuffer = samples
bufferLock.unlock()
}

// Too little audio to hand to the caller; the samples are still restored by
// the `defer` above.
guard samples.count >= 16000 else { return nil }
return MLXArray(samples)
}
Expand Down
14 changes: 7 additions & 7 deletions package/ios/Sources/HybridSTT.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ enum STTError: Error {
}

class HybridSTT: HybridSTTSpec {
private var model: GLMASRModel?
private var model: Qwen3ASRModel?
private var activeTask: Task<String, Error>?
private var loadTask: Task<Void, Error>?
private var captureManager: AudioCaptureManager?
Expand Down Expand Up @@ -39,7 +39,7 @@ class HybridSTT: HybridSTTSpec {
self.model = nil
MLX.Memory.clearCache()

let loadedModel = try await GLMASRModel.fromPretrained(modelId)
let loadedModel = try await Qwen3ASRModel.fromPretrained(modelId)

try Task.checkCancellation()

Expand All @@ -62,7 +62,7 @@ class HybridSTT: HybridSTTSpec {
return Promise.async { [self] in
let task = Task<String, Error> {
let mlxAudio = self.arrayBufferToMLXArray(audio)
let output = model.generate(audio: mlxAudio)
let output = model.generate(audio: mlxAudio, language: "English")
return output.text
}

Expand All @@ -84,7 +84,7 @@ class HybridSTT: HybridSTTSpec {
return Promise.async { [self] in
let task = Task<String, Error> {
let mlxAudio = self.arrayBufferToMLXArray(audio)
let stream = model.generateStream(audio: mlxAudio)
let stream = model.generateStream(audio: mlxAudio, language: "English")
var finalText = ""

for try await event in stream {
Expand Down Expand Up @@ -132,13 +132,13 @@ class HybridSTT: HybridSTTSpec {
guard let manager = captureManager, manager.isCapturing else {
throw STTError.notListening
}
guard let audio = manager.snapshot() else {
guard let audio = manager.snapshotAndClear() else {
return Promise.resolved(withResult: "")
}

return Promise.async { [self] in
let task = Task<String, Error> {
let output = model.generate(audio: audio)
let output = model.generate(audio: audio, language: "English")
return output.text
}

Expand All @@ -164,7 +164,7 @@ class HybridSTT: HybridSTTSpec {

return Promise.async { [self] in
let task = Task<String, Error> {
let output = model.generate(audio: audio)
let output = model.generate(audio: audio, language: "English")
return output.text
}

Expand Down
14 changes: 14 additions & 0 deletions package/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export enum ModelFamily {
OpenELM = 'OpenELM',
PocketTTS = 'PocketTTS',
GLMASR = 'GLMASR',
Qwen3ASR = 'Qwen3ASR',
}

export enum ModelProvider {
Expand Down Expand Up @@ -86,6 +87,9 @@ export enum MLXModel {

// GLM-ASR (GLMASR) - Speech-to-Text
GLM_ASR_Nano_4bit = 'mlx-community/GLM-ASR-Nano-2512-4bit',

// Qwen3-ASR (Alibaba) - Speech-to-Text
Qwen3_ASR_0_6B_4bit = 'mlx-community/Qwen3-ASR-0.6B-4bit',
}

export const MLXModels: ModelInfo[] = [
Expand Down Expand Up @@ -389,4 +393,14 @@ export const MLXModels: ModelInfo[] = [
downloadSize: 600000000,
type: 'stt',
},
{
id: MLXModel.Qwen3_ASR_0_6B_4bit,
family: ModelFamily.Qwen3ASR,
provider: ModelProvider.Alibaba,
parameters: '0.6B',
quantization: '4bit',
displayName: 'Qwen3 ASR 0.6B (4-bit)',
downloadSize: 712781278,
type: 'stt',
},
]
Loading