Nerds N' Computers

ソフトウェア技術や思ったことについて書きます.

ソフトウェア技術や思ったことについていろいろと書きます.

Pytorchの疎行列演算まとめ

疎行列とはほとんどがゼロ要素で埋められているような行列のことを指します.  グラフ構造を扱う際に隣接行列やAffinity Matrixが疎な行列となることが多いです. このような行列はすべての要素を保存するのではなく, 非ゼロ要素の座標と値のみを保存する方がメモリ効率がよくなります.

疎行列の保存形式もいくつかあるのですが, PytorchではCOOフォーマットのみをサポートしています. 疎行列のフォーマット(COO, LIL, CSR, CSC)について気になる方は, はむかずさんのscipy.sparseでの疎ベクトルの扱いがとてもわかりやすいです.

Pytorchにおいても疎行列の演算がサポートされていますが, 前述したようにCOOフォーマットのみのサポートであり実装されている演算が限られているなどの制約はありますが, GCNなどのグラフ構造を用いた深層学習の研究が一般化するに連れて今後も開発が進んでいくと考えています.

This API is currently experimental and may change in the near future. とあるようにまだ実験段階の機能なので, As-Isでの機能を紹介する形となりますが, 疎行列のAPIの実装に関してはv1.0.0で大幅に改善されていることもあり, Pytorchのバージョンはv1.0.0を利用して検証を行いました. 疎行列の演算に関する公式ドキュメントがほとんどない状態なので少しでも理解の助けになれば嬉しいです.

疎行列の初期化方法

一般的なCOO形式の疎行列の初期化方法と同じく, インデックスの座標と値をペアになるように渡すことで初期化ができます.

>>> i = torch.LongTensor([[0, 1, 1],
                          [2, 0, 2]])
>>> v = torch.FloatTensor([3, 4, 5])
>>> sm = torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()
 0  0  3
 4  0  5
[torch.FloatTensor of size 2x3]

最近になって密行列を疎行列に変換するメソッドが実装されました. 下記のように密行列をto_sparse() メソッドで疎行列に変換できます.

>>> sm = torch.randn(2, 3).to_sparse()
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 1.5901,  0.0183, -0.6146,  1.8061, -0.0112,  0.6302]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)

また, 疎行列の諸情報を取得したい場合は以下のメソッドが利用できます.

インデックスの取得

>>> sm._indices()
tensor([[ 0,  1,  1],
        [ 2,  0,  2]])

値の取得

>>> sm._values()
tensor([ 3.,  4.,  5.])

非ゼロ要素数の取得

>>> sm._nnz()
3

転置行列の取得

>>> sm.t()
torch.sparse.FloatTensor of size (3,2) with indices:
tensor([[ 2,  0,  2],
        [ 0,  1,  1]])
and values:
tensor([ 3.,  4.,  5.])

疎行列の演算

Pytorchでは基本的な疎行列の演算が実装されています.


torch.sparse.mm(mat1, mat2) → Tensor

引数が mat1 はsparse.Tensorで, mat2Tensorである必要があるということに注意してください.

疎行列と密行列の行列積の計算はv1.0.0以前からtorch.spmm(mat1, mat2)を用いることで計算可能でしたが, 疎行列の勾配を求めることができませんでした.

例えば下記のコードはエラーになります.

>>> a = torch.sparse.FloatTensor(i, v, torch.Size([2,3])).requires_grad_(True)
>>> b = torch.rand(3, 2, requires_grad=True)
>>> c = torch.spmm(a, b)
>>> y = c.sum()
>>> y.backward()
RuntimeError: calculating the gradient of a sparse Tensor argument to mm is not supported.

これに対して, torch.sparse.mm(mat1, mat2)メソッドを使うことによって, 疎行列に対しても誤差逆伝搬を適用することができるようになっています.

>>> c = torch.sparse.mm(a, b)
>>> y = c.sum()
>>> y.backward()
>>> a.grad
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([1., 1., 1.]),
       size=(2, 3), nnz=3, layout=torch.sparse_coo)

torch.sparse.addmm(mat, mat1, mat2, alpha=1, beta=1) → Tensor

下式のように和と行列積を同時に計算します. 引数が mat, mat2Tensorで, mat2がsparse.Tensorである必要があるということに注意してください.

$\text { out } = \beta \text { mat } + \alpha \left( \operatorname { mat } 1 _ { i } @ \operatorname { mat } 2 _ { i } \right)$

>>> a = torch.sparse.FloatTensor(i, v, torch.Size([2,3])).requires_grad_(True)
>>> b = torch.randn(3, 2, requires_grad=True)
>>> c = torch.randn(2, 2, requires_grad=True)
>>> d = torch.sparse.addmm(c, a, b)
>>> d.grad_fn
<SparseAddmmBackward object at 0x110000000>
>>> y = d.sum()
>>> y.backward()
>>> a.grad
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([2.6244, 1.9874, 2.6244]),
       size=(2, 3), nnz=3, layout=torch.sparse_coo)

torch.sparse.sum(input, dim=None, dtype=None) → Tensor

疎行列の要素の和を計算します. このメソッドに関しても疎行列に関しての誤差逆伝搬が可能になっています.

>>> a = torch.sparse.FloatTensor(i, v, torch.Size([2,3])).requires_grad_(True)
>>> y = torch.sparse.sum(a)
>>> y.grad_fn
<SumBackward0 object at 0x1016c7208>
>>> y.backward()
>>> a.grad
tensor(indices=tensor([[0, 1, 1],
                       [2, 0, 2]]),
       values=tensor([1., 1., 1.]),
       size=(2, 3), nnz=3, layout=torch.sparse_coo)

まとめ

公式ドキュメントの焼き直しのようになってしまいましたが, Pytorchでの疎行列の演算の方法を簡単にまとめてみました. 追記の必要, 間違い等あれば気軽にコメントください.

From Shun