Skip to content

Commit

Permalink
inferstats
Browse files Browse the repository at this point in the history
  • Loading branch information
bicycle1885 committed Nov 21, 2018
1 parent d0acaca commit b6969ad
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/CellFishing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ using Distributions: NegativeBinomial, logcdf, logccdf
include("svd.jl")
include("bitvectors.jl")
include("HammingIndexes.jl")
include("cmatrix.jl")
include("ematrix.jl")
include("preprocessor.jl")
include("features.jl")
include("cmatrix.jl")
include("index.jl")
include("search.jl")
include("degenes.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ function CellIndex(
end
# preprocess expression profiles
preproc = Preprocessor(Y, transformer, PCA(n_dims, randomize=randomize), normalize, standardize, Float32(scalefactor))
X = preprocess(preproc, Y, false)
X = preprocess(preproc, Y, :database)
# hash preprocessed data
lshashes = LSHash{T}[]
for _ in 1:n_lshashes
Expand Down
43 changes: 28 additions & 15 deletions src/preprocessor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ function Preprocessor(
)
end

function preprocess(proc::Preprocessor, Y::ExpressionMatrix, inferparams::Bool)
function preprocess(proc::Preprocessor, Y::ExpressionMatrix, inferstats::Symbol)
perm = zeros(Int, length(proc.featurenames))
for (i, name) in enumerate(proc.featurenames)
perm[i] = get(Y.featuremap, name, 0)
Expand All @@ -116,23 +116,36 @@ function preprocess(proc::Preprocessor, Y::ExpressionMatrix, inferparams::Bool)
X .*= proc.scalefactor ./ sum(X, dims=1)
end
transform!(proc.transformer, X)
n1 = proc.n
m, n2 = size(X)
sums1 = proc.sums
soss1 = proc.soss
sums2 = vec(sum(X, dims=2))
soss2 = vec(sum(X.^2, dims=2))
μ = zeros(Float32, m)
σ = zeros(Float32, m)
@inbounds for i in 1:m
μ[i], σ[i] = mean_and_std(
sums1[i], soss1[i], n1,
sums2[i], soss2[i], n2,
)
m, n = size(X)
@assert length(proc.sums) == length(proc.soss) == m
if inferstats == :query
μ = vec(mean(X, dims=2))
σ = vec(std(X, dims=2))
@inbounds for i in 1:m
# insert the std of the database if no variability in the query cells
if σ[i] == 0
σ[i] = sqrt((proc.soss[i] - 2 * μ[i] * proc.sums[i]) / proc.n + μ[i]^2)
end
end
elseif inferstats == :database
μ = proc.sums ./ proc.n
σ = sqrt.((proc.soss .- 2 .* μ .* proc.sums) ./ proc.n .+ μ.^2)
else
@assert inferstats == :both
sums = vec(sum(X, dims=2))
soss = vec(sum(X.^2, dims=2))
μ = zeros(Float32, m)
σ = zeros(Float32, m)
@inbounds for i in 1:m
μ[i], σ[i] = mean_and_std(
proc.sums[i], proc.soss[i], proc.n,
sums[i], soss[i], n,
)
end
end
if proc.standardize
invstd = inv.(σ)
@inbounds for j in 1:n2, i in 1:m
@inbounds for j in 1:n, i in 1:m
X[i,j] = (X[i,j] - μ[i]) * invstd[i]
end
else
Expand Down
21 changes: 13 additions & 8 deletions src/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,41 @@ end
counts::AbstractMatrix,
featurenames::AbstractVector{String},
index::CellIndex;
inferparams::Bool=true
inferstats::Symbol=:both
) -> NearestCells
Find `k`-nearest neighboring cells from `index`.
If `inferparams=true`, feature (gene) parameters are inferred from the `counts`
argument. Note that `counts` should not be biased and have enough cells to
properly infer the parameters.
The `inferstats` parameter specifies the way to infer feature (gene)
statistics. If it is `:both` (default), they are inferred both from the query
and the database cells. If it is `:query`, they are inferred only from the
query cells. If it is `:database`, they are inferred only from the database
cells. No other values are allowed.
"""
function findneighbors(
k::Integer,
counts::AbstractMatrix,
featurenames::AbstractVector{String},
index::CellIndex;
inferparams::Bool=true)
return findneighbors(k, ExpressionMatrix(counts, featurenames), index; inferparams=inferparams)
inferstats::Symbol=:both)
return findneighbors(k, ExpressionMatrix(counts, featurenames), index; inferstats=inferstats)
end

function findneighbors(k::Integer, Y::ExpressionMatrix, index::CellIndex; inferparams::Bool=true)
function findneighbors(k::Integer, Y::ExpressionMatrix, index::CellIndex; inferstats::Symbol=:both)
if k < 0
throw(ArgumentError("negative k"))
end
if inferstats (:both, :query, :database)
throw(ArgumentError("invalid value for the inferstats parameter"))
end
n = size(Y, 2)
L = length(index.lshashes)
T = bitvectype(first(index.lshashes))
@assert L 1
rtstats = index.rtstats
tic!(rtstats)
# preprocess
X = preprocess(index.preproc, Y, inferparams)
X = preprocess(index.preproc, Y, inferstats)
# allocate temporary memories
neighbors = Matrix{Int}(undef, k * L, n)
neighbors_tmp = Matrix{Int}(undef, k, n)
Expand Down

0 comments on commit b6969ad

Please sign in to comment.