/* -*- mode:c++; coding: koi8-r -*- */

/* $Id: DiscreteVariation.cpp,v 1.1 2003/04/14 20:43:30 cher Exp $ */
/* Copyright (C) 2003 Alexander Chernov <cher@unicorn.cmc.msu.ru> */

/*
 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 Lesser General Public License for more details.

 See the `COPYING' file for the full terms and conditions.
*/

#include "DiscreteVariation.hpp"

#include <math.h>

/*
 * Builds a sampler for a discrete distribution over the consecutive
 * integer values _first, _first + 1, ..., _first + _n - 1, where the value
 * _first + i is drawn with probability _p[i].
 *
 * Only outcomes with a strictly positive probability are stored:
 *   f[0..n-1]  cumulative lower bounds; stored outcome j covers the
 *              interval [f[j], f[j + 1]), and f[n - 1] is pinned to 1.0;
 *   v[0..n-2]  the value of each stored outcome; v[n - 1] repeats
 *              v[n - 2] so that no slot is ever left uninitialized.
 *
 * _rnd   - source of uniform numbers; must not be NULL.  Only the pointer
 *          is stored (the destructor does not delete it).
 * _n     - number of entries in _p; must be positive.
 * _first - the value associated with _p[0].
 * _p     - probabilities; each must lie in [0, 1] (NaN is rejected), at
 *          least one must be positive, and they must sum to 1 within 1e-10.
 *
 * Throws BadArgsError if any of the above preconditions is violated.
 */
DiscreteVariation::DiscreteVariation(Random *_rnd, int _n, int _first,
                                     double const *_p)
  throw (BadArgsError)
{
  double s = 0.0;
  int k = 0;

  if (!_rnd) throw BadArgsError();
  if (_n <= 0) throw BadArgsError();
  if (!_p) throw BadArgsError();

  for (int i = 0; i < _n; i++) {
    // Negated range test so that NaN (for which every comparison is false)
    // is rejected instead of silently slipping through.
    if (!(_p[i] >= 0.0 && _p[i] <= 1.0)) throw BadArgsError();
    if (_p[i] > 0.0) k++;
    s += _p[i];
  }
  if (k <= 0) throw BadArgsError();
  if (fabs(s - 1.0) > 1e-10) throw BadArgsError();

  // k stored outcomes need k + 1 interval bounds.
  n = k + 1;
  f = new double[n];
  v = new int[n];
  s = 0.0;
  int j = 0;
  for (int i = 0; i < _n; i++) {
    // Must be the same test as the counting loop above, so that exactly
    // k slots are filled.
    if (!(_p[i] > 0.0)) continue;
    f[j] = s;
    v[j] = _first + i;
    s += _p[i];
    j++;
  }
  f[n - 1] = 1.0;
  // Sentinel: keep the extra value slot defined (n >= 2 because k >= 1).
  v[n - 1] = v[n - 2];
  rnd = _rnd;
}

/*
 * Builds a sampler for a discrete distribution over arbitrary integer
 * values: _v[i] is drawn with probability _p[i].
 *
 * Only outcomes with a strictly positive probability are stored:
 *   f[0..n-1]  cumulative lower bounds; stored outcome j covers the
 *              interval [f[j], f[j + 1]), and f[n - 1] is pinned to 1.0;
 *   v[0..n-2]  the value of each stored outcome; v[n - 1] repeats
 *              v[n - 2] so that no slot is ever left uninitialized.
 *
 * _rnd - source of uniform numbers; must not be NULL.  Only the pointer is
 *        stored (the destructor does not delete it).
 * _n   - number of entries in _v and _p; must be positive.
 * _v   - the values; must not be NULL.  Duplicates are not checked.
 * _p   - probabilities; each must lie in [0, 1] (NaN is rejected), at least
 *        one must be positive, and they must sum to 1 within 1e-10.
 *
 * Throws BadArgsError if any of the above preconditions is violated.
 */
DiscreteVariation::DiscreteVariation(Random *_rnd, int _n, int const *_v,
                                     double const *_p)
  throw (BadArgsError)
{
  double s = 0.0;
  int k = 0;

  if (!_rnd) throw BadArgsError();
  if (_n <= 0) throw BadArgsError();
  if (!_v || !_p) throw BadArgsError();

  for (int i = 0; i < _n; i++) {
    // Negated range test so that NaN (for which every comparison is false)
    // is rejected instead of silently slipping through.
    if (!(_p[i] >= 0.0 && _p[i] <= 1.0)) throw BadArgsError();
    if (_p[i] > 0.0) k++;
    s += _p[i];
  }
  if (k <= 0) throw BadArgsError();
  if (fabs(s - 1.0) > 1e-10) throw BadArgsError();

  // k stored outcomes need k + 1 interval bounds.
  n = k + 1;
  f = new double[n];
  v = new int[n];
  s = 0.0;
  int j = 0;
  for (int i = 0; i < _n; i++) {
    // Must be the same test as the counting loop above, so that exactly
    // k slots are filled.
    if (!(_p[i] > 0.0)) continue;
    f[j] = s;
    v[j] = _v[i];   // the value paired with _p[i], not with slot j
    s += _p[i];
    j++;
  }
  f[n - 1] = 1.0;
  // Sentinel: keep the extra value slot defined (n >= 2 because k >= 1).
  v[n - 1] = v[n - 2];
  rnd = _rnd;
}

/*
 * Releases the value table and the cumulative-bound table allocated with
 * new[] by the constructors.  The Random object is not owned and is left
 * alone.
 *
 * NOTE(review): the class holds raw owning arrays; unless copying is
 * disabled in the header (not visible here), a copied instance would
 * double-delete them -- confirm.
 */
DiscreteVariation::~DiscreteVariation()
  throw ()
{
  delete[] v;
  delete[] f;
}

/*
 * Draws one value from the distribution.
 *
 * Takes x = rnd->uniform() (presumably in [0, 1) -- confirm against Random)
 * and returns v[k] for the first stored outcome k whose upper bound
 * f[k + 1] exceeds x.  Only f[1..n-1] and v[0..n-2] are touched: there are
 * n - 1 stored outcomes, f[n] does not exist, and v[n - 1] is not a real
 * outcome.  If x is not below any bound (e.g. x >= 1.0), the last stored
 * outcome is returned.
 */
int
DiscreteVariation::get()
  throw ()
{
  double const x = rnd->uniform();
  int const last = n - 2;   // index of the last stored outcome

  // Linear scan over the cumulative bounds.
  // TODO: a binary search (std::upper_bound over f + 1 .. f + n) would
  // make this O(log n) for distributions with many outcomes.
  for (int k = 0; k < last; k++) {
    if (f[k + 1] > x) return v[k];
  }
  return v[last];
}

/*
 * Local variables:
 *  compile-command: "make -C .."
 *  c++-font-lock-extra-types: ("[A-Z]\\sw*[a-z]\\sw*")
 * End:
 */
