The (aux hansei) module

Introduction

Quoting from [1]:


Broadly speaking, probabilistic programming languages are to express computations with degrees of uncertainty, which comes from the imprecision in input data, lack of the complete knowledge or is inherent in the domain. More precisely, the goal of probabilistic programming languages is to represent and automate reasoning about probabilistic models, which describe uncertain quantities -- random variables -- and relationships among them.

Oleg Kiselyov

Here we focus on the DSL Hansei [2] and the accompanying paper [3] by Oleg Kiselyov and Chung-chieh Shan; an application to logic can be found in [4].

Implementation

The first implementation of our language uses the probability monad that represents a stochastic computation as a lazy search tree. That is, our implementation uses the type constructor pV defined below.

type 'a vc = V of 'a | C of (unit -> 'a pV) and 'a pV = (prob * 'a vc) list
Each node in a tree is a weighted list of branches. The empty list denotes failure, and a singleton list [(p, V v)] denotes a deterministic successful outcome with the probability mass p. A branch of the form V v is a leaf node that describes a possible successful outcome, whereas a branch of the form C thunk is not yet explored. The intended meaning of a search tree of type 'a pV is a discrete probability distribution over values of type 'a. In the Scheme code below, a branch is written as a two-element list with the weight last, either ((V v) p) or ((C thunk) p), and a tree is a list of such branches.

;; (aux hansei): a Scheme port of the Hansei probabilistic programming DSL.
;; A model is a thunk whose random choices capture their continuation
;; (letcc/shift) up to the resetcc installed by probcc-reify/0.  Reifying a
;; model yields a lazy weighted search tree: a list of branches, each written
;; `(slot weight)`, where slot is `(V value)` (a resolved outcome) or
;; `(C thunk)` (an unexplored subtree; calling the thunk yields another such
;; list).  Note the weight comes last, unlike the OCaml `(prob * 'a vc)`.
(module
  (aux hansei)
  *
  (import
    scheme
    (chicken base)
    (chicken continuation)
    (chicken pretty-print)
    (chicken fixnum)
    (chicken sort)
    srfi-69
    (aux base)
    (aux continuation)
    (aux continuation delimited))
  ;; Arithmetic on branch weights, held in parameters so callers can swap the
  ;; weight arithmetic with `parameterize`.  Each procedure below reads the
  ;; current value when it is called.
  (define op/times (make-parameter *))
  (define op/plus (make-parameter +))
  (define op/subtract (make-parameter -))
  ;; Division always returns an inexact number, so normalized weights are
  ;; floats even when the input weights are exact.
  (define op/divide (make-parameter (λ (m n) (exact->inexact (/ m n)))))
  ;; Ordering on weights used to sort the result of probcc-explore.
  (define op/greater (make-parameter >))
  ;; `(probcc-τ p body ...)` builds a suspended branch `((C thunk) p)`.
  ;; `τ` is assumed to build a nullary thunk around `body ...` (the tests use
  ;; it that way), so the body runs only when the branch is forced.
  (define-syntax-rule (probcc-τ p body ...) `((C ,(τ body ...)) ,p))
  ;; `(probcc-value p body ...)` builds a resolved branch `((V value) p)`,
  ;; evaluating `body ...` right away.
  (define-syntax-rule (probcc-value p body ...) `((V ,(begin body ...)) ,p))
  ;; Explore the search tree `choices`, forcing suspended branches down to at
  ;; most `maxdepth` levels (pass +inf.0 for no limit).
  ;; Returns a list of branches whose weights are the products (op/times) of
  ;; the weights along their paths:
  ;;   - one `((V value) w)` per distinct resolved value, with the weights of
  ;;     equal values summed (op/plus); values are keys of a SRFI-69 table
  ;;     created with the default equivalence;
  ;;   - the suspensions found beyond the depth limit, kept lazy.
  ;; All of them are sorted together by decreasing weight (op/greater).
  (define (probcc-explore maxdepth choices)
    (letrec ((times (op/times))
             (plus (op/plus))
             ;; p:      product of the weights on the path leading to `choices`
             ;; depth:  number of suspensions forced to reach `choices`
             ;; down:   whether suspensions at this level may still be forced
             ;; ans:    hash table value -> accumulated weight (mutated)
             ;; susp:   suspended branches collected so far; the loop returns
             ;;         it once `choices` is exhausted
             (loop (λ (p depth down choices ans susp)
                       (match/first
                         choices
                         (() susp)
                         ;; First branch `(slot pt)` and the remaining ones.
                         (((,slot ,pt) unquote rest)
                          (let* ((p*pt (times p pt)) (A (λ (w) (plus w p*pt))))
                            (match/first
                              slot
                              ((V ,v)
                               ;; Resolved leaf: add its path weight to `v`'s total.
                               (hash-table-update!/default ans v A 0)
                               (loop p depth down rest ans susp))
                              ((C ,t)
                               ;; Within budget: force the thunk and explore its
                               ;; subtree one level deeper, threading the
                               ;; suspensions it collects into the rest of this
                               ;; level.  Otherwise keep it suspended, carrying
                               ;; the whole path weight p*pt.
                               (cond (down
                                      (loop p
                                            depth
                                            down
                                            rest
                                            ans
                                            (loop p*pt
                                                  (add1 depth)
                                                  (< depth maxdepth)
                                                  (t)
                                                  ans
                                                  susp)))
                                     (else
                                      (let1 (s (cons (probcc-τ p*pt (t)) susp))
                                            (loop p depth down rest ans s))))))))))))
      ;; Start at the root with weight 1 and depth 0, then prepend the merged
      ;; resolved values to the leftover suspensions and sort everything.
      (let* ((ans (make-hash-table))
             (susp (loop 1 0 #t choices ans '()))
             (f (λ (v p l) (cons (probcc-value p v) l)))
             (folded (hash-table-fold ans f susp))
             (greater (op/greater)))
        (sort folded (λ (a b) (greater (cadr a) (cadr b)))))))
  ;; Rescale the weights of `choices` (resolved and suspended branches alike)
  ;; so that they sum to one, using op/plus and op/divide.
  ;; NOTE(review): a total weight of zero is not guarded against; it reaches
  ;; `divide` as the denominator.
  (define (probcc-normalize choices)
    (let* ((divide (op/divide))
           (plus (op/plus))
           (tot (foldr (λ (each t) (plus t (cadr each))) 0 choices))
           (N (λ (each) (list (car each) (divide (cadr each) tot)))))
      (map N choices)))
  ;; Make a random choice among `pairs`, a list of `(value weight)` lists.
  ;; letcc/shift (from (aux continuation delimited)) is assumed to bind `k` to
  ;; the continuation up to the enclosing resetcc and to return the body's
  ;; value there: one suspended branch per pair, each resuming `k` with its
  ;; value.  An empty `pairs` therefore produces the empty tree (failure).
  (define (probcc-distribution pairs)
    (letcc/shift
      k
      (map (λ1-match/non-overlapping ((,v ,p) (probcc-τ p (k v)))) pairs)))
  ;; Turn an already built search tree `choices` back into a random choice
  ;; inside the current model.  Resolved branches resume `k` with their value.
  ;; Suspended branches stay lazy and are reflected only when forced.
  (define (probcc-reflect choices)
    (letcc/shift
      k
      (letrec ((make-choices (λ (pv) (map F pv)))
               (F (λ1-match/non-overlapping
                    (((V ,v) ,p) (probcc-τ p (k v)))
                    (((C ,t) ,p) (probcc-τ p (make-choices (t)))))))
        (make-choices choices))))
  ;; Reject the current execution path: a choice with no branches.
  (define (probcc-impossible) (probcc-distribution '()))
  ;; The deterministic tree holding just `v` with weight 1.  This builds a tree
  ;; and makes no choice; probcc-reify/0 uses it to wrap a model's final value.
  (define (probcc-unit v) (list (probcc-value 1 v)))
  ;; Choose `t` with weight `p` and `f` with weight 1 - p (op/subtract).
  (define (probcc-bernoulli t f p)
    (probcc-distribution `((,t ,p) (,f ,((op/subtract) 1 p)))))
  ;; Flip a coin that lands #t with weight `p` and #f otherwise.
  (define (probcc-coin p) (probcc-bernoulli #t #f p))
  ;; Choose an integer in 0 .. n-1 uniformly.  Values n-1 down to 1 get weight
  ;; 1/n (computed with `/`, so it stays exact for an exact `n`).  Value 0 gets
  ;; 1 minus the sum of the others, so the weights add up to 1.  A non-positive
  ;; `n` makes the current path impossible.
  (define (probcc-uniform n)
    (cond ((> n 0)
           (letrec ((p (/ 1 n))
                    (plus (op/plus))
                    (subtract (op/subtract))
                    (loop (λ (pacc acc i)
                              (if (zero? i)
                                (probcc-distribution
                                  (cons `(,i ,(subtract 1 pacc)) acc))
                                (loop (plus pacc p) (cons `(,i ,p) acc) (sub1 i))))))
             (loop 0 '() (sub1 n))))
          (else (probcc-impossible))))
  ;; Choose an integer uniformly from the half-open range [low, high).
  (define (probcc-uniform/range low high)
    (+ low (probcc-uniform (- high low))))
  ;; Choose an element of `lst` uniformly by position; an empty `lst` is
  ;; impossible (probcc-uniform 0).
  (define (probcc-uniform/either lst)
    (list-ref lst (probcc-uniform (length lst))))
  ;; Geometric choice: each trial succeeds with weight `p`.  The outcome is
  ;; `(s . fs)`, where `fs` lists the failures `f` before the success, e.g.
  ;; `(s f f)` with weight p(1-p)^2.  The tree is infinite, so reify it with a
  ;; finite depth; exact reification does not terminate.
  (define (probcc-geometric p s f)
    (letrec ((subtract (op/subtract))
             (loop (λ (n)
                       (list (probcc-τ p (probcc-unit (cons s n)))
                             (probcc-τ (subtract 1 p) (loop (cons f n)))))))
      (probcc-reflect (loop '()))))
  ;; Continue with `body ...` when `test` holds; otherwise reject this path.
  (define-syntax-rule
    (probcc-when test body ...)
    (cond (test body ...) (else (probcc-impossible))))
  ;; Run the thunk `model` under resetcc and return its search tree without
  ;; exploring it.  The model's final value becomes a leaf of weight 1.
  (define (probcc-reify/0 model) (resetcc (probcc-unit (model))))
  ;; `(probcc-reify depth)` returns a procedure that reifies a model thunk and
  ;; explores the resulting tree up to `depth` levels (see probcc-explore).
  (define ((probcc-reify depth) model)
    (probcc-explore depth (probcc-reify/0 model)))
  ;; Unbounded exploration.  A depth limit of +inf.0 never stops the descent,
  ;; so this terminates only on finite trees (not e.g. on probcc-geometric).
  (define probcc-reify/exact/a (probcc-reify +inf.0))
  ;; Exact reification of an inline model body.
  (define-syntax-rule
    (probcc-reify/exact body ...)
    (probcc-reify/exact/a (τ body ...)))
  ;; Wrap the model procedure `f`.  Each call reifies `(apply f args)` exactly,
  ;; which merges equal outcomes, and then reflects the resulting flat
  ;; distribution back as a single choice.  Nothing is cached between calls.
  (define (probcc-variable-elimination f)
    (λ args (probcc-reflect (probcc-reify/exact (apply f args)))))
  ;; Like probcc-variable-elimination applied to `(λ args body ...)`, except the
  ;; exact reification goes through λ-memo.  λ-memo presumably caches on the
  ;; argument list `bargs` (TODO confirm in (aux base)), so each distinct
  ;; argument tuple would be reified only once.  `o` composes, so calling the
  ;; result reflects `B`'s reified tree.
  (define-syntax-rule
    (λ-probcc-bucket args body ...)
    (letrec ((F (λ args body ...))
             (B (λ-memo bargs (probcc-reify/exact (apply F bargs)))))
      (o probcc-reflect B)))
  ;; Count the resolved leaves of `choices`, forcing every suspension along the
  ;; way.  It does not terminate on infinite trees.
  (define (probcc-leaves choices)
    (let L ((choices* choices) (count 0))
      (let1 (F (λ (probpair acc)
                   (match/first
                     probpair
                     (((V ,v) ,p) (add1 acc))
                     (((C ,t) ,p) (L (t) acc)))))
            (foldr F count choices*))))
  ;; Force every suspension of `choices` depth-first, scaling each child's
  ;; weight by its parent's weight (op/times).  Equal values are not merged.
  ;; NOTE(review): the recursive calls flatten their results with
  ;; `apply append`, but the top-level call returns `(map M choices)`
  ;; unflattened.  The result is one list of branches per top-level branch,
  ;; not a flat branch list; confirm that callers expect this shape.
  (define (probcc-dfs choices)
    (let1 (M (λ1-match/first
               (((C ,t) ,p)
                (apply append
                       (probcc-dfs
                         (map (λ1-match/first
                                ((,slot ,pi) `(,slot ,((op/times) p pi))))
                              (t)))))
               (,probpair (list probpair))))
          (map M choices))))

Tests

test/procc/coin-model: pass

Joint distribution of tossing two biased coins, each of which shows heads with probability 0.6.

;; Two independent tosses of a coin biased towards heads (p = 0.6); the
;; normalized joint distribution must list all four outcomes.
(define (test/procc/coin-model _)
  (let* ((expected '(((V ((x #t) (y #t))) 0.36)
                     ((V ((x #t) (y #f))) 0.24)
                     ((V ((x #f) (y #t))) 0.24)
                     ((V ((x #f) (y #f))) 0.16)))
         (joint (probcc-reify/exact
                  (let ((p 0.6))
                    (let ((x (probcc-coin p)))
                      (let ((y (probcc-coin p)))
                        (list (list 'x x) (list 'y y))))))))
    (⊦= expected (probcc-normalize joint))))
((eta 0.004) (memory #(6291456 1960968 1048576)) (stdout "") (stderr ""))

test/procc/coin-model/when: pass

A slight variation of the previous test: here we condition on the observation that at least one head appeared.

;; The same two coins, conditioned on at least one head: the (#f #f) outcome
;; is rejected and the remaining three are renormalized.
(define (test/procc/coin-model/when _)
  (let* ((expected '(((V (#t #t)) 0.428571428571429)
                     ((V (#t #f)) 0.285714285714286)
                     ((V (#f #t)) 0.285714285714286)))
         (observed (probcc-reify/exact
                     (let ((p 0.6))
                       (let ((x (probcc-coin p)))
                         (let ((y (probcc-coin p)))
                           (probcc-when (or x y) (list x y))))))))
    (⊦= expected (probcc-normalize observed))))
((eta 0.0) (memory #(6291456 1962376 1048576)) (stdout "") (stderr ""))

test/procc/grass-model: pass


The canonical example is the grass model, with three random variables representing the events of rain, of a switched-on sprinkler and wet grass. The (a priori) probabilities of the first two events are judged to be 30% and 50% correspondingly. Probabilities are non-negative real numbers that may be regarded as weights on non-deterministic choices. Rain almost certainly (90%) wets the grass. The sprinkler also makes the grass wet, in 80% of the cases. The grass may also be wet for some other reason. The modeler gives such an unaccounted event 10% of a chance. This model is often depicted as a directed acyclic graph (DAG) -- so-called Bayesian, or belief network -- with nodes representing random variables and edges conditional dependencies. Associated with each node is a distribution (such as Bernoulli distribution: the flip of a biased coin), or a function that computes a distribution from the node's inputs (such as the noisy disjunction nor). The sort of reasoning we wish to perform on the model is finding out the probability distribution of some of its random variables. For example, we can work out from the model that the probability of the grass being wet is 60.6%. Such reasoning is called probabilistic inference. Often we are interested in the distribution conditioned on the fact that some random variables have been observed to hold a particular value. In our example, having observed that the grass is wet, we want to find out the chance it was raining on that day.

Oleg Kiselyov

The solution to this problem is the probability distribution of rain, given that the grass has been observed to be wet:

(((V (rain #f)) 0.53152855727963) ((V (rain #t)) 0.46847144272037))
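The following test encodes this model. Note that it checks the unnormalized weights 0.322 (no rain) and 0.2838 (rain). Dividing each by their sum 0.6058, which is the probability that the grass is wet, yields the distribution above.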

;; Rain / sprinkler / wet-grass network conditioned on wet grass.  The weights
;; stay unnormalized: their sum, 0.6058, is the probability of wet grass.
(define (test/procc/grass-model _)
  (let* ((result
           (probcc-reify/exact
             (let ((rain (probcc-coin 0.3)))
               (let ((sprinkler (probcc-coin 0.5)))
                 ;; Noisy-or: rain wets with 0.9, the sprinkler with 0.8, and
                 ;; some other cause with 0.1 (coins flipped in this order).
                 (let ((grass-is-wet
                         (or (and (probcc-coin 0.9) rain)
                             (and (probcc-coin 0.8) sprinkler)
                             (probcc-coin 0.1))))
                   (probcc-when grass-is-wet (list 'rain rain)))))))
         (expected (list (probcc-value 0.322 '(rain #f))
                         (probcc-value 0.2838 '(rain #t)))))
    (⊦= expected result)))
((eta 0.001) (memory #(6291456 1964272 1048576)) (stdout "") (stderr ""))

test/procc/grass-model/complete: pass

If we drop the observation that the grass is wet, we obtain the joint probability distribution of all three variables:

;; The same network with no observation: the full joint distribution of
;; rain, sprinkler and wet grass, which already has total weight one.
(define (test/procc/grass-model/complete _)
  (let ((joint
          (probcc-reify/exact
            (let ((rain (probcc-coin 0.3)))
              (let ((sprinkler (probcc-coin 0.5)))
                (let ((grass-is-wet
                        (or (and (probcc-coin 0.9) rain)
                            (and (probcc-coin 0.8) sprinkler)
                            (probcc-coin 0.1))))
                  (list (list 'rain rain)
                        (list 'sprinkler sprinkler)
                        (list 'grass-is-wet grass-is-wet))))))))
    (⊦= '(((V ((rain #f) (sprinkler #f) (grass-is-wet #f))) 0.315)
            ((V ((rain #f) (sprinkler #t) (grass-is-wet #t))) 0.287)
            ((V ((rain #t) (sprinkler #t) (grass-is-wet #t))) 0.1473)
            ((V ((rain #t) (sprinkler #f) (grass-is-wet #t))) 0.1365)
            ((V ((rain #f) (sprinkler #t) (grass-is-wet #f))) 0.063)
            ((V ((rain #f) (sprinkler #f) (grass-is-wet #t))) 0.035)
            ((V ((rain #t) (sprinkler #f) (grass-is-wet #f))) 0.0135)
            ((V ((rain #t) (sprinkler #t) (grass-is-wet #f))) 0.0027))
          joint)
    ;; With total mass one, normalizing must leave the result unchanged.
    (⊦= (probcc-normalize joint) joint)))
((eta 0.0) (memory #(6291456 1970672 1048576)) (stdout "") (stderr ""))

test/uniform/range: pass

;; probcc-uniform/range draws from the half-open range [low, high): each of
;; 1 .. 7 must get the exact weight 1/7.  The result is sorted by value.
(define (test/uniform/range _)
  (let* ((value-of (λ (branch) (cadr (car branch))))
         (by-value (λ (a b) (< (value-of a) (value-of b))))
         (expected (map (λ (i) (list (list 'V i) 1/7)) '(1 2 3 4 5 6 7))))
    (⊦= expected
          (sort (probcc-reify/exact (probcc-uniform/range 1 8)) by-value))))
((eta 0.0) (memory #(6291456 1972240 1048576)) (stdout "") (stderr ""))

test/geometric: pass

;; Geometric choice reified to depth 5.  The expected result has five
;; resolved outcomes (s) .. (s f f f f) and three branches still suspended.
;; Those thunks are taken from the result itself, so only their positions
;; and weights are really checked.
(define (test/geometric _)
  (let* ((result ((probcc-reify 5) (τ (probcc-geometric 0.85 's 'f))))
         (thunk-of (λ (branch) (cadr (car branch))))
         (t6 (thunk-of (sixth result)))
         (t7 (thunk-of (seventh result)))
         (t8 (thunk-of (eighth result))))
    (⊦= `(((V (s)) 0.85)
            ((V (s f)) 0.1275)
            ((V (s f f)) 0.019125)
            ((V (s f f f)) 0.00286875)
            ((V (s f f f f)) 0.0004303125)
            ((C ,t6) 6.4546875e-05)
            ((C ,t7) 9.68203125000001e-06)
            ((C ,t8) 1.70859375e-06))
          result)))
((eta 0.001) (memory #(6291456 1976704 1048576)) (stdout "") (stderr ""))