2017年12月18日

tensorflow Defun で ternary weight

業界動向

お疲れ様です。Leapmindの脇坂です。

この記事は、TensorFlow Advent Calendar 2017 17日目の記事です

今回は、tensorflow の tensorflow.python.framework.function.Defun を使って、weight の 3値化をしてみます。

Defunとは

tensorflow.python.framework.function.Defun はtensorflowのfunctionを定義するためのデコレータです。 今現在(tensorflow v1.4)ドキュメントにはのってない機能です。tf.contrib.eager.defun とは違うので注意

引数にtensorflow DataTypeをとって、任意のpython functionをデコレートし、tensorflowのgraphでそのfunctionが使えるように定義します。内部的には、たぶん、graphにfunction.protoを追加します。 tensorflowにおけるfucntionは1つ以上のoperationをまとめたものだそうです。

コード例は、こんな感じです。

MyFuncの定義

@tf.Defun(tf.float32, tf.float32)
def MyFunc(x, y):
  return x + y, x - y

グラフを作る時に、MyFuncを実行。

a = tf.Constant([1.0])
b = tf.Constant([2.0])
c, d = MyFunc(a, b, name='mycall')

他にオプションのキーワード引数として、

  • func_name: Functionの名前
  • grad_func or python_grad_func: backward関数
  • out_names: output tensor の名前
  • shape_func: output tensor のshapeを指定 が、あります。

さらに詳しく知りたい方は、ソースコードとtestを見てみてください。 https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/function.py https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/function_test.py

Defunをどんな時に使いますか

さて、このDefunがどういったモノかはなんとなくわかりましたが、どんな時に効果的に使えるのでしょうか?

  1. メモリ・計算量を減らす
  2. Backwardの近似 (backwardを別の関数で近似)

この二つが考えれます。

メモリ・計算量を減らす

この使用法が一般的になるのだと思いますが、すみません、ちゃんと検証していないので自信がありません。

tensorflowでの使用例を見ると、swish の実装で、operationをまとめてmemory消費を少なくするために使われています。

swish は2、3ヶ月前に発表された(https://arxiv.org/abs/1710.05941) activation functionで、式はすごく単純

f(x)=x×sigmoid(x)\begin{aligned} f(x) = x \times sigmoid(x) \end{aligned}

です。

普通にtensorflowで実装するとこうなります。

output = tf.sigmoid(x) * x

さて、deeplearningのフレームワークでは、勾配計算のためにBackpropagationを使い、そのためforward時に各operationの入力をメモリに残しておきます。 上記のように実装すると、* operationのために、xtf.sigmoid(x) の両方をメモリに残しておきます。が、これは無駄です。 メモリの観点だけで見れば、 x だけ残しておけば良いです。

そこで、tensorflowの使用例では、 noinline=True を使い、下記のような実装をしています。

@function.Defun(
    grad_func=_swish_grad,
    shape_func=_swish_shape,
    func_name="swish",
    noinline=True)
def swish(x):
  return x * math_ops.sigmoid(x)

注意: noinlineで本当にメモリ使用が抑えられるのか、試しておりません。。。XLAを有効にした時だけの話かも?

Backwardの近似

こちらの使用方法はレアケースになるかと思いますが、本記事のメインになります。 やっとです。

ここでニューラルネットのパラメータやアクティベーションの量子化について、さらっと振り返っておきます。 通常32bitや16bit floatのパラメータ(weight)を、メモリ削減のため2値化、3値化(ternary)、N bit化などに量子化する際、 推論時には2値や3値を使い、学習時にはパラメータの値を実数で保持しておいて実数の値を更新すると、高い精度が得られる事が知られています。 この際、forward時には実数のパラメータから量子化を行いますが、量子化を行う関数はstep関数なので、普通に勾配を求めるとうまく学習できないため、 backward時には量子化を行う関数を近似して、勾配を求めます。

参考: https://developer.smartnews.com/blog/2017/03/neural-network-quantization/

さて、tensorflow の python apiで、このようなbackward時の近似を行う際、Graph.gradient_override_map()tf.stop_gradient() を使う方法もあるのですが、非常に読みにくコードになってしまいます。 ここで Defun を使ってforwardとbackwardを分けて実装すると、綺麗に書く事ができます。

パラメータの3値化

では、パラメータ(weight)の3値化(ternary)の実装例を、順をおってみていきます。

3値化の定義

まず、3値化の定義です。この記事ではわかりやすくするために適当に以下のようにします。

  • スケールファクタは、パラメータの絶対値の平均
  • パラメータが閾値0.5より大きい場合 +スケールファクタ
  • パラメータが閾値-0.5より小さい場合に -スケールファクタ
  • パラメータが0.5以下-0.5以上の場合、0
f(x)={+mean(x)(x>0.5)mean(x)(x<0.5)0(otherwise)\begin{aligned} f(x) = \left\{ \begin{array}{ll} + mean(|x|) & (x>0.5) \\ - mean(|x|) & (x<-0.5) \\ 0 & (otherwise) \end{array} \right. \end{aligned}

Backwardの近似なし3値化関数

では、次に、 この3値化関数をBackwardの近似なしで、実装してみます。

def no_backward_approximate_ternary_function(weights):
    """Forward

    Args:
        weights(tf.Variable): The weights to be ternary.

    Returns:
        ternary_weights(tf.Variable): The ternary valued weights.
    """
    threshold = 0.5
    mask_positive = (weights > threshold)
    mask_negative = (weights < -threshold)

    scaling_factor = tf.reduce_mean(tf.abs(weights))

    positive_weights = scaling_factor * tf.where(mask_positive, tf.ones_like(weights), tf.zeros_like(weights))
    negative_weights = - scaling_factor * tf.where(mask_negative, tf.ones_like(weights), tf.zeros_like(weights))

    ternary_weights = positive_weights + negative_weights
    return ternary_weights

forward時のこの関数をplotしてみます。

tf.InteractiveSession()

weights_size = (100,)
np_weights = np.random.uniform(-2., 2., size=weights_size).astype(np.float32)
weights = tf.convert_to_tensor(np_weights)

no_backward_approximate_ternary_weights = no_backward_approximate_ternary_function(weights)

order = np.argsort(np_weights)
plt.plot(np_weights[order], no_backward_approximate_ternary_weights.eval()[order])

plt.ylabel("ternary weights")
plt.xlabel("input weights")
plt.title("ternary func")

plt.show()

forward時はこれで良いのですが、backward時の勾配をplotすると、 (グラフの横軸が3値化したパラメーターに関数する勾配、縦軸が実数のパラメータに関する勾配です。)

np_grad_tenary_weights = np.random.uniform(-10., 10., size=weights_size).astype(np.float32)
grad_tenary_weights = tf.convert_to_tensor(np_grad_tenary_weights)
gard_no_backward_approximate_weights, = tf.gradients(no_backward_approximate_ternary_weights, weights, grad_ys=grad_y)

order = np.argsort(np_grad_tenary_weights)
plt.plot(np_grad_tenary_weights[order], gard_no_backward_approximate_weights.eval()[order])

plt.ylabel("grad weights")
plt.xlabel("grad ternary weights")
plt.title("backward ternary func")

plt.show()

このように無茶苦茶になってしまいます。

Defunを使ったBackwardの近似ありの3値化関数

では、Defunを使ってBackwardを近似してみます。 3値関数が恒等写像であると仮定して、その上でbackwardを実装します。 このように簡潔に書けます。

@Defun(tf.float32, tf.float32)
def ternary_backward(weights, grad_ternary):
    """Backward

    Args:
        grad_ternary(tf.Tensor): The gradient w.r.t ternary valued weights.

    Return:
        grad_weights(tf.Tensor): The gradient w.r.t. normal (non-ternary) weights.
    """
    grad_weights = grad_ternary
    return grad_weights


@Defun(tf.float32,
       grad_func=ternary_backward,
       func_name="TernaryWeight",
       shape_func=lambda op: [op.inputs[0].get_shape()])
def ternary_function(weights):
    """Forward

    Args:
        weights(tf.Variable): The weights to be ternary.

    Returns:
        ternary_weights(tf.Variable): The ternary valued weights.
    """
    threshold = 0.5
    mask_positive = (weights > threshold)
    mask_negative = (weights < -threshold)

    scaling_factor = tf.reduce_mean(tf.abs(weights))

    positive_weights = scaling_factor * tf.where(mask_positive, tf.ones_like(weights), tf.zeros_like(weights))
    negative_weights = - scaling_factor * tf.where(mask_negative, tf.ones_like(weights), tf.zeros_like(weights))

    ternary_weights = positive_weights + negative_weights
    return ternary_weights

3値化関数に恒等写像を仮定したので、 backwardは、3値化したパラメーターに関数する勾配と、実数のパラメータに関する勾配とが同じになります。 では、backward時の勾配をplotすると、

order = np.argsort(np_grad_tenary_weights)
plt.plot(np_grad_tenary_weights[order], gard_weights.eval()[order])

plt.ylabel("grad weights")
plt.xlabel("grad ternary weights")
plt.title("backward ternary func")

plt.show()

わーい、意図した通りになりましたー!!

これでおしまいです。