# Source code for pymc.tuning.starting

#   Copyright 2024 - present The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

"""
Created on Mar 12, 2011.

@author: johnsalvatier
"""

import warnings

from collections.abc import Sequence

import numpy as np
import pytensor.gradient as tg

from numpy import isfinite
from pytensor.graph.basic import Variable
from pytensor.utils import lazy_scipy_module
from rich.console import Console
from rich.progress import Progress, TextColumn

import pymc as pm

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model import modelcontext
from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import floatX, inputvars
from pymc.util import (
    get_default_varnames,
    get_value_vars_from_user_vars,
)
from pymc.vartypes import discrete_types, typefilter

# Names exported by ``from pymc.tuning.starting import *``.
__all__ = ["find_MAP"]

# scipy.optimize obtained through pytensor's ``lazy_scipy_module`` helper
# (presumably deferring the actual scipy import until first use).
optimize = lazy_scipy_module("optimize")


def find_MAP(
    start=None,
    vars: Sequence[Variable] | None = None,
    method="L-BFGS-B",
    return_raw=False,
    include_transformed=True,
    progressbar=True,
    progressbar_theme=default_progress_theme,
    maxeval=5000,
    model=None,
    *args,
    seed: int | None = None,
    **kwargs,
):
    """Find the local maximum a posteriori point given a model.

    `find_MAP` should not be used to initialize the NUTS sampler. Simply call
    ``pymc.sample()`` and it will automatically initialize NUTS in a better way.

    Parameters
    ----------
    start: `dict` of parameter values (Defaults to `model.initial_point`)
        These values will be fixed and used for any free RandomVariables that are
        not being optimized.
    vars: list of TensorVariable
        List of free RandomVariables to optimize the posterior with respect to.
        Defaults to all continuous RVs in a model. The respective value variables
        may also be passed instead.
    method: string or callable, optional
        Optimization algorithm. Defaults to 'L-BFGS-B'. If ``vars`` contains
        discrete variables, or the gradient of the log-probability cannot be
        computed, 'Powell' is used instead (overriding this argument).
        For instructions on use of a callable, refer to SciPy's documentation
        of `optimize.minimize`.
    return_raw: bool, optional defaults to False
        Whether to return the full output of scipy.optimize.minimize
    include_transformed: bool, optional defaults to True
        Flag for reporting automatically unconstrained transformed values in
        addition to the constrained values
    progressbar: bool, optional defaults to True
        Whether to display a progress bar in the command line.
    progressbar_theme: Theme, optional
        Custom theme for the progress bar.
    maxeval: int, optional, defaults to 5000
        The maximum number of times the posterior distribution is evaluated.
    model: Model (optional if in `with` context)
    *args, **kwargs
        Extra args passed to scipy.optimize.minimize
    seed: int, optional
        Seed forwarded to the initial-point function that builds the start point.

    Returns
    -------
    dict or tuple
        Mapping from unobserved value-variable names to their values at the
        optimum. If ``return_raw`` is True, a ``(mapping, opt_result)`` tuple is
        returned, where ``opt_result`` is the SciPy result, or ``None`` when the
        optimization was interrupted (KeyboardInterrupt) or stopped by the
        evaluation limit.

    Raises
    ------
    ValueError
        If ``vars`` is not given and the model has no unobserved continuous
        variables, or if the given ``vars`` cannot be resolved to any value
        variables of the model.

    Notes
    -----
    Older code examples used `find_MAP` to initialize the NUTS sampler,
    but this is not an effective way of choosing starting values for sampling.
    As a result, we have greatly enhanced the initialization of NUTS and
    wrapped it inside ``pymc.sample()`` and you should thus avoid this method.
    """
    model = modelcontext(model)

    # Resolve which value variables are optimized.
    if vars is None:
        vars = model.continuous_value_vars
        if not vars:
            raise ValueError("Model has no unobserved continuous variables.")
    else:
        try:
            vars = get_value_vars_from_user_vars(vars, model)
        except ValueError:
            # Accommodate case where user passed non-pure RV nodes
            vars = inputvars(model.replace_rvs_by_values(vars))
            if not vars:
                raise
            warnings.warn(
                "Intermediate variables (such as Deterministic or Potential) were passed. "
                "find_MAP will optimize the underlying free_RVs instead.",
                UserWarning,
            )

    disc_vars = list(typefilter(vars, discrete_types))

    # Build the (transformed-space) starting point; no jitter is applied.
    ipfn = make_initial_point_fn(
        model=model,
        jitter_rvs=set(),
        return_transformed=True,
        overrides=start,
    )
    start = ipfn(seed)
    model.check_start_vals(start)

    # Only the optimized variables are raveled into the flat vector x0; the
    # rest of `start` stays fixed via the bijection's point mapping.
    vars_dict = {var.name: var for var in vars}
    x0 = DictToArrayBijection.map(
        {var_name: value for var_name, value in start.items() if var_name in vars_dict}
    )

    # TODO: If the mapping is fixed, we can simply create graphs for the
    # mapping and avoid all this bijection overhead
    compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False), start)
    logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info))  # noqa: E731

    # Gradient w.r.t. the RVs, in the same order as the raveled x0.
    rvs = [model.values_to_rvs[vars_dict[name]] for name, _, _, _ in x0.point_map_info]
    try:
        compiled_dlogp_func = DictToArrayBijection.mapf(
            model.compile_dlogp(rvs, jacobian=False), start
        )
        dlogp_func = lambda x: compiled_dlogp_func(RaveledVars(x, x0.point_map_info))  # noqa: E731
        compute_gradient = True
    except (AttributeError, NotImplementedError, tg.NullTypeGradError):
        compute_gradient = False

    if disc_vars or not compute_gradient:
        pm._log.warning(
            "Warning: gradient not available. "
            "(E.g. vars contains discrete variables). MAP "
            "estimates may not be accurate for the default "
            "parameters. Defaulting to non-gradient minimization "
            "'Powell'."
        )
        method = "Powell"

    if compute_gradient and method != "Powell":
        cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func, dlogp_func)
    else:
        cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func)
        compute_gradient = False

    with cost_func.progress:
        try:
            opt_result = optimize.minimize(
                cost_func, x0.data, *args, method=method, jac=compute_gradient, **kwargs
            )
            mx0 = opt_result["x"]
        except (KeyboardInterrupt, StopIteration) as e:
            # Stopped early (user interrupt or evaluation limit): keep the last
            # point recorded by the cost function.
            mx0, opt_result = cost_func.previous_x, None
            if isinstance(e, StopIteration):
                pm._log.info(e)
        finally:
            cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True)

    if mx0 is None:
        # The cost function never recorded a usable point (e.g. stopped before
        # any finite gradient was seen); fall back to the starting point rather
        # than crashing while un-raveling None.
        pm._log.warning(
            "find_MAP stopped before a point with a finite gradient was found; "
            "returning the starting point."
        )
        mx0 = x0.data

    mx0 = RaveledVars(mx0, x0.point_map_info)
    unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
    unobserved_vars_values = model.compile_fn(inputs=model.value_vars, outs=unobserved_vars)(
        DictToArrayBijection.rmap(mx0, start)
    )
    mx = {var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)}

    if return_raw:
        return mx, opt_result
    return mx
def allfinite(x):
    """Return True if every element of ``x`` is finite (no NaN or +/-inf)."""
    return np.all(isfinite(x))


class CostFuncWrapper:
    """Objective callable for ``scipy.optimize.minimize`` built from a model logp.

    The wrapper returns the negated log-probability (and optionally its negated
    gradient) so that minimizing it maximizes the posterior. It counts
    evaluations, raises ``StopIteration`` once ``maxeval`` evaluations have
    completed, remembers the last usable point in ``previous_x`` and reports
    progress through a rich progress bar.

    Parameters
    ----------
    maxeval : int
        Maximum number of completed evaluations before ``StopIteration`` is raised.
    progressbar : bool
        Whether to display the progress bar.
    progressbar_theme : Theme
        Theme for the rich console used by the progress bar.
    logp_func : callable
        Maps a flat parameter array to the model log-probability.
    dlogp_func : callable, optional
        Maps a flat parameter array to the gradient of the log-probability.
        When omitted, calls return only the cost (no gradient).
    """

    def __init__(
        self,
        maxeval=5000,
        progressbar=True,
        progressbar_theme=default_progress_theme,
        logp_func=None,
        dlogp_func=None,
    ):
        self.n_eval = 0
        self.maxeval = maxeval
        self.logp_func = logp_func
        self.dlogp_func = dlogp_func
        self.use_gradient = dlogp_func is not None
        if self.use_gradient:
            self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}"
        else:
            self.desc = "logp = {:,.5g}"
        # Last evaluated point whose gradient was finite (or simply the last
        # evaluated point when no gradient is used); None until then.
        self.previous_x = None
        self.progressbar = progressbar
        self.progress = CustomProgress(
            *Progress.get_default_columns(),
            TextColumn("{task.fields[loss]}"),
            console=Console(theme=progressbar_theme),
            disable=not progressbar,
        )
        self.task = self.progress.add_task("MAP", total=maxeval, loss="")

    def __call__(self, x):
        """Return ``-logp(x)``, plus ``-dlogp(x)`` as float64 when a gradient is used.

        Raises
        ------
        StopIteration
            When called after ``maxeval`` evaluations have already completed.
        """
        # `neg_value` holds logp itself; the optimizer minimizes its negation.
        neg_value = np.float64(self.logp_func(floatX(x)))
        value = -1.0 * neg_value
        if self.use_gradient:
            neg_grad = self.dlogp_func(floatX(x))
            if allfinite(neg_grad):
                self.previous_x = x
            grad = (-1.0 * neg_grad).astype(np.float64)
        else:
            self.previous_x = x
            grad = None

        # Refresh the text column only every 10 evaluations to limit overhead.
        if self.n_eval % 10 == 0:
            self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad))

        # `>=` caps completed evaluations at exactly `maxeval` (matching the
        # progress total); a `>` check would let maxeval + 1 complete.
        if self.n_eval >= self.maxeval:
            self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad))
            raise StopIteration

        self.n_eval += 1
        self.progress.update(self.task, completed=self.n_eval)
        if self.use_gradient:
            return value, grad
        return value

    def update_progress_desc(self, neg_value: float, grad: np.ndarray | None = None) -> str | None:
        """Format the progress text for the current logp and, if given, gradient norm.

        Returns None when the progress bar is disabled.
        """
        if not self.progressbar:
            return None
        if grad is None:
            return self.desc.format(neg_value)
        return self.desc.format(neg_value, np.linalg.norm(grad))