docs: Update README and compilation guides for clarity and consistency, including path corrections and improved formatting. Add copyright notices to source files and adjust file permissions for several scripts and directories.

This commit is contained in:
dian.yuan 2026-02-28 11:06:26 +08:00
parent f960c5030d
commit bd891a96dd
136 changed files with 14413 additions and 9399 deletions

0
examples/whisper/cpp/.gitkeep Normal file → Executable file
View file

View file

@ -1,35 +1,35 @@
cmake_minimum_required(VERSION 3.10...3.27)
project(whisper_demo)
set(CMAKE_CXX_STANDARD 17)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/../../../../cmake")
find_package(AMLNN REQUIRED)
include_directories(${AMLNN_INCLUDE_DIR})
link_directories(${AMLNN_LIBRARY_DIR})
include_directories(${CMAKE_SOURCE_DIR}/../../../../common)
# Set 3rdparty path
set(3RDPARTY_DIR "${CMAKE_SOURCE_DIR}/../../../../dependency")
if(CMAKE_SYSTEM_NAME STREQUAL "Android")
# Android needs log
link_libraries(log)
endif()
add_executable(${PROJECT_NAME}
main.cpp
common.cpp
whisper.cpp
whisper_invoke.cpp
pre_process_whisper.cpp
post_process_whisper.cpp
)
target_link_libraries(${PROJECT_NAME}
${AMLNN_LIBRARY}
dl
m
)
cmake_minimum_required(VERSION 3.10...3.27)
project(whisper_demo)
set(CMAKE_CXX_STANDARD 17)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/../../../../cmake")
find_package(AMLNN REQUIRED)
include_directories(${AMLNN_INCLUDE_DIR})
link_directories(${AMLNN_LIBRARY_DIR})
include_directories(${CMAKE_SOURCE_DIR}/../../../../common)
# Set 3rdparty path
set(3RDPARTY_DIR "${CMAKE_SOURCE_DIR}/../../../../dependency")
if(CMAKE_SYSTEM_NAME STREQUAL "Android")
# Android needs log
link_libraries(log)
endif()
add_executable(${PROJECT_NAME}
main.cpp
common.cpp
whisper.cpp
whisper_invoke.cpp
pre_process_whisper.cpp
post_process_whisper.cpp
)
target_link_libraries(${PROJECT_NAME}
${AMLNN_LIBRARY}
dl
m
)

View file

@ -1,135 +1,135 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
bool is_wav_buffer(const std::string buf) {
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") {
return false;
}
uint32_t chunk_size = *reinterpret_cast<const uint32_t*>(buf.data() + 4);
if (chunk_size + 8 != buf.size()) {
return false;
}
return true;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname == "-") {
{
#ifdef _WIN32
_setmode(_fileno(stdin), _O_BINARY);
#endif
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return false;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (is_wav_buffer(fname)) {
if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from fname buffer\n");
return false;
}
}
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
drwav_uninit(&wav);
return false;
}
if (stereo && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
drwav_uninit(&wav);
return false;
}
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
drwav_uninit(&wav);
return false;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
drwav_uninit(&wav);
return false;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (stereo) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
return true;
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common.h"
// third-party utilities
// use your favorite implementations
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
bool is_wav_buffer(const std::string buf) {
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") {
return false;
}
uint32_t chunk_size = *reinterpret_cast<const uint32_t*>(buf.data() + 4);
if (chunk_size + 8 != buf.size()) {
return false;
}
return true;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname == "-") {
{
#ifdef _WIN32
_setmode(_fileno(stdin), _O_BINARY);
#endif
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return false;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (is_wav_buffer(fname)) {
if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from fname buffer\n");
return false;
}
}
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
drwav_uninit(&wav);
return false;
}
if (stereo && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
drwav_uninit(&wav);
return false;
}
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
drwav_uninit(&wav);
return false;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
drwav_uninit(&wav);
return false;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (stereo) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
return true;
}

View file

@ -1,40 +1,40 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <map>
#include <vector>
#include <random>
#include <thread>
#include <ctime>
#include <fstream>
#define COMMON_SAMPLE_RATE 16000
bool is_wav_buffer(const std::string buf);
// Read WAV audio file and store the PCM data into pcmf32
// fname can be a buffer of WAV data instead of a filename
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
bool read_wav(
const std::string & fname,
std::vector<float> & pcmf32,
std::vector<std::vector<float>> & pcmf32s,
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <map>
#include <vector>
#include <random>
#include <thread>
#include <ctime>
#include <fstream>
#define COMMON_SAMPLE_RATE 16000
bool is_wav_buffer(const std::string buf);
// Read WAV audio file and store the PCM data into pcmf32
// fname can be a buffer of WAV data instead of a filename
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM
bool read_wav(
const std::string & fname,
std::vector<float> & pcmf32,
std::vector<std::vector<float>> & pcmf32s,
bool stereo);

View file

@ -30,7 +30,7 @@ callback which will give you information about the data format. To determine the
Introduction
============
This is a single file library. To use it, do something like the following in one .c file.
```c
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
@ -391,7 +391,7 @@ If the return value differs from bytesToWrite, it indicates an error.
typedef size_t (* drwav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite);
/*
Callback for when data needs to be seeked.
Callback for when data needs to be sought.
pUserData [in] The user data that was passed to drwav_init() and family.
offset [in] The number of bytes to move, relative to the origin. Will never be negative.
@ -415,16 +415,16 @@ pChunkHeader [in] A pointer to an object containing basic header informatio
container [in] Whether or not the WAV file is a RIFF or Wave64 container. If you're unsure of the difference, assume RIFF.
pFMT [in] A pointer to the object containing the contents of the "fmt" chunk.
Returns the number of bytes read + seeked.
Returns the number of bytes read + sought.
To read data from the chunk, call onRead(), passing in pReadSeekUserData as the first parameter. Do the same for seeking with onSeek(). The return value must
be the total number of bytes you have read _plus_ seeked.
be the total number of bytes you have read _plus_ sought.
Use the `container` argument to discriminate the fields in `pChunkHeader->id`. If the container is `drwav_container_riff` or `drwav_container_rf64` you should
use `id.fourcc`, otherwise you should use `id.guid`.
The `pFMT` parameter can be used to determine the data format of the wave file. Use `drwav_fmt_get_format()` to get the sample format, which will be one of the
`DR_WAVE_FORMAT_*` identifiers.
`DR_WAVE_FORMAT_*` identifiers.
The read pointer will be sitting on the first byte after the chunk's header. You must not attempt to read beyond the boundary of the chunk.
*/
@ -499,7 +499,7 @@ typedef struct
/* A pointer to the function to call when data needs to be written. Only used when the drwav object is opened in write mode. */
drwav_write_proc onWrite;
/* A pointer to the function to call when the wav file needs to be seeked. */
/* A pointer to the function to call when the wav file needs to be sought. */
drwav_seek_proc onSeek;
/* The user data to pass to callbacks. */
@ -534,7 +534,7 @@ typedef struct
/* The size in bytes of the data chunk. */
drwav_uint64 dataChunkDataSize;
/* The position in the stream of the first byte of the data chunk. This is used for seeking. */
drwav_uint64 dataChunkDataPos;
@ -565,7 +565,7 @@ typedef struct
{
drwav_uint64 iCurrentPCMFrame; /* The index of the next PCM frame that will be read by drwav_read_*(). This is used with "totalPCMFrameCount" to ensure we don't read excess samples at the end of the last block. */
} compressed;
/* Microsoft ADPCM specific data. */
struct
{
@ -1985,7 +1985,7 @@ static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk,
We need to enumerate over each chunk for two reasons:
1) The "data" chunk may not be the next one
2) We may want to report each chunk back to the client
In order to correctly report each chunk back to the client we will need to keep looping until the end of the file.
*/
foundDataChunk = DRWAV_FALSE;
@ -2017,7 +2017,7 @@ static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk,
}
}
}
if (!foundDataChunk) {
pWav->dataChunkDataPos = cursor;
@ -2159,7 +2159,7 @@ static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk,
}
cursor = pWav->dataChunkDataPos;
}
/* At this point we should be sitting on the first byte of the raw audio data. */
@ -2431,7 +2431,7 @@ static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_for
runningPos += drwav__write(pWav, "WAVE", 4);
}
/* "ds64" chunk (RF64 only). */
if (pFormat->container == drwav_container_rf64) {
drwav_uint32 initialds64ChunkSize = 28; /* 28 = [Size of RIFF (8 bytes)] + [Size of DATA (8 bytes)] + [Sample Count (8 bytes)] + [Table Length (4 bytes)]. Table length always set to 0. */
@ -3291,7 +3291,7 @@ static drwav_bool32 drwav__on_seek_memory(void* pUserData, int offset, drwav_see
return DRWAV_FALSE; /* Trying to seek too far forward. */
}
}
return DRWAV_TRUE;
}
@ -3360,7 +3360,7 @@ static drwav_bool32 drwav__on_seek_memory_write(void* pUserData, int offset, drw
pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; /* Trying to seek too far forward. */
}
}
return DRWAV_TRUE;
}
@ -3452,7 +3452,7 @@ DRWAV_API drwav_result drwav_uninit(drwav* pWav)
} else {
paddingSize = drwav__chunk_padding_size_w64(pWav->dataChunkDataSize);
}
if (paddingSize > 0) {
drwav_uint64 paddingData = 0;
drwav__write(pWav, &paddingData, paddingSize); /* Byte order does not matter for this. */
@ -3561,16 +3561,16 @@ DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOu
/* When we get here we may need to read-and-discard some data. */
while (bytesRead < bytesToRead) {
drwav_uint8 buffer[4096];
size_t bytesSeeked;
size_t bytessought;
size_t bytesToSeek = (bytesToRead - bytesRead);
if (bytesToSeek > sizeof(buffer)) {
bytesToSeek = sizeof(buffer);
}
bytesSeeked = pWav->onRead(pWav->pUserData, buffer, bytesToSeek);
bytesRead += bytesSeeked;
bytessought = pWav->onRead(pWav->pUserData, buffer, bytesToSeek);
bytesRead += bytessought;
if (bytesSeeked < bytesToSeek) {
if (bytessought < bytesToSeek) {
break; /* Reached the end. */
}
}
@ -3662,7 +3662,7 @@ DRWAV_API drwav_bool32 drwav_seek_to_first_pcm_frame(drwav* pWav)
DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */
}
}
pWav->bytesRemaining = pWav->dataChunkDataSize;
return DRWAV_TRUE;
}
@ -3696,7 +3696,7 @@ DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetF
*/
if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) {
/* TODO: This can be optimized. */
/*
If we're seeking forward it's simple - just keep reading samples until we hit the sample we're requesting. If we're seeking backwards,
we first need to seek back to the start and then just do the same thing as a forward seek.
@ -3844,7 +3844,7 @@ DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 frame
pRunningData = (const drwav_uint8*)pData;
bytesPerSample = drwav_get_bytes_per_pcm_frame(pWav) / pWav->channels;
while (bytesToWrite > 0) {
drwav_uint8 temp[4096];
drwav_uint32 sampleCount;
@ -3972,9 +3972,9 @@ static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64
if (pWav->msadpcm.bytesRemainingInBlock == 0) {
continue;
} else {
static drwav_int32 adaptationTable[] = {
230, 230, 230, 230, 307, 409, 512, 614,
768, 614, 512, 409, 307, 230, 230, 230
static drwav_int32 adaptationTable[] = {
230, 230, 230, 230, 307, 409, 512, 614,
768, 614, 512, 409, 307, 230, 230, 230
};
static drwav_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 };
static drwav_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 };
@ -4081,15 +4081,15 @@ static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 fra
};
static drwav_int32 stepTable[89] = {
7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
};
DRWAV_ASSERT(pWav != NULL);
@ -4229,40 +4229,40 @@ static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 fra
#ifndef DR_WAV_NO_CONVERSION_API
static unsigned short g_drwavAlawTable[256] = {
0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580,
0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0,
0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600,
0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00,
0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58,
0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58,
0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960,
0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0,
0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80,
0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40,
0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00,
0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500,
0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8,
0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8,
0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0,
0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580,
0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0,
0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600,
0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00,
0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58,
0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58,
0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960,
0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0,
0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80,
0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40,
0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00,
0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500,
0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8,
0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8,
0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0,
0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350
};
static unsigned short g_drwavMulawTable[256] = {
0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84,
0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84,
0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004,
0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844,
0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64,
0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74,
0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C,
0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000,
0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C,
0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C,
0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC,
0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC,
0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C,
0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C,
0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084,
0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84,
0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84,
0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004,
0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844,
0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64,
0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74,
0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C,
0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000,
0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C,
0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C,
0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC,
0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC,
0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C,
0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C,
0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084,
0x0078, 0x0070, 0x0068, 0x0060, 0x0058, 0x0050, 0x0048, 0x0040, 0x0038, 0x0030, 0x0028, 0x0020, 0x0018, 0x0010, 0x0008, 0x0000
};
@ -4355,14 +4355,14 @@ static drwav_uint64 drwav_read_pcm_frames_s16__pcm(drwav* pWav, drwav_uint64 fra
if ((pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 16) || pBufferOut == NULL) {
return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
}
bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
if (bytesPerFrame == 0) {
return 0;
}
totalFramesRead = 0;
while (framesToRead > 0) {
drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
if (framesRead == 0) {
@ -4395,7 +4395,7 @@ static drwav_uint64 drwav_read_pcm_frames_s16__ieee(drwav* pWav, drwav_uint64 fr
}
totalFramesRead = 0;
while (framesToRead > 0) {
drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
if (framesRead == 0) {
@ -4428,7 +4428,7 @@ static drwav_uint64 drwav_read_pcm_frames_s16__alaw(drwav* pWav, drwav_uint64 fr
}
totalFramesRead = 0;
while (framesToRead > 0) {
drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData);
if (framesRead == 0) {
@ -4777,7 +4777,7 @@ static drwav_uint64 drwav_read_pcm_frames_f32__ieee(drwav* pWav, drwav_uint64 fr
if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT && pWav->bitsPerSample == 32) {
return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
}
bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
if (bytesPerFrame == 0) {
return 0;
@ -5110,7 +5110,7 @@ static drwav_uint64 drwav_read_pcm_frames_s32__pcm(drwav* pWav, drwav_uint64 fra
if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 32) {
return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut);
}
bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav);
if (bytesPerFrame == 0) {
return 0;

View file

@ -1,204 +1,204 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <time.h>
#include <iostream>
#include "whisper_invoke.h"
#include "nn_sdk.h"
#define BILLION 1000000000
#define GET_INFERENCE_TIME (1)
#define WHISPER_DECODER_INPUTS 48
struct Get_Times
{
uint64_t init_start_time, init_end_time, init_total_time;
uint64_t preProcess_start_time, preProcess_end_time, preProcess_total_time;
uint64_t invoke_start_time, invoke_end_time, invoke_total_time; /* for whisper_decoder or llm invoke time once */
uint64_t total_time; /* for whisper or llm pipeline time */
std::vector<uint64_t> total_time_group; /* for whisper_decoder or llm invoke time everytimes */
};
static uint64_t get_time_count()
{
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return (uint64_t)((uint64_t)ts.tv_nsec + (uint64_t)ts.tv_sec * BILLION);
}
int main(int argc, char ** argv)
{
Get_Times encoder_time, decoder_time, whisper_time;
Input_Decoder decoder_inputs_data;
std::vector<float> encoder_input_data;
std::vector<float> encoder_output_data;
int64_t input_1_data[] = {50257, 50362}; /* init token, for tiny_en or base_en */
int input_1_data_size = sizeof(input_1_data) / sizeof(input_1_data[0]);
int ret = 0;
char* model_path_encoder = argv[1];
char* model_path_decoder = argv[2];
void *context_enc = NULL;
void *context_dec = NULL;
whisper_time.init_start_time = get_time_count();
context_enc = init_network_file(model_path_encoder);
context_dec = init_network_file(model_path_decoder);
whisper_time.init_end_time = get_time_count();
whisper_time.init_total_time = (whisper_time.init_end_time - whisper_time.init_start_time) / 1000000;
if (context_enc == NULL)
{
printf("init_network [context_enc] fail.\n");
return -1;
}
if (context_dec == NULL)
{
printf("init_network [context_dec] fail.\n");
return -1;
}
if (getenv("GET_TIME"))
{
std::cout << "init_whisper_total time : " << whisper_time.init_total_time << "ms" << std::endl;
}
while (true)
{
std::string input_str;
bool is_finish = false;
std::string out_text = "start"; /* end adla model output text init */
printf("\n");
printf("Audio Path:\n");
std::getline(std::cin, input_str);
if (input_str == "exit")
{
break;
} else if (input_str == "") {
printf("Please enter wav path\n");
continue;
} else if (input_str.size() < 4 || input_str.substr(input_str.size() - 4) != ".wav") {
std::cout << "Invalid wav path or file does not exist, please try again" << std::endl;
continue;
}
decoder_inputs_data.input_1_size = WHISPER_DECODER_INPUTS;
decoder_inputs_data.input_1= new int64_t[decoder_inputs_data.input_1_size];
std::copy(input_1_data, input_1_data + input_1_data_size, decoder_inputs_data.input_1);
// need enough data 0
std::fill(decoder_inputs_data.input_1 + input_1_data_size,
decoder_inputs_data.input_1 + decoder_inputs_data.input_1_size,
0);
whisper_time.preProcess_start_time = get_time_count();
encoder_input_data = do_pre_process(input_str);
if (!encoder_input_data.size()) /* support wav 0s */
{
is_finish = is_finish_end();
std::cout << "wav is null, please try again" << std::endl;
continue;
}
whisper_time.preProcess_end_time = get_time_count();
encoder_output_data = run_network_encoder_process(context_enc, encoder_input_data);
encoder_time.invoke_end_time = get_time_count();
decoder_inputs_data.input_0_size = encoder_output_data.size();
decoder_inputs_data.input_0 = new float[decoder_inputs_data.input_0_size];
std::copy(encoder_output_data.begin(), encoder_output_data.end(), decoder_inputs_data.input_0);
whisper_time.preProcess_total_time = (whisper_time.preProcess_end_time - whisper_time.preProcess_start_time) / 1000000;
encoder_time.invoke_total_time = (encoder_time.invoke_end_time - whisper_time.preProcess_end_time) / 1000000;
printf("\n");
printf("Audio Text:\n");
while (!is_finish)
{
decoder_time.invoke_start_time = get_time_count();
out_text = run_network_decoder(context_dec, &decoder_inputs_data);
decoder_time.invoke_end_time = get_time_count();
is_finish = is_finish_end();
decoder_time.total_time_group.push_back((decoder_time.invoke_end_time - decoder_time.invoke_start_time) / 1000000);
std::cout << out_text << std::flush;
}
printf("\n");
if (getenv("GET_OUTPUTS_SIZE"))
{
std::cout << "==================================" << std::endl;
std::cout << "WHISPER_OUTPUTS_SIZE : " << decoder_time.total_time_group.size() << std::endl;
}
if (getenv("GET_TIME"))
{
uint64_t total_time_whisper, total_time_decoder, total_time_llm;
for (int i = 0; i < decoder_time.total_time_group.size(); i++) {
std::cout << "==================================" << std::endl;
if (i < 1)
{
total_time_whisper = whisper_time.preProcess_total_time + encoder_time.invoke_total_time;
whisper_time.total_time = whisper_time.preProcess_total_time + encoder_time.invoke_total_time;
std::cout << "pre-process time : " << whisper_time.preProcess_total_time << "ms" << std::endl;
std::cout << "encoder_inference_total time : " << encoder_time.invoke_total_time << "ms" << std::endl;
}
decoder_time.invoke_total_time += decoder_time.total_time_group[i];
std::cout << "decoder inference time[" << i << "] : " << decoder_time.total_time_group[i] << "ms" << std::endl;
}
whisper_time.total_time += decoder_time.invoke_total_time;
std::cout << "model->whisper decoder avg : " << decoder_time.invoke_total_time / decoder_time.total_time_group.size() << "ms" << std::endl;
std::cout << "model->whisper total time : " << whisper_time.total_time << "ms" << std::endl;
whisper_time.total_time = decoder_time.invoke_total_time = 0;
}
encoder_time.total_time_group.clear();
if (decoder_inputs_data.input_0 != nullptr)
{
delete[] decoder_inputs_data.input_0;
decoder_inputs_data.input_0 = nullptr;
decoder_inputs_data.input_0_size = 0;
}
if (decoder_inputs_data.input_1 != nullptr)
{
delete[] decoder_inputs_data.input_1;
decoder_inputs_data.input_1 = nullptr;
decoder_inputs_data.input_1_size = 0;
}
}
ret = destroy_network(context_enc);
if (ret != 0)
{
printf("destroy_network [context_enc] fail.\n");
return -1;
}
ret = destroy_network(context_dec);
if (ret != 0)
{
printf("destroy_network [context_dec] fail.\n");
return -1;
}
return ret;
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <time.h>
#include <iostream>
#include "whisper_invoke.h"
#include "nn_sdk.h"
#define BILLION 1000000000
#define GET_INFERENCE_TIME (1)
#define WHISPER_DECODER_INPUTS 48
struct Get_Times
{
uint64_t init_start_time, init_end_time, init_total_time;
uint64_t preProcess_start_time, preProcess_end_time, preProcess_total_time;
uint64_t invoke_start_time, invoke_end_time, invoke_total_time; /* for whisper_decoder or llm invoke time once */
uint64_t total_time; /* for whisper or llm pipeline time */
std::vector<uint64_t> total_time_group; /* for whisper_decoder or llm invoke time everytimes */
};
static uint64_t get_time_count()
{
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return (uint64_t)((uint64_t)ts.tv_nsec + (uint64_t)ts.tv_sec * BILLION);
}
int main(int argc, char ** argv)
{
Get_Times encoder_time, decoder_time, whisper_time;
Input_Decoder decoder_inputs_data;
std::vector<float> encoder_input_data;
std::vector<float> encoder_output_data;
int64_t input_1_data[] = {50257, 50362}; /* init token, for tiny_en or base_en */
int input_1_data_size = sizeof(input_1_data) / sizeof(input_1_data[0]);
int ret = 0;
char* model_path_encoder = argv[1];
char* model_path_decoder = argv[2];
void *context_enc = NULL;
void *context_dec = NULL;
whisper_time.init_start_time = get_time_count();
context_enc = init_network_file(model_path_encoder);
context_dec = init_network_file(model_path_decoder);
whisper_time.init_end_time = get_time_count();
whisper_time.init_total_time = (whisper_time.init_end_time - whisper_time.init_start_time) / 1000000;
if (context_enc == NULL)
{
printf("init_network [context_enc] fail.\n");
return -1;
}
if (context_dec == NULL)
{
printf("init_network [context_dec] fail.\n");
return -1;
}
if (getenv("GET_TIME"))
{
std::cout << "init_whisper_total time : " << whisper_time.init_total_time << "ms" << std::endl;
}
while (true)
{
std::string input_str;
bool is_finish = false;
std::string out_text = "start"; /* end adla model output text init */
printf("\n");
printf("Audio Path:\n");
std::getline(std::cin, input_str);
if (input_str == "exit")
{
break;
} else if (input_str == "") {
printf("Please enter wav path\n");
continue;
} else if (input_str.size() < 4 || input_str.substr(input_str.size() - 4) != ".wav") {
std::cout << "Invalid wav path or file does not exist, please try again" << std::endl;
continue;
}
decoder_inputs_data.input_1_size = WHISPER_DECODER_INPUTS;
decoder_inputs_data.input_1= new int64_t[decoder_inputs_data.input_1_size];
std::copy(input_1_data, input_1_data + input_1_data_size, decoder_inputs_data.input_1);
// need enough data 0
std::fill(decoder_inputs_data.input_1 + input_1_data_size,
decoder_inputs_data.input_1 + decoder_inputs_data.input_1_size,
0);
whisper_time.preProcess_start_time = get_time_count();
encoder_input_data = do_pre_process(input_str);
if (!encoder_input_data.size()) /* support wav 0s */
{
is_finish = is_finish_end();
std::cout << "wav is null, please try again" << std::endl;
continue;
}
whisper_time.preProcess_end_time = get_time_count();
encoder_output_data = run_network_encoder_process(context_enc, encoder_input_data);
encoder_time.invoke_end_time = get_time_count();
decoder_inputs_data.input_0_size = encoder_output_data.size();
decoder_inputs_data.input_0 = new float[decoder_inputs_data.input_0_size];
std::copy(encoder_output_data.begin(), encoder_output_data.end(), decoder_inputs_data.input_0);
whisper_time.preProcess_total_time = (whisper_time.preProcess_end_time - whisper_time.preProcess_start_time) / 1000000;
encoder_time.invoke_total_time = (encoder_time.invoke_end_time - whisper_time.preProcess_end_time) / 1000000;
printf("\n");
printf("Audio Text:\n");
while (!is_finish)
{
decoder_time.invoke_start_time = get_time_count();
out_text = run_network_decoder(context_dec, &decoder_inputs_data);
decoder_time.invoke_end_time = get_time_count();
is_finish = is_finish_end();
decoder_time.total_time_group.push_back((decoder_time.invoke_end_time - decoder_time.invoke_start_time) / 1000000);
std::cout << out_text << std::flush;
}
printf("\n");
if (getenv("GET_OUTPUTS_SIZE"))
{
std::cout << "==================================" << std::endl;
std::cout << "WHISPER_OUTPUTS_SIZE : " << decoder_time.total_time_group.size() << std::endl;
}
if (getenv("GET_TIME"))
{
uint64_t total_time_whisper, total_time_decoder, total_time_llm;
for (int i = 0; i < decoder_time.total_time_group.size(); i++) {
std::cout << "==================================" << std::endl;
if (i < 1)
{
total_time_whisper = whisper_time.preProcess_total_time + encoder_time.invoke_total_time;
whisper_time.total_time = whisper_time.preProcess_total_time + encoder_time.invoke_total_time;
std::cout << "pre-process time : " << whisper_time.preProcess_total_time << "ms" << std::endl;
std::cout << "encoder_inference_total time : " << encoder_time.invoke_total_time << "ms" << std::endl;
}
decoder_time.invoke_total_time += decoder_time.total_time_group[i];
std::cout << "decoder inference time[" << i << "] : " << decoder_time.total_time_group[i] << "ms" << std::endl;
}
whisper_time.total_time += decoder_time.invoke_total_time;
std::cout << "model->whisper decoder avg : " << decoder_time.invoke_total_time / decoder_time.total_time_group.size() << "ms" << std::endl;
std::cout << "model->whisper total time : " << whisper_time.total_time << "ms" << std::endl;
whisper_time.total_time = decoder_time.invoke_total_time = 0;
}
encoder_time.total_time_group.clear();
if (decoder_inputs_data.input_0 != nullptr)
{
delete[] decoder_inputs_data.input_0;
decoder_inputs_data.input_0 = nullptr;
decoder_inputs_data.input_0_size = 0;
}
if (decoder_inputs_data.input_1 != nullptr)
{
delete[] decoder_inputs_data.input_1;
decoder_inputs_data.input_1 = nullptr;
decoder_inputs_data.input_1_size = 0;
}
}
ret = destroy_network(context_enc);
if (ret != 0)
{
printf("destroy_network [context_enc] fail.\n");
return -1;
}
ret = destroy_network(context_dec);
if (ret != 0)
{
printf("destroy_network [context_dec] fail.\n");
return -1;
}
return ret;
}

View file

@ -1,127 +1,127 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_post_common.h"
#include "post_process_whisper.h"
whisper_vocab read_token_info(std::string token_path)
{
struct whisper_context ctx;
auto & vocab = ctx.vocab;
whisper_model_loader loader = {};
auto fin = std::ifstream(token_path, std::ios::binary);
if (!fin)
{
fprintf(stderr, "%s : fail to open '%s'\n", __func__, token_path.c_str());
}
loader.context = &fin;
loader.read = [](void * ctx, void * output, size_t read_size) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->read((char *)output, read_size);
return read_size;
};
loader.eof = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
return fin->eof();
};
loader.close = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->close();
};
int32_t n_vocab = 0;
read_safe(&loader, n_vocab);
std::string word;
std::vector<char> tmp;
tmp.reserve(128);
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
read_safe(&loader, len);
if (len > 0 and i != 50256) {
tmp.resize(len);
loader.read(loader.context, &tmp[0], tmp.size()); // read to buffer
word.assign(tmp.data(), tmp.size());
} else {
word = "";
}
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
fin.eof();
fin.close();
n_vocab = 50256;
if (n_vocab < 51863) {
// WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
for (int i = n_vocab; i < 51863; i++) {
if (i > vocab.token_beg) {
word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
} else if (i == vocab.token_eot) {
word = "<|endoftext|>";
} else if (i == vocab.token_sot) {
word = "<|startoftranscript|>";
} else if (i == vocab.token_translate) {
word = "<|translate|>";
} else if (i == vocab.token_transcribe) {
word = "<|transcribe|>";
} else if (i == vocab.token_solm) {
word = "[_SOLM_]";
} else if (i == vocab.token_prev) {
word = "[_PREV_]";
} else if (i == vocab.token_nosp) {
word = "[_NOSP_]";
} else if (i == vocab.token_not) {
word = "<|notimestamps|>";
} else if (i == vocab.token_beg) {
word = "[_BEG_]";
}
else if (i == 50258) {
word= "<|en|>";
}
else if (i == 50259) {
word= "<|zh|>";
}
else if (i == 50263) {
word= "<|ko|>";
}
else {
word = "[_extra_token_" + std::to_string(i) + "]";
}
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
return vocab;
}
std::string do_post_process(int64_t output_id, whisper_vocab vocab)
{
// std::vector<whisper_token> prompt_init = {50258, 50259, 50359, 50363,2221,13,2326,388,391,307,264,50244,295,264,2808,5359,11,293,321,366,5404,281,2928,702,14943,13,50257};
std::string text;
text = vocab.id_to_token.at(output_id).c_str();
return text;
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_post_common.h"
#include "post_process_whisper.h"
whisper_vocab read_token_info(std::string token_path)
{
struct whisper_context ctx;
auto & vocab = ctx.vocab;
whisper_model_loader loader = {};
auto fin = std::ifstream(token_path, std::ios::binary);
if (!fin)
{
fprintf(stderr, "%s : fail to open '%s'\n", __func__, token_path.c_str());
}
loader.context = &fin;
loader.read = [](void * ctx, void * output, size_t read_size) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->read((char *)output, read_size);
return read_size;
};
loader.eof = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
return fin->eof();
};
loader.close = [](void * ctx) {
std::ifstream * fin = (std::ifstream*)ctx;
fin->close();
};
int32_t n_vocab = 0;
read_safe(&loader, n_vocab);
std::string word;
std::vector<char> tmp;
tmp.reserve(128);
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
read_safe(&loader, len);
if (len > 0 and i != 50256) {
tmp.resize(len);
loader.read(loader.context, &tmp[0], tmp.size()); // read to buffer
word.assign(tmp.data(), tmp.size());
} else {
word = "";
}
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
fin.eof();
fin.close();
n_vocab = 50256;
if (n_vocab < 51863) {
// WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
for (int i = n_vocab; i < 51863; i++) {
if (i > vocab.token_beg) {
word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
} else if (i == vocab.token_eot) {
word = "<|endoftext|>";
} else if (i == vocab.token_sot) {
word = "<|startoftranscript|>";
} else if (i == vocab.token_translate) {
word = "<|translate|>";
} else if (i == vocab.token_transcribe) {
word = "<|transcribe|>";
} else if (i == vocab.token_solm) {
word = "[_SOLM_]";
} else if (i == vocab.token_prev) {
word = "[_PREV_]";
} else if (i == vocab.token_nosp) {
word = "[_NOSP_]";
} else if (i == vocab.token_not) {
word = "<|notimestamps|>";
} else if (i == vocab.token_beg) {
word = "[_BEG_]";
}
else if (i == 50258) {
word= "<|en|>";
}
else if (i == 50259) {
word= "<|zh|>";
}
else if (i == 50263) {
word= "<|ko|>";
}
else {
word = "[_extra_token_" + std::to_string(i) + "]";
}
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
return vocab;
}
std::string do_post_process(int64_t output_id, whisper_vocab vocab)
{
// std::vector<whisper_token> prompt_init = {50258, 50259, 50359, 50363,2221,13,2326,388,391,307,264,50244,295,264,2808,5359,11,293,321,366,5404,281,2928,702,14943,13,50257};
std::string text;
text = vocab.id_to_token.at(output_id).c_str();
return text;
}

View file

@ -1,21 +1,21 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_post_common.h"
whisper_vocab read_token_info(std::string token_path);
int get_output_max_index(size_t id_shape, std::vector<float> buf_data);
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_post_common.h"
whisper_vocab read_token_info(std::string token_path);
int get_output_max_index(size_t id_shape, std::vector<float> buf_data);
std::string do_post_process(int64_t output_id, whisper_vocab vocab);

View file

@ -1,105 +1,105 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PRE_POST_COMMON_H
#define PRE_POST_COMMON_H
#include "common.h"
#include "whisper.h"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <regex>
#include <thread>
#include <vector>
#include <cstring>
#include <float.h>
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel = 80;
int32_t n_fft = 201;
std::vector<float> data;
};
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
int64_t t_decode_us = 0;
int64_t t_batchd_us = 0;
int64_t t_prompt_us = 0;
int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
whisper_mel mel;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
std::vector<whisper_token> prompt_past;
int lang_id = 0; // english by default
std::string path_model; // populated by whisper_init_from_file_with_params()
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
};
struct whisper_model {
whisper_filters filters;
};
struct whisper_context {
int64_t t_load_us = 0;
int64_t t_start_us = 0;
whisper_model model;
whisper_vocab vocab;
whisper_state * state = nullptr;
};
template<typename T>
static void read_safe(whisper_model_loader * loader, T & dest) {
loader->read(loader->context, &dest, sizeof(T));
}
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PRE_POST_COMMON_H
#define PRE_POST_COMMON_H
#include "common.h"
#include "whisper.h"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <regex>
#include <thread>
#include <vector>
#include <cstring>
#include <float.h>
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel = 80;
int32_t n_fft = 201;
std::vector<float> data;
};
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
int64_t t_decode_us = 0;
int64_t t_batchd_us = 0;
int64_t t_prompt_us = 0;
int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
whisper_mel mel;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
std::vector<whisper_token> prompt_past;
int lang_id = 0; // english by default
std::string path_model; // populated by whisper_init_from_file_with_params()
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
};
struct whisper_model {
whisper_filters filters;
};
struct whisper_context {
int64_t t_load_us = 0;
int64_t t_start_us = 0;
whisper_model model;
whisper_vocab vocab;
whisper_state * state = nullptr;
};
template<typename T>
static void read_safe(whisper_model_loader * loader, T & dest) {
loader->read(loader->context, &dest, sizeof(T));
}
#endif // PRE_POST_COMMON_H

View file

@ -1,53 +1,53 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_process_whisper.h"
#include "pre_post_common.h"
extern bool is_finish;
std::vector<float> do_pre_process(std::string fname_inp)
{
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
struct whisper_context ctx;
struct whisper_state state;
if (!read_wav(fname_inp, pcmf32, pcmf32s, false)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
is_finish = true;
return {};
}
if (float(pcmf32.size())/WHISPER_SAMPLE_RATE == 0) {
is_finish = true;
return {};
}
if (whisper_pcm_to_mel_with_state(&ctx, &state, pcmf32.data(), pcmf32.size(), 8) != 0) {
printf("%s: failed to compute log mel spectrogram\n", __func__);
}
std::vector<float> input_data;
for (int j = 0; j < 80; j++) {
for (int i = 0; i < 3000; i++) {
input_data.push_back(state.mel.data[j * state.mel.n_len + i]);
}
}
return input_data;
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_process_whisper.h"
#include "pre_post_common.h"
extern bool is_finish;
std::vector<float> do_pre_process(std::string fname_inp)
{
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
struct whisper_context ctx;
struct whisper_state state;
if (!read_wav(fname_inp, pcmf32, pcmf32s, false)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
is_finish = true;
return {};
}
if (float(pcmf32.size())/WHISPER_SAMPLE_RATE == 0) {
is_finish = true;
return {};
}
if (whisper_pcm_to_mel_with_state(&ctx, &state, pcmf32.data(), pcmf32.size(), 8) != 0) {
printf("%s: failed to compute log mel spectrogram\n", __func__);
}
std::vector<float> input_data;
for (int j = 0; j < 80; j++) {
for (int i = 0; i < 3000; i++) {
input_data.push_back(state.mel.data[j * state.mel.n_len + i]);
}
}
return input_data;
}

View file

@ -1,21 +1,21 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <string>
#include <float.h>
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <string>
#include <float.h>
std::vector<float> do_pre_process(std::string fname_inp);

View file

@ -1,433 +1,433 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "whisper.h"
#include <atomic>
#include <algorithm>
#define _USE_MATH_DEFINES
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdarg>
#include <cstring>
#include <fstream>
#include <set>
#include <thread>
#include <vector>
#include <regex>
#include <random>
#include <functional>
#include <codecvt>
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel = 80;
int32_t n_fft = 201;
std::vector<float> data;
};
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
int64_t t_decode_us = 0;
int64_t t_batchd_us = 0;
int64_t t_prompt_us = 0;
int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
whisper_mel mel;
std::vector<float> logits;
std::vector<whisper_token> prompt_past;
int lang_id = 0; // english by default
std::string path_model; // populated by whisper_init_from_file_with_params()
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
};
struct whisper_model {
whisper_filters filters;
};
struct whisper_context {
int64_t t_load_us = 0;
int64_t t_start_us = 0;
whisper_model model;
whisper_vocab vocab;
whisper_state * state = nullptr;
};
#define SIN_COS_N_COUNT WHISPER_N_FFT
static float sin_vals[SIN_COS_N_COUNT];
static float cos_vals[SIN_COS_N_COUNT];
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
static void fill_sin_cos_table() {
static bool is_filled = false;
if (is_filled) return;
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
is_filled = true;
}
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
static void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N; k++) {
float re = 0;
float im = 0;
for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*cos_vals[idx]; // cos(t)
im -= in[n]*sin_vals[idx]; // sin(t)
}
out[k*2 + 0] = re;
out[k*2 + 1] = im;
}
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2);
int N = in.size();
if (N == 1) {
out[0] = in[0];
out[1] = 0;
return;
}
if (N%2 == 1) {
dft(in, out);
return;
}
std::vector<float> even;
std::vector<float> odd;
even.reserve(N/2);
odd.reserve(N/2);
for (int i = 0; i < N; i++) {
if (i % 2 == 0) {
even.push_back(in[i]);
} else {
odd.push_back(in[i]);
}
}
std::vector<float> even_fft;
std::vector<float> odd_fft;
fft(even, even_fft);
fft(odd, odd_fft);
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N/2; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = cos_vals[idx]; // cos(t)
float im = -sin_vals[idx]; // sin(t)
float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
}
}
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
if (output.size() < static_cast<size_t>(length)) {
output.resize(length);
}
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
}
return true;
}
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size, 0.0);
std::vector<float> fft_out(2 * frame_size);
int n_fft = filters.n_fft;
int i = ith;
assert(n_fft == 1 + (frame_size / 2));
// calculate FFT only when fft_in are not all zero
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;
// apply Hanning window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j];
}
// fill the rest with zeros
if (n_samples - offset < frame_size) {
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
}
// FFT
fft(fft_in, fft_out);
// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
for (int j = 0; j < n_fft; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
// unroll loop (suggested by GH user @lunixbochs)
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
}
// handle n_fft remainder
for (; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k];
}
sum = log10(std::max(sum, 1e-10));
mel.data[j * mel.n_len + i] = sum;
}
}
// Otherwise fft_out are all zero
double sum = log10(1e-10);
for (; i < mel.n_len; i += n_threads) {
for (int j = 0; j < mel.n_mel; j++) {
mel.data[j * mel.n_len + i] = sum;
}
}
}
static bool log_mel_spectrogram(
whisper_state & wstate,
const float * samples,
const int n_samples,
const int /*sample_rate*/,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
// Hanning window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
fill_sin_cos_table();
// auto & filters = filters;
filters.data.resize(filters.n_mel*filters.n_fft);
auto fin = std::ifstream("./data_bin/data.bin", std::ios::binary);
if (!fin)
{
fprintf(stderr, "%s : fail to open '%s'\n", __func__, "./data_bin/data.bin");
}
fin.read((char *)filters.data.data(), filters.data.size()*sizeof(float));
fin.eof();
fin.close();
std::vector<float> hann;
hann_window(frame_size, true, hann);
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = frame_size / 2;
// Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
// reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
mel.n_mel = n_mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
// Calculate number of frames + remove the last frame
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
// Calculate semi-padded sample length to ensure compatibility
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
mel.data.resize(mel.n_mel * mel.n_len);
{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
std::cref(filters), std::ref(mel));
}
// main thread
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
}
// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
return true;
}
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, 80, n_threads, ctx->model.filters, true, state->mel)) {
printf("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
return 0;
}
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<whisper_vocab::id> tokens;
for (const auto & word : words) {
if (word.empty()) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
bool found = false;
while (j > i) {
auto sub = word.substr(i, j-i);
auto it = vocab.token_to_id.find(sub);
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
found = true;
break;
}
--j;
}
if (!found) {
printf("unknown token\n");
++i;
}
}
}
return tokens;
}
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "whisper.h"
#include <atomic>
#include <algorithm>
#define _USE_MATH_DEFINES
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdarg>
#include <cstring>
#include <fstream>
#include <set>
#include <thread>
#include <vector>
#include <regex>
#include <random>
#include <functional>
#include <codecvt>
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel = 80;
int32_t n_fft = 201;
std::vector<float> data;
};
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
int64_t t_decode_us = 0;
int64_t t_batchd_us = 0;
int64_t t_prompt_us = 0;
int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
whisper_mel mel;
std::vector<float> logits;
std::vector<whisper_token> prompt_past;
int lang_id = 0; // english by default
std::string path_model; // populated by whisper_init_from_file_with_params()
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
};
struct whisper_model {
whisper_filters filters;
};
struct whisper_context {
int64_t t_load_us = 0;
int64_t t_start_us = 0;
whisper_model model;
whisper_vocab vocab;
whisper_state * state = nullptr;
};
#define SIN_COS_N_COUNT WHISPER_N_FFT
static float sin_vals[SIN_COS_N_COUNT];
static float cos_vals[SIN_COS_N_COUNT];
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
static void fill_sin_cos_table() {
static bool is_filled = false;
if (is_filled) return;
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2*M_PI*i)/SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
is_filled = true;
}
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
static void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N; k++) {
float re = 0;
float im = 0;
for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*cos_vals[idx]; // cos(t)
im -= in[n]*sin_vals[idx]; // sin(t)
}
out[k*2 + 0] = re;
out[k*2 + 1] = im;
}
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2);
int N = in.size();
if (N == 1) {
out[0] = in[0];
out[1] = 0;
return;
}
if (N%2 == 1) {
dft(in, out);
return;
}
std::vector<float> even;
std::vector<float> odd;
even.reserve(N/2);
odd.reserve(N/2);
for (int i = 0; i < N; i++) {
if (i % 2 == 0) {
even.push_back(in[i]);
} else {
odd.push_back(in[i]);
}
}
std::vector<float> even_fft;
std::vector<float> odd_fft;
fft(even, even_fft);
fft(odd, odd_fft);
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N/2; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = cos_vals[idx]; // cos(t)
float im = -sin_vals[idx]; // sin(t)
float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
}
}
static bool hann_window(int length, bool periodic, std::vector<float> & output) {
if (output.size() < static_cast<size_t>(length)) {
output.resize(length);
}
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
}
return true;
}
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size, 0.0);
std::vector<float> fft_out(2 * frame_size);
int n_fft = filters.n_fft;
int i = ith;
assert(n_fft == 1 + (frame_size / 2));
// calculate FFT only when fft_in are not all zero
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;
// apply Hanning window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j];
}
// fill the rest with zeros
if (n_samples - offset < frame_size) {
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
}
// FFT
fft(fft_in, fft_out);
// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
for (int j = 0; j < n_fft; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
// unroll loop (suggested by GH user @lunixbochs)
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
}
// handle n_fft remainder
for (; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k];
}
sum = log10(std::max(sum, 1e-10));
mel.data[j * mel.n_len + i] = sum;
}
}
// Otherwise fft_out are all zero
double sum = log10(1e-10);
for (; i < mel.n_len; i += n_threads) {
for (int j = 0; j < mel.n_mel; j++) {
mel.data[j * mel.n_len + i] = sum;
}
}
}
static bool log_mel_spectrogram(
whisper_state & wstate,
const float * samples,
const int n_samples,
const int /*sample_rate*/,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
// Hanning window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
fill_sin_cos_table();
// auto & filters = filters;
filters.data.resize(filters.n_mel*filters.n_fft);
auto fin = std::ifstream("./data_bin/data.bin", std::ios::binary);
if (!fin)
{
fprintf(stderr, "%s : fail to open '%s'\n", __func__, "./data_bin/data.bin");
}
fin.read((char *)filters.data.data(), filters.data.size()*sizeof(float));
fin.eof();
fin.close();
std::vector<float> hann;
hann_window(frame_size, true, hann);
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = frame_size / 2;
// Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
// reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
mel.n_mel = n_mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
// Calculate number of frames + remove the last frame
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
// Calculate semi-padded sample length to ensure compatibility
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
mel.data.resize(mel.n_mel * mel.n_len);
{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded,
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
std::cref(filters), std::ref(mel));
}
// main thread
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
}
// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
return true;
}
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, 80, n_threads, ctx->model.filters, true, state->mel)) {
printf("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
return 0;
}
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<whisper_vocab::id> tokens;
for (const auto & word : words) {
if (word.empty()) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
bool found = false;
while (j > i) {
auto sub = word.substr(i, j-i);
auto it = vocab.token_to_id.find(sub);
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
found = true;
break;
}
--j;
}
if (!found) {
printf("unknown token\n");
++i;
}
}
}
return tokens;
}

View file

@ -1,404 +1,404 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef WHISPER_H
#define WHISPER_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#include <string>
#include <map>
#include <cstdint>
#ifdef __GNUC__
# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
#else
# define WHISPER_DEPRECATED(func, hint) func
#endif
#ifdef WHISPER_SHARED
# ifdef _WIN32
# ifdef WHISPER_BUILD
# define WHISPER_API __declspec(dllexport)
# else
# define WHISPER_API __declspec(dllimport)
# endif
# else
# define WHISPER_API __attribute__ ((visibility ("default")))
# endif
#else
# define WHISPER_API
#endif
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30
#define WHISPER_N_MELS 80
#define WHISPER_N_FRAMES 3000
#define WHISPER_N_SAMPLES 48000
struct whisper_vocab {
using id = int32_t;
using token = std::string;
int n_vocab = 51864;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
// reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
id token_eot = 50256;
id token_sot = 50257;
// task tokens (used only for multilingual models)
id token_translate = 50357;
id token_transcribe = 50358;
// other special tokens
id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
id token_prev = 50360;
id token_nosp = 50361;
id token_not = 50362; // no timestamps
id token_beg = 50363; // begin timestamps
bool is_multilingual() const {
return n_vocab >= 51865;
}
int num_languages() const {
return n_vocab - 51765 - (is_multilingual() ? 1 : 0);
}
};
#ifdef __cplusplus
extern "C" {
#endif
struct whisper_context;
struct whisper_state;
struct whisper_full_params;
typedef int32_t whisper_pos;
typedef int32_t whisper_token;
typedef int32_t whisper_seq_id;
typedef struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token
int64_t t1; // end time of the token
// [EXPERIMENTAL] Token-level timestamps with DTW
// do not use if you haven't computed token-level timestamps with dtw
// Roughly corresponds to the moment in audio in which the token was output
int64_t t_dtw;
float vlen; // voice length of the token
} whisper_token_data;
typedef struct whisper_model_loader {
void * context;
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params);
// These are the same as the above, but the internal state of the context is not allocated automatically
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params);
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
// Given a context, enable use of OpenVINO for encode inference.
// model_path: Optional path to OpenVINO encoder IR model. If set to nullptr,
// the path will be generated from the ggml model path that was passed
// in to whisper_init_from_file. For example, if 'path_model' was
// "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be
// assumed to be "/path/to/ggml-base.en-encoder-openvino.xml".
// device: OpenVINO device to run inference on ("CPU", "GPU", etc.)
// cache_dir: Optional cache directory that can speed up init time, especially for
// GPU, by caching compiled 'blobs' there.
// Set to nullptr if not used.
// Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1.
WHISPER_API int whisper_ctx_init_openvino_encoder(
struct whisper_context * ctx,
const char * model_path,
const char * device,
const char * cache_dir);
// Frees all allocated memory
WHISPER_API void whisper_free (struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
WHISPER_API void whisper_free_context_params(struct whisper_context_params * params);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel(
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
WHISPER_API int whisper_pcm_to_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
WHISPER_API int whisper_set_mel(
struct whisper_context * ctx,
const float * data,
int n_len,
int n_mel);
WHISPER_API int whisper_set_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * data,
int n_len,
int n_mel);
// Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
WHISPER_API int whisper_encode(
struct whisper_context * ctx,
int offset,
int n_threads);
WHISPER_API int whisper_encode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset,
int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
WHISPER_API int whisper_decode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
// Return the number of tokens in the provided text
// Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
int whisper_token_count(struct whisper_context * ctx, const char * text);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();
// Return the id of the specified language, returns -1 if not found
// Examples:
// "de" -> 2
// "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);
// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str_full(int id);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whisper_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx);
WHISPER_API int whisper_model_ftype (struct whisper_context * ctx);
WHISPER_API int whisper_model_type (struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
// Special tokens
WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
// Task tokens
WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
////////////////////////////////////////////////////////////////////////////
// Number of generated text segments
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);
WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
// Language id associated with the context's default state
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Language id associated with the provided state
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
// Get the start and end time of the specified segment
WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
// Get whether the next segment is predicted as a speaker turn
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
// Get the text of the specified segment
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
// Get number of tokens in the specified segment
WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment);
WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
// Get the token text of the specified token in the specified segment
WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get token data for the specified token in the specified segment
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment
WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
////////////////////////////////////////////////////////////////////////////
#ifdef __cplusplus
}
#endif
#endif
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef WHISPER_H
#define WHISPER_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#include <string>
#include <map>
#include <cstdint>
#ifdef __GNUC__
# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
#else
# define WHISPER_DEPRECATED(func, hint) func
#endif
#ifdef WHISPER_SHARED
# ifdef _WIN32
# ifdef WHISPER_BUILD
# define WHISPER_API __declspec(dllexport)
# else
# define WHISPER_API __declspec(dllimport)
# endif
# else
# define WHISPER_API __attribute__ ((visibility ("default")))
# endif
#else
# define WHISPER_API
#endif
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30
#define WHISPER_N_MELS 80
#define WHISPER_N_FRAMES 3000
#define WHISPER_N_SAMPLES 48000
struct whisper_vocab {
using id = int32_t;
using token = std::string;
int n_vocab = 51864;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
// reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
id token_eot = 50256;
id token_sot = 50257;
// task tokens (used only for multilingual models)
id token_translate = 50357;
id token_transcribe = 50358;
// other special tokens
id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
id token_prev = 50360;
id token_nosp = 50361;
id token_not = 50362; // no timestamps
id token_beg = 50363; // begin timestamps
bool is_multilingual() const {
return n_vocab >= 51865;
}
int num_languages() const {
return n_vocab - 51765 - (is_multilingual() ? 1 : 0);
}
};
#ifdef __cplusplus
extern "C" {
#endif
struct whisper_context;
struct whisper_state;
struct whisper_full_params;
typedef int32_t whisper_pos;
typedef int32_t whisper_token;
typedef int32_t whisper_seq_id;
typedef struct whisper_token_data {
whisper_token id; // token id
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
// token-level timestamp data
// do not use if you haven't computed token-level timestamps
int64_t t0; // start time of the token
int64_t t1; // end time of the token
// [EXPERIMENTAL] Token-level timestamps with DTW
// do not use if you haven't computed token-level timestamps with dtw
// Roughly corresponds to the moment in audio in which the token was output
int64_t t_dtw;
float vlen; // voice length of the token
} whisper_token_data;
typedef struct whisper_model_loader {
void * context;
size_t (*read)(void * ctx, void * output, size_t read_size);
bool (*eof)(void * ctx);
void (*close)(void * ctx);
} whisper_model_loader;
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params);
// These are the same as the above, but the internal state of the context is not allocated automatically
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params);
WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx);
// Given a context, enable use of OpenVINO for encode inference.
// model_path: Optional path to OpenVINO encoder IR model. If set to nullptr,
// the path will be generated from the ggml model path that was passed
// in to whisper_init_from_file. For example, if 'path_model' was
// "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be
// assumed to be "/path/to/ggml-base.en-encoder-openvino.xml".
// device: OpenVINO device to run inference on ("CPU", "GPU", etc.)
// cache_dir: Optional cache directory that can speed up init time, especially for
// GPU, by caching compiled 'blobs' there.
// Set to nullptr if not used.
// Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1.
WHISPER_API int whisper_ctx_init_openvino_encoder(
struct whisper_context * ctx,
const char * model_path,
const char * device,
const char * cache_dir);
// Frees all allocated memory
WHISPER_API void whisper_free (struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
WHISPER_API void whisper_free_context_params(struct whisper_context_params * params);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel(
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
WHISPER_API int whisper_pcm_to_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
// Returns 0 on success
WHISPER_API int whisper_set_mel(
struct whisper_context * ctx,
const float * data,
int n_len,
int n_mel);
WHISPER_API int whisper_set_mel_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * data,
int n_len,
int n_mel);
// Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
// Returns 0 on success
WHISPER_API int whisper_encode(
struct whisper_context * ctx,
int offset,
int n_threads);
WHISPER_API int whisper_encode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset,
int n_threads);
// Run the Whisper decoder to obtain the logits and probabilities for the next token.
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
WHISPER_API int whisper_decode_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
// Return the number of tokens in the provided text
// Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
int whisper_token_count(struct whisper_context * ctx, const char * text);
// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();
// Return the id of the specified language, returns -1 if not found
// Examples:
// "de" -> 2
// "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);
// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str_full(int id);
// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whisper_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx);
WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx);
WHISPER_API int whisper_model_ftype (struct whisper_context * ctx);
WHISPER_API int whisper_model_type (struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits (struct whisper_context * ctx);
WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
// Special tokens
WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
// Task tokens
WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
////////////////////////////////////////////////////////////////////////////
// Number of generated text segments
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);
WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
// Language id associated with the context's default state
WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
// Language id associated with the provided state
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
// Get the start and end time of the specified segment
WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
// Get whether the next segment is predicted as a speaker turn
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
// Get the text of the specified segment
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
// Get number of tokens in the specified segment
WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment);
WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
// Get the token text of the specified token in the specified segment
WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get token data for the specified token in the specified segment
// This contains probabilities, timestamps, etc.
WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment
WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
////////////////////////////////////////////////////////////////////////////
#ifdef __cplusplus
}
#endif
#endif

View file

@ -70,7 +70,7 @@ void* init_network_file(const char *model_path)
/* set omp, If you are considering high CPU usage during operation,
you can turn off this api, set_openmp_opt_flag = false */
aml_openmp_opt_t openmp_opt[] =
aml_openmp_opt_t openmp_opt[] =
{
{
.operator_type = AML_Unknown,
@ -84,7 +84,7 @@ void* init_network_file(const char *model_path)
config.forward_ctrl.softop_info.openmp_opt = openmp_opt;
/* set neon */
aml_neon_opt_t neon_opt[] =
aml_neon_opt_t neon_opt[] =
{
{
.operator_type = AML_Unknown,
@ -193,11 +193,11 @@ nn_output* run_network_decoder_process(void *qcontext, Input_Decoder* input_data
}
inData.input_type = INPUT_DMA_DATA;
memcpy(mem_data[i].viraddr, i == 0 ? static_cast<const void*>(input_data->input_0) :
memcpy(mem_data[i].viraddr, i == 0 ? static_cast<const void*>(input_data->input_0) :
static_cast<const void*>(input_data->input_1), mem_config[i].mem_size);
inData.input = NULL;
} else {
inData.input = i == 0 ? reinterpret_cast<unsigned char*>(const_cast<float*>(input_data->input_0)) :
inData.input = i == 0 ? reinterpret_cast<unsigned char*>(const_cast<float*>(input_data->input_0)) :
reinterpret_cast<unsigned char*>(const_cast<int64_t*>(input_data->input_1));
inData.input_type = BINARY_RAW_DATA;
@ -266,7 +266,7 @@ int destroy_network(void *qcontext)
{
int ret = 0;
/* free encoder
/* free encoder
encoder.use_dma = true
encoder.malloc_buffer_once = false
*/
@ -279,7 +279,7 @@ int destroy_network(void *qcontext)
}
encoder.use_dma = false;
/* free decoder
/* free decoder
first use destroy_network, decoder.malloc_buffer_once is false,
and set decoder.malloc_buffer_once is true
*/

View file

@ -1,39 +1,39 @@
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef WHISPER_INVOKE_H
#define WHISPER_INVOKE_H
#include <string>
#include <vector>
#include <map>
#include "nn_sdk.h"
struct Input_Decoder {
float * input_0;
int input_0_size;
int64_t * input_1;
int input_1_size;
};
void* init_network_file(const char *model_path);
std::vector<float> do_pre_process(std::string fname_inp);
std::vector<float> run_network_encoder_process(void *qcontext, std::vector<float> input_ids);
std::string run_network_decoder(void *qcontext_sec, Input_Decoder* input_data);
bool is_finish_end();
int destroy_network(void *qcontext);
#endif // WHISPER_INVOKE_H
/*
* Copyright (C) 20242025 Amlogic, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef WHISPER_INVOKE_H
#define WHISPER_INVOKE_H
#include <string>
#include <vector>
#include <map>
#include "nn_sdk.h"
struct Input_Decoder {
float * input_0;
int input_0_size;
int64_t * input_1;
int input_1_size;
};
void* init_network_file(const char *model_path);
std::vector<float> do_pre_process(std::string fname_inp);
std::vector<float> run_network_encoder_process(void *qcontext, std::vector<float> input_ids);
std::string run_network_decoder(void *qcontext_sec, Input_Decoder* input_data);
bool is_finish_end();
int destroy_network(void *qcontext);
#endif // WHISPER_INVOKE_H