给 MShadow 添加新的 Extension

MShadow 是一个轻量级的 Tensor Template Library。mshadow 同时支持 C++ 和 CUDA,一套代码既可以跑在CPU上,也可以跑在GPU上。并且是lazy computation的。如果对于cuda上面的运行没有极致要求的话,mshadow对于Tensor的计算是个很好的选择,性能也不会很差。如果对于性能有极致的需求,那就要手动去tune了,或者也可以尝试一下tvm。
本文尝试给mshadow添加一个新的extension: clip。完整实现在 Github

clip 的定义

clip(x, min, max) 的功能是,把x的值限定在minmax之间,超出的部分截断。基本逻辑是:

1
2
3
4
5
6
7
8
9
10
template <typename DType>
DType clip(DType x, DType min, DType max){
if (x < min) {
return min;
} else if (x > max) {
return max;
} else {
return x;
}
}

一种简单的实现方法

可以借助 mshadow 提供的三元计算函数F,这样我们只需要简单的定义一个三元计算逻辑就可以了:

1
2
3
4
5
6
7
8
9
10
11
12
struct clip {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType min, DType max) {
if (x < min) {
return min;
} else if (x > max) {
return max;
} else {
return x;
}
}
};

调用方法是:

1
dst = F<clip>(src, min, max);

其中 minmax 要求是和src的shape相同的Exp,因此,如果,minmax是常数的话,例如,我们想把src的所有的元素clip到[-128, 127]之间,那么,就需要首先把minmax扩展到src相同的shape,这样,就浪费了大量的memory。为了解决这个问题,本文定义了clip extension,它的接口是:

1
2
3
4
template <typename Expr, typename MinExp, typename MaxExp>
Expr clip(Expr& src, MinExp& min, MaxExp& max);
template <typename Expr, typename DType>
Expr clip(Expr& src, DType min, DType max);

接下来就要实现这两个接口。

extension 的方法实现

一个完整的 extension 的实现,需要实现大概5个部分的内容:

  1. clip 的表达式,继承自 Exp,对应的 clip 的Exp命名为 ClipExp
  2. clip 的计算逻辑,也就是 mshadow 中的 Plan
  3. clip表达式ClipExp的 MakePlan
  4. 一些检查,包括data type, tensor shape,以及定义clip表达式的的一些信息 ExpInfo
  5. 封装易用的接口(optional)

clip返回的shape和dtype和输入的src是相同的,因此,这里直接让ClipExp继承MakeTensorExp,这样,就不需要实现上面的3,4中的内容了。

表达式的定义

表达式记录了clip这个操作需要用到哪些计算数,以及如何构建这些计算数,另外还有必要的检查和初始化操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
template <typename SrcExp, typename MinExp, typename MaxExp, typename DType, int srcdim>
struct ClipExp
: public MakeTensorExp<ClipExp<SrcExp, MinExp, MaxExp, DType, srcdim>, SrcExp, srcdim, DType> {
/*! \brief operand */
const SrcExp& src_; // 要clip的 Exp
const MinExp& min_; // 下限
const MaxExp& max_; // 上限
/*! \brief constructor */
explicit ClipExp(const SrcExp& src, const MinExp& min, const MaxExp& max)
: src_(src), min_(min), max_(max) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_); // 做维数检查,并且确定ClipExp的shape_
}
// 为了方面使用,min和max也可以直接传具体数值,而不是像上面一样必须要传mshadow::Exp
explicit ClipExp(const SrcExp& src, const DType min, const DType max)
: src_(src), min_(scalar(min)), max_(scalar(max)) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
}
};

表达式的计算

这里定义了ClipExp这个表达式的计算逻辑,也就是实现 ClipExpPlan

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
template <typename SrcExp, typename MinExp, typename MaxExp, typename DType, int srcdim>
struct Plan<ClipExp<SrcExp, MinExp, MaxExp, DType, srcdim>, DType> {
public:
explicit Plan(const ClipExp<SrcExp, MinExp, MaxExp, DType, srcdim>& e)
: src_(MakePlan(e.src_)), min_(MakePlan(e.min_)), max_(MakePlan(e.max_)) {}
// mshadow 中所有的 Eval 都是使用二维坐标(i, j)索引的,这里的Eval方法就是ClipExp这个表达式的实际的计算逻辑
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
DType src = src_.Eval(i, j);
DType min = min_.Eval(i, j);
DType max = max_.Eval(i, j);
if (src < min) {
return min;
} else if (src > max) {
return max;
} else {
return src;
}
}

private:
Plan<SrcExp, DType> src_;
Plan<MaxExp, DType> min_;
Plan<MinExp, DType> max_;
};

调用接口

至此,就可以通过ClipExp来构建表达式,实现clip的计算了,但是,很明显,ClipExp 这种接口不符合用户的使用直觉,因此,通常要给 extension 定义符合用户直接的使用接口,这里要定义clip方法。
接口一:所有的参数全部是mshadow的Exp

1
2
3
4
5
6
7
8
template <typename SrcExp, typename MinExp, typename MaxExp, typename DType, int etype, int mintype,
int maxtype>
inline ClipExp<SrcExp, MinExp, MaxExp, DType, ExpInfo<SrcExp>::kDim> clip(
const Exp<SrcExp, DType, etype>& src, const Exp<MinExp, DType, mintype>& min,
const Exp<MaxExp, DType, maxtype>& max) {
return ClipExp<SrcExp, MinExp, MaxExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), min.self(),
max.self());
}

传入的 minmax必须是Exp,例如,一个Tensor,这种调用方式的 minmax 中的元素可以不相同。

接口二:minmax 直接传入具体算数数值

1
2
3
4
5
6
template <typename SrcExp, typename DType, int etype>
inline ClipExp<SrcExp, ScalarExp<DType>, ScalarExp<DType>, DType, ExpInfo<SrcExp>::kDim> clip(
const Exp<SrcExp, DType, etype>& src, const DType min, const DType max) {
return ClipExp<SrcExp, ScalarExp<DType>, ScalarExp<DType>, DType, ExpInfo<SrcExp>::kDim>(
src.self(), min, max);
}

minmax 会在 ClipExp 的构造函数中构造成ScalarExp

合法性检查

这里没有进行合法性检查,例如,我们要检查 src 的shape 和 min/max的shape是否相同,因为如果minmax不是ScalarExp时必须要求这三个操作数的shape是一致的。

0%