Upload first version
This commit is contained in:
parent
f95d5a63b0
commit
3bdf2003ec
898 changed files with 1405811 additions and 1 deletions
2
examples/yoloworld/cpp/.gitkeep
Normal file
2
examples/yoloworld/cpp/.gitkeep
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
src
|
||||
build-*
|
||||
77
examples/yoloworld/cpp/build-android.sh
Executable file
77
examples/yoloworld/cpp/build-android.sh
Executable file
|
|
@ -0,0 +1,77 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
#
|
||||
# Copyright (C) 2024–2025 Amlogic, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 [-a <target_abi>]"
|
||||
echo " -a <target_abi> : Target ABI (default: arm64-v8a)"
|
||||
echo " -h : Show this help message"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Default values
|
||||
TARGET_ABI=arm64-v8a
|
||||
|
||||
# Parse arguments
|
||||
while getopts 'a:h' opt; do
|
||||
case "$opt" in
|
||||
a)
|
||||
TARGET_ABI=$OPTARG
|
||||
;;
|
||||
h)
|
||||
usage
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ -z "${ANDROID_NDK_PATH}" ]; then
|
||||
if [ -n "${ANDROID_NDK}" ]; then
|
||||
ANDROID_NDK_PATH=${ANDROID_NDK}
|
||||
elif [ -n "${ANDROID_NDK_HOME}" ]; then
|
||||
ANDROID_NDK_PATH=${ANDROID_NDK_HOME}
|
||||
else
|
||||
echo "Error: ANDROID_NDK_PATH is not set."
|
||||
echo "Please set ANDROID_NDK_PATH to your Android NDK directory."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
ROOT_PWD=$(cd "$(dirname $0)" && pwd)
|
||||
BUILD_DIR=${ROOT_PWD}/build/android
|
||||
|
||||
echo "Building for Android..."
|
||||
echo "NDK_PATH: ${ANDROID_NDK_PATH}"
|
||||
echo "TARGET_ABI: ${TARGET_ABI}"
|
||||
echo "BUILD_DIR: ${BUILD_DIR}"
|
||||
|
||||
mkdir -p ${BUILD_DIR}
|
||||
cd ${BUILD_DIR}
|
||||
|
||||
cmake ../../src \
|
||||
-DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_PATH}/build/cmake/android.toolchain.cmake \
|
||||
-DANDROID_ABI=${TARGET_ABI} \
|
||||
-DANDROID_PLATFORM=android-24 \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DOpenCV_DIR=${ROOT_PWD}/../../../dependency/opencv/opencv-android-sdk-build/sdk/native/jni/abi-${TARGET_ABI}
|
||||
|
||||
make -j4
|
||||
|
||||
echo "Build complete. Executable in ${BUILD_DIR}/yolo_world_demo"
|
||||
78
examples/yoloworld/cpp/build-linux.sh
Executable file
78
examples/yoloworld/cpp/build-linux.sh
Executable file
|
|
@ -0,0 +1,78 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
#
|
||||
# Copyright (C) 2024–2025 Amlogic, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 [-a <target_arch>]"
|
||||
echo " -a <target_arch> : Target architecture (default: aarch64)"
|
||||
echo " -h : Show this help message"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Default values
|
||||
TARGET_ARCH=aarch64
|
||||
|
||||
# Parse arguments
|
||||
while getopts 'a:h' opt; do
|
||||
case "$opt" in
|
||||
a)
|
||||
TARGET_ARCH=$OPTARG
|
||||
;;
|
||||
h)
|
||||
usage
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Default to aarch64-linux-gnu if GCC_COMPILER is not set
|
||||
GCC_COMPILER=${GCC_COMPILER:-aarch64-linux-gnu}
|
||||
|
||||
# Set compilers
|
||||
export CC=${GCC_COMPILER}-gcc
|
||||
export CXX=${GCC_COMPILER}-g++
|
||||
|
||||
# Validate compiler
|
||||
if ! command -v ${CC} &> /dev/null; then
|
||||
echo "Error: Compiler ${CC} not found."
|
||||
echo "Please set GCC_COMPILER environment variable to your cross-compiler path prefix."
|
||||
echo "Example: export GCC_COMPILER=/path/to/toolchain/bin/aarch64-linux-gnu"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ROOT_PWD=$(cd "$(dirname $0)" && pwd)
|
||||
BUILD_DIR=${ROOT_PWD}/build/linux
|
||||
|
||||
echo "Building for Linux..."
|
||||
echo "COMPILER: ${CC}"
|
||||
echo "TARGET_ARCH: ${TARGET_ARCH}"
|
||||
echo "BUILD_DIR: ${BUILD_DIR}"
|
||||
|
||||
mkdir -p ${BUILD_DIR}
|
||||
cd ${BUILD_DIR}
|
||||
|
||||
cmake ../../src \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=${TARGET_ARCH} \
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
|
||||
make -j4
|
||||
|
||||
echo "Build complete. Executable in ${BUILD_DIR}/yolo_world_demo"
|
||||
40
examples/yoloworld/cpp/src/CMakeLists.txt
Executable file
40
examples/yoloworld/cpp/src/CMakeLists.txt
Executable file
|
|
@ -0,0 +1,40 @@
|
|||
cmake_minimum_required(VERSION 3.5)
|
||||
project(yolo_world_demo)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
# Set NNSDK path
|
||||
set(NNSDK_ROOT "${CMAKE_SOURCE_DIR}/../../../../dependency/nnsdk")
|
||||
include_directories(${NNSDK_ROOT}/include)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/../../../../common)
|
||||
|
||||
# Set 3rdparty path
|
||||
set(3RDPARTY_DIR "${CMAKE_SOURCE_DIR}/../../../../dependency")
|
||||
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Android")
|
||||
if (ANDROID_ABI STREQUAL "arm64-v8a")
|
||||
link_directories(${NNSDK_ROOT}/lib/android/arm64-v8a)
|
||||
else()
|
||||
link_directories(${NNSDK_ROOT}/lib/android/armeabi-v7a)
|
||||
endif()
|
||||
# Android needs log
|
||||
link_libraries(log)
|
||||
elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||
link_directories(${NNSDK_ROOT}/lib/linux/lib64_yocto)
|
||||
endif()
|
||||
|
||||
# Find OpenCV
|
||||
find_package(OpenCV REQUIRED)
|
||||
include_directories(${OpenCV_INCLUDE_DIRS})
|
||||
|
||||
add_executable(yolo_world_demo
|
||||
main.cpp
|
||||
postprocess.cpp
|
||||
postprocess.h
|
||||
${CMAKE_SOURCE_DIR}/../../../../common/model_loader.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(yolo_world_demo
|
||||
${OpenCV_LIBS}
|
||||
nnsdk
|
||||
)
|
||||
120
examples/yoloworld/cpp/src/main.cpp
Executable file
120
examples/yoloworld/cpp/src/main.cpp
Executable file
|
|
@ -0,0 +1,120 @@
|
|||
/*
|
||||
* Copyright (C) 2024–2025 Amlogic, Inc. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <tuple>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "postprocess.h"
|
||||
#include "model_loader.h"
|
||||
|
||||
const std::string DEFAULT_OUTPUT_PATH = "./result.jpg";
|
||||
const int MODEL_INPUT_WIDTH = 640;
|
||||
const int MODEL_INPUT_HEIGHT = 480;
|
||||
const float SCORE_THRESHOLD = 0.4f;
|
||||
const float NMS_THRESHOLD = 0.45f;
|
||||
|
||||
const std::vector<std::string> CLASS_NAMES = {
|
||||
"handbag", "backpack", "wallet",
|
||||
"watch", "necklace", "bracelet", "earrings", "finger ring", "sunglass", "hat", "shoes", "belt",
|
||||
"makeup palette", "lipstick tube",
|
||||
"car", "truck", "bicycle", "motorcycle",
|
||||
"phone", "laptop", "camera", "wine bottle", "stuffed toy"
|
||||
};
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string model_path;
|
||||
std::string image_path;
|
||||
if (argc != 3)
|
||||
{
|
||||
printf("%s <model_path> <image_path>\n", argv[0]);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (argc > 1) model_path = argv[1];
|
||||
if (argc > 2) image_path = argv[2];
|
||||
|
||||
std::cout << "YOLOWorld Native Demo" << std::endl;
|
||||
std::cout << "Model: " << model_path << std::endl;
|
||||
std::cout << "Image: " << image_path << std::endl;
|
||||
std::cout << "Output: " << DEFAULT_OUTPUT_PATH << std::endl;
|
||||
|
||||
// 1. Load Image
|
||||
cv::Mat img = cv::imread(image_path);
|
||||
if (img.empty()) {
|
||||
std::cerr << "Failed to load image from " << image_path << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 2. Initialize Network
|
||||
void* context = init_network(model_path.c_str());
|
||||
if (!context) {
|
||||
std::cerr << "Failed to initialize network." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 3. Preprocess
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
std::tuple<cv::Mat, float, std::tuple<int, int>> input_tuple =
|
||||
preprocess(img, std::make_tuple(MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH));
|
||||
|
||||
// 4. Run Network
|
||||
void* output_ptr = run_network(context, {input_tuple});
|
||||
if (!output_ptr) {
|
||||
std::cerr << "Failed to run network." << std::endl;
|
||||
uninit_network(context);
|
||||
return -1;
|
||||
}
|
||||
nn_output* outdata = (nn_output*)output_ptr;
|
||||
|
||||
// 5. Postprocess
|
||||
float* outbuf0 = (float*)outdata->out[0].buf;
|
||||
float* outbuf1 = (float*)outdata->out[1].buf;
|
||||
float* outbuf2 = (float*)outdata->out[2].buf;
|
||||
|
||||
int num_classes = CLASS_NAMES.size();
|
||||
int channels = 87;
|
||||
|
||||
// Using standard stride logic assuming standard YOLOv8/World export
|
||||
std::vector<Detection> detections = postprocess(
|
||||
std::make_tuple(outbuf0, std::make_tuple(MODEL_INPUT_HEIGHT / 8, MODEL_INPUT_WIDTH / 8, channels), 8),
|
||||
std::make_tuple(outbuf1, std::make_tuple(MODEL_INPUT_HEIGHT / 16, MODEL_INPUT_WIDTH / 16, channels), 16),
|
||||
std::make_tuple(outbuf2, std::make_tuple(MODEL_INPUT_HEIGHT / 32, MODEL_INPUT_WIDTH / 32, channels), 32),
|
||||
input_tuple,
|
||||
SCORE_THRESHOLD,
|
||||
NMS_THRESHOLD,
|
||||
num_classes,
|
||||
1 // reverse=1 for YOLOWorld format
|
||||
);
|
||||
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double, std::milli> inference_time = end_time - start_time;
|
||||
std::cout << "Inference + Postprocess time: " << inference_time.count() << " ms" << std::endl;
|
||||
std::cout << "Detections found: " << detections.size() << std::endl;
|
||||
|
||||
// 6. Draw and Save
|
||||
cv::Mat result_img = draw_detections(img, detections, CLASS_NAMES);
|
||||
cv::imwrite(DEFAULT_OUTPUT_PATH, result_img);
|
||||
std::cout << "Result saved to " << DEFAULT_OUTPUT_PATH << std::endl;
|
||||
|
||||
// 7. Cleanup
|
||||
uninit_network(context);
|
||||
|
||||
return 0;
|
||||
}
|
||||
291
examples/yoloworld/cpp/src/postprocess.cpp
Executable file
291
examples/yoloworld/cpp/src/postprocess.cpp
Executable file
|
|
@ -0,0 +1,291 @@
|
|||
/*
|
||||
* Copyright (C) 2024–2025 Amlogic, Inc. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "postprocess.h"
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <map>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include "nn_sdk.h"
|
||||
|
||||
#define LOGI(...) do { printf(__VA_ARGS__); printf("\n"); } while(0)
|
||||
#define LOGE(...) do { fprintf(stderr, __VA_ARGS__); fprintf(stderr, "\n"); } while(0)
|
||||
|
||||
static float compute_iou(const Detection& det1, const Detection& det2) {
|
||||
float xx1 = std::max(det1.x1, det2.x1);
|
||||
float yy1 = std::max(det1.y1, det2.y1);
|
||||
float xx2 = std::min(det1.x2, det2.x2);
|
||||
float yy2 = std::min(det1.y2, det2.y2);
|
||||
|
||||
float w = std::max(0.0f, xx2 - xx1);
|
||||
float h = std::max(0.0f, yy2 - yy1);
|
||||
float inter = w * h;
|
||||
|
||||
float area1 = (det1.x2 - det1.x1) * (det1.y2 - det1.y1);
|
||||
float area2 = (det2.x2 - det2.x1) * (det2.y2 - det2.y1);
|
||||
|
||||
return inter / (area1 + area2 - inter);
|
||||
}
|
||||
|
||||
static std::vector<Detection> nms_by_class(const std::vector<Detection>& detections, float iou_threshold) {
|
||||
if (detections.empty()) return {};
|
||||
|
||||
std::vector<Detection> final_detections;
|
||||
|
||||
std::unordered_map<int, std::vector<Detection>> class_detections;
|
||||
for (const auto& det : detections) {
|
||||
class_detections[det.class_id].push_back(det);
|
||||
}
|
||||
|
||||
for (auto& [class_id, cls_dets] : class_detections) {
|
||||
std::sort(cls_dets.begin(), cls_dets.end(), [](const Detection& a, const Detection& b) {
|
||||
return a.score > b.score;
|
||||
});
|
||||
|
||||
std::vector<bool> removed(cls_dets.size(), false);
|
||||
for (size_t i = 0; i < cls_dets.size(); ++i) {
|
||||
if (removed[i]) continue;
|
||||
final_detections.push_back(cls_dets[i]);
|
||||
|
||||
for (size_t j = i + 1; j < cls_dets.size(); ++j) {
|
||||
if (removed[j]) continue;
|
||||
if (compute_iou(cls_dets[i], cls_dets[j]) > iou_threshold) {
|
||||
removed[j] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return final_detections;
|
||||
}
|
||||
|
||||
static std::vector<Detection> suppress_cross_class_iou_conflicts(std::vector<Detection> detections, float iou_threshold) {
|
||||
std::sort(detections.begin(), detections.end(), [](const Detection& a, const Detection& b) {
|
||||
return a.score > b.score;
|
||||
});
|
||||
|
||||
std::vector<bool> removed(detections.size(), false);
|
||||
std::vector<Detection> final_detections;
|
||||
|
||||
for (size_t i = 0; i < detections.size(); ++i) {
|
||||
if (removed[i]) continue;
|
||||
final_detections.push_back(detections[i]);
|
||||
for (size_t j = i + 1; j < detections.size(); ++j) {
|
||||
if (removed[j]) continue;
|
||||
if (detections[i].class_id != detections[j].class_id &&
|
||||
compute_iou(detections[i], detections[j]) > iou_threshold) {
|
||||
removed[j] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return final_detections;
|
||||
}
|
||||
|
||||
static float sigmoid(float x) {
|
||||
return 1.0f / (1.0f + std::exp(-x));
|
||||
}
|
||||
|
||||
static std::vector<Detection> get_detections(float* output, std::tuple<int, int, int> output_shape,
|
||||
int stride, float conf_thresh, int num_classes, int reverse) {
|
||||
std::vector<Detection> detections;
|
||||
|
||||
int grid_h = std::get<0>(output_shape);
|
||||
int grid_w = std::get<1>(output_shape);
|
||||
int total_cells = grid_h * grid_w;
|
||||
int coords = 4 * 16; // DFL coords: 64
|
||||
|
||||
// reverse=0: standard YOLO [classes + box]
|
||||
// reverse>0: YOLOWorld [box + classes]
|
||||
int cls_offset = (reverse > 0) ? coords : 0;
|
||||
int dfl_offset = (reverse > 0) ? 0 : num_classes;
|
||||
|
||||
for (int i = 0; i < grid_h; ++i) {
|
||||
for (int j = 0; j < grid_w; ++j) {
|
||||
int idx = (i * grid_w + j) * (num_classes + coords);
|
||||
|
||||
float max_score = -1.0f;
|
||||
int class_id = -1;
|
||||
for (int c = 0; c < num_classes; ++c) {
|
||||
int cls_idx = idx + cls_offset + c;
|
||||
float score = sigmoid(output[cls_idx]);
|
||||
if (score > max_score) {
|
||||
max_score = score;
|
||||
class_id = c;
|
||||
}
|
||||
}
|
||||
|
||||
if (max_score < conf_thresh) continue;
|
||||
|
||||
float exp_vals[4] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
int dfl_idx = idx + dfl_offset + k * 16;
|
||||
float exp_logits[16];
|
||||
float sum_exp = 0.0f;
|
||||
|
||||
float max_logit = output[dfl_idx];
|
||||
for (int t = 1; t < 16; ++t) {
|
||||
if (output[dfl_idx + t] > max_logit) max_logit = output[dfl_idx + t];
|
||||
}
|
||||
|
||||
for (int t = 0; t < 16; ++t) {
|
||||
exp_logits[t] = std::exp(output[dfl_idx + t] - max_logit);
|
||||
sum_exp += exp_logits[t];
|
||||
}
|
||||
|
||||
for (int t = 0; t < 16; ++t) {
|
||||
exp_logits[t] /= sum_exp;
|
||||
exp_vals[k] += t * exp_logits[t];
|
||||
}
|
||||
}
|
||||
|
||||
float x1 = (j + 0.5f - exp_vals[0]) * stride;
|
||||
float y1 = (i + 0.5f - exp_vals[1]) * stride;
|
||||
float x2 = (j + 0.5f + exp_vals[2]) * stride;
|
||||
float y2 = (i + 0.5f + exp_vals[3]) * stride;
|
||||
|
||||
detections.push_back({x1, y1, x2, y2, max_score, class_id});
|
||||
}
|
||||
}
|
||||
return detections;
|
||||
}
|
||||
|
||||
|
||||
std::tuple<cv::Mat, float, std::tuple<int, int>> preprocess(cv::Mat img, std::tuple<int, int> new_shape) {
|
||||
cv::Mat img_rgb;
|
||||
// Check if image is valid
|
||||
if (img.empty()) {
|
||||
LOGE("Preprocess received empty image");
|
||||
return {};
|
||||
}
|
||||
|
||||
if (img.channels() == 4)
|
||||
cv::cvtColor(img, img_rgb, cv::COLOR_RGBA2RGB);
|
||||
else if (img.channels() == 3)
|
||||
img_rgb = img.clone();
|
||||
|
||||
if (img.channels() == 3) {
|
||||
cv::cvtColor(img, img_rgb, cv::COLOR_BGR2RGB);
|
||||
}
|
||||
|
||||
int orig_h = img.rows;
|
||||
int orig_w = img.cols;
|
||||
float scale = std::min(static_cast<float>(std::get<0>(new_shape)) / orig_h, static_cast<float>(std::get<1>(new_shape)) / orig_w);
|
||||
int new_h = static_cast<int>(round(orig_h * scale));
|
||||
int new_w = static_cast<int>(round(orig_w * scale));
|
||||
|
||||
cv::Mat img_resized;
|
||||
cv::resize(img_rgb, img_resized, cv::Size(new_w, new_h), 0, 0, cv::INTER_LINEAR);
|
||||
|
||||
int pad_h = std::get<0>(new_shape) - new_h;
|
||||
int pad_w = std::get<1>(new_shape) - new_w;
|
||||
int pad_left = static_cast<int>(round(pad_w / 2. - 0.1f));
|
||||
int pad_right = static_cast<int>(round(pad_w / 2. + 0.1f));
|
||||
int pad_top = static_cast<int>(round(pad_h / 2. - 0.1f));
|
||||
int pad_bottom = static_cast<int>(round(pad_h / 2. + 0.1f));
|
||||
|
||||
cv::Mat img_padded;
|
||||
cv::copyMakeBorder(img_resized, img_padded, pad_top, pad_bottom, pad_left, pad_right, cv::BORDER_CONSTANT, cv::Scalar(114, 114, 114));
|
||||
|
||||
cv::Mat img_float;
|
||||
img_padded.convertTo(img_float, CV_32F, 1.0 / 255.0);
|
||||
|
||||
return std::make_tuple(img_float, scale, std::make_tuple(pad_left, pad_top));
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::vector<Detection> postprocess(std::tuple<float*, std::tuple<int, int, int>, int> out0,
|
||||
std::tuple<float*, std::tuple<int, int, int>, int> out1,
|
||||
std::tuple<float*, std::tuple<int, int, int>, int> out2,
|
||||
std::tuple<cv::Mat, float, std::tuple<int, int>> input_tuple,
|
||||
float conf_thresh, float iou_threshold, int num_classes, int reverse) {
|
||||
float scale = std::get<1>(input_tuple);
|
||||
int pad_left = std::get<0>(std::get<2>(input_tuple));
|
||||
int pad_top = std::get<1>(std::get<2>(input_tuple));
|
||||
|
||||
std::vector<Detection> detections;
|
||||
|
||||
auto process_out = [&](auto& out) {
|
||||
float* output = std::get<0>(out);
|
||||
auto shape = std::get<1>(out);
|
||||
int stride = std::get<2>(out);
|
||||
std::vector<Detection> dets = get_detections(output, shape, stride, conf_thresh, num_classes, reverse);
|
||||
detections.insert(detections.end(), dets.begin(), dets.end());
|
||||
};
|
||||
|
||||
process_out(out0);
|
||||
process_out(out1);
|
||||
process_out(out2);
|
||||
|
||||
std::vector<Detection> detections_orig;
|
||||
for (const auto& det : detections) {
|
||||
float x1_orig = (det.x1 - pad_left) / scale;
|
||||
float y1_orig = (det.y1 - pad_top) / scale;
|
||||
float x2_orig = (det.x2 - pad_left) / scale;
|
||||
float y2_orig = (det.y2 - pad_top) / scale;
|
||||
detections_orig.push_back({x1_orig, y1_orig, x2_orig, y2_orig, det.score, det.class_id});
|
||||
}
|
||||
|
||||
std::vector<Detection> detections_nms = nms_by_class(detections_orig, iou_threshold);
|
||||
return suppress_cross_class_iou_conflicts(detections_nms, 0.8f);
|
||||
}
|
||||
|
||||
cv::Mat draw_detections(cv::Mat image, const std::vector<Detection>& detections,
|
||||
const std::vector<std::string>& classes, int seed_offset) {
|
||||
int num_classes = classes.size();
|
||||
std::vector<cv::Scalar> color_palette;
|
||||
std::mt19937 rng(42 + seed_offset);
|
||||
std::uniform_int_distribution<int> color_dist(0, 255);
|
||||
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
color_palette.emplace_back(color_dist(rng), color_dist(rng), color_dist(rng));
|
||||
}
|
||||
|
||||
cv::Mat drawn_image = image.clone();
|
||||
|
||||
for (const auto& det : detections) {
|
||||
int class_id = det.class_id;
|
||||
if (class_id < 0 || class_id >= num_classes) continue;
|
||||
|
||||
cv::Scalar color = color_palette[class_id];
|
||||
cv::rectangle(drawn_image,
|
||||
cv::Point(static_cast<int>(det.x1), static_cast<int>(det.y1)),
|
||||
cv::Point(static_cast<int>(det.x2), static_cast<int>(det.y2)),
|
||||
color, 2);
|
||||
|
||||
std::string label = classes[class_id] + ": " + cv::format("%.2f", det.score);
|
||||
int baseline = 0;
|
||||
cv::Size text_size = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseline);
|
||||
|
||||
int label_x = static_cast<int>(det.x1);
|
||||
int label_y = static_cast<int>(det.y1) - 10;
|
||||
if (label_y < text_size.height) label_y = static_cast<int>(det.y1) + text_size.height + 10;
|
||||
|
||||
cv::rectangle(drawn_image,
|
||||
cv::Point(label_x, label_y - text_size.height - baseline),
|
||||
cv::Point(label_x + text_size.width, label_y + baseline),
|
||||
color, cv::FILLED);
|
||||
|
||||
cv::putText(drawn_image, label,
|
||||
cv::Point(label_x, label_y),
|
||||
cv::FONT_HERSHEY_SIMPLEX, 0.5,
|
||||
cv::Scalar(0, 0, 0), 1, cv::LINE_AA);
|
||||
}
|
||||
return drawn_image;
|
||||
}
|
||||
43
examples/yoloworld/cpp/src/postprocess.h
Executable file
43
examples/yoloworld/cpp/src/postprocess.h
Executable file
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright (C) 2024–2025 Amlogic, Inc. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef _AMLNN_YOLO_WORLD_DEMO_POSTPROCESS_H_
|
||||
#define _AMLNN_YOLO_WORLD_DEMO_POSTPROCESS_H_
|
||||
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
|
||||
// Detection result structure
|
||||
struct Detection {
|
||||
float x1, y1, x2, y2; // Bounding box coordinates
|
||||
float score; // Confidence score
|
||||
int class_id; // Predicted class ID
|
||||
};
|
||||
|
||||
|
||||
std::vector<Detection> postprocess(std::tuple<float*, std::tuple<int, int, int>, int> out0,
|
||||
std::tuple<float*, std::tuple<int, int, int>, int> out1,
|
||||
std::tuple<float*, std::tuple<int, int, int>, int> out2,
|
||||
std::tuple<cv::Mat, float, std::tuple<int, int>> input_tuple,
|
||||
float conf_thresh, float iou_threshold, int num_classes, int reverse);
|
||||
|
||||
cv::Mat draw_detections(cv::Mat image, const std::vector<Detection>& detections,
|
||||
const std::vector<std::string>& classes, int seed_offset = 0);
|
||||
|
||||
#endif // _AMLNN_YOLO_WORLD_DEMO_POSTPROCESS_H_
|
||||
Loading…
Add table
Add a link
Reference in a new issue