Skip to content

Commit

Permalink
Updating seqweaver (FunctionLab#188)
Browse files Browse the repository at this point in the history
* init get_new_training_data script and strand spec

* refactor main script, fix strand-spec

* debugging and testing update_seqweaver

* fixed h5 file output

* added class modules for training seqweaver

* added validation/training strat

* debugging main update seqweaver module

* strand backward compatibility

* further fixes to backward compatibility

* val partition fix

* indexing fix for backward compatibility

* addressing kathy's comments

* fixed relative paths in update_seqweaver

* handling strand=. as None
  • Loading branch information
aviyalitman authored Oct 31, 2023
1 parent 86f4df3 commit 6ff4a5e
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 9 deletions.
70 changes: 70 additions & 0 deletions models/seqweaver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
Seqweaver architecture (Park & Troyanskaya, 2021).
"""
import torch
import torch.nn as nn


class LambdaBase(nn.Sequential):
def __init__(self, fn, *args):
super(LambdaBase, self).__init__(*args)
self.lambda_func = fn

def forward_prepare(self, input):
output = []
for module in self._modules.values():
output.append(module(input))
return output if output else input


class Lambda(LambdaBase):
def forward(self, input):
return self.lambda_func(self.forward_prepare(input))


class Seqweaver(nn.Module):

def __init__(self, n_classes): # 217 human, 43 mouse
super(Seqweaver, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(4, 160, (1, 8)),
nn.ReLU(),
nn.MaxPool2d((1, 4), (1, 4)),
nn.Dropout(0.1),
nn.Conv2d(160, 320, (1, 8)),
nn.ReLU(),
nn.MaxPool2d((1, 4), (1, 4)),
nn.Dropout(0.1),
nn.Conv2d(320, 480, (1, 8)),
nn.ReLU(),
nn.Dropout(0.3))
self.fc = nn.Sequential(
Lambda(lambda x: torch.reshape(x, (x.size(0), 25440))),
nn.Sequential(
Lambda(lambda x: x.reshape(1, -1)
if 1 == len(x.size()) else x),
nn.Linear(25440, n_classes)
),
nn.ReLU(),
nn.Sequential(
Lambda(lambda x: x.view(1, -1)
if 1 == len(x.size()) else x),
nn.Linear(n_classes, n_classes)
),
nn.Sigmoid(),
)

def forward(self, x):
x = x.unsqueeze(2)
x = self.model(x)
x = self.fc(x)
return x


def criterion():
return nn.BCELoss()


def get_optimizer(lr):
return (torch.optim.SGD,
{"lr": lr, "weight_decay": 1e-6, "momentum": 0.9})
152 changes: 152 additions & 0 deletions scripts/update_seqweaver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
This module provides the `UpdateSeqweaver` class, which wraps the master bed file
containing all of the features' binding sites parsed from CLIP-seq.
It supports new dataset construction and training for Seqweaver.
"""
import h5py
import gzip
import numpy as np
import sys

from selene_sdk.sequences.genome import Genome
from selene_sdk.targets.genomic_features import GenomicFeatures
from selene_sdk.samplers.dataloader import H5DataLoader
from selene_sdk.train_model import TrainModel
from selene_sdk.utils.config import load_path
from selene_sdk.utils.config_utils import parse_configs_and_run

class UpdateSeqweaver():
"""
Stores a dataset specifying sequence regions and features.
Accepts a tabix-indexed `*.bed` file with the following columns,
in order:
[chrom, start, end, feature, strand]
Parameters
----------
input_path : str
Path to the tabix-indexed dataset. Note that for the file to
be tabix-indexed, it must have been compressed with `bgzip`.
Thus, `input_path` should be a `*.gz` file with a
corresponding `*.tbi` file in the same directory.
output_path : str
Path to the output constructed-training data file.
feature_path : str
Path to a '\n'-delimited .txt file containing feature names.
hg_fasta : str
Path to an indexed FASTA file -- a `*.fasta` file with
a corresponding `*.fai` file in the same directory. This file
should contain the target organism's genome sequence.
"""
def __init__(self, input_path, train_path, validate_path, feature_path, hg_fasta, yaml_path, val_prop=0.1, sequence_len=1000):
"""
Constructs a new `UpdateSeqweaver` object.
"""
self.input_path = input_path
self.train_path = train_path
self.validate_path = validate_path
self.feature_path = feature_path
self.yaml_path = yaml_path
self.val_prop = val_prop

self.hg_fasta = hg_fasta

self.sequence_len = sequence_len

with open(self.feature_path, 'r') as handle:
self.feature_set = [line.split('\n')[0] for line in handle.readlines()]

def _from_midpoint(self, start, end):
"""
Computes start and end of the sequence about the peak midpoint.
Parameters
----------
start : int
The 0-based first position in the region.
end : int
One past the 0-based last position in the region.
Returns
-------
seq_start : int
Sequence start position about the peak midpoint.
seq_end : int
Sequence end position about the peak midpoint.
"""
region_len = end - start
midpoint = start + region_len // 2
seq_start = midpoint - np.floor(self.sequence_len / 2.)
seq_end = midpoint + np.ceil(self.sequence_len / 2.)

return int(seq_start), int(seq_end)

def construct_training_data(self):
"""
Construct training dataset from bed file and write to output_file.
Parameters
----------
output_path : str
Path to the output file for the constructed training data.
colname_file : str
Path to a .txt file containing newline-delimited feature names.
"""
list_of_regions = []
with gzip.open(self.input_path) as f:
for line in f:
line = [str(data,'utf-8') for data in line.strip().split()]
list_of_regions.append(line)

seqs = Genome(self.hg_fasta, blacklist_regions = 'hg19')
targets = GenomicFeatures(self.input_path,
features = self.feature_set, feature_thresholds = 0.5)

data_seqs = []
data_labels = []
for r in list_of_regions:
chrom, start, end, target, strand = r
start, end = int(start), int(end)
sstart, ssend = self._from_midpoint(start, end)

# 1 x 4 x 1000 bp
# get_encoding_from_coords : Converts sequence to one-hot-encoding for each of the 4 bases
dna_seq, has_unk = seqs.get_encoding_from_coords_check_unk(chrom, sstart, ssend, strand=strand)
if has_unk:
continue
if len(dna_seq) != self.sequence_len:
continue

# 1 x n_features
# get_feature_data: Computes which features overlap with the given region.
labels = targets.get_feature_data(chrom, start, end, strand=strand)

data_seqs.append(dna_seq)
data_labels.append(labels)

# partition some to validation before writing
val_count = int(np.floor(self.val_prop * len(data_seqs)))
validate_seqs = data_seqs[:val_count]
validate_labels = data_labels[:val_count]
training_seqs = data_seqs[val_count:]
training_labels = data_labels[val_count:]

with h5py.File(self.validate_path, "w") as fh:
fh.create_dataset("valid_sequences", data=np.array(validate_seqs, dtype=np.int64))
fh.create_dataset("valid_targets", data=np.array(validate_labels, dtype=np.int64))

with h5py.File(self.train_path, "w") as fh:
fh.create_dataset("train_sequences", data=np.array(training_seqs, dtype=np.int64))
fh.create_dataset("train_targets", data=np.array(training_labels, dtype=np.int64))

def _load_yaml(self):
# load yaml configuration
return load_path(self.yaml_path)

def train_model(self):
# load config file and train model
yaml_config = self._load_yaml()
parse_configs_and_run(yaml_config)
31 changes: 22 additions & 9 deletions selene_sdk/targets/genomic_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _is_positive_row(start, end,


def _get_feature_data(chrom, start, end,
thresholds, feature_index_dict, get_feature_rows):
thresholds, feature_index_dict, get_feature_rows, strand=None):
"""
Generates a target vector for the given query region.
Expand All @@ -125,6 +125,9 @@ def _get_feature_data(chrom, start, end,
get_feature_rows : types.FunctionType
A function that takes coordinates and returns rows
(`list(tuple(int, int, str))`).
strand : {'+', '-'}, optional
The strand the sequence is located on. Default is None (no strand provided).
If '+' or '-' is passed in, only retrieve rows with the correct matching strand.
Returns
-------
Expand All @@ -133,7 +136,7 @@ def _get_feature_data(chrom, start, end,
`i`th feature is positive, and zero otherwise.
"""
rows = get_feature_rows(chrom, start, end)
rows = get_feature_rows(chrom, start, end, strand=strand)
return _fast_get_feature_data(
start, end, thresholds, feature_index_dict, rows)

Expand Down Expand Up @@ -303,7 +306,7 @@ def dfunc(self, *args, **kwargs):
return func(self, *args, **kwargs)
return dfunc

def _query_tabix(self, chrom, start, end):
def _query_tabix(self, chrom, start, end, strand=None):
"""
Queries a tabix-indexed `*.bed` file for features falling into
the specified region.
Expand All @@ -317,6 +320,9 @@ def _query_tabix(self, chrom, start, end):
The 0-based start position of the query coordinates.
end : int
One past the last position of the query coordinates.
strand : {'+', '-'}, optional
The strand the sequence is located on. Default is None (no strand provided).
If '+' or '-' is passed in, only retrieve rows with the correct matching strand.
Returns
-------
Expand All @@ -329,12 +335,16 @@ def _query_tabix(self, chrom, start, end):
"""
try:
return self.data.query(chrom, start, end)
tabix_query = self.data.query(chrom, start, end)
if strand == '+' or strand == '-':
return [line for line in tabix_query if str(line[4]) == strand] # strand specificity
else: # not strand specific
return tabix_query
except tabix.TabixError:
return None

@init
def is_positive(self, chrom, start, end):
def is_positive(self, chrom, start, end, strand=None):
"""
Determines whether the query the `chrom` queried contains any
genomic features within the :math:`[start, end)` region. If so,
Expand All @@ -357,11 +367,11 @@ def is_positive(self, chrom, start, end):
assume the error was the result of no features being present
in the queried region and return `False`.
"""
rows = self._query_tabix(chrom, start, end)
rows = self._query_tabix(chrom, start, end, strand=strand)
return _any_positive_rows(rows, start, end, self.feature_thresholds)

@init
def get_feature_data(self, chrom, start, end):
def get_feature_data(self, chrom, start, end, strand=None):
"""
Computes which features overlap with the given region.
Expand All @@ -373,6 +383,9 @@ def get_feature_data(self, chrom, start, end):
The 0-based first position in the region.
end : int
One past the 0-based last position in the region.
strand : {'+', '-'}, optional
The strand the sequence is located on. Default is None (no strand provided).
If '+' or '-' is passed in, only retrieve rows with the correct matching strand.
Returns
-------
Expand All @@ -388,7 +401,7 @@ def get_feature_data(self, chrom, start, end):
"""
if self._feature_thresholds_vec is None:
features = np.zeros(self.n_features)
rows = self._query_tabix(chrom, start, end)
rows = self._query_tabix(chrom, start, end, strand=strand) # strand specificity
if not rows:
return features
for r in rows:
Expand All @@ -398,4 +411,4 @@ def get_feature_data(self, chrom, start, end):
return features
return _get_feature_data(
chrom, start, end, self._feature_thresholds_vec,
self.feature_index_dict, self._query_tabix)
self.feature_index_dict, self._query_tabix, strand=strand)

0 comments on commit 6ff4a5e

Please sign in to comment.