Skip to content

Commit

Permalink
make training conditional for the inferrer
Browse files Browse the repository at this point in the history
  • Loading branch information
baxtree committed Feb 12, 2021
1 parent a27f454 commit 6fbc435
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 230 deletions.
4 changes: 3 additions & 1 deletion examples/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def main():
# Create the model and load the trained weights.
trainer = Trainer(model, config)
trainer.build_model()
trainer.train_model()

if config.load_from_data is not None:
trainer.train_model()

trainer.infer_tails(1, 10, topk=5)
trainer.infer_heads(10, 20, topk=5)
Expand Down
104 changes: 0 additions & 104 deletions pykg2vec/models/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ class TransE(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import TransE
>>> from pykg2vec.utils.trainer import Trainer
>>> model = TransE()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
Portion of the code based on `OpenKE_TransE`_ and `wencolani`_.
.. _OpenKE_TransE: https://github.com/thunlp/OpenKE/blob/master/models/TransE.py
Expand Down Expand Up @@ -115,14 +107,6 @@ class TransH(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import TransH
>>> from pykg2vec.utils.trainer import Trainer
>>> model = TransH()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
Portion of the code based on `OpenKE_TransH`_ and `thunlp_TransH`_.
.. _OpenKE_TransH:
Expand Down Expand Up @@ -208,14 +192,6 @@ class TransD(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import TransD
>>> from pykg2vec.utils.trainer import Trainer
>>> model = TransD()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
Portion of the code based on `OpenKE_TransD`_.
.. _OpenKE_TransD:
Expand Down Expand Up @@ -312,14 +288,6 @@ class TransM(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import TransM
>>> from pykg2vec.utils.trainer import Trainer
>>> model = TransM()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Transition-based Knowledge Graph Embedding with Relational Mapping Properties:
https://pdfs.semanticscholar.org/0ddd/f37145689e5f2899f8081d9971882e6ff1e9.pdf
Expand Down Expand Up @@ -406,14 +374,6 @@ class TransR(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import TransR
>>> from pykg2vec.utils.trainer import Trainer
>>> model = TransR()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _thunlp_transR:
https://github.com/thunlp/TensorFlow-TransX/blob/master/transR.py
Expand Down Expand Up @@ -519,14 +479,6 @@ class SLM(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import SLM
>>> from pykg2vec.utils.trainer import Trainer
>>> model = SLM()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Reasoning With Neural Tensor Networks for Knowledge Base Completion:
https://nlp.stanford.edu/pubs/SocherChenManningNg_NIPS2013.pdf
"""
Expand Down Expand Up @@ -602,14 +554,6 @@ class SME(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import SME
>>> from pykg2vec.utils.trainer import Trainer
>>> model = SME()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
Portion of the code based on glorotxa_.
.. _glorotxa: https://github.com/glorotxa/SME/blob/master/model.py
Expand Down Expand Up @@ -725,14 +669,6 @@ class SME_BL(SME):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import SME
>>> from pykg2vec.utils.trainer import Trainer
>>> model = SME_BL()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _`SME`: api.html#pykg2vec.models.pairwise.SME
"""
Expand Down Expand Up @@ -799,14 +735,6 @@ class RotatE(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import RotatE
>>> from pykg2vec.utils.trainer import Trainer
>>> model = RotatE()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Rotate-Knowledge graph embedding by relation rotation in complex space:
https://openreview.net/pdf?id=HkgEQnRqYQ
"""
Expand Down Expand Up @@ -873,14 +801,6 @@ class Rescal(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import Rescal
>>> from pykg2vec.utils.trainer import Trainer
>>> model = Rescal()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _mnick: https://github.com/mnick/rescal.py/blob/master/rescal/rescal.py
.. _OpenKE_Rescal: https://github.com/thunlp/OpenKE/blob/master/models/RESCAL.py
Expand Down Expand Up @@ -958,14 +878,6 @@ class NTN(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import NTN
>>> from pykg2vec.utils.trainer import Trainer
>>> model = NTN()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _siddharth-agrawal:
https://github.com/siddharth-agrawal/Neural-Tensor-Network/blob/master/neuralTensorNetwork.py
Expand Down Expand Up @@ -1065,14 +977,6 @@ class KG2E(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import KG2E
>>> from pykg2vec.utils.trainer import Trainer
>>> model = KG2E()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _`mana-ysh's repository`:
https://github.com/mana-ysh/gaussian-embedding/blob/master/src/models/gaussian_model.py
Expand Down Expand Up @@ -1189,14 +1093,6 @@ class HoLE(PairwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import HoLE
>>> from pykg2vec.utils.trainer import Trainer
>>> model = HoLE()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Holographic Embeddings of Knowledge Graphs:
https://arxiv.org/pdf/1510.04935.pdf
Expand Down
88 changes: 0 additions & 88 deletions pykg2vec/models/pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@ class ANALOGY(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import ANALOGY
>>> from pykg2vec.utils.trainer import Trainer
>>> model = ANALOGY()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Analogical Inference for Multi-relational Embeddings:
http://proceedings.mlr.press/v70/liu17d/liu17d.pdf
Expand Down Expand Up @@ -136,14 +128,6 @@ class Complex(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import Complex
>>> from pykg2vec.utils.trainer import Trainer
>>> model = Complex()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Complex Embeddings for Simple Link Prediction:
http://proceedings.mlr.press/v48/trouillon16.pdf
Expand Down Expand Up @@ -227,14 +211,6 @@ class ComplexN3(Complex):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import ComplexN3
>>> from pykg2vec.utils.trainer import Trainer
>>> model = ComplexN3()
>>> trainer = Trainer(model=model, debug=False)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Complex Embeddings for Simple Link Prediction:
http://proceedings.mlr.press/v48/trouillon16.pdf
Expand Down Expand Up @@ -272,14 +248,6 @@ class ConvKB(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import ConvKB
>>> from pykg2vec.utils.trainer import Trainer
>>> model = ConvKB()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _daiquocnguyen:
https://github.com/daiquocnguyen/ConvKB
Expand Down Expand Up @@ -357,14 +325,6 @@ class CP(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import CP
>>> from pykg2vec.utils.trainer import Trainer
>>> model = CP()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Canonical Tensor Decomposition for Knowledge Base Completion:
http://proceedings.mlr.press/v80/lacroix18a/lacroix18a.pdf
Expand Down Expand Up @@ -438,14 +398,6 @@ class DistMult(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import DistMult
>>> from pykg2vec.utils.trainer import Trainer
>>> model = DistMult()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _EMBEDDING ENTITIES AND RELATIONS FOR LEARNING AND INFERENCE IN KNOWLEDGE BASES:
https://arxiv.org/pdf/1412.6575.pdf
Expand Down Expand Up @@ -513,14 +465,6 @@ class SimplE(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import SimplE
>>> from pykg2vec.utils.trainer import Trainer
>>> model = SimplE()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _SimplE Embedding for Link Prediction in Knowledge Graphs:
https://papers.nips.cc/paper/7682-simple-embedding-for-link-prediction-in-knowledge-graphs.pdf
Expand Down Expand Up @@ -599,14 +543,6 @@ class SimplE_ignr(SimplE):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import SimplE_ignr
>>> from pykg2vec.utils.trainer import Trainer
>>> model = SimplE_ignr()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _SimplE Embedding for Link Prediction in Knowledge Graphs:
https://papers.nips.cc/paper/7682-simple-embedding-for-link-prediction-in-knowledge-graphs.pdf
Expand Down Expand Up @@ -652,14 +588,6 @@ class QuatE(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import QuatE
>>> from pykg2vec.utils.trainer import Trainer
>>> model = QuatE()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _cheungdaven: https://github.com/cheungdaven/QuatE.git
.. _Quaternion Knowledge Graph Embeddings:
Expand Down Expand Up @@ -848,14 +776,6 @@ class OctonionE(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pointwise import OctonionE
>>> from pykg2vec.utils.trainer import Trainer
>>> model = OctonionE()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _cheungdaven: https://github.com/cheungdaven/QuatE.git
.. _Quaternion Knowledge Graph Embeddings:
Expand Down Expand Up @@ -1088,14 +1008,6 @@ class MuRP(PointwiseModel):
Args:
config (object): Model configuration parameters.
Examples:
>>> from pykg2vec.models.pairwise import MuRP
>>> from pykg2vec.utils.trainer import Trainer
>>> model = MuRP()
>>> trainer = Trainer(model=model)
>>> trainer.build_model()
>>> trainer.train_model()
.. _Multi-relational Poincaré Graph Embeddings:
https://arxiv.org/abs/1905.09791
Expand Down
Loading

0 comments on commit 6fbc435

Please sign in to comment.