import json
import os
import wave
from pathlib import Path
from typing import Callable, List
from urllib.error import URLError
from urllib.request import urlopen
from piper import PiperVoice
from piper.download_voices import _needs_download, VOICE_PATTERN, URL_FORMAT
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf
from tqdm import tqdm
from .piper_models import PIPER_MODELS as _DEFAULT_PIPER_MODELS, MODELS as _DEFAULT_MODELS, COUNTRYS as _DEFAULT_COUNTRYS
from .._audio_player import AudioPlayer
from .._base import _Base
# Home directory of the user running this process ("~" expanded by os.path).
HOME = os.path.expanduser("~")
# Directory holding downloaded Piper voices (*.onnx + *.onnx.json) and the
# cached voices.json model list.
PIPER_MODEL_DIR = HOME + "/.piper_models"
# Upstream catalogue of all published Piper voices.
VOICES_JSON_URL = (
    "https://huggingface.co/rhasspy/piper-voices/resolve/main/voices.json"
    "?download=true"
)
# Local copy of VOICES_JSON_URL, used when offline.
PIPER_MODEL_LIST_CACHE_PATH = Path(PIPER_MODEL_DIR) / "voices.json"
def _parse_voices_json(voices_dict: dict) -> dict:
    """Group the voice ids of a HuggingFace ``voices.json`` by language and voice.

    Only the keys of *voices_dict* are inspected. Keys that do not match
    ``VOICE_PATTERN`` are skipped. Ids keep the order in which they appear
    in *voices_dict* and are not sorted.

    Example:
        Input:  {"en_US-lessac-medium": {...}, "en_US-lessac-low": {...}, ...}
        Output: {"en_US": {"lessac": ["en_US-lessac-medium", "en_US-lessac-low"], ...}, ...}

    Args:
        voices_dict (dict): parsed voices.json, keyed by voice id.

    Returns:
        dict: ``{lang_code: {voice_name: [voice_id, ...]}}`` where
        ``lang_code`` is ``<lang_family>_<lang_region>``.
    """
    grouped = {}
    for voice_id in voices_dict:
        parts = VOICE_PATTERN.match(voice_id)
        if not parts:
            # Not a "<lang>-<name>-<quality>" id; ignore it.
            continue
        lang = parts.group("lang_family") + "_" + parts.group("lang_region")
        per_lang = grouped.setdefault(lang, {})
        per_lang.setdefault(parts.group("voice_name"), []).append(voice_id)
    return grouped
[Doku]
class Piper(_Base):
    """ Piper TTS engine.

    Args:
        model (str, optional): model to load immediately (downloaded first if
            missing). Leave it None to create the engine without a loaded
            model; call :meth:`set_model` before synthesizing. Defaults to None.
        *args: passed to :class:`sunfounder_voice_assistant._base._Base`.
        **kwargs: passed to :class:`sunfounder_voice_assistant._base._Base`.
    """

    # Replacements applied to zh_CN text before synthesis. Full-width keys are
    # written as escapes so editors/tooling cannot silently normalize them to
    # their ASCII look-alikes. That would turn e.g. ASCII "," into ". " and
    # break numbers like "1,000".
    _ZH_PUNCTUATION_MAP = {
        "\uff0c": ". ",   # full-width comma
        "。": ". ",
        "\uff01": "! ",   # full-width exclamation mark
        "\uff1f": "? ",   # full-width question mark
        "——": ". ",
        "“": '"',
        "”": '"',
        "‘": "'",
        "’": "'",
        "~": ". ",
        "\uff5e": ". ",   # full-width tilde
        "\uff1a": ". ",   # full-width colon
        "...": ". ",
        "……": ". ",
        "、": ". ",
    }

    def __init__(self, *args, model: str = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._ensure_model_dir()
        # Populates self._piper_models / self._models / self._countrys
        # from the local cache or the built-in defaults (no network).
        self._load_model_list()
        self.model = None
        self.piper = None
        if model is not None:
            self.set_model(model)

    def _ensure_model_dir(self) -> None:
        """Create PIPER_MODEL_DIR if missing and try to chown it to 1000:1000."""
        if os.path.exists(PIPER_MODEL_DIR):
            return
        # exist_ok guards against another process creating it in between.
        os.makedirs(PIPER_MODEL_DIR, 0o777, exist_ok=True)
        try:
            # Presumably so a directory created under sudo stays usable by the
            # default (uid 1000) user -- TODO confirm intent.
            os.chown(PIPER_MODEL_DIR, 1000, 1000)
        except (AttributeError, OSError) as e:
            # os.chown does not exist on Windows (AttributeError). Non-root
            # users who are not uid 1000 get PermissionError. The directory
            # is still usable by the current user, so carry on.
            self.log.debug(f"Could not chown {PIPER_MODEL_DIR} to 1000:1000: {e}")

    @staticmethod
    def _model_list_from_voices(voices_dict: dict):
        """Parse a voices.json dict into ``(piper_models, flat list of model ids)``."""
        piper_models = _parse_voices_json(voices_dict)
        models = [
            model_id
            for voices in piper_models.values()
            for model_list in voices.values()
            for model_id in model_list
        ]
        return piper_models, models

    @classmethod
    def _read_model_list_cache(cls):
        """Load and parse the cached voices.json; raises on I/O or JSON errors."""
        with open(PIPER_MODEL_LIST_CACHE_PATH, "r", encoding="utf-8") as f:
            voices_dict = json.load(f)
        return cls._model_list_from_voices(voices_dict)

    def _write_model_list_cache(self, voices_dict: dict) -> None:
        """Best-effort atomic write of *voices_dict* to the cache file."""
        tmp_path = PIPER_MODEL_LIST_CACHE_PATH.with_name(
            PIPER_MODEL_LIST_CACHE_PATH.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(voices_dict, f, ensure_ascii=False, indent=2)
            # Replace in one step so an interrupted write never corrupts
            # an existing good cache.
            os.replace(tmp_path, PIPER_MODEL_LIST_CACHE_PATH)
        except Exception as e:
            self.log.warning(f"Failed to write model list cache {PIPER_MODEL_LIST_CACHE_PATH}: {e}")

    def _apply_model_list(self, piper_models: dict, models: list) -> None:
        """Install a new model table; countries are the table's keys."""
        self._piper_models = piper_models
        self._models = models
        self._countrys = list(piper_models.keys())

    def _load_model_list(self):
        """Load model list from local cache or built-in defaults (offline, no network)."""
        piper_models, models = None, None
        if PIPER_MODEL_LIST_CACHE_PATH.exists():
            try:
                piper_models, models = self._read_model_list_cache()
            except Exception as e:
                self.log.warning(f"Failed to read cached model list {PIPER_MODEL_LIST_CACHE_PATH}: {e}")
        if not models:
            # Copy the defaults so later mutation cannot alter the module constants.
            piper_models = {k: dict(v) for k, v in _DEFAULT_PIPER_MODELS.items()}
            models = list(_DEFAULT_MODELS)
        self._apply_model_list(piper_models, models)

    def update_model_list(self):
        """Fetch latest model list from network and save to cache.

        Call this manually when you want to check for new models online.
        If fetching or parsing fails, falls back to the local cache. If
        neither source yields any model, the current list is kept.
        """
        piper_models, models = None, None
        try:
            with urlopen(VOICES_JSON_URL, timeout=5) as response:
                voices_dict = json.load(response)
            piper_models, models = self._model_list_from_voices(voices_dict)
            self.log.info(f"Model list updated from network ({len(models)} models)")
        except Exception as e:
            self.log.warning("Network unavailable, using local model list...")
            self.log.debug(f"Fetching {VOICES_JSON_URL} failed: {e}")
            if PIPER_MODEL_LIST_CACHE_PATH.exists():
                try:
                    piper_models, models = self._read_model_list_cache()
                    self.log.info(f"Model list loaded from cache ({len(models)} models)")
                except Exception as cache_error:
                    self.log.warning(f"Failed to read cached model list: {cache_error}")
            if not models:
                self.log.warning("No local model list available, keeping current list")
        else:
            if models:
                self._write_model_list_cache(voices_dict)
        if models:
            self._apply_model_list(piper_models, models)

    def _require_model(self) -> None:
        """Raise ValueError unless a voice has been loaded with set_model()."""
        if self.piper is None:
            raise ValueError("Model not set, set model first, with Piper.set_model(model)")

    def get_language(self) -> str:
        """ Get language from model.

        Returns:
            str: language, the part of the model id before the first "-"
            (e.g. "en_US" for "en_US-lessac-medium").

        Raises:
            ValueError: Model not set, set model first, with Piper.set_model(model)
        """
        if self.model is None:
            raise ValueError("Model not set, set model first, with Piper.set_model(model)")
        return self.model.split("-")[0]

    def is_model_downloaded(self, model: str = None) -> bool:
        """ Check if model is downloaded.

        Args:
            model (str, optional): model, defaults to the current model

        Returns:
            bool: True if both the .onnx file and its .onnx.json config exist
        """
        if model is None:
            model = self.model
        onnx_file = self.get_model_path(model)
        json_file = onnx_file + ".json"
        onnx_exists = os.path.exists(onnx_file)
        json_exists = os.path.exists(json_file)
        self.log.debug(f"Model {model} onnx file exists: {onnx_exists}")
        self.log.debug(f"Model {model} json file exists: {json_exists}")
        return onnx_exists and json_exists

    def download_model(self,
                       model: str,
                       force: bool = False,
                       progress_callback: Callable[[int, int], None] = None) -> None:
        """ Download model into PIPER_MODEL_DIR.

        Args:
            model (str): model
            force (bool, optional): download even if already present, default is False
            progress_callback (Callable[[int, int], None], optional): forwarded
                to :func:`download_voice`, default is None
        """
        model_path = self.get_model_path(model)
        if force or not self.is_model_downloaded(model):
            self.log.info(f"Downloading model {model} to {model_path}")
            download_voice(model,
                           Path(PIPER_MODEL_DIR),
                           force_redownload=force,
                           progress_callback=progress_callback)

    def fix_chinese_punctuation(self, text: str) -> str:
        """Replace Chinese punctuation with English punctuation.

        Only applies when the current model's language is ``zh_CN``; other
        languages get *text* back unchanged. After the replacement pass, a
        "." between two digits (e.g. "3.5") is rewritten with "点" ("point"),
        presumably so the number is spoken as a decimal.

        Args:
            text (str): text

        Returns:
            str: text with English punctuation
        """
        if self.get_language() != "zh_CN":
            return text
        for src, dst in self._ZH_PUNCTUATION_MAP.items():
            text = text.replace(src, dst)
        import re
        return re.sub(r'(\d)\.(\d)', r'\1点\2', text)

    def tts(self, text: str, file: str) -> None:
        """ Synthesize text to wave file.

        Args:
            text (str): text
            file (str): wave file path

        Raises:
            ValueError: Model not set, set model first, with Piper.set_model(model)
        """
        self._require_model()
        text = self.fix_chinese_punctuation(text)
        with wave.open(file, "wb") as wav_file:
            self.piper.synthesize_wav(text, wav_file)

    def stream(self, text: str) -> None:
        """ Stream text to speaker, playing audio chunk by chunk as it is synthesized.

        Args:
            text (str): text

        Raises:
            ValueError: Model not set, set model first, with Piper.set_model(model)
        """
        self._require_model()
        text = self.fix_chinese_punctuation(text)
        with AudioPlayer(self.piper.config.sample_rate) as player:
            for chunk in self.piper.synthesize(text):
                player.play(chunk.audio_int16_bytes)

    def say(self, text: str, stream: bool = True) -> None:
        """ Say text.

        Args:
            text (str): text
            stream (bool, optional): stream to speaker, default is True.
                When False, the whole utterance is first written to
                "./tts_piper.wav" and then played.

        Raises:
            ValueError: Model not set, set model first, with Piper.set_model(model)
        """
        self._require_model()
        if stream:
            self.stream(text)
            return
        # NOTE(review): fixed path in the current working directory. It needs
        # a writable CWD, and concurrent callers share the same file.
        file = "./tts_piper.wav"
        self.tts(text, file)
        # Use the correct sample rate from Piper config
        with AudioPlayer(self.piper.config.sample_rate) as player:
            player.play_file(file)

    def available_models(self, country: str = None) -> List[str]:
        """ Get available models.

        Args:
            country (str, optional): language code such as "en_US", leave it
                None to get all models, defaults to None

        Returns:
            List[str]: all model ids when *country* is None. NOTE: when a
            country is given, the entry stored for it is returned as-is. For
            lists parsed from voices.json that is a
            ``{voice_name: [model_id, ...]}`` dict. ``[]`` is returned if the
            country is unknown.
        """
        if country is None:
            return self._models
        return self._piper_models.get(country, [])

    def available_countrys(self) -> List[str]:
        """ Get available countrys.

        Returns:
            List[str]: language codes present in the current model table
        """
        return self._countrys

    def get_model_path(self, model: str) -> str:
        """ Get model path.

        Args:
            model (str): model

        Returns:
            str: path of the model's .onnx file inside PIPER_MODEL_DIR
        """
        return os.path.join(PIPER_MODEL_DIR, model + ".onnx")

    def set_model(self, model: str) -> None:
        """ Set model, downloading it first if it is not present locally.

        Args:
            model (str): model

        Raises:
            ValueError: Model not found
        """
        if model not in self._models:
            raise ValueError("Model not found")
        model_path = self.get_model_path(model)
        if not self.is_model_downloaded(model):
            self.log.warning(f"Model {model} not downloaded, downloading...")
            self.download_model(model)
        try:
            self.piper = PiperVoice.load(model_path)
        except InvalidProtobuf as e:
            # A corrupt/truncated .onnx file: fetch it again and retry once.
            self.log.warning(f"Failed to load model {model_path}: {e}, try to redownload model.")
            self.download_model(model, force=True)
            self.piper = PiperVoice.load(model_path)
        self.model = model
def download_voice(voice: str,
                   download_dir: Path,
                   force_redownload: bool = False,
                   progress_callback: Callable[[int, int], None] = None
                   ) -> None:
    """Download a voice model and config file to a directory.

    Fetches ``<voice>.onnx`` followed by ``<voice>.onnx.json``. Each file is
    skipped when ``_needs_download`` reports it is already present, unless
    *force_redownload* is set. *download_dir* is created if it does not exist.

    Args:
        voice (str): voice id such as "en_US-lessac-medium"
        download_dir (Path): download directory (a str path is also accepted)
        force_redownload (bool, optional): force redownload, default is False
        progress_callback (Callable[[int, int], None], optional): progress
            callback passed to each file download, default is None

    Raises:
        ValueError: if *voice* does not match ``VOICE_PATTERN``
    """
    voice = voice.strip()
    voice_match = VOICE_PATTERN.match(voice)
    if not voice_match:
        raise ValueError(
            f"Voice '{voice}' did not match pattern: <language>-<name>-<quality> like 'en_US-lessac-medium'",
        )
    lang_family = voice_match.group("lang_family")
    lang_code = lang_family + "_" + voice_match.group("lang_region")
    voice_name = voice_match.group("voice_name")
    voice_quality = voice_match.group("voice_quality")
    voice_code = f"{lang_code}-{voice_name}-{voice_quality}"
    format_args = {
        "lang_family": lang_family,
        "lang_code": lang_code,
        "voice_name": voice_name,
        "voice_quality": voice_quality,
    }
    download_dir = Path(download_dir)
    # Opening the output file would fail if the target directory is missing.
    download_dir.mkdir(parents=True, exist_ok=True)
    # Model weights first, then the model config (with progress reporting).
    for extension in (".onnx", ".onnx.json"):
        target_path = download_dir / f"{voice_code}{extension}"
        if force_redownload or _needs_download(target_path):
            url = URL_FORMAT.format(extension=extension, **format_args)
            _download_with_progress(url, target_path, progress_callback)
def _download_with_progress(url: str,
output_path: Path,
progress_callback: Callable[[int, int], None] = None) -> None:
""" Download file with progress bar.
Args:
url (str): URL
output_path (Path): output path
progress_callback (Callable[[int, int], None], optional): progress callback function, default is None
"""
with urlopen(url) as response:
file_size = int(response.headers.get("Content-Length", 0))
if progress_callback:
progress_callback(0, file_size)
else:
progress_bar = tqdm(
total=file_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
desc=f"Downloading {output_path.name}",
leave=True
)
with open(output_path, "wb") as out_file:
while True:
chunk = response.read(8192) # 8KB 块
if not chunk:
break
out_file.write(chunk)
if progress_callback:
progress_callback(len(chunk), file_size)
else:
progress_bar.update(len(chunk))
if progress_callback:
progress_callback(file_size, file_size)
else:
progress_bar.close()