Quellcode für sunfounder_voice_assistant.stt.vosk

import logging
import queue
import wave
import requests
from urllib.request import urlopen
import sounddevice as sd
import time
import threading  # 用于终止控制
from vosk import Model, KaldiRecognizer, SetLogLevel
from tqdm import tqdm
from zipfile import ZipFile
from .._utils import ignore_stderr
from .vosk_models import DEFAULT_MODELS

import json
from pathlib import Path
import os

# Locations used for model management: the user's home directory, the remote
# Vosk model repository, and the local directory/cache file for models.
HOME = os.path.expanduser("~")
MODEL_PRE_URL = "https://alphacephei.com/vosk/models/"
MODEL_LIST_URL = f"{MODEL_PRE_URL}model-list.json"
MODEL_BASE_PATH = HOME + "/.vosk_models"
MODEL_LIST_CACHE_PATH = Path(MODEL_BASE_PATH) / "model-list.json"

# Suppress noisy urllib3/requests connection errors — we handle them with our own warnings
logging.getLogger("requests").setLevel(logging.CRITICAL)
logging.getLogger("urllib3").setLevel(logging.CRITICAL)

class Vosk():
    """Speech-to-text helper built on the Vosk offline recognizer.

    The class manages the list of available small models, downloads a model
    on demand into ``MODEL_BASE_PATH``, recognizes speech from WAV files or
    from the microphone, and offers simple wake-word detection.
    """

    DEFAULT_LANGUAGE = "en-us"

    def __init__(self, language=None, samplerate=None, device=None, log=None):
        """Initialize Vosk STT.

        Args:
            language (str, optional): Language to load. When None, no model
                is loaded until ``set_language`` is called.
            samplerate (int, optional): Sample rate for the recognizer and
                the microphone stream. When None, the default sample rate of
                the input device is used.
            device (int, optional): Input device. When None, the sounddevice
                default device is used.
            log (logging.Logger, optional): Logger, default is a module logger.
        """
        Path(MODEL_BASE_PATH).mkdir(parents=True, exist_ok=True)
        self.log = log or logging.getLogger(__name__)
        self._load_model_list()
        SetLogLevel(-1)

        self.downloading = False
        self.stop_downloading_event = threading.Event()
        self.stop_listening_event = threading.Event()
        self.wake_word_thread = None
        self.waked = False
        self.wake_word_thread_started = False

        # An explicit None check is required: `device or default` would
        # silently discard the perfectly valid device index 0.
        self._device = device if device is not None else sd.default.device
        if samplerate is None:
            device_info = sd.query_devices(self._device, "input")
            samplerate = int(device_info["default_samplerate"])
        self._samplerate = samplerate

        self.recognizer = None
        self._language = None
        self.wake_words = None
        if language is not None:
            self.set_language(language, init=False)
            self.init()

    def is_ready(self):
        """Check if Vosk STT is ready.

        Returns:
            bool: True if a recognizer has been created, False otherwise.
        """
        return self.recognizer is not None

    def init(self):
        """Load the model for the current language and create the recognizer.

        The model is downloaded first when it is not present locally.

        Raises:
            ValueError: If no supported language is set.
            RuntimeError: If the model is still missing after the download
                attempt (download failed, was cancelled, or is already in
                progress elsewhere).
        """
        model_path = self.get_model_path(self._language)
        if not model_path.exists():
            self.download_model(self._language)
        if not model_path.exists():
            # download_model() never raises; fail here with a clear message
            # instead of letting Model() choke on a missing directory.
            raise RuntimeError(
                f"Vosk model for language '{self._language}' is not available at {model_path}")
        model = Model(str(model_path))
        self.recognizer = KaldiRecognizer(model, self._samplerate)

    @staticmethod
    def _filter_models(all_models):
        """Keep only the small, non-obsolete entries of a Vosk model list."""
        return [model for model in all_models
                if model["type"] == "small" and model["obsolete"] == "false"]

    def _read_cached_models(self):
        """Return the filtered model list from the cache file, or None if unavailable."""
        if not MODEL_LIST_CACHE_PATH.exists():
            return None
        try:
            with open(MODEL_LIST_CACHE_PATH, "r", encoding="utf-8") as f:
                return self._filter_models(json.load(f))
        except (OSError, ValueError, KeyError, TypeError) as e:
            # A broken cache is not fatal; callers fall back to other sources.
            self.log.debug(f"Could not read cached model list: {e}")
            return None

    def _set_available_models(self, models):
        """Store the model list together with its derived language/name lists."""
        self.available_models = models
        self.available_languages = [model["lang"] for model in models]
        self.available_model_names = [model["name"] for model in models]

    def _load_model_list(self):
        """Load model list from local cache or built-in defaults (offline, no network)."""
        models = self._read_cached_models()
        if not models:
            models = DEFAULT_MODELS.copy()
        self._set_available_models(models)

    def update_model_list(self):
        """Fetch latest model list from network and save to cache.

        Call this manually when you want to check for new models online.
        Falls back to local cache if network is unavailable. The current
        list is kept when neither source yields any model.
        """
        models = None
        try:
            with urlopen(MODEL_LIST_URL, timeout=5) as response:
                all_models = json.load(response)
            models = self._filter_models(all_models)
            self.log.info(f"Model list updated from network ({len(models)} models)")
        except Exception as e:
            self.log.warning("Network unavailable, using local model list...")
            self.log.debug(f"Model list update failed: {e}")
            models = self._read_cached_models()
            if models:
                self.log.info(f"Model list loaded from cache ({len(models)} models)")
            else:
                self.log.warning("No local model list available, keeping current list")
        else:
            if models:
                # Caching is best effort: a read-only home must not break the update.
                try:
                    with open(MODEL_LIST_CACHE_PATH, "w", encoding="utf-8") as f:
                        json.dump(all_models, f, ensure_ascii=False, indent=2)
                except OSError as e:
                    self.log.debug(f"Could not write model list cache: {e}")

        if models:
            self._set_available_models(models)

    @staticmethod
    def _matches_wake_word(text, wake_words):
        """Return True if text equals one of the wake words, ignoring case."""
        if not wake_words:
            return False
        if isinstance(wake_words, str):
            wake_words = [wake_words]
        heard = text.lower()
        return any(heard == word.lower() for word in wake_words)

    def wait_until_heard(self, wake_words=None, print_callback=lambda x: print(f"heard: \x1b[K{x}", end="\r", flush=True)):
        """Block until one of the wake words is heard.

        Args:
            wake_words (list | str, optional): Wake words; defaults to the
                words set with ``set_wake_words``. Matching ignores case.
            print_callback (function, optional): Called with every
                recognized phrase; by default prints it on one line.

        Returns:
            str | None: The recognized phrase that matched a wake word, or
            None if listening was stopped via ``stop_listening``/``close``.

        Raises:
            ValueError: If no wake words are configured.
        """
        if wake_words is None:
            wake_words = self.wake_words
        if isinstance(wake_words, str):
            wake_words = [wake_words]
        if not wake_words:
            raise ValueError("No wake words configured")

        while True:
            result = self.listen(stream=False)
            if result is None:
                # listen() only returns None after a stop request; honour it
                # instead of reopening the microphone forever.
                return None
            print_callback(result)
            if self._matches_wake_word(result, wake_words):
                return result

    def heard_wake_word(self, print_callback=lambda x: print(f"heard: \x1b[K{x}", end="\r", flush=True)):
        """Listen for one phrase and check whether it is a wake word.

        Args:
            print_callback (function, optional): Called with the recognized
                phrase; by default prints it on one line.

        Returns:
            bool: True if the phrase equals a configured wake word (case
            insensitive); False otherwise or when listening was stopped.
        """
        result = self.listen(stream=False)
        if result is None:
            return False
        print_callback(result)
        return self._matches_wake_word(result, self.wake_words)

    def wait_for_wake_word(self):
        """Loop until a wake word is heard or listening is stopped.

        Sets ``self.waked`` to True when a wake word was heard. Intended as
        the target of the thread created by ``start_listening_wake_words``.
        """
        self.wake_word_thread_started = True
        self.stop_listening_event.clear()
        try:
            while self.wake_word_thread_started:
                if self.stop_listening_event.is_set():
                    break
                if self.heard_wake_word():
                    print("")
                    self.waked = True
                    break
                time.sleep(0.1)
        finally:
            # Reset the bookkeeping even if listening raised, so the state
            # never claims a listener that no longer exists.
            self.wake_word_thread_started = False
            self.wake_word_thread = None

    def start_listening_wake_words(self):
        """Start a background thread that listens for wake words."""
        self.waked = False
        self.wake_word_thread = threading.Thread(name="wake_word_thread", target=self.wait_for_wake_word)
        self.wake_word_thread_started = True
        self.wake_word_thread.start()

    def is_waked(self):
        """Check whether a wake word has been heard.

        Returns:
            bool: True if a wake word was heard since
            ``start_listening_wake_words`` was last called.
        """
        return self.waked

    def stt(self, filename, stream=False):
        """Perform STT on an audio file.

        Args:
            filename (str): Path of a mono 16-bit PCM WAV file.
            stream (bool, optional): Stream mode, default is False.

        Returns:
            str: In non-stream mode, one JSON string with the ``text`` of the
            whole file (and the word-level ``result`` list when the
            recognizer provides it).
            generator: In stream mode, a generator yielding the recognizer's
            JSON strings (partial and final results) while the file is read,
            followed by the final result for the remaining audio.

        Raises:
            ValueError: If the file is not mono 16-bit uncompressed PCM.
        """
        with wave.open(filename, "rb") as wf:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
                raise ValueError("Audio file must be WAV format mono PCM.")
            if wf.getframerate() != self._samplerate:
                # The recognizer was created for self._samplerate.
                self.log.warning(
                    f"Audio file sample rate {wf.getframerate()} differs from "
                    f"recognizer sample rate {self._samplerate}")

            self.recognizer.SetWords(True)
            if not stream:
                self.recognizer.SetPartialWords(False)
                return self._recognize_whole_file(wf)

        self.recognizer.SetPartialWords(True)
        # The generator opens the file itself: a handle opened here would be
        # closed again before the caller starts iterating.
        return self._stream_file(filename)

    def _recognize_whole_file(self, wf):
        """Feed a complete wave file to the recognizer and merge all segment results."""
        segments = []
        while True:
            data = wf.readframes(4000)
            if len(data) == 0:
                break
            if self.recognizer.AcceptWaveform(data):
                # Collect every finished utterance as it is reported.
                segments.append(json.loads(self.recognizer.Result()))
        segments.append(json.loads(self.recognizer.FinalResult()))

        merged = {"text": " ".join(s["text"] for s in segments if s.get("text"))}
        words = [word for s in segments for word in s.get("result", [])]
        if words:
            merged["result"] = words
        return json.dumps(merged, ensure_ascii=False)

    def _stream_file(self, filename):
        """Yield streaming results for a wave file, keeping it open while iterating."""
        with wave.open(filename, "rb") as wf:
            yield from self.get_stream_result(wf, self.recognizer)
        yield self.recognizer.FinalResult()

    def get_stream_result(self, wf, recognizer):
        """Get streaming results from recognizer.

        Args:
            wf (wave.Wave_read): Open wave file object.
            recognizer (KaldiRecognizer): Vosk recognizer to feed.

        Yields:
            str: JSON string; a full result when the recognizer accepted the
            chunk as a finished utterance, otherwise a partial result.
        """
        while True:
            data = wf.readframes(4000)
            if len(data) == 0:
                break
            if recognizer.AcceptWaveform(data):
                yield recognizer.Result()
            else:
                yield recognizer.PartialResult()

    def listen(self, stream=False, device=None, samplerate=None):
        """Listen from microphone and return results.

        Args:
            stream (bool, optional): Stream mode, default is False.
            device (int, optional): Device index; defaults to the device
                chosen in the constructor.
            samplerate (int, optional): Sampling rate; defaults to the rate
                the recognizer was created with.

        Returns:
            str | None: In non-stream mode, the recognized text, or None if
            listening was stopped.
            generator: In stream mode, a generator of result dicts (see
            ``_listen_streaming``).
        """
        # Fall back to the configured values so the microphone stream runs at
        # the same sample rate the recognizer was built for.
        if device is None:
            device = self._device
        if samplerate is None:
            samplerate = self._samplerate

        q = queue.Queue()

        def callback(indata, frames, time_info, status):
            if status:
                self.log.warning(status)
            q.put(bytes(indata))

        # NOTE(review): in stream mode a generator is returned, so the audio
        # stream is opened after this context has exited - confirm intended.
        with ignore_stderr():
            self.stop_listening_event.clear()
            if stream:
                return self._listen_streaming(q, device, samplerate, callback)
            return self._listen_non_streaming(q, device, samplerate, callback)

    def _listen_streaming(self, q, device=None, samplerate=None, callback=None):
        """Listen from microphone and yield streaming results.

        Args:
            q (queue.Queue): Queue the callback fills with audio data.
            device (int, optional): Device index, default is None.
            samplerate (int, optional): Sampling rate, default is None.
            callback (function, optional): Audio callback, default is None.

        Yields:
            dict: ``{"done", "partial", "final"}``. Partial results have
            ``done`` False; the generator ends after the first final result
            or when listening is stopped.
        """
        with sd.RawInputStream(
                samplerate=samplerate,
                blocksize=1024,
                device=device,
                dtype="int16",
                channels=1,
                callback=callback):
            while True:
                if self.stop_listening_event.is_set():
                    return
                try:
                    # Short timeout so a stop request is noticed promptly.
                    data = q.get(timeout=0.5)
                except queue.Empty:
                    continue

                result = {"done": False, "partial": "", "final": ""}
                if self.recognizer.AcceptWaveform(data):
                    text = json.loads(self.recognizer.Result())["text"]
                    if text == "":
                        continue
                    result["done"] = True
                    result["final"] = text.strip()
                    yield result
                    break

                partial = json.loads(self.recognizer.PartialResult())["partial"]
                if not partial.strip():
                    continue
                result["partial"] = partial.strip()
                yield result

    def _listen_non_streaming(self, q, device=None, samplerate=None, callback=None):
        """Listen from microphone and return the first non-empty final result.

        Args:
            q (queue.Queue): Queue the callback fills with audio data.
            device (int, optional): Device index, default is None.
            samplerate (int, optional): Sampling rate, default is None.
            callback (function, optional): Audio callback, default is None.

        Returns:
            str | None: Recognized text, or None if listening was stopped.
        """
        with sd.RawInputStream(
                samplerate=samplerate,
                blocksize=1024,
                device=device,
                dtype="int16",
                channels=1,
                callback=callback):
            while True:
                if self.stop_listening_event.is_set():
                    return None
                try:
                    data = q.get(timeout=0.5)
                except queue.Empty:
                    continue
                if self.recognizer.AcceptWaveform(data):
                    text = json.loads(self.recognizer.Result())["text"]
                    if text == "":
                        continue
                    return text

    def set_wake_words(self, wake_words: list):
        """Set wake words.

        Args:
            wake_words (list): List of wake words.
        """
        self.wake_words = wake_words

    def language(self) -> str:
        """Get current language.

        Returns:
            str: Current language, or None if none has been set.
        """
        return self._language

    def set_language(self, language: str, init=True):
        """Set language.

        Args:
            language (str): Language to set.
            init (bool, optional): Re-create the recognizer, default is True.

        Raises:
            ValueError: If the language is not in the available model list.
        """
        if language not in self.available_languages:
            raise ValueError(f"Vosk does not support language: {language}. Available languages: {self.available_languages}")
        self._language = language
        if init:
            self.init()

    def get_model_name(self, lang: str) -> str:
        """Get model name for language.

        Args:
            lang (str): Language.

        Returns:
            str: Model name.

        Raises:
            ValueError: If the language has no entry in the model list.
        """
        try:
            index = self.available_languages.index(lang)
        except ValueError:
            raise ValueError(
                f"No Vosk model available for language: {lang}. "
                f"Available languages: {self.available_languages}") from None
        return self.available_model_names[index]

    def get_model_path(self, lang: str) -> Path:
        """Get model path for language.

        Args:
            lang (str): Language.

        Returns:
            Path: Directory the model is (or will be) stored in.
        """
        return Path(MODEL_BASE_PATH, self.get_model_name(lang))

    def is_model_downloaded(self, lang: str) -> bool:
        """Check if model is downloaded.

        Args:
            lang (str): Language.

        Returns:
            bool: True if the model directory exists, False otherwise.
        """
        return self.get_model_path(lang).exists()

    def cancel_download(self):
        """Public method to cancel ongoing download."""
        if self.downloading:
            self.stop_downloading_event.set()  # signal the download loop to stop
            self.log.info("Download cancellation requested")

    def download_model(self, lang: str, progress_callback=None, max_retries: int = 5):
        """Download and unpack the model for a language.

        Does nothing when the model already exists or another download is
        running. Failures are logged, never raised; a partial zip file is
        kept so that a later call can resume it.

        Args:
            lang (str): Language.
            progress_callback (function, optional): Called as
                ``progress_callback(downloaded_bytes, total_bytes_or_None)``.
                When None, a tqdm progress bar is shown instead.
            max_retries (int, optional): Maximum attempts, default is 5.
        """
        model_path = self.get_model_path(lang)
        if self.is_model_downloaded(lang):
            return
        if self.downloading:
            return

        self.downloading = True
        # Make sure a stale cancel request does not abort this download.
        self.stop_downloading_event.clear()
        zip_url = MODEL_PRE_URL + f"{model_path.name}.zip"
        zip_path = f"{model_path}.zip"

        try:
            for attempt in range(1, max_retries + 1):
                if self.stop_downloading_event.is_set():
                    self.log.info("Download cancelled by user")
                    break
                try:
                    finished = self._download_and_extract(zip_url, zip_path, model_path, progress_callback)
                except Exception as e:
                    self.log.warning(f"Download attempt {attempt}/{max_retries} failed: {e}")
                    if attempt == max_retries:
                        self.log.warning(f"Download failed after {max_retries} attempts. Check network connection.")
                        break
                    # Exponential backoff; Event.wait returns True as soon as
                    # a cancel is requested, so cancelling is not delayed.
                    wait_time = 2 ** attempt
                    self.log.info(f"Retrying in {wait_time} seconds...")
                    if self.stop_downloading_event.wait(wait_time):
                        self.log.info("Download cancelled by user")
                        break
                    continue
                if not finished:
                    self.log.info("Download cancelled by user")
                break

            # Partial files are kept on purpose so the download can resume.
            if os.path.exists(zip_path):
                self.log.debug(f"Partial download saved to {zip_path}")
        finally:
            self.downloading = False
            self.stop_downloading_event.clear()

    def _download_and_extract(self, zip_url, zip_path, model_path, progress_callback):
        """Run one download attempt; return True when unpacked, False when cancelled."""
        from zipfile import BadZipFile

        resume_byte_pos = os.path.getsize(zip_path) if os.path.exists(zip_path) else 0
        headers = {}
        if resume_byte_pos > 0:
            self.log.info(f"Resuming download from byte position {resume_byte_pos}")
            headers["Range"] = f"bytes={resume_byte_pos}-"

        with requests.get(zip_url, headers=headers, stream=True, timeout=30) as response:
            if response.status_code == 416:
                # The requested range cannot be served, so the partial file is
                # unusable for resuming; drop it and let the retry start over.
                os.remove(zip_path)
                raise RuntimeError("Server rejected resume request, restarting download")
            response.raise_for_status()
            if response.status_code not in (200, 206):
                raise RuntimeError(f"Unexpected HTTP status {response.status_code}")
            if response.status_code == 200:
                # Full body (Range ignored or not sent): overwrite instead of
                # appending, otherwise the zip would be corrupted.
                resume_byte_pos = 0

            content_length = response.headers.get("content-length")
            total_size = None if content_length is None else int(content_length) + resume_byte_pos

            progress_bar = None
            if progress_callback:
                progress_callback(resume_byte_pos, total_size)
            else:
                progress_bar = tqdm(
                    total=total_size,
                    initial=resume_byte_pos,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=zip_url.rsplit("/", maxsplit=1)[-1])
            try:
                with open(zip_path, "ab" if resume_byte_pos > 0 else "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        # Check for a cancel request before every write.
                        if self.stop_downloading_event.is_set():
                            return False
                        if not chunk:  # skip keep-alive chunks
                            continue
                        f.write(chunk)
                        resume_byte_pos += len(chunk)
                        if progress_callback:
                            progress_callback(resume_byte_pos, total_size)
                        else:
                            progress_bar.update(len(chunk))
            finally:
                if progress_bar is not None:
                    progress_bar.close()

        if total_size is not None:
            downloaded_size = os.path.getsize(zip_path)
            if downloaded_size != total_size:
                if downloaded_size > total_size:
                    # Too large can never be fixed by resuming.
                    os.remove(zip_path)
                raise RuntimeError(f"Download incomplete: received {downloaded_size} bytes, expected {total_size} bytes")
        else:
            self.log.warning("Cannot verify file integrity - server did not provide content length")

        try:
            with ZipFile(zip_path, "r") as model_ref:
                model_ref.extractall(model_path.parent)
        except BadZipFile:
            # A corrupt archive cannot be resumed; remove it so the next
            # attempt downloads a fresh copy.
            os.remove(zip_path)
            raise
        os.remove(zip_path)
        return True

    def download_progress_hook(self, tqdm_bar=None, progress_callback=None):
        """Create a ``urlretrieve``-style report hook.

        Args:
            tqdm_bar (tqdm, optional): Progress bar to update, default is None.
            progress_callback (function, optional): Called as
                ``progress_callback(current_bytes, tsize)`` when no tqdm bar
                is given, default is None.

        Returns:
            function: ``update_to(b, bsize, tsize)`` where ``b`` is the
            number of blocks so far, ``bsize`` the block size and ``tsize``
            the total size (None or -1 when unknown).
        """
        last_b = [0]

        def update_to(b=1, bsize=1, tsize=None):
            size_known = tsize not in (None, -1)
            if size_known and tqdm_bar:
                tqdm_bar.total = tsize
            # Bytes received since the previous call.
            downloaded = (b - last_b[0]) * bsize
            last_b[0] = b
            if tqdm_bar:
                return tqdm_bar.update(downloaded)
            if progress_callback:
                current = b * bsize
                # Clamp only against a real total; an unknown size of -1
                # would otherwise turn the position into -1.
                if size_known and tsize:
                    current = min(current, tsize)
                progress_callback(current, tsize)
            return downloaded

        return update_to

    def stop_listening(self):
        """Stop listening for wake word."""
        self.stop_listening_event.set()

    def close(self):
        """Close STT: stop the wake-word loop, any download and any listening."""
        self.wake_word_thread_started = False
        self.stop_downloading_event.set()
        self.stop_listening_event.set()