/*
 *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <cstdio>
#include <limits>

#include "third_party/googletest/src/include/gtest/gtest.h"

#include "./vp9_rtcd.h"
#include "test/acm_random.h"
#include "test/buffer.h"
#include "test/register_state_check.h"
#include "vpx_ports/vpx_timer.h"

namespace {

using ::libvpx_test::ACMRandom;
using ::libvpx_test::Buffer;

// Signature shared by every temporal filter implementation under test.
// 'a' is read through 'stride'; 'b' is passed without a stride, so the tests
// supply it as a borderless w x h block. Per-pixel results are added to
// 'accumulator' and 'count' (the reference implementation below uses +=).
typedef void (*TemporalFilterFunc)(const uint8_t *a, unsigned int stride,
                                   const uint8_t *b, unsigned int w,
                                   unsigned int h, int filter_strength,
                                   int filter_weight, unsigned int *accumulator,
                                   uint16_t *count);

// Calculate the difference between 'a' and 'b', sum in blocks of 9, and apply
// filter based on strength and weight. Store the resulting filter amount in
// 'count' and apply it to 'b' and store it in 'accumulator'.
35 36 37 38
void reference_filter(const Buffer<uint8_t> &a, const Buffer<uint8_t> &b, int w,
                      int h, int filter_strength, int filter_weight,
                      Buffer<unsigned int> *accumulator,
                      Buffer<uint16_t> *count) {
39
  Buffer<int> diff_sq = Buffer<int>(w, h, 0);
Johann's avatar
Johann committed
40
  ASSERT_TRUE(diff_sq.Init());
41 42 43 44 45 46 47 48 49
  diff_sq.Set(0);

  int rounding = 0;
  if (filter_strength > 0) {
    rounding = 1 << (filter_strength - 1);
  }

  // Calculate all the differences. Avoids re-calculating a bunch of extra
  // values.
50 51
  for (int height = 0; height < h; ++height) {
    for (int width = 0; width < w; ++width) {
52 53 54 55 56 57 58 59
      int diff = a.TopLeftPixel()[height * a.stride() + width] -
                 b.TopLeftPixel()[height * b.stride() + width];
      diff_sq.TopLeftPixel()[height * diff_sq.stride() + width] = diff * diff;
    }
  }

  // For any given point, sum the neighboring values and calculate the
  // modifier.
60 61
  for (int height = 0; height < h; ++height) {
    for (int width = 0; width < w; ++width) {
62 63 64
      // Determine how many values are being summed.
      int summed_values = 9;

65
      if (height == 0 || height == (h - 1)) {
66 67 68
        summed_values -= 3;
      }

69
      if (width == 0 || width == (w - 1)) {
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
        if (summed_values == 6) {  // corner
          summed_values -= 2;
        } else {
          summed_values -= 3;
        }
      }

      // Sum the diff_sq of the surrounding values.
      int sum = 0;
      for (int idy = -1; idy <= 1; ++idy) {
        for (int idx = -1; idx <= 1; ++idx) {
          const int y = height + idy;
          const int x = width + idx;

          // If inside the border.
85
          if (y >= 0 && y < h && x >= 0 && x < w) {
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
            sum += diff_sq.TopLeftPixel()[y * diff_sq.stride() + x];
          }
        }
      }

      sum *= 3;
      sum /= summed_values;
      sum += rounding;
      sum >>= filter_strength;

      // Clamp the value and invert it.
      if (sum > 16) sum = 16;
      sum = 16 - sum;

      sum *= filter_weight;

102 103
      count->TopLeftPixel()[height * count->stride() + width] += sum;
      accumulator->TopLeftPixel()[height * accumulator->stride() + width] +=
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
          sum * b.TopLeftPixel()[height * b.stride() + width];
    }
  }
}

// Value-parameterized fixture. Each instantiation supplies one temporal
// filter implementation, which the tests compare against reference_filter().
class TemporalFilterTest : public ::testing::TestWithParam<TemporalFilterFunc> {
 public:
  virtual void SetUp() {
    // A fixed seed makes every run see the same pseudo-random inputs.
    rnd_.Reset(ACMRandom::DeterministicSeed());
    // The implementation under test for this instantiation.
    filter_func_ = GetParam();
  }

 protected:
  TemporalFilterFunc filter_func_;  // Function under test.
  ACMRandom rnd_;                   // Deterministic input generator.
};

121 122 123 124
TEST_P(TemporalFilterTest, SizeCombinations) {
  // Depending on subsampling this function may be called with values of 8 or 16
  // for width and height, in any combination.
  Buffer<uint8_t> a = Buffer<uint8_t>(16, 16, 8);
Johann's avatar
Johann committed
125
  ASSERT_TRUE(a.Init());
126 127 128 129 130 131 132 133

  const int filter_weight = 2;
  const int filter_strength = 6;

  for (int width = 8; width <= 16; width += 8) {
    for (int height = 8; height <= 16; height += 8) {
      // The second buffer must not have any border.
      Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
Johann's avatar
Johann committed
134
      ASSERT_TRUE(b.Init());
135
      Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
Johann's avatar
Johann committed
136
      ASSERT_TRUE(accum_ref.Init());
137
      Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
Johann's avatar
Johann committed
138
      ASSERT_TRUE(accum_chk.Init());
139
      Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
Johann's avatar
Johann committed
140
      ASSERT_TRUE(count_ref.Init());
141
      Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
Johann's avatar
Johann committed
142
      ASSERT_TRUE(count_chk.Init());
143

144 145 146 147
      // The difference between the buffers must be small to pass the threshold
      // to apply the filter.
      a.Set(&rnd_, 0, 7);
      b.Set(&rnd_, 0, 7);
148 149 150 151 152 153 154

      accum_ref.Set(rnd_.Rand8());
      accum_chk.CopyFrom(accum_ref);
      count_ref.Set(rnd_.Rand8());
      count_chk.CopyFrom(count_ref);
      reference_filter(a, b, width, height, filter_strength, filter_weight,
                       &accum_ref, &count_ref);
155 156 157 158
      ASM_REGISTER_STATE_CHECK(
          filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width,
                       height, filter_strength, filter_weight,
                       accum_chk.TopLeftPixel(), count_chk.TopLeftPixel()));
159 160 161 162 163 164
      EXPECT_TRUE(accum_chk.CheckValues(accum_ref));
      EXPECT_TRUE(count_chk.CheckValues(count_ref));
      if (HasFailure()) {
        printf("Width: %d Height: %d\n", width, height);
        count_chk.PrintDifference(count_ref);
        accum_chk.PrintDifference(accum_ref);
165
        return;
166 167 168 169 170
      }
    }
  }
}

171
TEST_P(TemporalFilterTest, CompareReferenceRandom) {
172 173 174
  for (int width = 8; width <= 16; width += 8) {
    for (int height = 8; height <= 16; height += 8) {
      Buffer<uint8_t> a = Buffer<uint8_t>(width, height, 8);
Johann's avatar
Johann committed
175
      ASSERT_TRUE(a.Init());
176 177
      // The second buffer must not have any border.
      Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
Johann's avatar
Johann committed
178
      ASSERT_TRUE(b.Init());
179
      Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
Johann's avatar
Johann committed
180
      ASSERT_TRUE(accum_ref.Init());
181
      Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
Johann's avatar
Johann committed
182
      ASSERT_TRUE(accum_chk.Init());
183
      Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
Johann's avatar
Johann committed
184
      ASSERT_TRUE(count_ref.Init());
185
      Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
Johann's avatar
Johann committed
186
      ASSERT_TRUE(count_chk.Init());
187 188 189

      for (int filter_strength = 0; filter_strength <= 6; ++filter_strength) {
        for (int filter_weight = 0; filter_weight <= 2; ++filter_weight) {
190 191 192 193 194 195 196 197 198 199 200
          for (int repeat = 0; repeat < 100; ++repeat) {
            if (repeat < 50) {
              a.Set(&rnd_, 0, 7);
              b.Set(&rnd_, 0, 7);
            } else {
              // Check large (but close) values as well.
              a.Set(&rnd_, std::numeric_limits<uint8_t>::max() - 7,
                    std::numeric_limits<uint8_t>::max());
              b.Set(&rnd_, std::numeric_limits<uint8_t>::max() - 7,
                    std::numeric_limits<uint8_t>::max());
            }
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222

            accum_ref.Set(rnd_.Rand8());
            accum_chk.CopyFrom(accum_ref);
            count_ref.Set(rnd_.Rand8());
            count_chk.CopyFrom(count_ref);
            reference_filter(a, b, width, height, filter_strength,
                             filter_weight, &accum_ref, &count_ref);
            ASM_REGISTER_STATE_CHECK(filter_func_(
                a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width, height,
                filter_strength, filter_weight, accum_chk.TopLeftPixel(),
                count_chk.TopLeftPixel()));
            EXPECT_TRUE(accum_chk.CheckValues(accum_ref));
            EXPECT_TRUE(count_chk.CheckValues(count_ref));
            if (HasFailure()) {
              printf("Weight: %d Strength: %d\n", filter_weight,
                     filter_strength);
              count_chk.PrintDifference(count_ref);
              accum_chk.PrintDifference(accum_ref);
              return;
            }
          }
        }
223 224 225 226 227 228 229
      }
    }
  }
}

// Benchmark: time 10000 calls of filter_func_ for each block size and print
// the elapsed microseconds. Disabled by default; nothing is verified.
TEST_P(TemporalFilterTest, DISABLED_Speed) {
  Buffer<uint8_t> a(16, 16, 8);
  ASSERT_TRUE(a.Init());

  const int filter_weight = 2;
  const int filter_strength = 6;

  for (int width = 8; width <= 16; width += 8) {
    for (int height = 8; height <= 16; height += 8) {
      // The second buffer must not have any border.
      Buffer<uint8_t> b(width, height, 0);
      ASSERT_TRUE(b.Init());
      // Only the outputs of the function under test are needed; no reference
      // result is computed in a timing run.
      Buffer<unsigned int> accum_chk(width, height, 0);
      ASSERT_TRUE(accum_chk.Init());
      Buffer<uint16_t> count_chk(width, height, 0);
      ASSERT_TRUE(count_chk.Init());

      a.Set(&rnd_, 0, 7);
      b.Set(&rnd_, 0, 7);

      accum_chk.Set(0);
      count_chk.Set(0);

      vpx_usec_timer timer;
      vpx_usec_timer_start(&timer);
      for (int i = 0; i < 10000; ++i) {
        filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width,
                     height, filter_strength, filter_weight,
                     accum_chk.TopLeftPixel(), count_chk.TopLeftPixel());
      }
      vpx_usec_timer_mark(&timer);
      const int elapsed_time = static_cast<int>(vpx_usec_timer_elapsed(&timer));
      printf("Temporal filter %dx%d time: %5d us\n", width, height,
             elapsed_time);
    }
  }
}

// The portable C implementation is tested unconditionally.
INSTANTIATE_TEST_CASE_P(C, TemporalFilterTest,
                        ::testing::Values(&vp9_temporal_filter_apply_c));

// The SSE4.1 implementation is tested only when the build enables it.
#if HAVE_SSE4_1
INSTANTIATE_TEST_CASE_P(SSE4_1, TemporalFilterTest,
                        ::testing::Values(&vp9_temporal_filter_apply_sse4_1));
#endif  // HAVE_SSE4_1
}  // namespace