import logging
import queue
import wave
import requests
from urllib.request import urlopen
import sounddevice as sd
import time
import threading # 用于终止控制
from vosk import Model, KaldiRecognizer, SetLogLevel
from tqdm import tqdm
from zipfile import ZipFile
from .._utils import ignore_stderr
from .vosk_models import DEFAULT_MODELS
import json
from pathlib import Path
import os
HOME = os.path.expanduser("~")
MODEL_PRE_URL = "https://alphacephei.com/vosk/models/"
MODEL_LIST_URL = f"{MODEL_PRE_URL}model-list.json"
MODEL_BASE_PATH = f"{HOME}/.vosk_models"
MODEL_LIST_CACHE_PATH = Path(MODEL_BASE_PATH) / "model-list.json"

# requests/urllib3 log every failed connection attempt; this module reports
# network problems through its own warnings, so silence both loggers.
for _noisy_logger_name in ("urllib3", "requests"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.CRITICAL)
del _noisy_logger_name
[Doku]
class Vosk():
    """Vosk speech-to-text (STT) wrapper.

    Manages the list of downloadable Vosk models, downloads/extracts the model
    for the selected language, transcribes WAV files and microphone input
    (streaming or not), and can wait for wake words in a background thread.
    """
    DEFAULT_LANGUAGE = "en-us"

    class _DownloadCancelled(Exception):
        """Internal signal: a download was stopped via cancel_download()/close()."""

    def __init__(self, language=None, samplerate=None, device=None, log=None):
        """Initialize Vosk STT and load (downloading if necessary) the model.

        Args:
            language (str, optional): Language code from ``available_languages``.
                Defaults to ``DEFAULT_LANGUAGE`` when None.
            samplerate (int, optional): Sample rate for the recognizer and the
                microphone. Defaults to the input device's default sample rate.
            device (int | str, optional): sounddevice input device. Defaults to
                ``sd.default.device``; index 0 is a valid choice.
            log (logging.Logger, optional): Logger. Defaults to the module logger.

        Raises:
            ValueError: If the language is not in ``available_languages``.
        """
        Path(MODEL_BASE_PATH).mkdir(parents=True, exist_ok=True)
        self.log = log or logging.getLogger(__name__)
        self._load_model_list()
        SetLogLevel(-1)
        self.downloading = False
        self.stop_downloading_event = threading.Event()
        self.stop_listening_event = threading.Event()
        self.wake_word_thread = None
        self.waked = False
        self.wake_word_thread_started = False
        # Compare with None: `device or default` would discard device index 0.
        self._device = device if device is not None else sd.default.device
        if samplerate is None:
            device_info = sd.query_devices(self._device, "input")
            samplerate = int(device_info["default_samplerate"])
        self._samplerate = samplerate
        self.model = None
        self.recognizer = None
        self._language = None
        self.wake_words = None
        # init() needs a language to locate the model, so fall back to the default.
        self.set_language(self.DEFAULT_LANGUAGE if language is None else language, init=False)
        self.init()

    def is_ready(self):
        """Check if Vosk STT is ready.

        Returns:
            bool: True once a recognizer has been created by init().
        """
        return self.recognizer is not None

    def init(self):
        """Load the current language's model (downloading it if missing) and
        create the microphone recognizer at ``self._samplerate``."""
        model_path = self.get_model_path(self._language)
        if not model_path.exists():
            self.download_model(self._language)
        self.model = Model(str(model_path))
        self.recognizer = KaldiRecognizer(self.model, self._samplerate)

    # ------------------------------------------------------------------ models

    @staticmethod
    def _filter_models(all_models):
        """Keep only the small, non-obsolete entries of a Vosk model list."""
        return [model for model in all_models
                if model.get("type") == "small" and model.get("obsolete") == "false"]

    def _read_cached_model_list(self):
        """Return the filtered model list from the local cache, or None.

        Best effort: a missing or unreadable cache yields None (logged at debug).
        """
        if not MODEL_LIST_CACHE_PATH.exists():
            return None
        try:
            with open(MODEL_LIST_CACHE_PATH, "r", encoding="utf-8") as f:
                return self._filter_models(json.load(f))
        except Exception as e:
            self.log.debug(f"Ignoring unreadable model list cache: {e}")
            return None

    def _write_model_list_cache(self, all_models):
        """Persist the raw model list to the cache (best effort, atomic replace)."""
        tmp_path = MODEL_LIST_CACHE_PATH.with_name(MODEL_LIST_CACHE_PATH.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(all_models, f, ensure_ascii=False, indent=2)
            # Replace in one step so a failed write never leaves a truncated cache.
            os.replace(tmp_path, MODEL_LIST_CACHE_PATH)
        except OSError as e:
            self.log.warning(f"Could not write model list cache: {e}")

    def _set_model_list(self, models):
        """Install *models* and the derived language / model-name lists."""
        self.available_models = models
        self.available_languages = [model["lang"] for model in models]
        self.available_model_names = [model["name"] for model in models]

    def _load_model_list(self):
        """Load model list from local cache or built-in defaults (offline, no network)."""
        models = self._read_cached_model_list()
        if not models:
            models = DEFAULT_MODELS.copy()
        self._set_model_list(models)

    def update_model_list(self):
        """Fetch the latest model list from the network and save it to the cache.

        Call this manually when you want to check for new models online.
        Falls back to the local cache if the network is unavailable; if neither
        yields any model the current list is kept.
        """
        try:
            with urlopen(MODEL_LIST_URL, timeout=5) as response:
                all_models = json.load(response)
            models = self._filter_models(all_models)
            self.log.info(f"Model list updated from network ({len(models)} models)")
        except Exception:
            self.log.warning("Network unavailable, using local model list...")
            models = self._read_cached_model_list()
            if models:
                self.log.info(f"Model list loaded from cache ({len(models)} models)")
            else:
                self.log.warning("No local model list available, keeping current list")
        else:
            if models:
                self._write_model_list_cache(all_models)
        if models:
            self._set_model_list(models)

    # -------------------------------------------------------------- wake words

    def _normalized_wake_words(self, wake_words=None):
        """Return wake words as a list of stripped, lowercase strings.

        Args:
            wake_words (list | str, optional): Wake words; falls back to
                ``self.wake_words``. A single string is treated as one word.

        Raises:
            ValueError: If no wake words are given or configured.
        """
        if wake_words is None:
            wake_words = self.wake_words
        if wake_words is None:
            raise ValueError("No wake words set; call set_wake_words() first.")
        if isinstance(wake_words, str):
            wake_words = [wake_words]
        return [word.strip().lower() for word in wake_words]

    def wait_until_heard(self, wake_words=None, print_callback=lambda x: print(f"heard: \x1b[K{x}", end="\r", flush=True)):
        """Block until one of the wake words is heard (case-insensitive).

        Args:
            wake_words (list | str, optional): Wake words; defaults to
                ``self.wake_words``.
            print_callback (function, optional): Called with every recognized
                utterance.

        Returns:
            str | None: The heard wake word as recognized, or None if listening
            was stopped with stop_listening()/close().

        Raises:
            ValueError: If no wake words are given or configured.
        """
        targets = self._normalized_wake_words(wake_words)
        while True:
            result = self.listen(stream=False)
            if result is None:
                # listen(stream=False) only returns None when listening was stopped.
                return None
            print_callback(result)
            if result.lower() in targets:
                return result

    def heard_wake_word(self, print_callback=lambda x: print(f"heard: \x1b[K{x}", end="\r", flush=True)):
        """Listen for one utterance and report whether it is a wake word.

        Args:
            print_callback (function, optional): Called with the recognized text.

        Returns:
            bool: True if the utterance equals one of ``self.wake_words``
            (case-insensitive), False otherwise or if listening was stopped.

        Raises:
            ValueError: If no wake words are configured.
        """
        targets = self._normalized_wake_words()
        result = self.listen(stream=False)
        if result is None:
            return False
        print_callback(result)
        return result.lower() in targets

    def wait_for_wake_word(self):
        """Listen until a wake word is heard or listening is stopped.

        Sets ``self.waked`` to True when a wake word is heard. Used as the
        target of the thread created by start_listening_wake_words().
        """
        self.wake_word_thread_started = True
        self.stop_listening_event.clear()
        try:
            while self.wake_word_thread_started:
                if self.stop_listening_event.is_set():
                    break
                if self.heard_wake_word():
                    print("")  # end the "\r"-terminated line written by print_callback
                    self.waked = True
                    break
                time.sleep(0.1)
        finally:
            self.wake_word_thread_started = False
            # Only forget the thread reference if it is this thread, so a newer
            # thread started in the meantime is not clobbered.
            if self.wake_word_thread is threading.current_thread():
                self.wake_word_thread = None

    def start_listening_wake_words(self):
        """Start a background thread running wait_for_wake_word(); poll is_waked()."""
        self.waked = False
        self.wake_word_thread = threading.Thread(name="wake_word_thread", target=self.wait_for_wake_word)
        self.wake_word_thread_started = True
        self.wake_word_thread.start()

    def is_waked(self):
        """Check whether a wake word has been heard.

        Returns:
            bool: True once the wake-word listener heard a wake word since the
            last start_listening_wake_words(), False otherwise.
        """
        return self.waked

    # -------------------------------------------------------------- file STT

    def _file_recognizer(self, framerate):
        """Create a fresh recognizer for audio sampled at *framerate* Hz.

        Raises:
            RuntimeError: If no model has been loaded by init().
        """
        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError("Vosk model is not loaded; call init() first.")
        return KaldiRecognizer(model, framerate)

    def _stream_wave_file(self, wf, recognizer):
        """Yield get_stream_result() items and close *wf* when done."""
        with wf:
            yield from self.get_stream_result(wf, recognizer)

    def stt(self, filename, stream=False):
        """Perform STT on a mono 16-bit PCM WAV file.

        A dedicated recognizer is created at the file's frame rate, so results
        neither depend on nor disturb the microphone recognizer's state.

        Args:
            filename (str | file-like): WAV file path or binary file object.
            stream (bool, optional): Stream mode, default is False.

        Returns:
            str | generator: In normal mode, a Vosk-style JSON string whose
            ``"text"`` joins all utterances in the file (plus ``"result"`` word
            list when words were recognized). In stream mode, a generator of
            Vosk JSON result / partial-result strings (see get_stream_result).

        Raises:
            ValueError: If the file is not mono, 16-bit, uncompressed PCM
                (raised immediately, also in stream mode).
        """
        wf = wave.open(filename, "rb")
        try:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
                raise ValueError("Audio file must be WAV format mono PCM.")
            recognizer = self._file_recognizer(wf.getframerate())
            recognizer.SetWords(True)
            recognizer.SetPartialWords(bool(stream))
        except BaseException:
            wf.close()
            raise
        if stream:
            # The generator owns wf and closes it when exhausted or discarded.
            return self._stream_wave_file(wf, recognizer)

        texts, words = [], []

        def collect(raw_result):
            piece = json.loads(raw_result)
            if piece.get("text"):
                texts.append(piece["text"])
                words.extend(piece.get("result", []))

        with wf:
            while True:
                data = wf.readframes(4000)
                if not data:
                    break
                # Each completed utterance must be read now; Vosk starts a new one next.
                if recognizer.AcceptWaveform(data):
                    collect(recognizer.Result())
            collect(recognizer.FinalResult())
        text = " ".join(texts)
        combined = {"result": words, "text": text} if words else {"text": text}
        return json.dumps(combined, ensure_ascii=False)

    def get_stream_result(self, wf, recognizer=None):
        """Feed a wave file to a recognizer chunk by chunk, yielding results.

        Args:
            wf (wave.Wave_read): Open wave file to read from.
            recognizer (KaldiRecognizer, optional): Recognizer receiving the
                audio; defaults to ``self.recognizer``.

        Yields:
            str: Vosk JSON ``Result()`` when an utterance ends, otherwise
            ``PartialResult()``; finally one ``FinalResult()`` flushing any
            trailing speech.
        """
        if recognizer is None:
            recognizer = self.recognizer
        while True:
            data = wf.readframes(4000)
            if not data:
                break
            if recognizer.AcceptWaveform(data):
                yield recognizer.Result()
            else:
                yield recognizer.PartialResult()
        yield recognizer.FinalResult()

    # -------------------------------------------------------------- microphone

    def listen(self, stream=False, device=None, samplerate=None):
        """Listen from the microphone and return results.

        Args:
            stream (bool, optional): Stream mode, default is False.
            device (int, optional): Device index; defaults to the device chosen
                in __init__.
            samplerate (int, optional): Sampling rate; defaults to the rate the
                recognizer was created with, so captured audio matches it.

        Returns:
            str | None | generator: In normal mode, the first non-empty final
            text, or None if stop_listening() was called. In stream mode, a
            generator of ``{"done", "partial", "final"}`` dicts.
        """
        q = queue.Queue()

        def callback(indata, frames, time_info, status):
            if status:
                self.log.warning(status)
            q.put(bytes(indata))

        if device is None:
            device = self._device
        if samplerate is None:
            samplerate = self._samplerate
        with ignore_stderr():
            # Clear a stop request left over from a previous listen() call.
            self.stop_listening_event.clear()
            if stream:
                # NOTE(review): this only creates the generator; the input stream
                # is opened on first iteration, outside ignore_stderr().
                return self._listen_streaming(q, device, samplerate, callback)
            return self._listen_non_streaming(q, device, samplerate, callback)

    @staticmethod
    def _open_input_stream(device, samplerate, callback):
        """Create the mono int16 raw microphone stream used by listen()."""
        return sd.RawInputStream(samplerate=samplerate, blocksize=1024, device=device,
                                 dtype="int16", channels=1, callback=callback)

    def _listen_streaming(self, q, device=None, samplerate=None, callback=None):
        """Listen from the microphone and yield streaming results.

        Args:
            q (queue.Queue): Queue filled with audio chunks by *callback*.
            device (int, optional): Device index, default is None.
            samplerate (int, optional): Sampling rate, default is None.
            callback (function, optional): sounddevice callback feeding *q*.

        Yields:
            dict: ``{"done": False, "partial": text, "final": ""}`` for partial
            results, then one ``{"done": True, "partial": "", "final": text}``
            after which the generator ends. Ends early if listening is stopped.
        """
        with self._open_input_stream(device, samplerate, callback):
            while not self.stop_listening_event.is_set():
                try:
                    data = q.get(timeout=0.5)
                except queue.Empty:
                    continue
                if self.recognizer.AcceptWaveform(data):
                    text = json.loads(self.recognizer.Result())["text"].strip()
                    if not text:
                        continue
                    yield {"done": True, "partial": "", "final": text}
                    return
                partial = json.loads(self.recognizer.PartialResult())["partial"].strip()
                if partial:
                    yield {"done": False, "partial": partial, "final": ""}

    def _listen_non_streaming(self, q, device=None, samplerate=None, callback=None):
        """Listen from the microphone and return the first final result.

        Args:
            q (queue.Queue): Queue filled with audio chunks by *callback*.
            device (int, optional): Device index, default is None.
            samplerate (int, optional): Sampling rate, default is None.
            callback (function, optional): sounddevice callback feeding *q*.

        Returns:
            str | None: First non-empty recognized text, or None if listening
            was stopped.
        """
        with self._open_input_stream(device, samplerate, callback):
            while not self.stop_listening_event.is_set():
                try:
                    data = q.get(timeout=0.5)
                except queue.Empty:
                    continue
                if self.recognizer.AcceptWaveform(data):
                    text = json.loads(self.recognizer.Result())["text"]
                    if text.strip():
                        return text
        return None

    # -------------------------------------------------------------- settings

    def set_wake_words(self, wake_words: list):
        """Set wake words.

        Args:
            wake_words (list): List of wake words (a single string is also accepted).
        """
        self.wake_words = wake_words

    def language(self) -> str:
        """Get current language.

        Returns:
            str: Current language code.
        """
        return self._language

    def set_language(self, language: str, init=True):
        """Set language.

        Args:
            language (str): Language to set.
            init (bool, optional): Re-initialize the recognizer, default is True.

        Raises:
            ValueError: If the language is not available.
        """
        if language not in self.available_languages:
            raise ValueError(f"Vosk does not support language: {language}. Available languages: {self.available_languages}")
        self._language = language
        if init:
            self.init()

    def get_model_name(self, lang: str) -> str:
        """Get model name for a language.

        Args:
            lang (str): Language.

        Returns:
            str: Model name.

        Raises:
            ValueError: If the language is not available.
        """
        return self.available_model_names[self.available_languages.index(lang)]

    def get_model_path(self, lang: str) -> Path:
        """Get the local model directory for a language.

        Args:
            lang (str): Language.

        Returns:
            Path: ``MODEL_BASE_PATH / <model name>``.
        """
        return Path(MODEL_BASE_PATH, self.get_model_name(lang))

    def is_model_downloaded(self, lang: str) -> bool:
        """Check if the model for a language exists locally.

        Args:
            lang (str): Language.

        Returns:
            bool: True if the model directory exists.
        """
        return self.get_model_path(lang).exists()

    # -------------------------------------------------------------- download

    def cancel_download(self):
        """Public method to cancel an ongoing download."""
        if self.downloading:
            self.stop_downloading_event.set()  # trigger the stop event
            self.log.info("Download cancellation requested")

    def _download_zip(self, zip_url, zip_path, progress_callback=None):
        """Download *zip_url* to *zip_path*, resuming a partial file if possible.

        Raises:
            Vosk._DownloadCancelled: If cancellation is requested mid-transfer.
            requests.RequestException / OSError: On network, HTTP or size errors.
        """
        resume_byte_pos = os.path.getsize(zip_path) if os.path.exists(zip_path) else 0
        headers = {}
        if resume_byte_pos > 0:
            self.log.info(f"Resuming download from byte position {resume_byte_pos}")
            headers["Range"] = f"bytes={resume_byte_pos}-"
        with requests.get(zip_url, headers=headers, stream=True, timeout=30) as response:
            if response.status_code == 416 and resume_byte_pos > 0:
                # Requested range starts at/after EOF: the zip is presumably
                # complete. If not, extraction fails and the zip is discarded.
                return
            if response.status_code not in (200, 206):  # 200: full, 206: partial content
                response.raise_for_status()
                raise requests.HTTPError(f"Unexpected HTTP status {response.status_code}", response=response)
            if response.status_code == 200 and resume_byte_pos > 0:
                # Server ignored the Range header and sends the whole file;
                # appending it would corrupt the zip, so start over.
                self.log.info("Server does not support resuming; restarting download")
                resume_byte_pos = 0
            content_length = response.headers.get("content-length")
            total_size = None if content_length is None else int(content_length) + resume_byte_pos
            bar = None
            if progress_callback:
                progress_callback(resume_byte_pos, total_size)
            else:
                bar = tqdm(
                    total=total_size, initial=resume_byte_pos,
                    unit="B", unit_scale=True, unit_divisor=1024,
                    desc=zip_url.rsplit("/", maxsplit=1)[-1]
                )
            try:
                with open(zip_path, "ab" if resume_byte_pos > 0 else "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if self.stop_downloading_event.is_set():
                            raise self._DownloadCancelled("Download cancelled by user")
                        if not chunk:  # keep-alive
                            continue
                        f.write(chunk)
                        resume_byte_pos += len(chunk)
                        if bar is None:
                            progress_callback(resume_byte_pos, total_size)
                        else:
                            bar.update(len(chunk))
            finally:
                if bar is not None:
                    bar.close()
        if total_size is None:
            self.log.warning("Cannot verify file integrity - server did not provide content length")
            return
        downloaded_size = os.path.getsize(zip_path)
        if downloaded_size != total_size:
            raise OSError(f"Download incomplete: received {downloaded_size} bytes, expected {total_size} bytes")

    def _extract_zip(self, zip_path, dest_dir):
        """Extract the model zip into *dest_dir* and delete it.

        A corrupt zip is deleted before re-raising, so the next attempt starts
        a fresh download instead of resuming onto a broken file.
        """
        from zipfile import BadZipFile
        try:
            with ZipFile(zip_path, "r") as model_ref:
                model_ref.extractall(dest_dir)
        except BadZipFile:
            os.remove(zip_path)
            raise
        os.remove(zip_path)

    def download_model(self, lang: str, progress_callback=None, max_retries: int=5):
        """Download and extract the model for a language.

        Resumes a partial ``<model>.zip`` left by an earlier attempt, retries
        with exponential backoff, and stops promptly on cancel_download() or
        close(). Failures are logged rather than raised; the partial zip is
        kept so that a later call can resume.

        Args:
            lang (str): Language.
            progress_callback (function, optional): Called as
                ``progress_callback(downloaded_bytes, total_bytes_or_None)``;
                when omitted a tqdm progress bar is shown.
            max_retries (int, optional): Maximum attempts, default is 5.
        """
        model_path = self.get_model_path(lang)
        if self.is_model_downloaded(lang) or self.downloading:
            return
        self.downloading = True
        self.stop_downloading_event.clear()  # each download starts un-cancelled
        zip_url = MODEL_PRE_URL + f"{model_path.name}.zip"
        zip_path = f"{model_path}.zip"
        try:
            for retries in range(1, max_retries + 1):
                if self.stop_downloading_event.is_set():
                    raise self._DownloadCancelled("Download cancelled by user")
                try:
                    self._download_zip(zip_url, zip_path, progress_callback)
                    self._extract_zip(zip_path, model_path.parent)
                    return
                except self._DownloadCancelled:
                    raise
                except Exception as e:
                    if retries >= max_retries:
                        self.log.warning(f"Download failed after {max_retries} attempts. Check network connection. ({e})")
                        break
                    self.log.info(f"Retrying download ({retries}/{max_retries})... ({e})")
                    wait_time = 2 ** retries  # exponential backoff
                    self.log.info(f"Retrying in {wait_time} seconds...")
                    # Event.wait doubles as a sleep that a cancel request interrupts.
                    if self.stop_downloading_event.wait(wait_time):
                        raise self._DownloadCancelled("Download cancelled by user")
        except self._DownloadCancelled as e:
            self.log.info(f"Download cancelled: {e}")
            if os.path.exists(zip_path):
                self.log.debug(f"Partial download saved to {zip_path}")
        except Exception as e:
            self.log.warning(f"Download stopped, will retry later. ({e})")
            # Keep the partial file so a later call can resume it.
            if os.path.exists(zip_path):
                self.log.debug(f"Partial download saved to {zip_path}")
        finally:
            self.downloading = False
            self.stop_downloading_event.clear()

    def download_progress_hook(self, tqdm_bar=None, progress_callback=None):
        """Build a ``urlretrieve``-style report hook ``update_to(b, bsize, tsize)``.

        Args:
            tqdm_bar (tqdm, optional): Progress bar to advance; takes priority.
            progress_callback (function, optional): Called as
                ``progress_callback(current_bytes, total_or_None)``.

        Returns:
            function: Hook where *b* is the block count so far, *bsize* the
            block size and *tsize* the total size (None or -1 when unknown).
        """
        last_b = [0]

        def update_to(b=1, bsize=1, tsize=None):
            total = None if tsize in (None, -1) else tsize
            downloaded = (b - last_b[0]) * bsize
            last_b[0] = b
            # `is not None`: bool() of a tqdm bar without a total raises TypeError.
            if tqdm_bar is not None:
                if total is not None:
                    tqdm_bar.total = total
                return tqdm_bar.update(downloaded)
            if progress_callback:
                current = b * bsize
                if total:
                    current = min(current, total)
                progress_callback(current, total)
            return downloaded

        return update_to

    def stop_listening(self):
        """Stop the current listen() call (and the wake word listener)."""
        self.stop_listening_event.set()

    def close(self):
        """Close STT: stop the wake word listener, any download and listening."""
        self.wake_word_thread_started = False
        self.stop_downloading_event.set()
        self.stop_listening_event.set()