Madhav commited on
Commit
a243869
·
1 Parent(s): 7d417e9

Upload stream.py

Browse files
Files changed (1) hide show
  1. stream.py +97 -0
stream.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import jax.numpy as jnp
3
+ import matplotlib.pyplot as plt
4
+ import numpyro
5
+ import numpyro.distributions as dist
6
+ from numpyro.infer import MCMC, NUTS
7
+ from sklearn.datasets import make_regression
8
+ from jax import random
9
+ import streamlit as st
10
+
11
+
12
+ # Define the model
13
+ def linear_regression(X, y, alpha_prior, beta_prior, sigma_prior):
14
+ alpha = numpyro.sample('alpha', alpha_prior)
15
+ beta = numpyro.sample('beta', beta_prior)
16
+ sigma = numpyro.sample('sigma', sigma_prior)
17
+ mean = alpha + beta * X
18
+ numpyro.sample('obs', dist.Normal(mean, sigma), obs=y)
19
+
20
+
21
+ def run_linear_regression(X, y, alpha_prior, beta_prior, sigma_prior):
22
+ # Run MCMC
23
+ rng_key = random.PRNGKey(0)
24
+ nuts_kernel = NUTS(linear_regression)
25
+ mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=600)
26
+ mcmc.run(rng_key, jnp.array(X), jnp.array(y), alpha_prior=alpha_prior, beta_prior=beta_prior, sigma_prior=sigma_prior)
27
+
28
+ mcmc.print_summary()
29
+
30
+ # Get posterior samples
31
+ samples = mcmc.get_samples()
32
+
33
+ # Plot the results
34
+ fig, ax = plt.subplots(figsize=(8, 6))
35
+ ax.scatter(X, y, color='blue', alpha=0.5, label='data')
36
+ light_color = (1.0, 0.5, 0.5, 0.7)
37
+ for i in range(500):
38
+ alpha_i = samples['alpha'][i]
39
+ beta_i = samples['beta'][i]
40
+ ax.plot(X, alpha_i + beta_i * X, color=light_color)
41
+ ax.plot(X, np.mean(samples['alpha']) + np.mean(samples['beta']) * X, color='red', label='mean')
42
+ ax.legend(loc='upper left')
43
+ st.pyplot(fig)
44
+
45
+
46
+ # User Input
47
+ st.write("# Bayesian Linear Regression with numpyro")
48
+ alpha_prior_option = st.selectbox("Choose an option for alpha prior:", ["Normal", "Laplace", "Cauchy"])
49
+ beta_prior_option = st.selectbox("Choose an option for beta prior:", ["Normal", "Laplace", "Cauchy"])
50
+ sigma_prior_option = st.selectbox("Choose an option for sigma prior:", ["HalfNormal", "HalfCauchy"])
51
+
52
+ if st.button("Run Regression"):
53
+ alpha_prior = None
54
+ beta_prior = None
55
+ sigma_prior = None
56
+
57
+ if alpha_prior_option == "Normal":
58
+ alpha_loc = st.slider("Select a mean value for alpha", -10.0, 10.0, 0.0, 0.1)
59
+ alpha_scale = st.slider("Select a standard deviation value for alpha", 0.0, 10.0, 1.0, 0.1)
60
+ alpha_prior = dist.Normal(alpha_loc, alpha_scale)
61
+ elif alpha_prior_option == "Laplace":
62
+ alpha_loc = st.slider("Select a mean value for alpha", -10.0, 10.0, 0.0, 0.1)
63
+ alpha_scale = st.slider("Select a scale value for alpha", 0.0, 10.0, 1.0, 0.1)
64
+ alpha_prior = dist.Laplace(alpha_loc, alpha_scale)
65
+ elif alpha_prior_option == "Cauchy":
66
+ alpha_loc = st.slider("Select a location value for alpha", -10.0, 10.0, 0.0, 0.1)
67
+ alpha_scale = st.slider("Select a scale value for alpha", 0.0, 10.0, 1.0, 0.1)
68
+ alpha_prior = dist.Cauchy(alpha_loc, alpha_scale)
69
+
70
+ if beta_prior_option == "Normal":
71
+ beta_loc = st.slider("Select a mean value for beta", -10.0, 10.0, 0.0, 0.1)
72
+ beta_scale = st.slider("Select a standard deviation value for beta", 0.0, 10.0, 1.0, 0.1)
73
+ beta_prior = dist.Normal(beta_loc, beta_scale)
74
+ elif beta_prior_option == "Laplace":
75
+ beta_loc = st.slider("Select a mean value for beta", -10.0, 10.0, 0.0, 0.1)
76
+ beta_scale = st.slider("Select a scale value for beta", 0.0, 10.0, 1.0, 0.1)
77
+ beta_prior = dist.Laplace(beta_loc, beta_scale)
78
+ elif beta_prior_option == "Cauchy":
79
+ beta_loc = st.slider("Select a location value for beta", -10.0, 10.0, 0.0, 0.1)
80
+ beta_scale = st.slider("Select a scale value for beta", 0.0, 10.0, 1.0, 0.1)
81
+ beta_prior = dist.Cauchy(beta_loc, beta_scale)
82
+
83
+ if sigma_prior_option == "HalfNormal":
84
+ sigma_scale = st.slider("Select a scale value for sigma", 0.0, 10.0, 1.0, 0.1)
85
+ sigma_prior = dist.HalfNormal(sigma_scale)
86
+ elif sigma_prior_option == "HalfCauchy":
87
+ sigma_scale = st.slider("Select a scale value for sigma", 0.0, 10.0, 1.0, 0.1)
88
+ sigma_prior = dist.HalfCauchy(sigma_scale)
89
+
90
+ # Generate data
91
+ rng_key = random.PRNGKey(0)
92
+ X, y = make_regression(n_samples=25, n_features=1, noise=10.0, random_state=0)
93
+ X = X.reshape(25)
94
+
95
+ # Run the regression
96
+ run_linear_regression(X, y, alpha_prior, beta_prior, sigma_prior)
97
+