diff --git a/veadk/toolkits/audio/asr/__init__.py b/veadk/toolkits/audio/asr/__init__.py new file mode 100644 index 00000000..7f463206 --- /dev/null +++ b/veadk/toolkits/audio/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/veadk/toolkits/audio/asr/asr_client.py b/veadk/toolkits/audio/asr/asr_client.py new file mode 100644 index 00000000..71dc4e41 --- /dev/null +++ b/veadk/toolkits/audio/asr/asr_client.py @@ -0,0 +1,586 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import gzip +import json +import os +import struct +import subprocess +import uuid +from typing import Any, AsyncGenerator, Dict, List, Tuple + +import aiohttp + +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + + +DEFAULT_SAMPLE_RATE = 16000 + + +class ProtocolVersion: + V1 = 0b0001 + + +class MessageType: + CLIENT_FULL_REQUEST = 0b0001 + CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + SERVER_FULL_RESPONSE = 0b1001 + SERVER_ERROR_RESPONSE = 0b1111 + + +class MessageTypeSpecificFlags: + NO_SEQUENCE = 0b0000 + POS_SEQUENCE = 0b0001 + NEG_SEQUENCE = 0b0010 + NEG_WITH_SEQUENCE = 0b0011 + + +class SerializationType: + NO_SERIALIZATION = 0b0000 + JSON = 0b0001 + + +class CompressionType: + GZIP = 0b0001 + + +class Config: + def __init__(self): + self.auth = {"app_key": "", "access_key": ""} + + @property + def app_key(self) -> str: + return self.auth["app_key"] + + @property + def access_key(self) -> str: + return self.auth["access_key"] + + +config = Config() + + +class CommonUtils: + @staticmethod + def gzip_compress(data: bytes) -> bytes: + return gzip.compress(data) + + @staticmethod + def gzip_decompress(data: bytes) -> bytes: + return gzip.decompress(data) + + @staticmethod + def judge_wav(data: bytes) -> bool: + if len(data) < 44: + return False + return data[:4] == b"RIFF" and data[8:12] == b"WAVE" + + @staticmethod + def convert_wav_with_path( + audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE + ) -> bytes: + try: + cmd = [ + "ffmpeg", + "-v", + "quiet", + "-y", + "-i", + audio_path, + "-acodec", + "pcm_s16le", + "-ac", + "1", + "-ar", + str(sample_rate), + "-f", + "wav", + "-", + ] + result = subprocess.run( + cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + # 尝试删除原始文件 + try: + os.remove(audio_path) + except OSError as e: + logger.warning(f"Failed to remove original file: {e}") + + return result.stdout + except subprocess.CalledProcessError as e: + logger.error(f"FFmpeg conversion failed: 
{e.stderr.decode()}") + raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}") + + @staticmethod + def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]: + if len(data) < 44: + raise ValueError("Invalid WAV file: too short") + + # 解析WAV头 + chunk_id = data[:4] + if chunk_id != b"RIFF": + raise ValueError("Invalid WAV file: not RIFF format") + + format_ = data[8:12] + if format_ != b"WAVE": + raise ValueError("Invalid WAV file: not WAVE format") + + # 解析fmt子块 + # audio_format = struct.unpack(" "AsrRequestHeader": + self.message_type = message_type + return self + + def with_message_type_specific_flags(self, flags: int) -> "AsrRequestHeader": + self.message_type_specific_flags = flags + return self + + def with_serialization_type(self, serialization_type: int) -> "AsrRequestHeader": + self.serialization_type = serialization_type + return self + + def with_compression_type(self, compression_type: int) -> "AsrRequestHeader": + self.compression_type = compression_type + return self + + def with_reserved_data(self, reserved_data: bytes) -> "AsrRequestHeader": + self.reserved_data = reserved_data + return self + + def to_bytes(self) -> bytes: + header = bytearray() + header.append((ProtocolVersion.V1 << 4) | 1) + header.append((self.message_type << 4) | self.message_type_specific_flags) + header.append((self.serialization_type << 4) | self.compression_type) + header.extend(self.reserved_data) + return bytes(header) + + @staticmethod + def default_header() -> "AsrRequestHeader": + return AsrRequestHeader() + + +class RequestBuilder: + @staticmethod + def new_auth_headers() -> Dict[str, str]: + reqid = str(uuid.uuid4()) + return { + "X-Api-Resource-Id": "volc.bigasr.sauc.duration", + "X-Api-Request-Id": reqid, + "X-Api-Access-Key": config.access_key, + "X-Api-App-Key": config.app_key, + } + + @staticmethod + def new_full_client_request(seq: int) -> bytes: # 添加seq参数 + header = AsrRequestHeader.default_header().with_message_type_specific_flags( + 
MessageTypeSpecificFlags.POS_SEQUENCE + ) + + payload = { + "user": {"uid": "demo_uid"}, + "audio": { + # "format": "wav", + "format": "pcm", + "codec": "raw", + "rate": 16000, + "bits": 16, + "channel": 1, + }, + "request": { + "model_name": "bigmodel", + "enable_itn": True, + "enable_punc": True, + "enable_ddc": True, + "show_utterances": True, + "enable_nonstream": False, + }, + } + + payload_bytes = json.dumps(payload).encode("utf-8") + compressed_payload = CommonUtils.gzip_compress(payload_bytes) + payload_size = len(compressed_payload) + + request = bytearray() + request.extend(header.to_bytes()) + request.extend(struct.pack(">i", seq)) # 使用传入的seq + request.extend(struct.pack(">I", payload_size)) + request.extend(compressed_payload) + + return bytes(request) + + @staticmethod + def new_audio_only_request( + seq: int, segment: bytes, is_last: bool = False + ) -> bytes: + header = AsrRequestHeader.default_header() + if is_last: # 最后一个包特殊处理 + header.with_message_type_specific_flags( + MessageTypeSpecificFlags.NEG_WITH_SEQUENCE + ) + seq = -seq # 设为负值 + else: + header.with_message_type_specific_flags( + MessageTypeSpecificFlags.POS_SEQUENCE + ) + header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST) + + request = bytearray() + request.extend(header.to_bytes()) + request.extend(struct.pack(">i", seq)) + + compressed_segment = CommonUtils.gzip_compress(segment) + request.extend(struct.pack(">I", len(compressed_segment))) + request.extend(compressed_segment) + + return bytes(request) + + +class AsrResponse: + def __init__(self): + self.code = 0 + self.event = 0 + self.is_last_package = False + self.payload_sequence = 0 + self.payload_size = 0 + self.payload_msg = None + + def to_dict(self) -> Dict[str, Any]: + return { + "code": self.code, + "event": self.event, + "is_last_package": self.is_last_package, + "payload_sequence": self.payload_sequence, + "payload_size": self.payload_size, + "payload_msg": self.payload_msg, + } + + +class ResponseParser: + 
@staticmethod + def parse_response(msg: bytes) -> AsrResponse: + response = AsrResponse() + + header_size = msg[0] & 0x0F + message_type = msg[1] >> 4 + message_type_specific_flags = msg[1] & 0x0F + serialization_method = msg[2] >> 4 + message_compression = msg[2] & 0x0F + + payload = msg[header_size * 4 :] + + # 解析message_type_specific_flags + if message_type_specific_flags & 0x01: + response.payload_sequence = struct.unpack(">i", payload[:4])[0] + payload = payload[4:] + if message_type_specific_flags & 0x02: + response.is_last_package = True + if message_type_specific_flags & 0x04: + response.event = struct.unpack(">i", payload[:4])[0] + payload = payload[4:] + + # 解析message_type + if message_type == MessageType.SERVER_FULL_RESPONSE: + response.payload_size = struct.unpack(">I", payload[:4])[0] + payload = payload[4:] + elif message_type == MessageType.SERVER_ERROR_RESPONSE: + response.code = struct.unpack(">i", payload[:4])[0] + response.payload_size = struct.unpack(">I", payload[4:8])[0] + payload = payload[8:] + + if not payload: + return response + + # 解压缩 + if message_compression == CompressionType.GZIP: + try: + payload = CommonUtils.gzip_decompress(payload) + except Exception as e: + logger.error(f"Failed to decompress payload: {e}") + return response + + # 解析payload + try: + if serialization_method == SerializationType.JSON: + response.payload_msg = json.loads(payload.decode("utf-8")) + except Exception as e: + logger.error(f"Failed to parse payload: {e}") + + return response + + +ASR_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel" + + +class AsrWsClient: + def __init__( + self, + app_id: str, + access_token: str, + url: str = ASR_WS_URL, + segment_duration: int = 200, + ): + self.seq = 1 + self.url = url + self.segment_duration = segment_duration + self.conn = None + self.session = None + + global config + config.auth = { + "app_key": app_id, + "access_key": access_token, + } + + async def __aenter__(self): + self.session = 
aiohttp.ClientSession() + return self + + async def __aexit__(self, exc_type, exc, tb): + if self.conn and not self.conn.closed: + await self.conn.close() + if self.session and not self.session.closed: + await self.session.close() + + async def read_audio_data(self, file_path: str) -> bytes: + try: + with open(file_path, "rb") as f: + content = f.read() + + if not CommonUtils.judge_wav(content): + logger.info("Converting audio to WAV format...") + content = CommonUtils.convert_wav_with_path( + file_path, DEFAULT_SAMPLE_RATE + ) + + return content + except Exception as e: + logger.error(f"Failed to read audio data: {e}") + raise + + def get_segment_size(self, content: bytes) -> int: + try: + channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info( + content + )[:5] + size_per_sec = channel_num * samp_width * frame_rate + segment_size = size_per_sec * self.segment_duration // 1000 + return segment_size + except Exception as e: + logger.error(f"Failed to calculate segment size: {e}") + raise + + async def create_connection(self) -> None: + headers = RequestBuilder.new_auth_headers() + try: + self.conn = await self.session.ws_connect( # 使用self.session + self.url, headers=headers + ) + logger.info(f"Connected to {self.url}") + except Exception as e: + logger.error(f"Failed to connect to WebSocket: {e}") + raise + + async def send_full_client_request(self) -> None: + request = RequestBuilder.new_full_client_request(self.seq) + self.seq += 1 # 发送后递增 + try: + await self.conn.send_bytes(request) + logger.info(f"Sent full client request with seq: {self.seq - 1}") + + msg = await self.conn.receive() + if msg.type == aiohttp.WSMsgType.BINARY: + response = ResponseParser.parse_response(msg.data) + logger.info(f"Received response: {response.to_dict()}") + else: + logger.error(f"Unexpected message type: {msg.type}") + except Exception as e: + logger.error(f"Failed to send full client request: {e}") + raise + + async def send_messages( + self, segment_size: int, 
content: bytes + ) -> AsyncGenerator[None, None]: + audio_segments = self.split_audio(content, segment_size) + total_segments = len(audio_segments) + + for i, segment in enumerate(audio_segments): + is_last = i == total_segments - 1 + request = RequestBuilder.new_audio_only_request( + self.seq, segment, is_last=is_last + ) + await self.conn.send_bytes(request) + logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})") + + if not is_last: + self.seq += 1 + + await asyncio.sleep( + self.segment_duration / 1000 + ) # 逐个发送,间隔时间模拟实时流 + # 让出控制权,允许接受消息 + yield + + async def recv_messages(self) -> AsyncGenerator[AsrResponse, None]: + try: + async for msg in self.conn: + if msg.type == aiohttp.WSMsgType.BINARY: + response = ResponseParser.parse_response(msg.data) + yield response + + if response.is_last_package or response.code != 0: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error(f"WebSocket error: {msg.data}") + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + logger.info("WebSocket connection closed") + break + except Exception as e: + logger.error(f"Error receiving messages: {e}") + raise + + async def start_audio_stream( + self, segment_size: int, content: bytes + ) -> AsyncGenerator[AsrResponse, None]: + async def sender(): + async for _ in self.send_messages(segment_size, content): + pass + + # 启动发送和接收任务 + sender_task = asyncio.create_task(sender()) + + try: + async for response in self.recv_messages(): + yield response + finally: + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + + @staticmethod + def split_audio(data: bytes, segment_size: int) -> List[bytes]: + if segment_size <= 0: + return [] + + segments = [] + for i in range(0, len(data), segment_size): + end = i + segment_size + if end > len(data): + end = len(data) + segments.append(data[i:end]) + return segments + + async def execute(self, file_path: str) -> AsyncGenerator[AsrResponse, None]: + if not file_path: + raise 
ValueError("File path is empty") + + if not self.url: + raise ValueError("URL is empty") + + self.seq = 1 + + try: + # 1. 读取音频文件 + content = await self.read_audio_data(file_path) + + # 2. 计算分段大小 + segment_size = self.get_segment_size(content) + + # 3. 创建WebSocket连接 + await self.create_connection() + + # 4. 发送完整客户端请求 + await self.send_full_client_request() + + # 5. 启动音频流处理 + async for response in self.start_audio_stream(segment_size, content): + yield response + + except Exception as e: + logger.error(f"Error in ASR execution: {e}") + raise + finally: + if self.conn: + await self.conn.close() + + async def execute_stream( + self, audio_stream: AsyncGenerator[bytes, None] + ) -> AsyncGenerator[AsrResponse, None]: + """ + audio_stream: PCM bytes from frontend + """ + + self.seq = 1 + + await self.create_connection() + await self.send_full_client_request() + + async def sender(): + async for chunk in audio_stream: + request = RequestBuilder.new_audio_only_request( + self.seq, chunk, is_last=False + ) + await self.conn.send_bytes(request) + self.seq += 1 + + # 发送最后一个包 + end_req = RequestBuilder.new_audio_only_request(self.seq, b"", is_last=True) + await self.conn.send_bytes(end_req) + + sender_task = asyncio.create_task(sender()) + + try: + async for response in self.recv_messages(): + yield response + finally: + sender_task.cancel() + await self.conn.close() diff --git a/veadk/toolkits/audio/tts/__init__.py b/veadk/toolkits/audio/tts/__init__.py new file mode 100644 index 00000000..7f463206 --- /dev/null +++ b/veadk/toolkits/audio/tts/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/veadk/toolkits/audio/tts/protocols.py b/veadk/toolkits/audio/tts/protocols.py new file mode 100644 index 00000000..6bb2eb99 --- /dev/null +++ b/veadk/toolkits/audio/tts/protocols.py @@ -0,0 +1,581 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import io
import logging
import struct
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, List, Optional

import websockets

logger = logging.getLogger(__name__)


class MsgType(IntEnum):
    """Message type enumeration (high nibble of header byte 1)."""

    Invalid = 0
    FullClientRequest = 0b1
    AudioOnlyClient = 0b10
    FullServerResponse = 0b1001
    AudioOnlyServer = 0b1011
    FrontEndResultServer = 0b1100
    Error = 0b1111

    # Alias
    ServerACK = AudioOnlyServer

    def __str__(self) -> str:
        return self.name if self.name else f"MsgType({self.value})"


class MsgTypeFlagBits(IntEnum):
    """Message type flag bits (low nibble of header byte 1)."""

    NoSeq = 0  # Non-terminal packet with no sequence
    PositiveSeq = 0b1  # Non-terminal packet with sequence > 0
    LastNoSeq = 0b10  # Last packet with no sequence
    NegativeSeq = 0b11  # Last packet with sequence < 0
    WithEvent = 0b100  # Payload contains event number (int32)


class VersionBits(IntEnum):
    """Version bits"""

    Version1 = 1
    Version2 = 2
    Version3 = 3
    Version4 = 4


class HeaderSizeBits(IntEnum):
    """Header size bits; the value is the header length in 4-byte words."""

    HeaderSize4 = 1
    HeaderSize8 = 2
    HeaderSize12 = 3
    HeaderSize16 = 4


class SerializationBits(IntEnum):
    """Serialization method bits"""

    Raw = 0
    JSON = 0b1
    Thrift = 0b11
    Custom = 0b1111


class CompressionBits(IntEnum):
    """Compression method bits"""

    None_ = 0
    Gzip = 0b1
    Custom = 0b1111


class EventType(IntEnum):
    """Event type enumeration"""

    None_ = 0  # Default event

    # 1 ~ 49 Upstream Connection events
    StartConnection = 1
    StartTask = 1  # Alias of StartConnection
    FinishConnection = 2
    FinishTask = 2  # Alias of FinishConnection

    # 50 ~ 99 Downstream Connection events
    ConnectionStarted = 50  # Connection established successfully
    TaskStarted = 50  # Alias of ConnectionStarted
    ConnectionFailed = 51  # Connection failed (possibly due to authentication failure)
    TaskFailed = 51  # Alias of ConnectionFailed
    ConnectionFinished = 52  # Connection ended
    TaskFinished = 52  # Alias of ConnectionFinished

    # 100 ~ 149 Upstream Session events
    StartSession = 100
    CancelSession = 101
    FinishSession = 102

    # 150 ~ 199 Downstream Session events
    SessionStarted = 150
    SessionCanceled = 151
    SessionFinished = 152
    SessionFailed = 153
    UsageResponse = 154  # Usage response
    ChargeData = 154  # Alias of UsageResponse

    # 200 ~ 249 Upstream general events
    TaskRequest = 200
    UpdateConfig = 201

    # 250 ~ 299 Downstream general events
    AudioMuted = 250

    # 300 ~ 349 Upstream TTS events
    SayHello = 300

    # 350 ~ 399 Downstream TTS events
    TTSSentenceStart = 350
    TTSSentenceEnd = 351
    TTSResponse = 352
    TTSEnded = 359
    PodcastRoundStart = 360
    PodcastRoundResponse = 361
    PodcastRoundEnd = 362

    # 450 ~ 499 Downstream ASR events
    ASRInfo = 450
    ASRResponse = 451
    ASREnded = 459

    # 500 ~ 549 Upstream dialogue events
    ChatTTSText = 500  # (Ground-Truth-Alignment) text for speech synthesis

    # 550 ~ 599 Downstream dialogue events
    ChatResponse = 550
    ChatEnded = 559

    # 650 ~ 699 Downstream dialogue events
    # Events for source (original) language subtitle
    SourceSubtitleStart = 650
    SourceSubtitleResponse = 651
    SourceSubtitleEnd = 652
    # Events for target (translation) language subtitle
    TranslationSubtitleStart = 653
    TranslationSubtitleResponse = 654
    TranslationSubtitleEnd = 655

    def __str__(self) -> str:
        return self.name if self.name else f"EventType({self.value})"


# Message types whose frames may carry a sequence number.
_SEQUENCED_MSG_TYPES = (
    MsgType.FullClientRequest,
    MsgType.FullServerResponse,
    MsgType.FrontEndResultServer,
    MsgType.AudioOnlyClient,
    MsgType.AudioOnlyServer,
)

# Flag values that mean "a sequence number is present in the frame".
_SEQUENCE_FLAGS = (MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq)

# Connection-level events for which no session ID is written.
_NO_SESSION_ID_ON_WRITE = (
    EventType.StartConnection,
    EventType.FinishConnection,
    EventType.ConnectionStarted,
    EventType.ConnectionFailed,
)

# Events that carry a connection ID instead of a session ID when read.
_CONNECT_ID_EVENTS = (
    EventType.ConnectionStarted,
    EventType.ConnectionFailed,
    EventType.ConnectionFinished,
)

# Events for which no session ID is read (superset of the write-side list:
# ConnectionFinished is skipped on read but not on write, as in the original).
_NO_SESSION_ID_ON_READ = (
    EventType.StartConnection,
    EventType.FinishConnection,
) + _CONNECT_ID_EVENTS


def _read_length_prefixed(buffer: io.BytesIO, field: str) -> Optional[bytes]:
    """Read a big-endian uint32 length followed by that many bytes.

    Returns ``None`` when the buffer is already exhausted or the declared
    length is zero, so callers leave their current value untouched.

    Raises:
        ValueError: the buffer holds fewer bytes than the declared length.
        struct.error: the length prefix itself is shorter than 4 bytes.
    """
    size_bytes = buffer.read(4)
    if not size_bytes:
        return None

    size = struct.unpack(">I", size_bytes)[0]
    if size == 0:
        return None

    value = buffer.read(size)
    if len(value) != size:
        # A short read means the frame was cut off; accepting it would hand
        # the caller a silently corrupted field.
        raise ValueError(
            f"Truncated {field}: expected {size} bytes, got {len(value)}"
        )
    return value


@dataclass
class Message:
    """A single protocol frame.

    Wire layout written by :meth:`marshal` and parsed by :meth:`unmarshal`
    (all multi-byte integers are big-endian):

    ========  ==========================================================
    byte 0    version (high nibble) | header size in 4-byte words (low)
    byte 1    message type (high nibble) | flag bits (low nibble)
    byte 2    serialization (high nibble) | compression (low nibble)
    byte 3..  zero padding up to ``4 * header_size`` bytes
    then      optional fields selected by type/flag
              (sequence, error code, event, session/connection ID)
    finally   uint32 payload size followed by the payload bytes
    ========  ==========================================================
    """

    version: VersionBits = VersionBits.Version1
    header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
    type: MsgType = MsgType.Invalid
    flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
    serialization: SerializationBits = SerializationBits.JSON
    compression: CompressionBits = CompressionBits.None_

    event: EventType = EventType.None_
    session_id: str = ""
    connect_id: str = ""
    sequence: int = 0
    error_code: int = 0

    payload: bytes = b""

    @classmethod
    def from_bytes(cls, data: bytes) -> "Message":
        """Create a message object from a received frame.

        Raises:
            ValueError: the frame is shorter than 3 bytes, contains an
                unknown type/flag/enum value, is truncated, or has
                trailing bytes after the payload.
        """
        if len(data) < 3:
            raise ValueError(
                f"Data too short: expected at least 3 bytes, got {len(data)}"
            )

        # Type and flag decide which optional fields follow the header, so
        # they are decoded here, before unmarshal() walks the buffer.
        type_and_flag = data[1]
        msg_type = MsgType(type_and_flag >> 4)
        flag = MsgTypeFlagBits(type_and_flag & 0b00001111)

        msg = cls(type=msg_type, flag=flag)
        msg.unmarshal(data)
        return msg

    def marshal(self) -> bytes:
        """Serialize the message to bytes.

        Raises:
            ValueError: the message type is not supported or a field
                exceeds the uint32 size limit.
        """
        buffer = io.BytesIO()

        # Write header
        header = [
            (self.version << 4) | self.header_size,
            (self.type << 4) | self.flag,
            (self.serialization << 4) | self.compression,
        ]

        # Pad the three meaningful bytes with zeros up to the declared size.
        header_size = 4 * self.header_size
        if padding := header_size - len(header):
            header.extend([0] * padding)

        buffer.write(bytes(header))

        # Write other fields
        for writer in self._get_writers():
            writer(buffer)

        return buffer.getvalue()

    def unmarshal(self, data: bytes) -> None:
        """Deserialize the message from bytes.

        ``self.type`` and ``self.flag`` must already be set (see
        :meth:`from_bytes`); the second header byte is skipped here.

        Raises:
            ValueError: a length-prefixed field is truncated or bytes
                remain after the payload.
        """
        buffer = io.BytesIO(data)

        # Read version and header size
        version_and_header_size = buffer.read(1)[0]
        self.version = VersionBits(version_and_header_size >> 4)
        self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)

        # Skip second byte (type/flag, handled by the caller)
        buffer.read(1)

        # Read serialization and compression methods
        serialization_compression = buffer.read(1)[0]
        self.serialization = SerializationBits(serialization_compression >> 4)
        self.compression = CompressionBits(serialization_compression & 0b00001111)

        # Skip header padding
        header_size = 4 * self.header_size
        read_size = 3
        if padding_size := header_size - read_size:
            buffer.read(padding_size)

        # Read other fields
        for reader in self._get_readers():
            reader(buffer)

        # Check for remaining data
        remaining = buffer.read()
        if remaining:
            raise ValueError(f"Unexpected data after message: {remaining}")

    def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
        """Return the field writers for this type/flag, in wire order."""
        writers: List[Callable[[io.BytesIO], None]] = []

        if self.flag == MsgTypeFlagBits.WithEvent:
            writers.extend([self._write_event, self._write_session_id])

        if self.type in _SEQUENCED_MSG_TYPES:
            if self.flag in _SEQUENCE_FLAGS:
                writers.append(self._write_sequence)
        elif self.type == MsgType.Error:
            writers.append(self._write_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")

        writers.append(self._write_payload)
        return writers

    def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
        """Return the field readers for this type/flag, in wire order.

        NOTE(review): readers take sequence/error code before the event
        fields while writers emit the event fields first. The two orders
        only differ for an Error frame flagged WithEvent; confirm against
        the service specification before relying on that combination.
        """
        readers: List[Callable[[io.BytesIO], None]] = []

        if self.type in _SEQUENCED_MSG_TYPES:
            if self.flag in _SEQUENCE_FLAGS:
                readers.append(self._read_sequence)
        elif self.type == MsgType.Error:
            readers.append(self._read_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")

        if self.flag == MsgTypeFlagBits.WithEvent:
            readers.extend(
                [self._read_event, self._read_session_id, self._read_connect_id]
            )

        readers.append(self._read_payload)
        return readers

    def _write_event(self, buffer: io.BytesIO) -> None:
        """Write the event number as int32."""
        buffer.write(struct.pack(">i", self.event))

    def _write_session_id(self, buffer: io.BytesIO) -> None:
        """Write the length-prefixed session ID (skipped for connection events)."""
        if self.event in _NO_SESSION_ID_ON_WRITE:
            return

        session_id_bytes = self.session_id.encode("utf-8")
        size = len(session_id_bytes)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")

        buffer.write(struct.pack(">I", size))
        if size > 0:
            buffer.write(session_id_bytes)

    def _write_sequence(self, buffer: io.BytesIO) -> None:
        """Write the sequence number as int32."""
        buffer.write(struct.pack(">i", self.sequence))

    def _write_error_code(self, buffer: io.BytesIO) -> None:
        """Write the error code as uint32."""
        buffer.write(struct.pack(">I", self.error_code))

    def _write_payload(self, buffer: io.BytesIO) -> None:
        """Write the uint32 payload size followed by the payload."""
        size = len(self.payload)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Payload size ({size}) exceeds max(uint32)")

        buffer.write(struct.pack(">I", size))
        buffer.write(self.payload)

    def _read_event(self, buffer: io.BytesIO) -> None:
        """Read the event number; leaves it unchanged at end of buffer."""
        event_bytes = buffer.read(4)
        if event_bytes:
            self.event = EventType(struct.unpack(">i", event_bytes)[0])

    def _read_session_id(self, buffer: io.BytesIO) -> None:
        """Read the length-prefixed session ID (skipped for connection events).

        Raises:
            ValueError: fewer bytes are available than the declared length.
        """
        if self.event in _NO_SESSION_ID_ON_READ:
            return

        session_id_bytes = _read_length_prefixed(buffer, "session ID")
        if session_id_bytes is not None:
            self.session_id = session_id_bytes.decode("utf-8")

    def _read_connect_id(self, buffer: io.BytesIO) -> None:
        """Read the length-prefixed connection ID for connection events.

        Raises:
            ValueError: fewer bytes are available than the declared length.
        """
        if self.event not in _CONNECT_ID_EVENTS:
            return

        connect_id_bytes = _read_length_prefixed(buffer, "connection ID")
        if connect_id_bytes is not None:
            self.connect_id = connect_id_bytes.decode("utf-8")

    def _read_sequence(self, buffer: io.BytesIO) -> None:
        """Read the sequence number; leaves it unchanged at end of buffer."""
        sequence_bytes = buffer.read(4)
        if sequence_bytes:
            self.sequence = struct.unpack(">i", sequence_bytes)[0]

    def _read_error_code(self, buffer: io.BytesIO) -> None:
        """Read the error code; leaves it unchanged at end of buffer."""
        error_code_bytes = buffer.read(4)
        if error_code_bytes:
            self.error_code = struct.unpack(">I", error_code_bytes)[0]

    def _read_payload(self, buffer: io.BytesIO) -> None:
        """Read the length-prefixed payload.

        Raises:
            ValueError: fewer bytes are available than the declared length.
        """
        payload = _read_length_prefixed(buffer, "payload")
        if payload is not None:
            self.payload = payload

    def __str__(self) -> str:
        """Compact log representation; audio payloads are shown by size only."""
        prefix = f"MsgType: {self.type}, EventType:{self.event}"
        has_sequence = self.flag in _SEQUENCE_FLAGS

        if self.type in (MsgType.AudioOnlyServer, MsgType.AudioOnlyClient):
            if has_sequence:
                return (
                    f"{prefix}, Sequence: {self.sequence}, "
                    f"PayloadSize: {len(self.payload)}"
                )
            return f"{prefix}, PayloadSize: {len(self.payload)}"

        text = self.payload.decode("utf-8", "ignore")
        if self.type == MsgType.Error:
            return f"{prefix}, ErrorCode: {self.error_code}, Payload: {text}"
        if has_sequence:
            return f"{prefix}, Sequence: {self.sequence}, Payload: {text}"
        return f"{prefix}, Payload: {text}"


async def receive_message(
    websocket: websockets.WebSocketClientProtocol,
) -> Message:
    """Receive one binary frame from the websocket and parse it.

    Raises:
        ValueError: the frame is text, of an unexpected Python type, or
            cannot be parsed as a :class:`Message`.
    """
    try:
        data = await websocket.recv()
        if isinstance(data, str):
            raise ValueError(f"Unexpected text message: {data}")
        elif isinstance(data, bytes):
            msg = Message.from_bytes(data)
            logger.info(f"Received: {msg}")
            return msg
        else:
            raise ValueError(f"Unexpected message type: {type(data)}")
    except Exception as e:
        logger.error(f"Failed to receive message: {e}")
        raise


async def wait_for_event(
    websocket: websockets.WebSocketClientProtocol,
    msg_type: MsgType,
    event_type: EventType,
) -> Message:
    """Receive the next message and require it to match type and event.

    Raises:
        ValueError: the next message has a different type or event.
    """
    msg = await receive_message(websocket)
    if msg.type != msg_type or msg.event != event_type:
        raise ValueError(f"Unexpected message: {msg}")
    return msg


async def _send(
    websocket: websockets.WebSocketClientProtocol, msg: Message
) -> None:
    """Log and send a fully populated message."""
    logger.info(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def _send_event(
    websocket: websockets.WebSocketClientProtocol,
    event: EventType,
    payload: bytes,
    session_id: str = "",
) -> None:
    """Send a FullClientRequest frame carrying ``event``."""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = event
    msg.session_id = session_id
    msg.payload = payload
    await _send(websocket, msg)


async def full_client_request(
    websocket: websockets.WebSocketClientProtocol, payload: bytes
) -> None:
    """Send full client message"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.NoSeq)
    msg.payload = payload
    await _send(websocket, msg)


async def audio_only_client(
    websocket: websockets.WebSocketClientProtocol,
    payload: bytes,
    flag: MsgTypeFlagBits,
) -> None:
    """Send audio-only client message"""
    msg = Message(type=MsgType.AudioOnlyClient, flag=flag)
    msg.payload = payload
    await _send(websocket, msg)


async def start_connection(
    websocket: websockets.WebSocketClientProtocol,
) -> None:
    """Start connection"""
    await _send_event(websocket, EventType.StartConnection, b"{}")


async def finish_connection(
    websocket: websockets.WebSocketClientProtocol,
) -> None:
    """Finish connection"""
    await _send_event(websocket, EventType.FinishConnection, b"{}")


async def start_session(
    websocket: websockets.WebSocketClientProtocol,
    payload: bytes,
    session_id: str,
) -> None:
    """Start session"""
    await _send_event(websocket, EventType.StartSession, payload, session_id)


async def finish_session(
    websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
    """Finish session"""
    await _send_event(websocket, EventType.FinishSession, b"{}", session_id)


async def cancel_session(
    websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
    """Cancel session"""
    await _send_event(websocket, EventType.CancelSession, b"{}", session_id)


async def task_request(
    websocket: websockets.WebSocketClientProtocol,
    payload: bytes,
    session_id: str,
) -> None:
    """Send task request"""
    await _send_event(websocket, EventType.TaskRequest, payload, session_id)
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import uuid

import websockets

from veadk.toolkits.audio.tts.protocols import (
    MsgType,
    full_client_request,
    receive_message,
)

TTS_WS_URL = "wss://openspeech.bytedance.com/api/v1/tts/ws_binary"


async def tts_request(
    app_id: str,
    access_token: str,
    text: str,
    voice_type: str = "zh_female_meilinvyou_saturn_bigtts",
) -> bytes:
    """Synthesize ``text`` over the TTS websocket and return the audio bytes.

    Opens a connection to ``TTS_WS_URL``, submits one request, and
    concatenates the payload of every audio frame until a frame with a
    negative sequence number arrives.

    Args:
        app_id: Application ID placed in the request's ``app.appid`` field.
        access_token: Token sent both in the ``Authorization`` header and
            in the request's ``app.token`` field.
        text: Text to synthesize.
        voice_type: Voice identifier. IDs starting with ``"S_"`` are sent
            to the ``volcano_icl`` cluster, all others to ``volcano_tts``.

    Returns:
        The concatenated audio payloads (the request asks for ``wav``).

    Raises:
        RuntimeError: the server answered with an error frame.
    """
    headers = {"Authorization": f"Bearer;{access_token}"}
    cluster = "volcano_icl" if voice_type.startswith("S_") else "volcano_tts"
    request = {
        "app": {"appid": app_id, "token": access_token, "cluster": cluster},
        "user": {"uid": str(uuid.uuid4())},
        "audio": {"voice_type": voice_type, "encoding": "wav"},
        "request": {
            "reqid": str(uuid.uuid4()),
            "text": text,
            "operation": "submit",
        },
    }

    audio_data = bytearray()

    # The context manager closes the connection on every exit path.
    async with websockets.connect(
        TTS_WS_URL, additional_headers=headers, max_size=10 * 1024 * 1024
    ) as websocket:
        await full_client_request(websocket, json.dumps(request).encode())

        while True:
            msg = await receive_message(websocket)

            if msg.type == MsgType.Error:
                # Without this branch an error frame is ignored and the loop
                # keeps waiting for audio that will not arrive.
                raise RuntimeError(f"TTS request failed: {msg}")

            if msg.type == MsgType.AudioOnlyServer:
                audio_data.extend(msg.payload)
                # A negative sequence number marks the final audio frame.
                if msg.sequence < 0:
                    break

    return bytes(audio_data)