Skip to content

Commit

Permalink
Update align and subtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Catchxu committed Jun 3, 2024
1 parent 294e4ef commit d41de0d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/stands/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def fit(self, generator: GeneratorAD, raw: Dict[str, Any]):
t.set_description(f'Train Epochs')

# generate embeddings
z_ref = self.G.extract.encode(ref_g, ref_g[0].srcdata['gene'])
z_tgt = self.G.extract.encode(tgt_g, tgt_g[0].srcdata['gene'])
z_ref = self.G.extract.encode(ref_g, ref_g.ndata['gene'])
z_tgt = self.G.extract.encode(tgt_g, tgt_g.ndata['gene'])

self.UpdateD(z_ref, z_tgt)
self.UpdateM(z_ref, z_tgt)
Expand All @@ -63,8 +63,8 @@ def fit(self, generator: GeneratorAD, raw: Dict[str, Any]):

self.M.eval()
with torch.no_grad():
z_ref = self.G.extract.encode(ref_g, ref_g[0].srcdata['gene'])
z_tgt = self.G.extract.encode(tgt_g, tgt_g[0].srcdata['gene'])
z_ref = self.G.extract.encode(ref_g, ref_g.ndata['gene'])
z_tgt = self.G.extract.encode(tgt_g, ref_g.ndata['gene'])
_, _, m = self.M(z_ref, z_tgt)
pair_id = list(ref_g.nodes().cpu().numpy()[m.argmax(axis=1)])
ref_g = dgl.node_subgraph(ref_g, pair_id)
Expand Down Expand Up @@ -218,7 +218,7 @@ def fit(self, raw: Dict[str, Any], generator: GeneratorAD,
tqdm.write('Datasets have been corrected.\n')
return adata

def init_model(self, generator: GeneratorAD, raw, weight_dir):
def init_model(self, generator: GeneratorAD, raw):
z_dim = generator.extract.z_dim
self.G = GeneratorBC(generator.extract, raw['data_n'], z_dim).to(self.device)
self.D = Discriminator(raw['gene_dim'], raw['patch_size']).to(self.device)
Expand Down Expand Up @@ -269,6 +269,6 @@ def UpdateG(self, blocks):
Loss_adv = - torch.mean(d)

# store generator loss for printing training information and backward
self.G_loss = self.weight['w_rec'] * Loss_rec + self.weight['w_adv'] * Loss_adv
self.G_loss = self.weight['w_rec']*Loss_rec + self.weight['w_adv']*Loss_adv
self.G_loss.backward()
self.opt_G.step()
8 changes: 4 additions & 4 deletions src/stands/subtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ def generate_z_res(self, graph: dgl.DGLGraph):
'''Generate reconstructed data'''
self.G.eval()
if self.G.extract.only_ST:
z, fake_g = self.G.STforward([graph, graph], graph.ndata['gene'])
z, fake_g = self.G.STforward(graph, graph.ndata['gene'])
res_g = graph.ndata['gene'] - fake_g.detach()
res_z = self.C.STforward([graph, graph], res_g)
res_z = self.C.STforward(graph, res_g)

else:
z, fake_g, fake_p = self.G.Fullforward(
[graph, graph], graph.ndata['gene'], graph.ndata['patch']
graph, graph.ndata['gene'], graph.ndata['patch']
)
res_g = graph.ndata['gene'] - fake_g.detach()
res_p = graph.ndata['patch'] - fake_p.detach()
res_z = self.C.Fullforward([graph, graph], res_g, res_p)
res_z = self.C.Fullforward(graph, res_g, res_p)

return z, res_z

0 comments on commit d41de0d

Please sign in to comment.