DGL-KE Documentation

Knowledge graphs (KGs) are data structures that store information about different entities (nodes) and their relations (edges). A common approach to using KGs in various machine learning tasks is to compute knowledge graph embeddings. DGL-KE is a high-performance, easy-to-use, and scalable package for learning large-scale knowledge graph embeddings. The package is implemented on top of the Deep Graph Library (DGL), and developers can run DGL-KE on CPU machines, GPU machines, and clusters with a set of popular models, including TransE, TransR, RESCAL, DistMult, ComplEx, and RotatE.

[Figure: DGL-KE architecture]

Get started with DGL-KE!

Installation Guide

DGL-KE works on both Linux and macOS and requires Python 3.5 or later (Python 3.4 and earlier are not tested). DGL-KE can run on both PyTorch and MXNet; please refer to the following pages to install PyTorch or MXNet.

PyTorch installation

MXNet installation

Install DGL

DGL-KE is implemented on top of DGL. You can install DGL using pip:

pip install dgl

or you can install DGL from source:

git clone --recursive https://github.com/dmlc/dgl.git
cd dgl
mkdir build
cd build
cmake ../
make -j4
# install the DGL Python package after building the library
cd ../python
python3 setup.py install

Install DGL-KE

The fastest way to install DGL-KE is by using pip:

pip install dglke

or you can install DGL-KE from source:

git clone https://github.com/awslabs/dgl-ke.git
cd dgl-ke/python
sudo python3 setup.py install

Run a quick test

Once DGL-KE is installed successfully, you can test it with the following command:

dglke_train --model_name TransE_l2 --dataset FB15k --batch_size 1000 --neg_sample_size 200 --hidden_dim 400 \
--gamma 19.9 --lr 0.25 --max_step 500 --log_interval 100 --batch_size_eval 16 --test -adv \
--regularization_coef 1.00E-09 --num_thread 1 --num_proc 8

This command will download the FB15k dataset, train the TransE model on it, and save the trained embeddings to a file. You should see output like the following at the end of training:

training takes 37.735950231552124 seconds
-------------- Test result --------------
Test average MRR : 0.47615999491724303
Test average MR : 58.97734929153053
Test average HITS@1 : 0.28428501295051717
Test average HITS@3 : 0.6277276497773865
Test average HITS@10 : 0.775862944592101
-----------------------------------------
testing takes 110.887 seconds

DGL-KE Command Line

DGL-KE provides four commands to users:

dglke_train trains KG embeddings on CPUs or GPUs on a single machine and saves the trained node embeddings and relation embeddings to files.

dglke_eval reads the pre-trained embeddings and evaluates the embeddings with a link prediction task on the test set. This is a common task used for evaluating the quality of pre-trained KG embeddings.

dglke_partition partitions the given knowledge graph into N parts with the METIS partition algorithm. Different partitions will be stored on different machines for distributed training. You can find more details about the METIS partition algorithm at this link.

dglke_dist_train launches a set of processes in the cluster for distributed training.

Training on Multi-Core

Multi-core processors are common in modern computers, and DGL-KE is optimized for them to achieve the best system performance. The following command trains a TransE model on the FB15k dataset on a multi-core machine:

dglke_train --model_name TransE_l2 --dataset FB15k --batch_size 1000 --neg_sample_size 200 --hidden_dim 400 \
--gamma 19.9 --lr 0.25 --max_step 3000 --log_interval 100 --batch_size_eval 16 --test -adv \
--regularization_coef 1.00E-09 --num_thread 1 --num_proc 8

--num_proc indicates that we launch 8 processes in parallel for the training task, and --num_thread indicates that each process uses 1 thread. Typically, num_proc * num_thread should be less than or equal to the number of CPU cores on the current machine for the best performance. For example, when the number of processes equals the number of CPU cores, each process should use one thread.

--model_name specifies the model to train, including TransE_l2, TransE_l1, DistMult, ComplEx, TransR, RESCAL, and RotatE.

--dataset chooses a built-in dataset, including FB15k, FB15k-237, wn18, wn18rr, and Freebase. See this page for more details about the built-in knowledge graphs.

--batch_size and --neg_sample_size are hyper-parameters for the training sampler, and --batch_size_eval is the batch size used for the test.

--hidden_dim defines the dimensionality of the KG embeddings. --gamma is a hyper-parameter used to initialize the embeddings. --regularization_coef sets the coefficient for regularization.

--lr sets the learning rate for the optimization algorithm. --max_step defines the maximal number of training steps per process. Note that the total number of steps in training is max_step * num_proc, so with multi-processing we need to adjust --max_step in each process. Usually, the total number of steps performed by all processes should equal the number of steps performed in single-process training; for example, 8 processes with --max_step 500 perform the same 4000 steps in total as a single process with --max_step 4000.

-adv enables negative adversarial sampling, which gives negative samples with higher scores a larger weight in the loss.
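
The snippet below is a conceptual sketch of this idea, not DGL-KE's internal code: the negative samples are weighted by a softmax over their scores, so harder (higher-scoring) negatives contribute more to the loss. The scores and the temperature value are made up for illustration.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical scores of a batch of negative samples (higher = harder negative).
neg_scores = np.array([3.5, 2.0, 0.5, -1.0])
temperature = 1.0  # adversarial temperature, an assumed value

# Self-adversarial weights: higher-scoring negatives get larger weights.
weights = softmax(temperature * neg_scores)

# Each negative's log-sigmoid loss term is weighted before summing.
neg_loss = -np.sum(weights * np.log(1.0 / (1.0 + np.exp(neg_scores))))
print(weights, neg_loss)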

--log_interval 100 prints the training loss to the screen every 100 steps, like this:

[proc 7][Train](100/500) average pos_loss: 0.7686050720512867
[proc 7][Train](100/500) average neg_loss: 0.6058262066915632
[proc 7][Train](100/500) average loss: 0.6872156363725662
[proc 7][Train](100/500) average regularization: 8.930973201586312e-06
[proc 7][Train] 100 steps take 22.813 seconds
[proc 7]sample: 0.226, forward: 13.125, backward:

As we can see, every 100 steps take about 22.8 seconds in each process.

--test runs an evaluation after training and prints output like the following to the screen:

training takes 37.735950231552124 seconds
-------------- Test result --------------
Test average MRR : 0.47615999491724303
Test average MR : 58.97734929153053
Test average HITS@1 : 0.28428501295051717
Test average HITS@3 : 0.6277276497773865
Test average HITS@10 : 0.775862944592101
-----------------------------------------
testing takes 110.887 seconds

After training, a new folder ckpts/TransE_l2_FB15k_0 is created; it stores the training log and the trained KG embeddings. Users can set --no_save_emb to skip saving the embeddings to files.

Training on a single GPU

Training knowledge graph embeddings involves a large amount of tensor computation, which can be accelerated by GPUs. DGL-KE can run on a single GPU as well as on a multi-GPU machine. It can also run in a mixed CPU-GPU environment, for cases where the embeddings cannot fit into GPU memory.

The following command trains the TransE model on FB15k on a single GPU:

dglke_train --model_name TransE_l2 --dataset FB15k --batch_size 1000 --log_interval 1000 \
--neg_sample_size 200 --regularization_coef=1e-9 --hidden_dim 400 --gamma 19.9 \
--lr 0.25 --batch_size_eval 16 --test -adv --gpu 0 --max_step 24000

Most of the options here we have already seen in the previous section. The only difference is that we add --gpu 0 to indicate that we will use one GPU (GPU 0) to train the model. Compared to CPU training, every 100 steps take only 0.68 seconds on an Nvidia V100 GPU, which is much faster than the 22.8 seconds of CPU training:

[proc 0][Train](24000/24000) average pos_loss: 0.2704171320796013
[proc 0][Train](24000/24000) average neg_loss: 0.39646861135959627
[proc 0][Train](24000/24000) average loss: 0.33344287276268003
[proc 0][Train](24000/24000) average regularization: 0.0017754920991137624
[proc 0][Train] 100 steps take 0.680 seconds

Mix CPU-GPU training

By default, DGL-KE keeps all node and relation embeddings in GPU memory for single-GPU training. Therefore, it cannot train embeddings of large knowledge graphs, because GPU memory is typically much smaller than CPU memory. If your KG embeddings are too large to fit into GPU memory, you can use --mix_cpu_gpu training:

dglke_train --model_name TransE_l2 --dataset FB15k --batch_size 1000 --log_interval 1000 \
--neg_sample_size 200 --regularization_coef=1e-9 --hidden_dim 400 --gamma 19.9 \
--lr 0.25 --batch_size_eval 16 --test -adv --gpu 0 --max_step 24000 --mix_cpu_gpu

With --mix_cpu_gpu training, node and relation embeddings are kept in CPU memory and batch computation is performed on the GPU. In this way, you can train very large KG embeddings as long as they fit into CPU memory, although training will be slower than pure GPU training:

[proc 0][Train](24000/24000) average pos_loss: 0.2693914473056793
[proc 0][Train](24000/24000) average neg_loss: 0.39576649993658064
[proc 0][Train](24000/24000) average loss: 0.3325789734721184
[proc 0][Train](24000/24000) average regularization: 0.0017816077976021915
[proc 0][Train] 100 steps take 1.073 seconds
[proc 0]sample: 0.158, forward: 0.383, backward: 0.214, update: 0.316

As we can see, mix_cpu_gpu training takes about 1.07 seconds for every 100 steps.
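
The following is a minimal conceptual sketch of this pattern, assuming PyTorch is installed; it is not DGL-KE's actual implementation, and the sizes, score function, and update rule are simplified assumptions. The embedding table stays in CPU memory, only the rows needed by the current batch are copied to the GPU for computation, and the sparse gradient update is written back to the CPU table.

import torch

num_entities, dim = 100000, 400               # hypothetical sizes
entity_emb = torch.randn(num_entities, dim)   # embedding table lives in CPU memory
lr = 0.25
device = "cuda" if torch.cuda.is_available() else "cpu"

def train_step(head_ids, tail_ids):
    # Gather only the rows used by this batch and move them to the GPU.
    h = entity_emb[head_ids].to(device).requires_grad_(True)
    t = entity_emb[tail_ids].to(device).requires_grad_(True)
    # A toy distance-based score (no relation term), just for illustration.
    loss = (h - t).norm(p=2, dim=-1).mean()
    loss.backward()
    # Sparse update: only the touched rows are written back to CPU memory.
    with torch.no_grad():
        entity_emb[head_ids] -= lr * h.grad.cpu()
        entity_emb[tail_ids] -= lr * t.grad.cpu()

train_step(torch.randint(num_entities, (1000,)),
           torch.randint(num_entities, (1000,)))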

Training on Multi-GPU

DGL-KE also supports multi-GPU training, which can increase performance by distributing training across multiple GPUs. The following figure depicts 4 GPUs on a single machine, connected to the CPU through a PCIe switch. Multi-GPU training automatically keeps node and relation embeddings in CPU memory and dispatches batches to the different GPUs.

[Figure: multi-GPU training architecture]

The following command shows how to train the TransE model using 4 Nvidia V100 GPUs jointly:

dglke_train --model_name TransE_l2 --dataset FB15k --batch_size 1000 --log_interval 1000 \
--neg_sample_size 200 --regularization_coef=1e-9 --hidden_dim 400 --gamma 19.9 \
--lr 0.25 --batch_size_eval 16 --test -adv --gpu 0 1 2 3 --max_step 6000

Compared to single-GPU training, we change --gpu 0 to --gpu 0 1 2 3 and reduce --max_step from 24000 to 6000, so the total number of training steps across the four GPUs stays the same.

Users can add the --async_update option for multi-GPU training. This optimization overlaps batch computation on the GPU with gradient updates on the CPU to speed up overall training:

dglke_train --model_name TransE_l2 --dataset FB15k --batch_size 1000 --log_interval 1000 \
--neg_sample_size 200 --regularization_coef=1e-9 --hidden_dim 400 --gamma 19.9 \
--lr 0.25 --batch_size_eval 16 --test -adv --gpu 0 1 2 3 --async_update --max_step 6000

--async_update can increase system performance, but it may slow down model convergence. DGL-KE therefore provides another option, --force_sync_interval, which forces all GPUs to synchronize their models every N steps. For example, the following command syncs the model across GPUs every 1000 steps:

dglke_train --model_name TransE_l2 --dataset FB15k --batch_size 1000 --log_interval 1000 \
--neg_sample_size 200 --regularization_coef=1e-9 --hidden_dim 400 --gamma 19.9 \
--lr 0.25 --batch_size_eval 16 --test -adv --gpu 0 1 2 3 --async_update --max_step 6000 --force_sync_interval 1000

Evaluation on Pre-Trained Embeddings

By default, dglke_train saves the embeddings in the ckpts folder. Each run creates a new folder in ckpts to store the training results. The new folder is named xxxx_yyyy_zz, where xxxx is the model name, yyyy is the dataset name, and zz is a sequence number that ensures a unique name for each run.

The saved embeddings are stored as numpy ndarrays. The node embeddings are saved as XXX_YYY_entity.npy and the relation embeddings as XXX_YYY_relation.npy, where XXX is the dataset name and YYY is the model name.
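
For example, assuming a checkpoint folder produced by the FB15k runs above (the folder and file names below simply follow the naming scheme described here and may differ on your machine), the embeddings can be loaded with numpy:

import numpy as np

# Paths follow the naming scheme described above; adjust them to your own run.
ckpt_dir = "ckpts/TransE_l2_FB15k_0"
entity_emb = np.load(f"{ckpt_dir}/FB15k_TransE_l2_entity.npy")
relation_emb = np.load(f"{ckpt_dir}/FB15k_TransE_l2_relation.npy")

print(entity_emb.shape)    # (num_entities, hidden_dim), e.g. (14951, 400)
print(relation_emb.shape)  # (num_relations, hidden_dim), e.g. (1345, 400)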

A user can disable saving embeddings with --no_save_emb. This might be useful for some cases, such as hyperparameter tuning.

dglke_eval reads the pre-trained embeddings and evaluates them with a link prediction task on the test set. This is a common task for evaluating the quality of pre-trained KG embeddings. The following command evaluates the pre-trained KG embeddings on multiple CPU cores:

dglke_eval --model_name TransE_l2 --dataset FB15k --hidden_dim 400 --gamma 19.9 --batch_size_eval 16 \
--num_thread 1 --num_proc 8 --model_path ~/my_task/ckpts/TransE_l2_FB15k_0/

We can also use GPUs in our evaluation tasks:

dglke_eval --model_name TransE_l2 --dataset FB15k --hidden_dim 400 --gamma 19.9 --batch_size_eval 16 \
--gpu 0 1 2 3 4 5 6 7 --model_path ~/my_task/ckpts/TransE_l2_FB15k_0/

Train Built-in Knowledge Graphs

DGL-KE provides five built-in knowledge graphs:

Dataset    #nodes    #edges     #relations
FB15k      14951     592213     1345
FB15k-237  14541     310116     237
wn18       40943     151442     18
wn18rr     40943     93003      11
Freebase   86054151  338586276  14824

Users can specify one of these datasets with the --dataset option in their tasks.

Benchmark result

DGL-KE also provides benchmark results on FB15k, wn18, and Freebase. Users can go to the corresponding folder to check out the scripts and results. All the benchmark results were obtained on AWS EC2. For multi-CPU and distributed training, the target instance is r5dn.24xlarge, which has 48 CPU cores and 768 GB of memory; it also has 100 Gbit/s network throughput, which is well suited for distributed training. For GPU training, the target instance is p3.16xlarge, which has 64 CPU cores and 8 Nvidia V100 GPUs. Users can choose their own instances according to their needs and tune the hyper-parameters for the best performance.

All the scripts can be found on this page.

FB15k

One-GPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 47.34 0.672 0.557 0.763 0.849 201
TransE_l2 47.04 0.649 0.525 0.746 0.844 167
DistMult 61.43 0.696 0.586 0.782 0.873 150
ComplEx 64.73 0.757 0.672 0.826 0.886 171
RESCAL 124.5 0.661 0.589 0.704 0.787 1252
TransR 59.99 0.670 0.585 0.728 0.808 530
RotatE 43.85 0.726 0.632 0.799 0.873 1405

8-GPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 48.59 0.662 0.542 0.756 0.846 53
TransE_l2 47.52 0.627 0.492 0.733 0.838 49
DistMult 59.44 0.679 0.566 0.764 0.864 47
ComplEx 64.98 0.750 0.668 0.814 0.883 49
RESCAL 133.3 0.643 0.570 0.685 0.773 179
TransR 66.51 0.666 0.581 0.724 0.803 90
RotatE 50.04 0.685 0.581 0.763 0.851 120

Multi-CPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 48.32 0.645 0.521 0.741 0.838 140
TransE_l2 45.28 0.633 0.501 0.735 0.840 58
DistMult 62.63 0.647 0.529 0.733 0.846 58
ComplEx 67.83 0.694 0.590 0.772 0.863 69

Distributed training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 38.26 0.691 0.591 0.765 0.853 104
TransE_l2 34.84 0.645 0.510 0.754 0.854 31
DistMult 51.85 0.661 0.532 0.762 0.864 57
ComplEx 62.52 0.667 0.567 0.737 0.836 65

wn18

One-GPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 355.4 0.764 0.602 0.928 0.949 327
TransE_l2 209.4 0.560 0.306 0.797 0.943 223
DistMult 419.0 0.813 0.702 0.921 0.948 133
ComplEx 318.2 0.932 0.914 0.948 0.959 144
RESCAL 563.6 0.848 0.792 0.898 0.928 308
TransR 432.8 0.609 0.452 0.736 0.850 906
RotatE 451.6 0.944 0.940 0.945 0.950 671

8-GPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 348.8 0.739 0.553 0.927 0.948 111
TransE_l2 198.9 0.559 0.305 0.798 0.942 71
DistMult 798.8 0.806 0.705 0.903 0.932 66
ComplEx 535.0 0.938 0.931 0.944 0.949 53
RotatE 487.7 0.943 0.939 0.945 0.951 127

Multi-CPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 376.3 0.593 0.264 0.926 0.949 925
TransE_l2 218.3 0.528 0.259 0.777 0.939 210
DistMult 837.4 0.791 0.675 0.904 0.933 362
ComplEx 806.3 0.904 0.881 0.926 0.937 281

Distributed training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l1 136.0 0.848 0.768 0.927 0.950 759
TransE_l2 85.04 0.797 0.672 0.921 0.958 144
DistMult 278.5 0.872 0.816 0.926 0.939 275
ComplEx 333.8 0.838 0.796 0.870 0.906 273

Freebase

8-GPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l2 23.56 0.736 0.663 0.782 0.873 4767
DistMult 46.19 0.833 0.813 0.842 0.869 4281
ComplEx 46.70 0.834 0.815 0.843 0.869 8356
TransR 49.68 0.696 0.653 0.716 0.773 14235
RotatE 93.20 0.769 0.748 0.779 0.804 9060

Multi-CPU training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l2 30.82 0.815 0.766 0.848 0.902 6993
DistMult 44.16 0.834 0.815 0.843 0.869 7146
ComplEx 45.62 0.835 0.817 0.843 0.870 8732

Distributed training

Models MR MRR HITS-1 HITS-3 HITS-10 TIME
TransE_l2 34.25 0.764 0.705 0.802 0.869 1633
DistMult 75.15 0.769 0.751 0.779 0.801 1679
ComplEx 77.83 0.771 0.754 0.779 0.802 2293

Train User-Defined Knowledge Graphs

Users can use DGL-KE to train embeddings on their own knowledge graphs. In this case, users need to use --data_path to specify the path to the knowledge graph dataset, --data_files to specify the triplet files of the knowledge graph as well as the node/relation Id mappings, and --format to specify the input format of the knowledge graph.

The input format of users’ knowledge graphs

Users need to store all the data associated with a knowledge graph in the same directory. DGL-KE supports two knowledge graph input formats:

raw_udd_[h|r|t], raw user-defined dataset. In this format, users only need to provide the triplets; the dataloader generates the id mappings for the entities and relations in the triplets and outputs two files while loading the data: entities.tsv for the entity id mapping and relations.tsv for the relation id mapping. The order of head, relation and tail entities is described by [h|r|t]; for example, raw_udd_trh means the triplets are stored in the order of tail, relation and head. The directory contains up to three files (see the sketch after this list):

  • train stores the triplets in the training set. The format of a triplet, e.g., [src_name, rel_name, dst_name], should follow the order specified in [h|r|t]
  • valid stores the triplets in the validation set. The format of a triplet, e.g., [src_name, rel_name, dst_name], should follow the order specified in [h|r|t]. This is optional.
  • test stores the triplets in the test set. The format of a triplet, e.g., [src_name, rel_name, dst_name], should follow the order specified in [h|r|t]. This is optional.
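
As a sketch of the raw_udd_hrt layout, the snippet below writes a tiny, made-up training file; the directory, file name, entity and relation names are placeholders, and tab-separated columns are assumed:

import os

# Hypothetical triplets in head, relation, tail order (raw_udd_hrt).
triplets = [
    ("Alice", "knows",    "Bob"),
    ("Bob",   "works_at", "AcmeCorp"),
    ("Alice", "lives_in", "Seattle"),
]
os.makedirs("my_kg", exist_ok=True)
with open("my_kg/train.tsv", "w") as f:
    for h, r, t in triplets:
        f.write(f"{h}\t{r}\t{t}\n")

Such a directory can then be passed to dglke_train with --data_path, --data_files, and --format raw_udd_hrt (plus the usual training hyper-parameters); the dataloader generates entities.tsv and relations.tsv while loading the data.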

udd_[h|r|t], user-defined dataset. In this format, users should provide the id mappings for entities and relations. The order of head, relation and tail entities is described by [h|r|t]; for example, udd_trh means the triplets are stored in the order of tail, relation and head. The directory should contain five files (see the sketch after this list):

  • entities stores the mapping between entity name and entity Id
  • relations stores the mapping between relation name and relation Id
  • train stores the triplets in the training set. The format of a triplet, e.g., [src_id, rel_id, dst_id], should follow the order specified in [h|r|t]
  • valid stores the triplets in the validation set. The format of a triplet, e.g., [src_id, rel_id, dst_id], should follow the order specified in [h|r|t]
  • test stores the triplets in the test set. The format of a triplet, e.g., [src_id, rel_id, dst_id], should follow the order specified in [h|r|t]
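
A comparable sketch for the udd_hrt layout is shown below. Everything here is hypothetical: the directory and file names are placeholders, and the column order of the mapping files (name, then id, tab-separated) is an assumption rather than something specified above.

import os

os.makedirs("my_kg_udd", exist_ok=True)

# Hypothetical id mappings; the column order (name, then id) is an assumption.
entities = {"Alice": 0, "Bob": 1, "AcmeCorp": 2, "Seattle": 3}
relations = {"knows": 0, "works_at": 1, "lives_in": 2}

with open("my_kg_udd/entities.tsv", "w") as f:
    for name, eid in entities.items():
        f.write(f"{name}\t{eid}\n")
with open("my_kg_udd/relations.tsv", "w") as f:
    for name, rid in relations.items():
        f.write(f"{name}\t{rid}\n")

# Triplets are written as integer ids in head, relation, tail order (udd_hrt).
with open("my_kg_udd/train.tsv", "w") as f:
    for h, r, t in [("Alice", "knows", "Bob"), ("Bob", "works_at", "AcmeCorp")]:
        f.write(f"{entities[h]}\t{relations[r]}\t{entities[t]}\n")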

Distributed Training on Large Data