|
| 1 | +from __future__ import print_function |
| 2 | +from pysam import Samfile |
| 3 | +from numpy.random import random |
| 4 | +from heapq import heappush, heappushpop |
| 5 | +from os import path, makedirs |
| 6 | +from sys import argv, stdout |
| 7 | + |
| 8 | + |
def main():
    """Command-line entry point.

    Takes the input alignment file path from ``argv[1]`` and treats every
    remaining argument as a target sample size expressed in millions of
    reads. The converted sizes are echoed to stdout before the work is
    handed to ``subsample`` in paired-end mode.

    Returns:
        Whatever ``subsample`` returns for the given file and sizes.
    """
    alignment_path = argv[1]
    # Sizes are given in millions on the command line; scale to read counts.
    targets = [1e6 * float(arg) for arg in argv[2:]]
    print(targets)
    stdout.flush()
    return subsample(alignment_path, targets, paired=True)
| 16 | + |
| 17 | + |
def subsample(fn, ns=None, paired=False):
    """Randomly subsample reads from an alignment file into one BAM per size.

    A single pass over ``fn`` keeps the ``max(ns)`` records with the largest
    random keys in a min-heap (reservoir sampling). For every requested size
    ``n`` the ``n`` best-keyed records are then written, ordered by
    ``(tid, pos)``, to ``<dirname(fn)>/subset<NN.N>M/accepted_hits.bam``.

    Args:
        fn: Path of the input file, or a ``(path, sizes)`` pair when ``ns``
            is None.
        ns: Iterable of target sample sizes (numbers of reads).
        paired: When true, only records flagged ``is_read1`` are sampled;
            all other records are stored by ``qname`` and written right
            after their sampled read 1 if that read is a proper pair.

    Returns:
        The sampled read objects (up to ``max(ns)``), ordered by
        ``(tid, pos)``.

    Raises:
        ValueError: If ``ns`` is empty.
        OSError: If an output directory cannot be created.
    """
    if ns is None:
        # Allow a single (path, sizes) tuple, e.g. when called through map().
        fn, ns = fn
    ns = list(ns)
    max_n = max(ns)  # hoisted: previously recomputed for every read

    def coordinate(heap_item):
        # Heap items are (key, index, read); order output by reference/pos.
        return (heap_item[-1].tid, heap_item[-1].pos)

    sample = []
    count = 0
    read_2s = {}
    missing_mates = 0
    outdir_base = path.join(path.dirname(fn), 'subset')
    sf = Samfile(fn)
    try:
        try:
            i_weight = float(sf.mapped) / max_n
            print("Read out ", i_weight)
        except ValueError:
            # No usable mapped-read count: count records by hand, then
            # reopen the file for the sampling pass.
            i_weight = 0.0
            for _ in sf:
                i_weight += 1
            print("Counted ", i_weight)
            i_weight /= float(max_n)
            sf.close()  # do not leak the exhausted handle
            sf = Samfile(fn)

        print(fn, count, i_weight)
        for i, read in enumerate(sf):
            count = i + 1  # total records seen so far (safe for empty input)
            key = random() ** i_weight
            if paired and not read.is_read1:
                # NOTE(review): mates of reads that were already evicted are
                # still stored here; they are never written, only held.
                read_2s[read.qname] = read
                continue
            item = (key, i, read)  # unique index breaks ties between keys
            if len(sample) < max_n:
                heappush(sample, item)
            else:
                # Min-heap: pushing then popping discards the smallest key,
                # so the heap always holds the max_n largest keys.
                dropped = heappushpop(sample, item)
                if paired:
                    read_2s.pop(dropped[-1].qname, None)

        # Best keys first; every size n takes a prefix of this ranking.
        ranked = sorted(sample, reverse=True)
        for n in ns:
            outdir = outdir_base + '{:04.1f}M'.format(n / 1e6)
            try:
                makedirs(outdir)
            except OSError:
                # An existing directory is fine; anything else is an error.
                if not path.isdir(outdir):
                    raise
            sampN = ranked[:int(n)]
            print("Kept {: >12,} of {: >12,} reads".format(len(sampN), count))
            print(fn, '->', outdir)
            stdout.flush()
            # Sort the records that are actually written (the original
            # sorted the full sample instead, leaving the output unsorted).
            sampN.sort(key=coordinate)
            of = Samfile(path.join(outdir, 'accepted_hits.bam'),
                         mode='wb', template=sf)
            missing_mates = 0
            try:
                for _key, _index, read in sampN:
                    of.write(read)
                    if paired and read.is_proper_pair:
                        mate = read_2s.get(read.qname)
                        if mate is None:
                            missing_mates += 1
                        else:
                            of.write(mate)
            finally:
                of.close()
    finally:
        sf.close()
    # Reports the missing-mate count of the last size processed.
    print(missing_mates)
    return [read for _key, _index, read in sorted(sample, key=coordinate)]
| 81 | + |
# Run the subsampler only when executed as a script; the result is kept in a
# module-level name so it can be inspected in an interactive session.
if __name__ == "__main__":
    subset_results = main()
0 commit comments