1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
| template <typename SrcExp, typename DType, int etype> inline ResizeExp<SrcExp, DType, ExpInfo<SrcExp>::kDim> resize(const Exp<SrcExp, DType, etype>& src, index_t out_height, index_t out_width, int pad_mode = resize_pad::kEdge, DType pad_value = 0) { return ResizeExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), out_height, out_width, pad_mode, pad_value); }
/*!
 * \brief test whether a coordinate lies in the closed interval [low, high]
 * \param x coordinate to test
 * \param low inclusive lower bound
 * \param high inclusive upper bound
 * \return true when x is neither below low nor above high
 */
MSHADOW_XINLINE static bool InBound(int32_t x, index_t low, index_t high) {
  return !(x < low || x > high);
}

/*!
 * \brief execution plan for ResizeExp
 *
 * Each output element is a bilinear blend of the four source elements that
 * surround its back-projected position. With resize_pad::kEdge the neighbour
 * coordinates are clamped into the source extent; with any other mode a
 * neighbour that falls outside the source contributes pad_value_ instead.
 */
template <typename SrcExp, typename DType, int srcdim>
struct Plan<ResizeExp<SrcExp, DType, srcdim>, DType> {
 public:
  /*! \brief copy the sampling parameters out of the expression and plan its source */
  explicit Plan(const ResizeExp<SrcExp, DType, srcdim>& e)
      : src_(MakePlan(e.src_)),
        start_y_(e.start_y_),
        start_x_(e.start_x_),
        step_y_(e.step_y_),
        step_x_(e.step_x_),
        src_height_(e.src_height_),
        src_width_(e.src_width_),
        out_height_(e.out_height_),
        pad_mode_(e.pad_mode_),
        pad_value_(e.pad_value_) {}

  /*!
   * \brief evaluate one output element
   * \param i flattened row index; i % out_height_ is the output row and
   *          i / out_height_ selects the source slice of src_height_ rows
   * \param j output column
   * \return the blended value converted to DType
   */
  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
    const index_t out_col = j;
    const index_t out_row = i % out_height_;
    const index_t channel = i / out_height_;

    // Position of this output element in source coordinates.
    const float sample_x = start_x_ + out_col * step_x_;
    const float sample_y = start_y_ + out_row * step_y_;

    // Integer neighbours: (y0, x0) is the floor corner, (y1, x1) the next one.
    int32_t y0 = static_cast<int32_t>(std::floor(sample_y));
    int32_t x0 = static_cast<int32_t>(std::floor(sample_x));
    int32_t y1 = y0 + 1;
    int32_t x1 = x0 + 1;

    if (pad_mode_ == resize_pad::kEdge) {
      // Edge mode: pull every neighbour back inside the source extent.
      y0 = ClampCoord(y0, src_height_ - 1);
      x0 = ClampCoord(x0, src_width_ - 1);
      y1 = ClampCoord(y1, src_height_ - 1);
      x1 = ClampCoord(x1, src_width_ - 1);
    }

    DType top_left = Sample(channel, y0, x0);
    DType top_right = Sample(channel, y0, x1);
    DType bottom_left = Sample(channel, y1, x0);
    DType bottom_right = Sample(channel, y1, x1);

    // Fractional offsets are measured from the (possibly clamped) floor corner.
    const float dy = sample_y - y0;
    const float dx = sample_x - x0;
    const float inv_dy = 1 - dy;
    const float inv_dx = 1 - dx;

    // Term order and association are kept fixed so rounding does not change.
    float blended = top_left * inv_dy * inv_dx
                  + bottom_right * dy * dx
                  + top_right * inv_dy * dx
                  + bottom_left * dy * inv_dx;
    return static_cast<DType>(blended);
  }

 private:
  /*! \brief clamp x into [0, max] using mshadow's min/max operators */
  MSHADOW_XINLINE static int32_t ClampCoord(int32_t x, int32_t max) {
    return mshadow::op::min::Map(mshadow::op::max::Map(x, 0), max);
  }

  /*! \brief read source element (y, x) of the given slice, or pad_value_ when out of range */
  MSHADOW_XINLINE DType Sample(index_t channel, int32_t y, int32_t x) const {
    if (!InBound(y, 0, src_height_ - 1) || !InBound(x, 0, src_width_ - 1)) {
      return pad_value_;
    }
    return src_.Eval(channel * src_height_ + y, x);
  }

  /*! \brief plan of the source expression */
  Plan<SrcExp, DType> src_;
  /*! \brief source-space coordinate of output row 0 / column 0 */
  const float start_y_;
  const float start_x_;
  /*! \brief source-space distance between adjacent output rows / columns */
  const float step_y_;
  const float step_x_;
  /*! \brief extent of one source slice */
  const index_t src_height_;
  const index_t src_width_;
  /*! \brief number of output rows per slice */
  const index_t out_height_;
  /*! \brief padding mode, compared against resize_pad::kEdge */
  const int pad_mode_;
  /*! \brief value used for neighbours outside the source */
  const DType pad_value_;
};
|