Skip to content

Commit

Permalink
Add dummy_to_categorical() and update write_mtx()
Browse files Browse the repository at this point in the history
  • Loading branch information
nh3 committed Nov 9, 2022
1 parent d0dece0 commit ec686e2
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 14 deletions.
1 change: 1 addition & 0 deletions sctk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from ._utils import (
cross_table,
dummy_to_categorical,
expand_feature_space,
find_top_expressed_genes,
lognorm_to_counts,
Expand Down
81 changes: 67 additions & 14 deletions sctk/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def lognorm_to_counts(X, norm_sum=1e4, n_counts=None, force=False, rounding=True
if rounding:
X_counts.data = np.round(X_counts.data).astype(np.int32)
return X_counts
warnings.warn(f"Non-integer residuals too large (res = {res}), "
"try inferring size_factor")
warnings.warn(
f"Non-integer residuals too large (res = {res}), try inferring size_factor"
)
x_min = np.array([X_expm1.getrow(i).data.min() for i in range(X_expm1.shape[0])])
size_factor = 1 / x_min
X_counts = (X_expm1.T * sp.csr_matrix(sp.diags(size_factor))).T
Expand Down Expand Up @@ -295,9 +296,11 @@ def cross_table(
elif normalise == "xy":
normaliser = np.tile(y_sizes.reshape(1, ny), (nx, 1))
else:
normaliser = np.tile(x_sizes.reshape(nx, 1), (1, ny)) + np.tile(
y_sizes.reshape(1, ny), (nx, 1)
) - crs_tbl.values
normaliser = (
np.tile(x_sizes.reshape(nx, 1), (1, ny))
+ np.tile(y_sizes.reshape(1, ny), (nx, 1))
- crs_tbl.values
)
crs_tbl = (crs_tbl / normaliser).round(4)
if sort in ("x", "index", "xy", "yx", "both"):
crs_tbl = crs_tbl.sort_index()
Expand All @@ -306,8 +309,11 @@ def cross_table(
return crs_tbl


def run_celltypist(ad, model, use_rep="X", require_lognorm=False, min_prob=None, key_added="ctp_pred"):
def run_celltypist(
ad, model, use_rep="X", require_lognorm=False, min_prob=None, key_added="ctp_pred"
):
import celltypist

X, var_names = _find_rep(ad, use_rep)
aux_ad = anndata.AnnData(X=X)
aux_ad.obs_names = ad.obs_names
Expand Down Expand Up @@ -951,19 +957,34 @@ def write_10x_h5(
f.create_dataset(path, data=dvalue.astype(dtype), dtype=dtype, compression="gzip")


def write_mtx(adata, fname_prefix="", var=None, obs=None, use_raw=False):
def write_mtx(
adata,
fname_prefix="",
var=["gene_ids"],
obs=None,
use_raw=False,
output_version="v3",
feature_type="Gene Expression",
):
"""Export AnnData object to mtx format
* Parameters
+ adata : AnnData
An AnnData object
+ fname_prefix : str
Prefix of the exported files. If not empty and not ending with '/' or '_',
a '_' will be appended. Full names will be <fname_prefix>matrix.mtx,
<fname_prefix>genes.tsv, <fname_prefix>barcodes.tsv
a '_' will be appended. Full names will be <fname_prefix>matrix.mtx(.gz),
<fname_prefix>genes.tsv/(features.tsv.gz), <fname_prefix>barcodes.tsv(.gz)
+ var : list
A list of column names to be exported to gene table
A list of extra column names to be exported to gene table, default ["gene_ids"]
+ obs : list
A list of column names to be exported to barcode/cell table
A list of extra column names to be exported to barcode/cell table, default None
+ use_raw: boolean
Whether to write `adata.raw` instead of `adata`, default False
+ output_version: str
Write v2 or v3 Cellranger mtx outputs, default v3
+ feature_type: str
Text added as the last column of "features.tsv.gz", only relevant when
`output_version="v3"`
"""
if fname_prefix and not (fname_prefix.endswith("/") or fname_prefix.endswith("_")):
fname_prefix = fname_prefix + "_"
Expand All @@ -984,9 +1005,14 @@ def write_mtx(adata, fname_prefix="", var=None, obs=None, use_raw=False):
n_var, n_obs, n_entry
)
df = pd.DataFrame({"col": mat.col + 1, "row": mat.row + 1, "data": mat.data})
mtx_fname = fname_prefix + "matrix.mtx"
gene_fname = fname_prefix + "genes.tsv"
barcode_fname = fname_prefix + "barcodes.tsv"
if output_version == "v2":
mtx_fname = fname_prefix + "matrix.mtx"
gene_fname = fname_prefix + "genes.tsv"
barcode_fname = fname_prefix + "barcodes.tsv"
else:
mtx_fname = fname_prefix + "matrix.mtx.gz"
gene_fname = fname_prefix + "features.tsv.gz"
barcode_fname = fname_prefix + "barcodes.tsv.gz"
with open(mtx_fname, "a", encoding="utf8") as fh:
fh.write(header)
df.to_csv(fh, sep=" ", header=False, index=False)
Expand All @@ -996,6 +1022,8 @@ def write_mtx(adata, fname_prefix="", var=None, obs=None, use_raw=False):
var_df = adata.var[var].reset_index(level=0)
if not var:
var_df["gene"] = var_df["index"]
if output_version != "v2":
var_df["feature_type"] = feature_type
var_df.to_csv(gene_fname, sep="\t", header=False, index=False)


Expand Down Expand Up @@ -1204,6 +1232,31 @@ def random_partition(
adata.obs[key_added] = part_idx.astype(str)


def dummy_to_categorical(mat, random_state=0):
    """Convert a sparse dummy matrix into a list of category indices.

    When a row has multiple entries equal to one, the row is randomly assigned
    to one of them; a row without any entry equal to one is assigned -1.

    * Parameters
        + mat: csr_matrix
            A csr_matrix of ones, with cells on the rows and nhoods on the columns
        + random_state: int
            Seed for the numpy RNG used to break ties between multiple entries
    * Returns
        A list with one column index per row of `mat` (-1 for rows that have
        no entry equal to one)
    """
    # A private RandomState seeded with `random_state` produces the same draws
    # as seeding the global numpy RNG, but leaves the caller's RNG state intact.
    rng = np.random.RandomState(random_state)
    nhoods = []
    for i in range(mat.shape[0]):
        # Column indices of the entries in row i that are exactly one
        k = (mat[i, ] == 1).indices
        if k.size == 1:
            idx = k[0]
        elif k.size > 1:
            idx = rng.choice(k, 1)[0]
        else:
            # No membership at all for this row
            idx = -1
        nhoods.append(idx)
    return nhoods


def pseudo_bulk(adata, groupby, use_rep="X", FUN=np.mean):
"""Make pseudo bulk data from grouped sc data"""
grouping = adata.obs[groupby]
Expand Down

0 comments on commit ec686e2

Please sign in to comment.