tensorflow Defun で ternary weight

お疲れ様です。Leapmindの脇坂です。

この記事は、TensorFlow Advent Calendar 2017 17日目の記事です

 

今回は、tensorflow の `tensorflow.python.framework.function.Defun` を使って、weight の 3値化をしてみます。

Defunとは

`tensorflow.python.framework.function.Defun` はtensorflowのfunctionを定義するためのデコレータです。
今現在(tensorflow v1.4)ドキュメントにはのってない機能です。`tf.contrib.eager.defun` とは違うので注意

引数にtensorflow DataTypeをとって、任意のpython functionをデコレートし、tensorflowのgraphでそのfunctionが使えるように定義します。内部的には、たぶん、graphにfunction.protoを追加します。
tensorflowにおけるfucntionは1つ以上のoperationをまとめたものだそうです。

参考: https://www.tensorflow.org/extend/language_bindings#overview

コード例は、こんな感じです。

MyFuncの定義


  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

グラフを作る時に、MyFuncを実行。


  a = tf.Constant([1.0])
  b = tf.Constant([2.0])
  c, d = MyFunc(a, b, name='mycall')

他にオプションのキーワード引数として、
* `func_name`: Functionの名前
* `grad_func` or `python_grad_func`: backward関数
* `out_names`: output tensor の名前
* `shape_func`: output tensor のshapeを指定
が、あります。

さらに詳しく知りたい方は、ソースコードとtestを見てみてください。
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/function.py
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/function_test.py

Defunをどんな時に使いますか

さて、この`Defun`がどういったモノかはなんとなくわかりましたが、どんな時に効果的に使えるのでしょうか?

1. メモリ・計算量を減らす
2. Backwardの近似 (backwardを別の関数で近似)

この二つが考えれます。

メモリ・計算量を減らす

この使用法が一般的になるのだと思いますが、すみません、ちゃんと検証していないので自信がありません。

tensorflowでの使用例を見ると、`swish` の実装で、operationをまとめてmemory消費を少なくするために使われています。

`swish` は2、3ヶ月前に発表された(https://arxiv.org/abs/1710.05941) activation functionで、式はすごく単純

\begin{aligned}
f(x) = x \times sigmoid(x)
\end{aligned}

です。

普通にtensorflowで実装するとこうなります。


output = tf.sigmoid(x) * x

さて、deeplearningのフレームワークでは、勾配計算のためにBackpropagationを使い、そのためforward時に各operationの入力をメモリに残しておきます。
上記のように実装すると、`*` operationのために、`x` と `tf.sigmoid(x)` の両方をメモリに残しておきます。が、これは無駄です。
メモリの観点だけで見れば、 `x` だけ残しておけば良いです。

そこで、tensorflowの使用例では、 `noinline=True` を使い、下記のような実装をしています。


@function.Defun(
    grad_func=_swish_grad,
    shape_func=_swish_shape,
    func_name="swish",
    noinline=True)
def swish(x):
  return x * math_ops.sigmoid(x)

注意: noinlineで本当にメモリ使用が抑えられるのか、試しておりません。。。XLAを有効にした時だけの話かも?

Backwardの近似

こちらの使用方法はレアケースになるかと思いますが、本記事のメインになります。
やっとです。

ここでニューラルネットのパラメータやアクティベーションの量子化について、さらっと振り返っておきます。
通常32bitや16bit floatのパラメータ(weight)を、メモリ削減のため2値化、3値化(ternary)、N bit化などに量子化する際、
推論時には2値や3値を使い、学習時にはパラメータの値を実数で保持しておいて実数の値を更新すると、高い精度が得られる事が知られています。
この際、forward時には実数のパラメータから量子化を行いますが、量子化を行う関数はstep関数なので、普通に勾配を求めるとうまく学習できないため、
backward時には量子化を行う関数を近似して、勾配を求めます。

参考: https://developer.smartnews.com/blog/2017/03/neural-network-quantization/

さて、tensorflow の python apiで、このようなbackward時の近似を行う際、`Graph.gradient_override_map()` や `tf.stop_gradient()` を使う方法もあるのですが、非常に読みにくコードになってしまいます。
ここで `Defun` を使ってforwardとbackwardを分けて実装すると、綺麗に書く事ができます。

パラメータの3値化

では、パラメータ(weight)の3値化(ternary)の実装例を、順をおってみていきます。

3値化の定義

まず、3値化の定義です。この記事ではわかりやすくするために適当に以下のようにします。

* スケールファクタは、パラメータの絶対値の平均
* パラメータが閾値0.5より大きい場合 +スケールファクタ
* パラメータが閾値-0.5より小さい場合に -スケールファクタ
* パラメータが0.5以下-0.5以上の場合、0

\begin{aligned}
f(x) = \left\{ \begin{array}{ll}
+ mean(|x|) & (x>0.5) \\
– mean(|x|) & (x<-0.5) \\ 0 & (otherwise) \end{array} \right. \end{aligned}

Backwardの近似なし3値化関数

では、次に、
この3値化関数をBackwardの近似なしで、実装してみます。


def no_backward_approximate_ternary_function(weights):
    """Forward

    Args:
        weights(tf.Variable): The weights to be ternary.

    Returns:
        ternary_weights(tf.Variable): The ternary valued weights.
    """
    threshold = 0.5
    mask_positive = (weights > threshold)
    mask_negative = (weights < -threshold)

    scaling_factor = tf.reduce_mean(tf.abs(weights))

    positive_weights = scaling_factor * tf.where(mask_positive, tf.ones_like(weights), tf.zeros_like(weights))
    negative_weights = - scaling_factor * tf.where(mask_negative, tf.ones_like(weights), tf.zeros_like(weights))

    ternary_weights = positive_weights + negative_weights
    return ternary_weights

forward時のこの関数をplotしてみます。


tf.InteractiveSession()

weights_size = (100,)
np_weights = np.random.uniform(-2., 2., size=weights_size).astype(np.float32)
weights = tf.convert_to_tensor(np_weights)

no_backward_approximate_ternary_weights = no_backward_approximate_ternary_function(weights)

order = np.argsort(np_weights)
plt.plot(np_weights[order], no_backward_approximate_ternary_weights.eval()[order])

plt.ylabel("ternary weights")
plt.xlabel("input weights")
plt.title("ternary func")

plt.show()

forward時はこれで良いのですが、backward時の勾配をplotすると、
(グラフの横軸が3値化したパラメーターに関数する勾配、縦軸が実数のパラメータに関する勾配です。)


np_grad_tenary_weights = np.random.uniform(-10., 10., size=weights_size).astype(np.float32)
grad_tenary_weights = tf.convert_to_tensor(np_grad_tenary_weights)
gard_no_backward_approximate_weights, = tf.gradients(no_backward_approximate_ternary_weights, weights, grad_ys=grad_y)

order = np.argsort(np_grad_tenary_weights)
plt.plot(np_grad_tenary_weights[order], gard_no_backward_approximate_weights.eval()[order])

plt.ylabel("grad weights")
plt.xlabel("grad ternary weights")
plt.title("backward ternary func")

plt.show()

このように無茶苦茶になってしまいます。

Defunを使ったBackwardの近似ありの3値化関数

では、Defunを使ってBackwardを近似してみます。
3値関数が恒等写像であると仮定して、その上でbackwardを実装します。
このように簡潔に書けます。


@Defun(tf.float32, tf.float32)
def ternary_backward(weights, grad_ternary):
    """Backward

    Args:
        grad_ternary(tf.Tensor): The gradient w.r.t ternary valued weights.

    Return:
        grad_weights(tf.Tensor): The gradient w.r.t. normal (non-ternary) weights.
    """
    grad_weights = grad_ternary
    return grad_weights


@Defun(tf.float32,
       grad_func=ternary_backward,
       func_name="TernaryWeight",
       shape_func=lambda op: [op.inputs[0].get_shape()])
def ternary_function(weights):
    """Forward

    Args:
        weights(tf.Variable): The weights to be ternary.

    Returns:
        ternary_weights(tf.Variable): The ternary valued weights.
    """
    threshold = 0.5
    mask_positive = (weights > threshold)
    mask_negative = (weights < -threshold)

    scaling_factor = tf.reduce_mean(tf.abs(weights))

    positive_weights = scaling_factor * tf.where(mask_positive, tf.ones_like(weights), tf.zeros_like(weights))
    negative_weights = - scaling_factor * tf.where(mask_negative, tf.ones_like(weights), tf.zeros_like(weights))

    ternary_weights = positive_weights + negative_weights
    return ternary_weights

3値化関数に恒等写像を仮定したので、
backwardは、3値化したパラメーターに関数する勾配と、実数のパラメータに関する勾配とが同じになります。
では、backward時の勾配をplotすると、


order = np.argsort(np_grad_tenary_weights)
plt.plot(np_grad_tenary_weights[order], gard_weights.eval()[order])

plt.ylabel("grad weights")
plt.xlabel("grad ternary weights")
plt.title("backward ternary func")

plt.show()

わーい、意図した通りになりましたー!!

これでおしまいです。

takuyawakisaka

Posted by takuyawakisaka