Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
sage_selection_public/tools.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
138 lines (106 sloc)
3.72 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Thu Dec 7 17:22:25 2017 | |
@author: mints | |
""" | |
import logging | |
import sys | |
import os | |
import numpy as np | |
from scipy.stats import binned_statistic_2d | |
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator | |
from unidam.utils.stats import to_bins | |
from scipy.stats import laplace | |
from scipy.spatial.qhull import QhullError | |
from sage_selection import constants as c | |
from l2di import Linear2DInterpolator | |
def get_logger(name, output=True, level=logging.DEBUG):
    """
    Return a logger that writes either to screen or to a file.

    Parameters
    ----------
    name : str
        Name passed to ``logging.getLogger``. When ``output`` is false it is
        also the stem of the log file, ``name + '.log'``.
    output : bool, optional
        If true (default), log to ``sys.stdout``; otherwise append to
        ``name + '.log'``.
    level : int, optional
        Level set on the logger (default ``logging.DEBUG``).

    Returns
    -------
    logging.Logger
        The named logger, with a formatted handler for the chosen target.

    Notes
    -----
    ``logging.getLogger`` returns the same object for the same name, so
    unconditionally attaching a handler on every call would make each
    message appear once per call. A handler is therefore only added when
    an equivalent one (same stream or same log file) is not attached yet.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if output:
        # FileHandler subclasses StreamHandler, so exclude it explicitly.
        already_attached = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            and handler.stream is sys.stdout
            for handler in logger.handlers)
    else:
        # FileHandler stores the absolute path of its file in baseFilename.
        log_path = os.path.abspath(name + '.log')
        already_attached = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
            for handler in logger.handlers)
    if already_attached:
        return logger

    if output:
        handler = logging.StreamHandler(sys.stdout)
    else:
        handler = logging.FileHandler(name + '.log')
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    return logger
def ensure_dir(directory):
    """
    Create ``directory`` (and any missing parent directories) if needed.

    ``os.makedirs(..., exist_ok=True)`` is used instead of a separate
    existence check followed by ``makedirs``: with the two-step form, a
    directory created by another process between the check and the call
    makes ``makedirs`` raise ``FileExistsError``.

    Raises
    ------
    FileExistsError
        If ``directory`` already exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)
class SuppressRuntime(object):
    """
    Context manager that temporarily changes NumPy floating-point error
    handling via ``np.seterr``.

    Keyword arguments are forwarded to ``np.seterr`` on entry (e.g.
    ``SuppressRuntime(divide='ignore', invalid='ignore')``). The settings
    that were in force at the moment of entry are restored on exit.
    ``np.errstate`` provides the same behaviour in NumPy itself.
    """

    def __init__(self, **kwargs):
        # Settings to apply inside the ``with`` block.
        self.set = kwargs
        # Kept for backward compatibility; overwritten in __enter__ with the
        # state that is actually current when the block starts.
        self.store = np.geterr()

    def __enter__(self):
        # np.seterr returns the previous settings, so this records the state
        # at entry time rather than at construction time.
        self.store = np.seterr(**self.set)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        np.seterr(**self.store)
        # Do not suppress exceptions raised inside the block.
        return False
def interp_log_2d(x, y, xx, **kwargs):
    """
    Interpolate ``y`` (sampled at points ``x``) onto ``xx`` in log space.

    With at least three sample points a ``Linear2DInterpolator`` is fitted to
    ``log(y)``; with fewer, a nearest-neighbour interpolator is used. If
    the interpolator raises ``QhullError``, every output point gets the
    arithmetic mean of ``y``. The result is always clipped to
    ``[y.min(), y.max()]``.

    ``y`` must be a NumPy array (``.mean``/``.min``/``.max`` are called on
    it) and should be positive, since ``np.log`` is applied. ``kwargs`` is
    accepted but not used.
    """
    log_y = np.log(y)
    try:
        if len(x) < 3:
            interpolator = NearestNDInterpolator(x, log_y)
        else:
            interpolator = Linear2DInterpolator(x, log_y)
        estimate = np.exp(interpolator(xx))
    except QhullError:
        estimate = np.ones(len(xx)) * y.mean()
    return np.clip(estimate, y.min(), y.max())
# Evaluation grid: np.meshgrid(JMAG_GRID, JK_GRID) stacked into one array of
# two coordinate planes; x2d_points has one row per grid node, column 0 taken
# from JMAG_GRID and column 1 from JK_GRID (JMAG varies fastest).
x2d = np.array(np.meshgrid(c.JMAG_GRID, c.JK_GRID))
x2d_points = x2d.reshape(2, len(c.JMAG_GRID) * len(c.JK_GRID)).T
# Output of unidam's to_bins for each grid axis (presumably bin edges).
bins = tuple(to_bins(grid) for grid in (c.JMAG_GRID, c.JK_GRID))
# Global scale factor applied by every selection function; set_norm resets it.
RATIO = 1.
# Two frozen per-axis Laplace distributions; the per-point product of each,
# summed, gives the shape used by selection2d_lap.
lap = laplace(loc=(14., 0.), scale=(2., 1.))
lap2 = laplace(loc=(11., 0.7), scale=(1., 3.))
# Peak of that summed density over x2d_points, used to normalise it.
lap_max = np.max(lap.pdf(x2d_points).prod(axis=1)
                 + lap2.pdf(x2d_points).prod(axis=1))
# Unseeded uniform draws, one per (JMAG_GRID, JK_GRID) node, made at import.
# Not referenced in the code shown here; presumably used by importers.
rand = np.random.random((len(c.JMAG_GRID), len(c.JK_GRID)))
def selection2d_flat(y):
    """
    Flat selection: ``RATIO`` inside a window of column 0, zero elsewhere.

    A row is selected when ``y[:, 0]`` lies strictly between
    ``c.JMAG_GRID[170]`` and ``c.JMAG_GRID[240]``. Returns a float array
    with one value per row of ``y``.
    """
    first_column = y[:, 0]
    low, high = c.JMAG_GRID[170], c.JMAG_GRID[240]
    inside = (first_column > low) & (first_column < high)
    selected = np.zeros(len(y))
    selected[inside] = RATIO
    return selected
def selection2d_lap(y):
    """
    Laplace-shaped selection: the sum of two separable Laplace densities.

    For each row of ``y`` the per-axis pdfs of ``lap`` and ``lap2`` are
    multiplied across columns, and the two products are added. The sum is
    divided by ``lap_max``, so its maximum over ``x2d_points`` equals
    ``RATIO``.
    """
    primary = lap.pdf(y).prod(axis=1)
    secondary = lap2.pdf(y).prod(axis=1)
    return (primary + secondary) * RATIO / lap_max
def _get_interpolator(filename): | |
_data = np.loadtxt(f'/home/mints/prog/sage_selection/observations/{filename}', | |
delimiter=',', skiprows=1, usecols=(4, 5, 14)) | |
return NearestNDInterpolator(_data[:, :2], _data[:, 2]) | |
def generate_selection(filename):
    """
    Return a selection function backed by the observations in ``filename``.

    The file is read once, when ``generate_selection`` is called (through
    ``_get_interpolator``). The returned callable evaluates the interpolator
    at its argument, replaces NaN values with 0 and multiplies by the
    module-level ``RATIO``, looked up each time the callable runs.
    """
    interpolate = _get_interpolator(filename)

    def selection2d_generated(y):
        values = interpolate(y)
        missing = np.isnan(values)
        values[missing] = 0.
        return values * RATIO

    return selection2d_generated
def selection2d_linear(y):
    """
    Toy selection: a linear ramp in column 0 times a parabolic weight in
    column 1.

    Computes ``max(0.17 * (y0 - 10) + 0.1, 0) * (1 - (1 - y1) ** 2)``, then
    sets every negative value, and every value above 1, to 0. The result is
    scaled by ``RATIO``.

    NOTE(review): values above 1 are zeroed rather than capped at 1, which
    removes the largest selection values entirely — confirm this is
    intended.
    """
    ramp = 0.17 * (y[:, 0] - 10) + 0.1
    ramp[ramp < 0] = 0
    ramp *= 1 - (1 - y[:, 1]) ** 2
    ramp[(ramp < 0) | (ramp > 1)] = 0
    return ramp * RATIO
# Registry of selection functions; set_norm evaluates each one on x2d_points.
# The three survey entries are built here, so their CSV files are read when
# this module is imported.
FUNCTIONS = dict(
    flat=selection2d_flat,
    lamost=generate_selection('292_LAMOST_GAC_VB.csv'),
    rave=generate_selection('176_RAVE_ON.csv'),
    gaia_eso=generate_selection('307_GAIA_ESO4.csv'),
    linear=selection2d_linear,
    lap=selection2d_lap,
)
# Per-function normalisation factors; 1 until set_norm() recomputes them.
NORM = dict.fromkeys(FUNCTIONS, 1.)
def set_norm(bg):
    """
    Recompute the normalisation factor of every selection function for a
    given background.

    Resets the global ``RATIO`` to 1, then, for each entry of ``FUNCTIONS``,
    stores ``1 / sum(bg.T.flatten() * func(x2d_points))`` in ``NORM``.
    Presumably ``bg`` has shape ``(len(JMAG_GRID), len(JK_GRID))``; the
    transpose makes its flattened order match ``x2d_points`` — NOTE(review):
    confirm against callers.
    """
    global RATIO
    # Every selection function multiplies by RATIO; reset it so the sums are
    # taken over the unscaled functions.
    RATIO = 1.
    weights = bg.T.flatten()
    for name, selection in FUNCTIONS.items():
        NORM[name] = 1. / (weights * selection(x2d_points)).sum()