Commit

Update AnomalyDetect
Catchxu committed May 30, 2024
1 parent 0702170 commit 613af0c
Showing 2 changed files with 55 additions and 7 deletions.
60 changes: 54 additions & 6 deletions src/stands/anomaly.py
@@ -1,16 +1,12 @@
import os
import dgl
import numpy as np
import pandas as pd
import anndata as ad
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from typing import Optional, Dict, Union, Any
from sklearn.preprocessing import LabelEncoder

from .model import GeneratorAD, Discriminator
from .model import GMMWithPrior
from ._utils import select_device, seed_everything, calculate_gradient_penalty


@@ -69,6 +65,33 @@ def fit(self, ref: Dict[str, Any], only_ST: bool = False, weight_dir: Optional[s
                t.update(1)

        tqdm.write('Training has been finished.')

    @torch.no_grad()
    def predict(self, tgt: Dict[str, Any], run_gmm: bool = True):
        '''Detect anomalous spots on the target graph.'''

        tgt_g = tgt['graph']
        dataset = dgl.dataloading.DataLoader(
            tgt_g, tgt_g.nodes(), self.sampler, batch_size=self.batch_size,
            shuffle=False, drop_last=False, num_workers=0, device=self.device
        )

        self.G.eval()
        self.D.eval()
        tqdm.write('Detecting anomalous spots on the target dataset...')

        ref_score = self.score(self.dataset)
        tgt_score = self.score(dataset)

        tqdm.write('Anomalous spots have been detected.\n')

        if run_gmm:
            gmm = GMMWithPrior(ref_score)
            threshold = gmm.fit(tgt_score=tgt_score)
            tgt_label = [1 if s >= threshold else 0 for s in tgt_score]
            return tgt_score, tgt_label
        else:
            return tgt_score

    def init_model(self, ref, weight_dir):
        self.G = GeneratorAD(ref['gene_dim'], ref['patch_size'], self.only_ST).to(self.device)
@@ -119,7 +142,7 @@ def UpdateD(self, blocks):
            d1 = torch.mean(self.D.SCforward(real_g))
            d2 = torch.mean(self.D.SCforward(fake_g.detach()))
            gp = calculate_gradient_penalty(self.D, real_g, fake_g.detach())

        else:
            _, fake_g, fake_p = self.G.fullforward(
                blocks, blocks[0].srcdata['gene'], blocks[1].srcdata['patch']
@@ -179,3 +202,28 @@ def UpdateG(self, blocks):

        # update the memory block with the generated embeddings z
        self.G.Memory.update_mem(z)

    def score(self, dataset):
        # calculate anomaly scores from the discriminator outputs
        dis = []
        for _, _, blocks in dataset:
            if self.only_ST:
                # generate fake data
                _, fake_g = self.G.STforward(blocks, blocks[0].srcdata['gene'])
                d = self.D.SCforward(fake_g.detach())

            else:
                _, fake_g, fake_p = self.G.fullforward(
                    blocks, blocks[0].srcdata['gene'], blocks[1].srcdata['patch']
                )
                d = self.D.fullforward(fake_g.detach(), fake_p.detach())

            dis.append(d.cpu().detach())

        # Min-max normalize so that lower discriminator outputs map to
        # higher anomaly scores in [0, 1]
        dis = torch.mean(torch.cat(dis, dim=0), dim=1).numpy()
        score = (dis.max() - dis) / (dis.max() - dis.min())

        score = list(score.reshape(-1))
        return score
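
For orientation, here is a minimal usage sketch of the new predict API. The class name AnomalyDetect and the exact layout of the ref/tgt dictionaries (a DGL graph under 'graph', plus 'gene_dim' and 'patch_size' for the reference) are inferred from this diff and the commit message, so treat the sketch as illustrative rather than as the package's documented interface.

from stands.anomaly import AnomalyDetect   # assumed import path; class name inferred from the commit message

detector = AnomalyDetect()                 # constructor arguments omitted here
detector.fit(ref)                          # ref: dict holding the reference graph, 'gene_dim' and 'patch_size'

# run_gmm=True (default): returns normalized scores plus hard labels,
# thresholded by a GMMWithPrior fitted on the reference scores
tgt_score, tgt_label = detector.predict(tgt, run_gmm=True)

# run_gmm=False: returns only the normalized anomaly scores
tgt_score = detector.predict(tgt, run_gmm=False)
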
2 changes: 1 addition & 1 deletion src/stands/model/GMM.py
@@ -5,7 +5,7 @@


class GMMWithPrior(object):
    def __init__(self, ref_score, random_state=None, max_iter=100, tol=1e-3, prior_beta=[1,10]):
    def __init__(self, ref_score, random_state=None, max_iter=100, tol=1e-5, prior_beta=[1,10]):
        self.ref_score = ref_score
        self.max_iter = max_iter
        self.tol = tol
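
The only change in this file tightens the EM convergence tolerance of the prior-regularized GMM from 1e-3 to 1e-5, so fit() will typically run more iterations (still capped at max_iter=100) before it is considered converged and returns the decision threshold that predict() uses above. Below is a minimal sketch of the call pattern, mirroring that usage; the import path is assumed from the package layout and the score values are invented purely for illustration.

from stands.model import GMMWithPrior   # assumed import path (module lives at src/stands/model/GMM.py)

ref_score = [0.05, 0.10, 0.08, 0.12]    # normalized anomaly scores from the reference data (illustrative)
tgt_score = [0.07, 0.95, 0.11, 0.88]    # normalized anomaly scores from the target data (illustrative)

gmm = GMMWithPrior(ref_score, max_iter=100, tol=1e-5, prior_beta=[1, 10])
threshold = gmm.fit(tgt_score=tgt_score)                   # returns the score threshold
labels = [1 if s >= threshold else 0 for s in tgt_score]   # 1 flags a spot at or above the threshold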
