lavfi/nlmeans: add aarch64 neon for compute_weights_line

Implement NEON optimization for compute_weights_line.

Also update the function signature to use ptrdiff_t for stack arguments
(max_meaningful_diff, startx, endx). This is done to unify the stack
layout between Apple platforms (which pack 32-bit stack arguments tightly)
and the generic AAPCS64 ABI (which requires 8-byte stack slots for 32-bit
arguments). Using ptrdiff_t ensures 8-byte slots are used on all AArch64
platforms, avoiding ABI mismatches with the assembly implementation.

The x86 AVX2 prototype is updated to match the new signature.

Performance benchmark (AArch64) on an Apple M4 running macOS:
./tests/checkasm/checkasm --test=vf_nlmeans --bench
compute_weights_line_c:     151.1 ( 1.00x)
compute_weights_line_neon:  62.6 ( 2.42x)

Reviewed-by: Martin Storsjö <martin@martin.st>
Signed-off-by: Jun Zhao <barryjzhao@tencent.com>
This commit is contained in:
Jun Zhao
2026-01-09 21:52:52 +08:00
committed by Zhao Zhili
parent 48c9c38684
commit 91ae6d10ab
6 changed files with 224 additions and 7 deletions

View File

@@ -25,10 +25,23 @@ void ff_compute_safe_ssd_integral_image_neon(uint32_t *dst, ptrdiff_t dst_linesi
const uint8_t *s2, ptrdiff_t linesize2,
int w, int h);
/**
 * AArch64 NEON implementation installed as NLMeansDSPContext.compute_weights_line.
 *
 * The first eight (pointer) arguments arrive in x0-x7; max_meaningful_diff,
 * startx and endx are passed on the stack. They are declared as ptrdiff_t so
 * that each one occupies a full 8-byte stack slot on every AArch64 ABI
 * (Apple platforms would otherwise pack 32-bit stack arguments tightly),
 * matching the [sp], [sp+8] and [sp+16] loads in the assembly.
 */
void ff_compute_weights_line_neon(const uint32_t *const iia,
const uint32_t *const iib,
const uint32_t *const iid,
const uint32_t *const iie,
const uint8_t *const src,
float *total_weight,
float *sum,
const float *const weight_lut,
ptrdiff_t max_meaningful_diff,
ptrdiff_t startx, ptrdiff_t endx);
/**
 * Install the AArch64 NEON routines into the nlmeans DSP context.
 *
 * When the CPU flags report NEON support, both compute_safe_ssd_integral_image
 * and compute_weights_line are pointed at their NEON implementations;
 * otherwise dsp is left untouched.
 *
 * NOTE(review): the two consecutive have_neon() tests below look like the
 * removed and the added line of the diff shown together -- confirm. As
 * written they nest, so the inner test is redundant but harmless.
 */
av_cold void ff_nlmeans_init_aarch64(NLMeansDSPContext *dsp)
{
int cpu_flags = av_get_cpu_flags();
if (have_neon(cpu_flags))
if (have_neon(cpu_flags)) {
dsp->compute_safe_ssd_integral_image = ff_compute_safe_ssd_integral_image_neon;
dsp->compute_weights_line = ff_compute_weights_line_neon;
}
}

View File

@@ -78,3 +78,129 @@ function ff_compute_safe_ssd_integral_image_neon, export=1
b.ne 1b
ret
endfunc
// NEON version of the nlmeans compute_weights_line routine.
//
// For every x in [startx, endx):
//     diff             = umin(iie[x] - iid[x] - iib[x] + iia[x], max_meaningful_diff)
//     w                = weight_lut[diff]
//     total_weight[x] += w
//     sum[x]          += w * src[x]
// The diff is computed with wrapping 32-bit arithmetic and clamped with an
// unsigned minimum, so the LUT index never exceeds max_meaningful_diff.
//
// Four pixels are handled per iteration of the vector loop (label 1); the
// remaining 0-3 pixels are handled one at a time by the scalar loop (label 2).
function ff_compute_weights_line_neon, export=1
// x0 = iia, x1 = iib, x2 = iid, x3 = iie
// x4 = src, x5 = total_weight, x6 = sum, x7 = weight_lut
// stack: [sp+0] = max_meaningful_diff, [sp+8] = startx, [sp+16] = endx
// Each stack argument is a ptrdiff_t in its own 8-byte slot; only one 32-bit
// word is read from the start of each slot (the low half on little-endian).
// Writing a w register clears the upper half of the x register, so x9 can be
// used directly as a 64-bit offset below.
ldr w13, [sp, #0] // max_meaningful_diff
ldr w9, [sp, #8] // startx
ldr w10, [sp, #16] // endx
cmp w9, w10
b.ge 9f // if startx >= endx return
// Offset pointers
lsl x11, x9, #2 // startx * 4 (for uint32/float)
add x0, x0, x11 // iia += startx
add x1, x1, x11 // iib += startx
add x2, x2, x11 // iid += startx
add x3, x3, x11 // iie += startx
add x5, x5, x11 // total_weight += startx
add x6, x6, x11 // sum += startx
// src is uint8, so offset is just startx
add x4, x4, x9 // src += startx
dup v7.4s, w13 // v7 = max_meaningful_diff (for vector ops)
// startx is not needed past this point; w9 becomes a scratch register and
// w10 holds the number of pixels still to process.
sub w10, w10, w9 // count = endx - startx
1: // Main loop
cmp w10, #4
b.lt 2f // Handle leftovers
// Load integral image values
ld1 {v0.4s}, [x0], #16 // iia
ld1 {v1.4s}, [x1], #16 // iib
ld1 {v2.4s}, [x2], #16 // iid
ld1 {v3.4s}, [x3], #16 // iie
// diff = a - b + e - d = e - d - b + a
sub v0.4s, v0.4s, v1.4s // v0 = a - b
sub v3.4s, v3.4s, v2.4s // v3 = e - d
add v3.4s, v3.4s, v0.4s // v3 = diff (a - b + e - d)
// min(diff, max)
// Unsigned minimum: a wrapped (huge) diff is clamped to max_meaningful_diff
umin v3.4s, v3.4s, v7.4s
// Schedule independent loads early
ld1 {v0.4s}, [x5] // v0 = total_weight
// Only lane 0 of v1 is written (4 src bytes); the other lanes keep stale
// data, which is harmless because the widening below only consumes the
// low 4 bytes.
ld1 {v1.s}[0], [x4], #4 // v1 = src pixels (low 4 bytes)
ld1 {v2.4s}, [x6] // v2 = sum
// Move to scalar registers to address lut
mov w8, v3.s[0]
mov w9, v3.s[1]
mov w11, v3.s[2]
mov w12, v3.s[3]
// Load 4 float weights using scalar registers
// Interleave with src conversion to hide latency
// Each index is scaled by 4 (uxtw #2) to address a float in weight_lut
ldr s3, [x7, w8, uxtw #2] // w0 -> v3.s[0] (v3 is now free)
ldr s4, [x7, w9, uxtw #2] // w1 -> v4.s[0]
ldr s5, [x7, w11, uxtw #2] // w2 -> v5.s[0]
ldr s6, [x7, w12, uxtw #2] // w3 -> v6.s[0]
// Convert src pixels to float (independent of weights)
uxtl v1.8h, v1.8b // u8 -> u16
uxtl v1.4s, v1.4h // u16 -> u32 (low 4 lanes = the 4 loaded pixels)
ucvtf v1.4s, v1.4s // u32 -> float
// Merge weights into v3.4s
trn1 v3.2s, v3.2s, v4.2s // v3 = [w0, w1, ?, ?]
trn1 v5.2s, v5.2s, v6.2s // v5 = [w2, w3, ?, ?]
trn1 v3.2d, v3.2d, v5.2d // v3 = [w0, w1, w2, w3]
// Update total_weight and sum
fadd v0.4s, v0.4s, v3.4s // total_weight += weight
fmla v2.4s, v1.4s, v3.4s // sum += src * weight
// Store back
// x5/x6 were loaded without post-increment above; advance them here
st1 {v0.4s}, [x5], #16
st1 {v2.4s}, [x6], #16
sub w10, w10, #4
b 1b
2: // Leftovers
cmp w10, #0
b.le 9f
// Single pixel handling
ldr w8, [x0], #4 // iia (reuse w8)
ldr w9, [x1], #4 // iib (reuse w9)
ldr w11, [x2], #4 // iid (reuse w11)
ldr w12, [x3], #4 // iie (reuse w12)
// w12 = e - d - b + a, same wrapping 32-bit diff as the vector path
sub w12, w12, w11
sub w12, w12, w9
add w12, w12, w8
// min (unsigned comparison) - use preloaded w13
cmp w12, w13
csel w12, w12, w13, ls // unsigned lower or same
// Load weight
ldr s0, [x7, w12, uxtw #2]
// Load src
ldrb w8, [x4], #1 // src (reuse w8)
ucvtf s1, w8
// Load acc
ldr s2, [x5] // total_weight[x]
ldr s3, [x6] // sum[x]
fadd s2, s2, s0 // total_weight += weight
fmadd s3, s1, s0, s3 // sum += src * weight
str s2, [x5], #4
str s3, [x6], #4
sub w10, w10, #1
b 2b
9: ret
endfunc

View File

@@ -35,8 +35,8 @@ typedef struct NLMeansDSPContext {
float *total_weight,
float *sum,
const float *const weight_lut,
int max_meaningful_diff,
int startx, int endx);
ptrdiff_t max_meaningful_diff,
ptrdiff_t startx, ptrdiff_t endx);
} NLMeansDSPContext;
void ff_nlmeans_init_aarch64(NLMeansDSPContext *dsp);

View File

@@ -79,8 +79,8 @@ static void compute_weights_line_c(const uint32_t *const iia,
float *total_weight,
float *sum,
const float *const weight_lut,
int max_meaningful_diff,
int startx, int endx)
ptrdiff_t max_meaningful_diff,
ptrdiff_t startx, ptrdiff_t endx)
{
for (int x = startx; x < endx; x++) {
/*

View File

@@ -28,8 +28,8 @@ void ff_compute_weights_line_avx2(const uint32_t *const iia,
float *total_weight,
float *sum,
const float *const weight_lut,
int max_meaningful_diff,
int startx, int endx);
ptrdiff_t max_meaningful_diff,
ptrdiff_t startx, ptrdiff_t endx);
av_cold void ff_nlmeans_init_x86(NLMeansDSPContext *dsp)
{

View File

@@ -18,10 +18,12 @@
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include <math.h>
#include "checkasm.h"
#include "libavfilter/vf_nlmeans_init.h"
#include "libavutil/avassert.h"
#include "libavutil/mem.h"
#include "libavutil/mem_internal.h"
#define randomize_buffer(buf, size) do { \
int i; \
@@ -110,5 +112,81 @@ void checkasm_check_nlmeans(void)
av_freep(&src);
}
if (check_func(dsp.compute_weights_line, "compute_weights_line")) {
#define TEST_W 256
#define MAX_MEANINGFUL_DIFF 255
const int startx = 10;
const int endx = 200;
// Allocate aligned buffers on stack
LOCAL_ALIGNED_32(uint32_t, iia, [TEST_W + 16]);
LOCAL_ALIGNED_32(uint32_t, iib, [TEST_W + 16]);
LOCAL_ALIGNED_32(uint32_t, iid, [TEST_W + 16]);
LOCAL_ALIGNED_32(uint32_t, iie, [TEST_W + 16]);
LOCAL_ALIGNED_32(uint8_t, src, [TEST_W + 16]);
LOCAL_ALIGNED_32(float, tw_ref, [TEST_W + 16]);
LOCAL_ALIGNED_32(float, tw_new, [TEST_W + 16]);
LOCAL_ALIGNED_32(float, sum_ref, [TEST_W + 16]);
LOCAL_ALIGNED_32(float, sum_new, [TEST_W + 16]);
LOCAL_ALIGNED_32(float, lut, [MAX_MEANINGFUL_DIFF + 1]);
declare_func(void, const uint32_t *const iia,
const uint32_t *const iib,
const uint32_t *const iid,
const uint32_t *const iie,
const uint8_t *const src,
float *total_weight,
float *sum,
const float *const weight_lut,
ptrdiff_t max_meaningful_diff,
ptrdiff_t startx, ptrdiff_t endx);
// Initialize LUT: weight = exp(-diff * scale)
// Using scale = 0.01 for testing
for (int i = 0; i <= MAX_MEANINGFUL_DIFF; i++)
lut[i] = expf(-i * 0.01f);
// Initialize source pixels
for (int i = 0; i < TEST_W; i++)
src[i] = rnd() & 0xff;
// Initialize integral images
// We need to ensure diff = e - d - b + a is non-negative and within range
// Set up as if computing real integral image values
for (int i = 0; i < TEST_W; i++) {
uint32_t base = rnd() % 1000;
iia[i] = base;
iib[i] = base + (rnd() % 100);
iid[i] = base + (rnd() % 100);
// e = a + (b - a) + (d - a) + diff
// So diff = e - d - b + a will be in range [0, max_meaningful_diff]
uint32_t diff = rnd() % (MAX_MEANINGFUL_DIFF + 1);
iie[i] = iia[i] + (iib[i] - iia[i]) + (iid[i] - iia[i]) + diff;
}
// Clear output buffers
memset(tw_ref, 0, (TEST_W + 16) * sizeof(float));
memset(tw_new, 0, (TEST_W + 16) * sizeof(float));
memset(sum_ref, 0, (TEST_W + 16) * sizeof(float));
memset(sum_new, 0, (TEST_W + 16) * sizeof(float));
call_ref(iia, iib, iid, iie, src, tw_ref, sum_ref, lut,
MAX_MEANINGFUL_DIFF, startx, endx);
call_new(iia, iib, iid, iie, src, tw_new, sum_new, lut,
MAX_MEANINGFUL_DIFF, startx, endx);
// Compare results with small tolerance for floating point
if (!float_near_abs_eps_array(tw_ref + startx, tw_new + startx, 1e-5f, endx - startx))
fail();
if (!float_near_abs_eps_array(sum_ref + startx, sum_new + startx, 1e-4f, endx - startx))
fail();
// Benchmark
memset(tw_new, 0, (TEST_W + 16) * sizeof(float));
memset(sum_new, 0, (TEST_W + 16) * sizeof(float));
bench_new(iia, iib, iid, iie, src, tw_new, sum_new, lut,
MAX_MEANINGFUL_DIFF, startx, endx);
}
report("dsp");
}