Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
sage_selection_public/estimate.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
executable file
154 lines (143 sloc)
6.39 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import argparse | |
import numpy as np | |
import joblib | |
import pandas as pd | |
import os | |
from os import path | |
from scipy.stats import binned_statistic_2d | |
from astropy.io import fits | |
from sage_selection import binning | |
from sage_selection.tools import get_logger, bins | |
from sage_selection import healpix_utils as hu | |
import multiprocessing as mp | |
# Module-level logger; SelectionEstimator reports its progress through it.
LOGGER = get_logger('2MASS_bg')
class SelectionEstimator(object):
    """
    Estimate a selection function from star counts in the J vs J-K plane,
    relative to precomputed background grids, on HEALPix cells of
    orders 3, 4 and 5.
    """

    # Cells holding fewer stars than this are skipped by get_selection().
    MIN_STARS_PER_CELL = 50

    def __init__(self, photometry='2mass', method='UniDAM',
                 full_data=None):
        """
        Load the cached background grids and the HEALPix variation table.

        Parameters
        ----------
        photometry : str
            Tag used in the background file names
            ``background_<photometry>_<order>.joblib``, which are read from
            the directory containing this module for orders 3, 4 and 5.
        method : str
            Stored as ``self.method``; not read by the methods of this class.
        full_data : table-like
            Catalog consumed by get_selection_in_cell(). It must provide the
            ``id``, ``Jmag``, ``Kmag`` and ``cell_3``/``cell_4``/``cell_5``
            columns accessed by get_selection().
        """
        module_dir = path.dirname(path.abspath(__file__))
        self.background = {
            order: joblib.load('%s/background_%s_%d.joblib' % (
                module_dir, photometry, order))
            for order in (3, 4, 5)}
        LOGGER.info('starting selection function estimation')
        # Historically read relative to the working directory; keep that
        # precedence, but fall back to the module directory (where the
        # background files live) so the script works from any CWD.
        variation_file = 'healpix_variations.csv'
        if not path.exists(variation_file):
            variation_file = path.join(module_dir, variation_file)
        self.variation = pd.read_csv(variation_file)
        self.full_data = full_data
        self.method = method

    def get_selection(self, cell, part_data, afilter='',
                      cell_level=3, store=True):
        """
        Compute per-star selection values for one HEALPix cell.

        Parameters
        ----------
        cell : int
            Cell number at order ``cell_level``.
        part_data : table-like
            Catalog filtered on its ``cell_<cell_level>`` column.
        afilter, store :
            Accepted for backward compatibility; currently unused.
        cell_level : int
            HEALPix order of ``cell``; selects the background grid.

        Returns
        -------
        (ids, selection, selection_err)
            Stripped string ids of the stars in the cell and the values
            returned by ``binning.get_median_2d`` for them, in catalog order.
            When the cell has fewer than MIN_STARS_PER_CELL stars, returns
            ``([], [[]], [[]])``.
        """
        data = part_data[part_data['cell_%s' % cell_level] == cell]
        LOGGER.info('Level %s, cell %s', cell_level, cell)
        if len(data) < self.MIN_STARS_PER_CELL:
            LOGGER.info('Data size: %d. Too few stars, passing', len(data))
            return [], [[]], [[]]
        LOGGER.info('Data size: %d', len(data))
        iids = np.char.strip(np.array(data['id'], dtype=str))
        # Shape (2, N): row 0 is J, row 1 is the J-K colour.
        magnitudes = np.array([data['Jmag'], data['Jmag'] - data['Kmag']],
                              dtype=float)
        # Minimum stars per bin grows sub-linearly with the cell population.
        occupation = int(0.91 * len(data)**0.37) + 1
        LOGGER.info('Using minimum %s stars per bin', occupation)
        selection, selection_err, _mask = \
            binning.get_median_2d(self.background[cell_level][cell],
                                  magnitudes,
                                  occupation=occupation,
                                  add_bg_variance=True, shrink=True,
                                  return_mask=True)
        return iids, selection, selection_err

    def get_foreground_indices(self, data):
        """
        Map 2-column data onto the grid used for the background.

        ``data[:, 0]`` and ``data[:, 1]`` are binned with the module-level
        ``bins``. Returns the expanded bin numbers, shape (2, N), shifted to
        be 0-based (out-of-range points therefore map to -1 or to the number
        of bins along that axis).
        """
        hfg = binned_statistic_2d(data[:, 0], data[:, 1], values=None,
                                  bins=bins, statistic='count',
                                  expand_binnumbers=True)
        return hfg.binnumber - 1

    def get_background_variation(self, cell, order):
        """
        Return the per-bin relative scatter (std / mean) of the background
        across the children of ``cell`` at order ``order + 1``.

        NaN background values are treated as zero. Bins whose mean is zero
        yield NaN/inf (numpy emits a RuntimeWarning).
        """
        children = hu.change_resolution(order, order + 1, cell)
        bg = np.array([self.background[order + 1][child]
                       for child in children])
        bg[np.isnan(bg)] = 0.
        return bg.std(axis=0) / bg.mean(axis=0)

    @staticmethod
    def _best_order(df):
        """
        Return, per row, the order (3, 4 or 5) with the smallest relative
        selection error ``selection_<k>_err / selection_<k>``.

        Levels that were never filled (0/0 -> NaN) or have zero selection
        are mapped to +inf. Otherwise ``np.argmin`` would return the first
        NaN and pick an unset level. Rows with no finite ratio fall back to
        order 3.
        """
        ratios = np.array(
            [df['selection_%d_err' % order] / df['selection_%d' % order]
             for order in (3, 4, 5)], dtype=float)
        ratios[~np.isfinite(ratios)] = np.inf
        return np.argmin(ratios, axis=0) + 3

    def get_selection_in_cell(self, cell):
        """
        Compute selection for all stars of an order-3 ``cell`` at orders
        3, 4 and 5.

        Returns a DataFrame with columns ``id``, ``selection_<k>`` and
        ``selection_<k>_err`` for k in 3..5, plus ``selection_best_order``.
        Entries for sub-cells too sparse to estimate are left at 0.
        """
        ids, sel, sel_err = self.get_selection(cell, self.full_data)
        # Columns: (sel, err) pairs for orders 3, 4, 5.
        output = np.zeros((len(ids), 6))
        output[:, 0] = sel
        output[:, 1] = sel_err
        if len(ids) > 0:
            for column, level in ((2, 4), (4, 5)):
                for subcell in hu.change_resolution(3, level, cell):
                    sub_ids, sub_sel, sub_err = self.get_selection(
                        subcell, self.full_data, cell_level=level)
                    if len(sub_ids) == 0:
                        continue
                    # Both selections keep catalog order, so the sub-cell
                    # rows appear in the same order within ``ids``.
                    in_subcell = np.isin(ids, sub_ids)
                    output[in_subcell, column] = sub_sel
                    output[in_subcell, column + 1] = sub_err
        df = pd.DataFrame(ids, columns=['id'])
        for level in range(3):
            df['selection_%s' % (level + 3)] = output[:, level*2]
            df['selection_%s_err' % (level + 3)] = output[:, level*2 + 1]
        df['selection_best_order'] = self._best_order(df)
        return df
if __name__ == '__main__': | |
parser = argparse.ArgumentParser(description=""" | |
Tool to calculate selection for different catalogs. | |
""", formatter_class=argparse.RawDescriptionHelpFormatter) | |
parser.add_argument('-i', '--input', type=str, default=None, | |
help='Input file name') | |
parser.add_argument('-o', '--output', type=str, default=None, | |
help='Output file name') | |
parser.add_argument('-p', '--photometry', type=str, default='2mass', | |
help='Photometric catalog name') | |
parser.add_argument('-m', '--method', type=str, default='UniDAM', | |
help='Photometric catalog name') | |
parser.add_argument('-c', '--cell', type=str, default='', | |
help="""Cell number""") | |
parser.add_argument('--cell-file', type=str, default=None, | |
help="""File with cell numbers""") | |
parser.add_argument('--parallel', action="store_true", | |
default=False, | |
help='Run in parallel') | |
args = parser.parse_args() | |
if args.cell_file is not None: | |
cells = open(args.cell_file, 'r').readlines() | |
else: | |
cells = args.cell.split(',') | |
cells = [int(cell) for cell in cells] | |
selest = SelectionEstimator(photometry=args.photometry, | |
method=args.method, | |
full_data=fits.open(args.input)[1].data) | |
results = [] | |
if args.parallel: | |
from copy import deepcopy | |
def run(cells): | |
selx = deepcopy(selest) | |
result = [] | |
for cell in cells: | |
result.append(selx.get_selection_in_cell(cell)) | |
return pd.concat(result) | |
if 'OMP_NUM_THREADS' in os.environ: | |
pool_size = int(os.environ['OMP_NUM_THREADS']) | |
else: | |
pool_size = 2 | |
pool = mp.Pool(pool_size) | |
results = pool.map(run, | |
np.array_split(cells, pool_size)) | |
else: | |
for cell in cells: | |
print(cell) | |
results.append(selest.get_selection_in_cell(cell)) | |
pd.concat(results).to_csv(args.output, index=False, float_format='%.4e') |