feat:update demo code of CLIP

This commit is contained in:
dian.yuan 2026-02-12 11:19:52 +08:00
parent 4bf4aafc73
commit 5478a8618b
12 changed files with 50385 additions and 694 deletions

View file

@ -20,31 +20,20 @@
#include <fstream>
#include <algorithm>
#include <vector>
#include <cmath>
#include <cstdlib>
#include "model_invoke.h"
#include "clip_process.h"
#include "nn_sdk.h"
#include "json.hpp"
#include <filesystem>
#include <regex>
using json = nlohmann::ordered_json;
namespace fs = std::__fs::filesystem;
// Global DMA config for models
static aml_memory_config_t vision_mem_config;
static aml_memory_data_t vision_mem_data;
static void* vision_context_flag = nullptr;
struct DMAConfig {
bool use_dma = true;
bool malloc_buffer_once = true;
};
DMAConfig context_model;
///////////////////////////////////////////////////////////
aml_memory_config_t mem_config_context_model;
aml_memory_data_t mem_data_context_model;
std::vector<float> preprocess_image(const std::string& image_path);
float post_process(const float* a, const std::vector<float>& b);
static aml_memory_config_t text_mem_config;
static aml_memory_data_t text_mem_data;
static void* text_context_flag = nullptr;
void* init_network_file(const char *model_path)
{
@ -95,202 +84,119 @@ void* init_network_file(const char *model_path)
return qcontext;
}
float* run_network(void *qcontext, std::vector<float> input_ids, const std::string image_type)
std::vector<float> run_vision_model(void* qcontext, const std::vector<float>& input_data)
{
int ret = 0;
nn_input inData;
nn_output *outdata = NULL;
aml_output_config_t outconfig;
inData.input_index = 0;
inData.info.input_format = AML_INPUT_DEFAULT;
inData.size = input_ids.size() * sizeof(float);
inData.size = input_data.size() * sizeof(float);
if (context_model.use_dma) {
if (context_model.malloc_buffer_once) {
mem_config_context_model.cache_type = AML_WITH_CACHE;
mem_config_context_model.memory_type = AML_VIRTUAL_ADDR;
mem_config_context_model.direction = AML_MEM_DIRECTION_READ_WRITE;
mem_config_context_model.index = 0;
mem_config_context_model.mem_size = inData.size;
aml_util_mallocBuffer(qcontext, &mem_config_context_model, &mem_data_context_model);
aml_util_swapExternalInputBuffer(qcontext, &mem_config_context_model, &mem_data_context_model);
}
inData.input_type = INPUT_DMA_DATA;
memcpy(mem_data_context_model.viraddr, input_ids.data(), mem_config_context_model.mem_size);
inData.input = NULL;
} else {
inData.input = reinterpret_cast<unsigned char*>(input_ids.data());
inData.input_type = BINARY_RAW_DATA;
ret = aml_module_input_set(qcontext, &inData);
if (ret)
{
printf("aml_module_input_set fail.\n");
}
// Use DMA
if (!vision_context_flag) {
vision_mem_config.cache_type = AML_WITH_CACHE;
vision_mem_config.memory_type = AML_VIRTUAL_ADDR;
vision_mem_config.direction = AML_MEM_DIRECTION_READ_WRITE;
vision_mem_config.index = 0;
vision_mem_config.mem_size = inData.size;
aml_util_mallocBuffer(qcontext, &vision_mem_config, &vision_mem_data);
aml_util_swapExternalInputBuffer(qcontext, &vision_mem_config, &vision_mem_data);
vision_context_flag = qcontext;
}
context_model.malloc_buffer_once = false;
inData.input_type = INPUT_DMA_DATA;
memcpy(vision_mem_data.viraddr, input_data.data(), vision_mem_config.mem_size);
inData.input = NULL;
memset(&outconfig, 0, sizeof(aml_output_config_t));
if (context_model.use_dma) {
outconfig.format = AML_OUTDATA_DMA;
} else {
outconfig.format = AML_OUTDATA_RAW;
}
outconfig.format = AML_OUTDATA_DMA;
outconfig.typeSize = sizeof(aml_output_config_t);
outdata = (nn_output*)aml_module_output_get(qcontext, outconfig);
return reinterpret_cast<float*>(outdata->out[0].buf);
}
int extract_index(const std::string& filename) {
std::regex pattern(R"(test_\w+_(\d+)\.jpg)");
std::smatch match;
if (std::regex_match(filename, match, pattern)) {
return std::stoi(match[1]);
if (outdata == NULL || outdata->out[0].buf == NULL) {
printf("Vision model inference failed.\n");
return {};
}
return -1;
// Copy output to vector
size_t output_size = outdata->out[0].size / sizeof(float);
float* output_ptr = reinterpret_cast<float*>(outdata->out[0].buf);
std::vector<float> result(output_ptr, output_ptr + output_size);
return result;
}
std::vector<std::string> process_image_dir(
void* context_model,
const std::string& image_dir_path,
const std::string& base_dir,
const std::string& json_filename)
std::vector<float> run_text_model(void* qcontext, const std::vector<int64_t>& input_ids)
{
std::vector<std::string> results;
std::regex file_pattern(R"(test_(\w+)_\d+\.jpg)");
// Get base_dir from parameter, environment variable, or use default
std::string actual_base_dir = base_dir;
if (actual_base_dir.empty()) {
const char* env_base_dir = std::getenv("CLIP_BASE_DIR");
if (env_base_dir != nullptr) {
actual_base_dir = env_base_dir;
} else {
actual_base_dir = "./demo_data/clip_datasets/";
}
}
// Ensure base_dir ends with '/'
if (!actual_base_dir.empty() && actual_base_dir.back() != '/') {
actual_base_dir += "/";
}
// Get json_filename from parameter, environment variable, or use default
std::string actual_json_filename = json_filename;
if (actual_json_filename.empty()) {
const char* env_json_filename = std::getenv("CLIP_JSON_FILENAME");
if (env_json_filename != nullptr) {
actual_json_filename = env_json_filename;
} else {
actual_json_filename = "clip_text_res.json";
}
int ret = 0;
nn_input inData;
nn_output *outdata = NULL;
aml_output_config_t outconfig;
inData.input_index = 0;
inData.info.input_format = AML_INPUT_DEFAULT;
inData.size = input_ids.size() * sizeof(int64_t);
// Use DMA
if (!text_context_flag) {
text_mem_config.cache_type = AML_WITH_CACHE;
text_mem_config.memory_type = AML_VIRTUAL_ADDR;
text_mem_config.direction = AML_MEM_DIRECTION_READ_WRITE;
text_mem_config.index = 0;
text_mem_config.mem_size = inData.size;
aml_util_mallocBuffer(qcontext, &text_mem_config, &text_mem_data);
aml_util_swapExternalInputBuffer(qcontext, &text_mem_config, &text_mem_data);
text_context_flag = qcontext;
}
// storing qualified paths
std::vector<fs::directory_entry> matched_files;
inData.input_type = INPUT_DMA_DATA;
memcpy(text_mem_data.viraddr, input_ids.data(), text_mem_config.mem_size);
inData.input = NULL;
// collect all relevant img.
for (const auto& entry : fs::directory_iterator(image_dir_path)) {
if (!entry.is_regular_file()) continue;
memset(&outconfig, 0, sizeof(aml_output_config_t));
outconfig.format = AML_OUTDATA_DMA;
outconfig.typeSize = sizeof(aml_output_config_t);
outdata = (nn_output*)aml_module_output_get(qcontext, outconfig);
std::string filename = entry.path().filename().string();
if (std::regex_match(filename, file_pattern)) {
matched_files.push_back(entry);
}
if (outdata == NULL || outdata->out[0].buf == NULL) {
printf("Text model inference failed.\n");
return {};
}
// use index sort, test_type_index.jpg
std::sort(matched_files.begin(), matched_files.end(),
[](const fs::directory_entry& a, const fs::directory_entry& b) {
return extract_index(a.path().filename().string()) <
extract_index(b.path().filename().string());
});
// Copy output to vector
size_t output_size = outdata->out[0].size / sizeof(float);
float* output_ptr = reinterpret_cast<float*>(outdata->out[0].buf);
std::vector<float> result(output_ptr, output_ptr + output_size);
for (const auto& entry : matched_files) {
if (!entry.is_regular_file()) continue;
std::string filename = entry.path().filename().string();
std::smatch match;
if (!std::regex_match(filename, match, file_pattern)) continue;
std::string name = match[1];
std::vector<float> input_data = preprocess_image(entry.path().string());
float* model_output = run_network(context_model, input_data, name);
float max_sim = -std::numeric_limits<float>::infinity();
std::string best_key, best_id;
// Iterate through all directories to find the directory containing the name
for (const auto& dir_entry : fs::directory_iterator(actual_base_dir)) {
if (!dir_entry.is_directory()) continue;
std::string folder_name = dir_entry.path().filename().string();
if (folder_name.find(name) == std::string::npos) continue;
std::string vit_res_path = actual_base_dir + folder_name + "/" + actual_json_filename;
std::ifstream vit_in(vit_res_path);
if (!vit_in.is_open()) {
printf("unopen: %s\n", vit_res_path.c_str());
continue;
}
json vit_json;
vit_in >> vit_json;
for (auto it = vit_json.begin(); it != vit_json.end(); ++it) {
const std::string& key = it.key();
const std::vector<float> vec = it.value().get<std::vector<float>>();
float sim = post_process(model_output, vec);
// printf("sim: %.4f\n", sim);
if (sim > max_sim) {
max_sim = sim;
best_key = key;
best_id = folder_name;
}
}
}
if (!best_key.empty() && !best_id.empty()) {
std::string best_path = actual_base_dir + best_id + "/";
results.push_back(best_path);
printf("\nProcessing images: %s, datasets img path: %s\n", filename.c_str(), best_path.c_str());
// printf("最相似图片: %s 相似度: %.4f\n", best_path.c_str(), max_sim); // for debug
}
}
return results;
return result;
}
int destroy_network(void *qcontext)
{
int ret = 0;
/* free model
model.use_dma = true
model.malloc_buffer_once = false
*/
if (context_model.use_dma && mem_config_context_model.mem_size != 0) {
ret = aml_util_freeBuffer(qcontext, &mem_config_context_model, &mem_data_context_model);
if (ret)
{
std::cout << "aml_util_freeBuffer fail." << std::endl;
}
if (vision_context_flag == qcontext) {
printf("Free vision model memory.\n");
aml_util_freeBuffer(qcontext, &vision_mem_config, &vision_mem_data);
vision_context_flag = nullptr;
} else if (text_context_flag == qcontext) {
printf("Free text model memory.\n");
aml_util_freeBuffer(qcontext, &text_mem_config, &text_mem_data);
text_context_flag = nullptr;
} else {
printf("Free network failed: context not found.\n");
return -1;
}
context_model.use_dma = false;
ret = aml_module_destroy(qcontext);
if (ret)
{
printf("aml_module_destroy fail.\n");
printf("Free network failed: destroy failed.\n");
return -1;
}
return ret;
}
}