Skip to content

Commit

Permalink
updated dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
alitinet committed Oct 31, 2022
1 parent e68f5f7 commit 8567b5a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ requires = ["hatchling"]
[project]
name = "multigrate"
version = "0.0.1"
description = "A very interesting piece of code"
description = "Multigrate: multimodal data integration for single-cell genomics."
readme = "README.md"
requires-python = ">=3.9,<3.11"
license = {file = "LICENSE"}
Expand All @@ -20,11 +20,7 @@ urls.Documentation = "https://multigrate.readthedocs.io/"
urls.Source = "https://github.com/theislab/multigrate"
urls.Home-page = "https://github.com/theislab/multigrate"
dependencies = [
"anndata==0.8.0",
"scikit-learn==1.0.2",
"statsmodels==0.13.2",
"numba==0.55.1",
"tables==3.7.0",
"scanpy==1.9.0",
"scvi-tools==0.14.6",
"torchmetrics==0.6.0",
"matplotlib"
Expand Down
16 changes: 12 additions & 4 deletions src/multigrate/module/_multivae_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,15 +523,19 @@ def loss(
else:
integ_loss = torch.tensor(0.0).to(self.device)
if self.mmd == "latent" or self.mmd == "both":
integ_loss += self._calc_integ_loss(z_joint, integrate_on).to(self.device)
integ_loss += self.calc_integ_loss(z_joint, integrate_on).to(
self.device
)
if self.mmd == "marginal" or self.mmd == "both":
for i in range(len(masks)):
for j in range(i + 1, len(masks)):
idx_where_to_calc_mmd = torch.eq(
masks[i] == masks[j],
torch.eq(masks[i], torch.ones_like(masks[i])),
)
if idx_where_to_calc_mmd.any(): # if need to calc mmd for a group between modalities
if (
idx_where_to_calc_mmd.any()
): # if need to calc mmd for a group between modalities
marginal_i = z_marginal[:, i, :][idx_where_to_calc_mmd]
marginal_j = z_marginal[:, j, :][idx_where_to_calc_mmd]
marginals = torch.cat([marginal_i, marginal_j])
Expand All @@ -542,13 +546,17 @@ def loss(
]
).to(self.device)

integ_loss += self._calc_integ_loss(marginals, modalities).to(self.device)
integ_loss += self.calc_integ_loss(
marginals, modalities
).to(self.device)

for i in range(len(masks)):
marginal_i = z_marginal[:, i, :]
marginal_i = marginal_i[masks[i]]
group_marginal = integrate_on[masks[i]]
integ_loss += self._calc_integ_loss(marginal_i, group_marginal).to(self.device)
integ_loss += self.calc_integ_loss(marginal_i, group_marginal).to(
self.device
)

cycle_loss = (
torch.tensor(0.0).to(self.device)
Expand Down

0 comments on commit 8567b5a

Please sign in to comment.