Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/qibinc/KBRD
Browse files Browse the repository at this point in the history
  • Loading branch information
qibinc committed Oct 7, 2019
2 parents 4c40617 + fb2b74d commit f230357
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 23 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from parlai.core.torch_agent import Output, TorchAgent
from parlai.core.utils import round_sigfigs

from .modules import RippleNet
from .modules import KBRD



Expand Down Expand Up @@ -73,11 +73,11 @@ def nltk_tokenize(text):

return full_text_embeddings

class RipplenetAgent(TorchAgent):
class KBRDAgent(TorchAgent):
@classmethod
def add_cmdline_args(cls, argparser):
"""Add command-line arguments specifically for this agent."""
super(RipplenetAgent, cls).add_cmdline_args(argparser)
super(KBRDAgent, cls).add_cmdline_args(argparser)
agent = argparser.add_argument_group("Arguments")
agent.add_argument("-ne", "--n-entity", type=int)
agent.add_argument("-nr", "--n-relation", type=int)
Expand All @@ -94,14 +94,14 @@ def add_cmdline_args(cls, argparser):
"-lr", "--learningrate", type=float, default=3e-3, help="learning rate"
)
agent.add_argument("-nb", "--num-bases", type=int, default=8)
RipplenetAgent.dictionary_class().add_cmdline_args(argparser)
KBRDAgent.dictionary_class().add_cmdline_args(argparser)
return agent

def __init__(self, opt, shared=None):
super().__init__(opt, shared)
init_model, is_finetune = self._get_init_model(opt, shared)

self.id = "RipplenetAgent"
self.id = "KBRDAgent"
self.n_entity = opt["n_entity"]
self.n_hop = opt["n_hop"]
self.n_memory = opt["n_memory"]
Expand All @@ -125,7 +125,7 @@ def __init__(self, opt, shared=None):
# entity_text_emb = _load_text_embeddings(entity2entityId, opt["dim"], abstract_path)

# encoder captures the input text
self.model = RippleNet(
self.model = KBRD(
n_entity=opt["n_entity"],
n_relation=opt["n_relation"],
dim=opt["dim"],
Expand Down Expand Up @@ -154,9 +154,9 @@ def __init__(self, opt, shared=None):
opt["learningrate"],
)

elif "ripplenet" in shared:
elif "kbrd" in shared:
# copy initialized data from shared table
self.model = shared["ripplenet"]
self.model = shared["kbrd"]
self.kg = shared["kg"]
self.movie_ids = shared["movie_ids"]
self.optimizer = shared["optimizer"]
Expand Down Expand Up @@ -204,7 +204,7 @@ def reset_metrics(self):
def share(self):
"""Share internal states."""
shared = super().share()
shared["ripplenet"] = self.model
shared["kbrd"] = self.model
shared["kg"] = self.kg
shared["movie_ids"] = self.movie_ids
shared["optimizer"] = self.optimizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def _edge_list(kg, n_entity, hop):

return [(h, t, relation_idx[r]) for h, t, r in edge_list if relation_cnt[r] > 1000], len(relation_idx)

class RippleNet(nn.Module):
class KBRD(nn.Module):
def __init__(
self,
n_entity,
Expand All @@ -397,7 +397,7 @@ def __init__(
entity_text_emb,
num_bases
):
super(RippleNet, self).__init__()
super(KBRD, self).__init__()

self.n_entity = n_entity
self.n_relation = n_relation
Expand Down
10 changes: 5 additions & 5 deletions parlai/agents/transformer_rec/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from parlai.core.torch_generator_agent import TorchGeneratorModel
from parlai.core.utils import neginf
from parlai.agents.ripplenet.modules import RippleNet
from parlai.agents.ripplenet.ripplenet import _load_kg_embeddings
from parlai.agents.kbrd.modules import KBRD
from parlai.agents.kbrd.kbrd import _load_kg_embeddings


def _normalize(tensor, norm_layer):
Expand Down Expand Up @@ -565,7 +565,7 @@ def __init__(self, opt, dictionary):
open(os.path.join(opt["datapath"], "redial", "entity2entityId.pkl"), "rb")
)
# entity_kg_emb = _load_kg_embeddings(entity2entityId, opt["dim"], "sub_joined_embeddings.tsv")
self.ripplenet = RippleNet(opt['n_entity'],opt['n_relation'],opt['dim'],opt['n_hop'],opt['kge_weight'],opt['l2_weight'],opt['n_memory'],opt['item_update_mode'],opt['using_all_hops'], kg, None, None, num_bases=8)
self.ripplenet = KBRD(opt['n_entity'],opt['n_relation'],opt['dim'],opt['n_hop'],opt['kge_weight'],opt['l2_weight'],opt['n_memory'],opt['item_update_mode'],opt['using_all_hops'], kg, None, None, num_bases=8)
state_dict = torch.load('saved/both_rgcn_1')['model']
# state_dict = OrderedDict([('ripplenet.' + key, state_dict[key]) for key in state_dict])
self.ripplenet.load_state_dict(state_dict)
Expand All @@ -575,10 +575,10 @@ def __init__(self, opt, dictionary):
# self.user_representation_to_bias_1 = nn.Linear(opt['dim'], 256)
# self.user_representation_to_bias_2 = nn.Linear(256, 2048)
# self.user_representation_to_bias_3 = nn.Linear(2048, len(dictionary))
for param in self.ripplenet.parameters():
for param in self.kbrd.parameters():
param.requires_grad = False

# self.ripplenet.user_representation(item_list)
# self.kbrd.user_representation(item_list)

def reorder_encoder_states(self, encoder_states, indices):
enc, mask = encoder_states
Expand Down
4 changes: 2 additions & 2 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def train_step(self, batch):
self.zero_grad()

if getattr(batch, 'movies', None):
assert hasattr(self.model, 'ripplenet')
assert hasattr(self.model, 'kbrd')
self.model.user_representation, _ = self.model.ripplenet.user_representation(batch.movies)
self.model.user_representation = self.model.user_representation.detach()
try:
Expand Down Expand Up @@ -594,7 +594,7 @@ def eval_step(self, batch):
self.model.eval()
cand_scores = None
if getattr(batch, 'movies', None):
assert hasattr(self.model, 'ripplenet')
assert hasattr(self.model, 'kbrd')
self.model.user_representation, _ = self.model.ripplenet.user_representation(batch.movies)
self.model.user_representation = self.model.user_representation.detach()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
parser.set_defaults(
task="redial",
dict_tokenizer="split",
model="ripplenet",
model="kbrd",
dict_file="saved/tmp",
model_file="saved/ripplenet",
model_file="saved/kbrd",
fp16=True,
batchsize=256,
n_entity=64368,
Expand Down
2 changes: 1 addition & 1 deletion scripts/both.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ let num_runs=32

for i in $(seq 0 $((num_runs-1)));
do
CUDA_VISIBLE_DEVICES=2 python parlai/tasks/redial/train_ripplenet.py -mf saved/both_rgcn_$i
CUDA_VISIBLE_DEVICES=2 python parlai/tasks/redial/train_kbrd.py -mf saved/both_rgcn_$i
done

2 changes: 1 addition & 1 deletion scripts/onlymovie.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ let num_runs=32

for i in $(seq 0 $((num_runs-1)));
do
CUDA_VISIBLE_DEVICES=0 python parlai/tasks/redial/train_ripplenet.py -mf saved/onlymovie_$i
CUDA_VISIBLE_DEVICES=0 python parlai/tasks/redial/train_kbrd.py -mf saved/onlymovie_$i
done

2 changes: 1 addition & 1 deletion scripts/show_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setup_args(parser=None):
agent = create_agent(opt, requireModelExists=True)
entity2entityId = pkl.load(open('data/redial/entity2entityId.pkl', 'rb'))

up, _ = agent.model.ripplenet.user_representation([
up, _ = agent.model.kbrd.user_representation([
list(map(lambda x: entity2entityId[x], movie_entities))
])

Expand Down

0 comments on commit f230357

Please sign in to comment.