/* OverloadHotspotDetectionAlgorithm.cpp */
#include "OverloadHotspotDetectionAlgorithm.hpp"

#include "algorithm/gpu/common/MedianFilter.hpp"
#include "algorithm/gpu/common/MedianFilter3x3.hpp"

/*
 * Load the static scene model and allocate every buffer used by handleFrame().
 *
 * frameSize - size of the frames that will later be passed to handleFrame().
 * loader    - source of the scene-model matrices (e.g. "/scene_model/FOV").
 */
void OverloadHotspotDetectionAlgorithm::setup(const cv::Size &frameSize,
                                              const CVMatLoader &loader) {
	currentFrameSize = frameSize;

	/* Field-of-view mask, applied to every incoming frame */
	deviceFov.upload(loader.asMat("/scene_model/FOV", CV_8UC1));

	/* Per-pixel thresholds: handleFrame() flags pixels hotter than these */
	PfcSegmenter segmenter(loader, frameSize);
	devicePfc.upload(segmenter.segment());

	/* Pixels without a threshold (value 0) are raised to MAX_TEMPERATURE so
	 * that, assuming temperatures never exceed it, they are never flagged. */
	cv::cuda::GpuMat zerosMask;
	cv::cuda::compare(devicePfc, 0, zerosMask, cv::CMP_EQ);
	devicePfc.setTo(MAX_TEMPERATURE, zerosMask);

	/* The contours are not used here; the call is kept because it is given
	 * hostPfcMask, which identifyHotspots() reads to label each hotspot with
	 * its component (presumably contour() fills it — confirm in PfcSegmenter). */
	std::vector<Contours> pfcContours;
	segmenter.contour(pfcContours, hostPfcMask);

	surfaceMap = std::make_unique<const SurfaceMap>(frameSize);
	medianFilter = createMedianFilter(3);

	overheating = std::make_unique<ManagedMat>(frameSize, CV_8UC1);

	deviceFrame.create(frameSize, CV_16UC1);

	medianFilterInput = std::make_unique<ContinuousGpuMat>(frameSize, CV_16UC1);
	medianFilterOutput =
	    std::make_unique<ContinuousGpuMat>(frameSize, CV_16UC1);

	/* Buffer aliasing used by handleFrame(): the Celsius frame is written to
	 * the filter's input buffer, while deviceInput first receives the masked
	 * raw frame and is later overwritten by the median-filter output. */
	deviceTemperature = medianFilterInput->device();
	deviceInput = medianFilterOutput->device();
}

/*
 * Process one frame: mask it with the field of view, convert Kelvin to
 * Celsius, median-filter it, threshold it against the per-pixel limits and
 * rebuild currentHotspots.
 *
 * hostFrame - CV_16UC1 frame of the size given to setup().
 * timestamp - forwarded to every hotspot detected in this frame.
 *
 * Throws cv::Exception if hostFrame does not match the configured size/type.
 */
void OverloadHotspotDetectionAlgorithm::handleFrame(const cv::Mat &hostFrame,
                                                    unsigned long timestamp) {
	/* All device buffers are allocated as CV_16UC1 of currentFrameSize. Any
	 * other frame would make upload() reallocate the temporary ROI header
	 * instead of writing into deviceFrame, silently keeping stale data. */
	CV_Assert(hostFrame.size() == currentFrameSize &&
	          hostFrame.type() == CV_16UC1);

	/* Overlap data transfer and computations.
	 * NOTE(review): the second pass relies on concurrent.run giving each tile
	 * the same stream as in the upload pass — confirm. */
	concurrent.run(
	    [this, &hostFrame](const auto &rows, const auto &cols, auto &stream) {
		    deviceFrame(rows, cols).upload(hostFrame(rows, cols), stream);
	    },
	    currentFrameSize);

	concurrent.run(
	    [this](const auto &rows, const auto &cols, auto &stream) {
		    /* A masked bitwise_and only writes pixels inside the mask, and
		     * deviceInput still holds the previous frame's median-filter
		     * output, so clear the tile before applying the FOV mask. */
		    deviceInput(rows, cols).setTo(cv::Scalar::all(0), stream);
		    /* Apply field-of-view mask */
		    cv::cuda::bitwise_and(
		        deviceFrame(rows, cols), deviceFrame(rows, cols),
		        deviceInput(rows, cols), deviceFov(rows, cols), stream);
		    /* Convert Kelvin to Celsius */
		    cv::cuda::subtract(deviceInput(rows, cols), KELVINS,
		                       deviceTemperature(rows, cols), cv::noArray(), -1,
		                       stream);
	    },
	    currentFrameSize);

	/* The median filter works on the whole frame, so every tile must be done */
	concurrent.synchronize();

	/* Apply median filter (deviceTemperature -> deviceInput) */
	medianFilter->apply(deviceTemperature, deviceInput);

	/* Threshold overheating pixels */
	cv::cuda::compare(deviceInput, devicePfc, overheating->device(),
	                  cv::CMP_GT);
	currentHotspots.clear();
	identifyHotspots(timestamp);
}

void OverloadHotspotDetectionAlgorithm::identifyHotspots(
    unsigned long timestamp) {
	/* Analyse topologial structure */
	Contours hotspotContours;
	cv::findContours(overheating->host(), hotspotContours, cv::RETR_TREE,
	                 cv::CHAIN_APPROX_SIMPLE);

	for (std::size_t i = 0, count = 0, size = hotspotContours.size();
	     i < size && count < blobsLimit; ++i) {
		const auto &contour = hotspotContours[i];
		/* Discard blobs below minimum area */
		if (contour.size() >= pixelsThreshold) {
			count += 1;

			const int component = hostPfcMask.at<uchar>(contour[0]);
			assert(component > 0);

			Hotspot hotspot(deviceTemperature, timestamp, component, contour,
			                surfaceMap->at(contour));
			
			/* Find corresponding blobs */
			matchHotspot(hotspot);
			currentHotspots.emplace_back(std::move(hotspot));
		}
	}
}

/*
 * Merge `hotspot` into the first tracked blob whose mask corresponds to it
 * (per clusterCorrespondence); if none does, start tracking a copy of it.
 */
void OverloadHotspotDetectionAlgorithm::matchHotspot(const Hotspot &hotspot) {
	auto candidate = uniqueHotspots.begin();
	const auto last = uniqueHotspots.end();

	/* Advance to the first persistent blob that overlaps the new one */
	while (candidate != last &&
	       !clusterCorrespondence.corresponds(candidate->mask, hotspot.mask)) {
		++candidate;
	}

	if (candidate == last) {
		/* Nothing matched: this is a new unique blob */
		uniqueHotspots.emplace_back(hotspot);
	} else {
		/* Merge matching blobs */
		candidate->merge(hotspot);
	}
}

/*
 * Build a 16-bit median filter for frames of currentFrameSize.
 *
 * kernel - aperture size; 3 selects the hand-written 3x3 implementation,
 *          any other value is delegated to CudaFilter::createMedianFilter16U.
 */
std::unique_ptr<cv::cuda::Filter>
OverloadHotspotDetectionAlgorithm::createMedianFilter(int kernel) const {
	constexpr int optimisedKernel = 3;

	if (kernel != optimisedKernel) {
		/* General-purpose implementation for arbitrary kernel sizes */
		return CudaFilter::createMedianFilter16U(currentFrameSize, kernel);
	}

	/* Implementation optimised for the small kernel size */
	return std::make_unique<MedianFilter3x3<ushort>>(currentFrameSize);
}