Control as an inference problem

What if the data is not optimal?

We will introduce binary (True/False) optimality variables $(O_1, \dots, O_T)$ that ask: is this person trying to be optimal at this point in time?

We will take this formula as given (note that all of our rewards need to be negative, so that $\exp(r(s_t,a_t)) \le 1$ is a valid probability):

$$p(O_t|s_t,a_t) = \exp(r(s_t,a_t))$$

Then our probabilistic model:

$$\begin{split} p(\underbrace{\tau}_{s_{1:T},\,a_{1:T}}\,|\,O_{1:T}) &= \frac{p(\tau, O_{1:T})}{p(O_{1:T})} \\ &\propto p(\tau) \prod_t \exp(r(s_t,a_t)) \\ &= p(\tau) \exp\Big(\sum_t r(s_t,a_t)\Big) \end{split}$$
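To make the model concrete, here is a minimal tabular sketch (the MDP tables `p1`, `P`, `R`, and the trajectory format are placeholder assumptions, not from the lecture) of the unnormalized log-posterior $\log p(\tau) + \sum_t r(s_t,a_t)$ with a uniform action prior:

```python
import numpy as np

# Hypothetical tabular MDP: p1[s] initial distribution, P[s, a, s'] dynamics,
# R[s, a] <= 0 rewards. traj is a list of (state, action) index pairs.
def log_posterior_unnormalized(traj, p1, P, R):
    num_actions = R.shape[1]
    s0, a0 = traj[0]
    logp = np.log(p1[s0]) - np.log(num_actions) + R[s0, a0]  # p(s_1), uniform action prior, reward
    for (s_prev, a_prev), (s, a) in zip(traj[:-1], traj[1:]):
        logp += np.log(P[s_prev, a_prev, s])                 # dynamics term of p(tau)
        logp += -np.log(num_actions) + R[s, a]               # action prior + reward term
    return logp
```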

To do inference:

  1. Compute backward messages $\beta_t(s_t,a_t) = p(O_{t:T}|s_t,a_t)$
    1. What's the probability of being optimal from now until the end of the trajectory, given the state and action we are in?
  2. Compute the policy $p(a_t|s_t, O_{1:T})$
  3. Compute forward messages $\alpha_t(s_t) = p(s_t | O_{1:t-1})$
    1. What's the probability of being in state $s_t$, given that all previous timesteps were optimal?

Backward Messages

$$\begin{split} \beta_t(s_t,a_t) &= p(O_{t:T}|s_t,a_t) \\ &= \int p(O_{t:T},s_{t+1}|s_t,a_t)\, ds_{t+1} \\ &= \int p(O_{t+1:T}|s_{t+1})\, p(s_{t+1}|s_t,a_t)\, p(O_t|s_t, a_t)\, ds_{t+1} \end{split}$$

Note:

$$\beta_{t+1}(s_{t+1}) = p(O_{t+1:T}|s_{t+1}) = \int \underbrace{p(O_{t+1:T}|s_{t+1}, a_{t+1})}_{\beta_{t+1}(s_{t+1},a_{t+1})}\, \underbrace{p(a_{t+1}|s_{t+1})}_{\text{which actions are likely a priori}}\, da_{t+1}$$

$p(a_{t+1}|s_{t+1})$ ⇒ which actions are likely a priori: if we don't know whether we are optimal or not, how likely are we to choose a particular action? We will assume a uniform prior for now.

This is reasonable because:

  1. We don't know anything about the policy, so uniform is a sensible default
  2. A non-uniform action prior can be folded into the reward function later (e.g. $\tilde r(s_t,a_t) = r(s_t,a_t) + \log p(a_t|s_t)$), so uniformity loses no generality

Therefore,

$$\begin{split} \beta_t(s_t,a_t) &= \int p(O_{t+1:T}|s_{t+1})\, p(s_{t+1}|s_t,a_t)\, p(O_t|s_t, a_t)\, ds_{t+1} \\ &= p(O_t|s_t,a_t)\, \mathbb{E}_{s_{t+1} \sim p(s_{t+1}|s_t,a_t)}\big[\underbrace{\beta_{t+1}(s_{t+1})}_{=\,\mathbb{E}_{a_{t+1} \sim p(a_{t+1}|s_{t+1})}[\beta_{t+1}(s_{t+1},a_{t+1})]}\big] \end{split}$$

This recursion is called the backward pass: starting from the base case $\beta_T(s_T,a_T) = p(O_T|s_T,a_T)$, we calculate $\beta_t$ recursively from $t = T-1$ down to $t = 1$.
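A minimal sketch of the backward pass for a small discrete MDP (the sizes, reward table `R`, and transitions `P` below are made-up placeholders; the action prior is uniform):

```python
import numpy as np

S, A, T = 4, 2, 10
rng = np.random.default_rng(0)
R = -rng.uniform(0.0, 1.0, size=(S, A))        # rewards <= 0, so exp(R) <= 1
P = rng.dirichlet(np.ones(S), size=(S, A))     # P[s, a, :] sums to 1 over s'

beta_sa = np.zeros((T, S, A))                  # beta_t(s_t, a_t)
beta_s = np.zeros((T, S))                      # beta_t(s_t)

# Base case: beta_T(s, a) = p(O_T | s, a) = exp(r(s, a))
beta_sa[T - 1] = np.exp(R)
beta_s[T - 1] = beta_sa[T - 1].mean(axis=1)    # expectation under the uniform action prior

# Recursion: beta_t(s, a) = exp(r(s, a)) * E_{s' ~ p(s'|s,a)}[beta_{t+1}(s')]
for t in reversed(range(T - 1)):
    beta_sa[t] = np.exp(R) * (P @ beta_s[t + 1])
    beta_s[t] = beta_sa[t].mean(axis=1)
```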

We will take a closer look at the backward pass

Let $V_t(s_t) = \log \beta_t(s_t)$ and $Q_t(s_t, a_t) = \log \beta_t(s_t,a_t)$.

$$V_t(s_t) = \log \int \exp(Q_t(s_t,a_t))\, da_t$$

We see that $V_t(s_t) \to \max_{a_t} Q_t(s_t,a_t)$ as $Q_t(s_t,a_t)$ gets bigger ⇒ we call this a soft max (not the softmax in neural nets, but a soft relaxation of the max operator).
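A quick numeric check of this soft relaxation (made-up Q values): the log-sum-exp tracks the largest Q value more and more tightly as the values grow:

```python
import numpy as np

def soft_max(Q):
    # numerically stable log sum_a exp(Q(a))
    m = Q.max()
    return m + np.log(np.exp(Q - m).sum())

Q = np.array([1.0, 0.0, -1.0])
for scale in (1.0, 10.0, 100.0):
    print(scale, soft_max(scale * Q), (scale * Q).max())
# gap to the hard max shrinks: ~0.41, ~5e-5, ~4e-44
```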

Let's also evaluate $Q_t$:

$$Q_t(s_t,a_t) = r(s_t,a_t) + \log \mathbb{E}_{s_{t+1} \sim p(s_{t+1}|s_t,a_t)}[\exp(V_{t+1}(s_{t+1}))]$$

If the transitions are deterministic, then this update of $Q_t$ is exactly the usual Bellman backup:

$$Q_t(s_t,a_t) = r(s_t,a_t) + V_{t+1}(s_{t+1})$$

But with stochastic transitions, the $\log \mathbb{E}[\exp(\cdot)]$ backup is optimistic: rare lucky transitions dominate the update, which is not what we want.
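A toy illustration of that optimism (hypothetical numbers): with one rare but very good next state, the $\log \mathbb{E}[\exp(\cdot)]$ backup is driven almost entirely by the lucky outcome:

```python
import numpy as np

probs = np.array([0.01, 0.99])          # rare "lucky" transition vs. the usual one
V_next = np.array([10.0, 0.0])          # value of the lucky / usual next state

optimistic = np.log(np.dot(probs, np.exp(V_next)))   # log E[exp(V')] backup
expected = np.dot(probs, V_next)                     # ordinary E[V'] backup
print(optimistic, expected)                          # ~5.4 vs 0.1
```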

Policy Computation

Policy ⇒ $p(a_t|s_t, O_{1:T})$: what is the probability of a given action, given the current state and that all timesteps are optimal?

$$\begin{split} p(a_t|s_t,O_{1:T}) &= \pi(a_t|s_t) \\ &= p(a_t|s_t,O_{t:T}) \\ &= \frac{p(a_t,s_t|O_{t:T})}{p(s_t|O_{t:T})} \\ &= \frac{p(O_{t:T}|a_t,s_t)\, p(a_t,s_t)/p(O_{t:T})}{p(O_{t:T}|s_t)\, p(s_t)/p(O_{t:T})} \\ &= \frac{p(O_{t:T}|s_t,a_t)}{p(O_{t:T}|s_t)} \cdot \frac{p(a_t,s_t)}{p(s_t)} \\ &= \frac{\beta_t(s_t,a_t)}{\beta_t(s_t)}\, \underbrace{p(a_t|s_t)}_{\text{action prior, assumed uniform}} \\ &= \exp(Q_t(s_t,a_t) - V_t(s_t)) = \exp(A_t(s_t,a_t)) \\ &\approx \exp\Big(\underbrace{\tfrac{1}{\alpha}}_{\text{$\alpha$ is an added temperature}} A_t(s_t,a_t)\Big) \end{split}$$
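Continuing the tabular sketch from the backward pass, the policy can be read off the message tables for one timestep (`alpha` here is the optional temperature from the last line, a placeholder of this sketch):

```python
import numpy as np

def policy_from_messages(beta_sa_t, beta_s_t, alpha=1.0):
    # pi(a|s) = exp((Q_t(s,a) - V_t(s)) / alpha), with Q = log beta(s,a), V = log beta(s)
    Q = np.log(beta_sa_t)                        # shape (S, A)
    V = np.log(beta_s_t)[:, None]                # shape (S, 1)
    logits = (Q - V) / alpha
    logits -= logits.max(axis=1, keepdims=True)  # shift for numerical stability
    pi = np.exp(logits)
    return pi / pi.sum(axis=1, keepdims=True)    # normalize rows to a distribution
```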

Forward Messages

$$\begin{split} \alpha_t(s_t) &= p(s_t|O_{1:t-1}) \\ &= \int p(s_t,s_{t-1},a_{t-1}|O_{1:t-1})\, ds_{t-1}\, da_{t-1} \\ &= \int p(s_t|s_{t-1},a_{t-1}, O_{1:t-1})\, p(a_{t-1}|s_{t-1},O_{1:t-1})\, p(s_{t-1}|O_{1:t-1})\, ds_{t-1}\, da_{t-1} \\ &= \int p(s_t | s_{t-1}, a_{t-1})\, p(a_{t-1}|s_{t-1},O_{1:t-1})\, p(s_{t-1}|O_{1:t-1})\, ds_{t-1}\, da_{t-1} \end{split}$$

Note:

$$\begin{split} p(a_{t-1}|s_{t-1}, O_{t-1})\, p(s_{t-1}|O_{1:t-1}) &= \frac{p(O_{t-1}|s_{t-1},a_{t-1})\overbrace{p(a_{t-1}|s_{t-1})}^{\text{uniform}}}{p(O_{t-1}|s_{t-1})} \cdot \frac{p(O_{t-1}|s_{t-1})\, \overbrace{p(s_{t-1}|O_{1:t-2})}^{\alpha_{t-1}(s_{t-1})}}{p(O_{t-1}|O_{1:t-2})} \\ &= \frac{p(O_{t-1}|s_{t-1},a_{t-1})}{p(O_{t-1}|O_{1:t-2})}\, \alpha_{t-1}(s_{t-1}) \end{split}$$

What if we want to know $p(s_t|O_{1:T})$?

$$\begin{split} p(s_t|O_{1:T}) &= \frac{p(s_t,O_{1:T})}{p(O_{1:T})} \\ &= \frac{\overbrace{p(O_{t:T}|s_t)}^{\beta_t(s_t)}\, p(s_t, O_{1:t-1})}{p(O_{1:T})} \\ &\propto \beta_t(s_t)\, p(s_t|O_{1:t-1})\, p(O_{1:t-1}) \\ &\propto \beta_t(s_t)\, \alpha_t(s_t) \end{split}$$
(Figure: the yellow cone is $\beta_t(s_t)$, the blue cone is $\alpha_t(s_t)$; the state marginal $p(s_t|O_{1:T})$ is high where they overlap.)
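The forward pass for the same tabular sketch (same placeholder `p1`, `P`, `R`; constants like $p(O_{t-1}|O_{1:t-2})$ are dropped by renormalizing). Multiplying $\alpha_t$ and $\beta_t$ elementwise and normalizing gives the state marginal $p(s_t|O_{1:T})$:

```python
import numpy as np

def forward_messages(p1, P, R, T):
    S, A = R.shape
    alpha = np.zeros((T, S))
    alpha[0] = p1
    for t in range(1, T):
        # weight each (s, a) by alpha_{t-1}(s) * exp(r(s,a)) * uniform prior, push through dynamics
        w = alpha[t - 1][:, None] * np.exp(R) / A
        alpha[t] = np.einsum('sa,sap->p', w, P)
        alpha[t] /= alpha[t].sum()               # renormalize (absorbs the p(O|...) constants)
    return alpha

# state marginal: p(s_t | O_{1:T}) ∝ beta_s[t] * alpha[t]
```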

Control as Variational Inference

In continuous high-dimensional spaces we have to approximate

Inference problem:

$$p(s_{1:T}, a_{1:T}|O_{1:T})$$

Marginalizing and conditioning, we get the policy

$$\pi(a_t|s_t) = p(a_t|s_t,O_{1:T})$$

However,

$$p(s_{t+1}|s_t,a_t,O_{1:T}) \ne p(s_{t+1}|s_t,a_t)$$

Instead of asking

“Given that you obtained high reward, what was your action probability and your transition probability”

We want to ask

"Given that you obtained high reward, what was your action probability, given that your transition probability did not change?"

Can we find another distribution $q(s_{1:T}, a_{1:T})$ that is close to $p(s_{1:T}, a_{1:T}|O_{1:T})$ but has the true dynamics $p(s_{t+1}|s_t,a_t)$?

Let’s try variational inference!

Let $q(s_{1:T}, a_{1:T}) = p(s_1) \prod_{t} p(s_{t+1}|s_t,a_t)\, q(a_t | s_t)$
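A small sketch of what this factorization means operationally (placeholder tabular quantities as before): trajectories from $q$ use the true dynamics $p(s_{t+1}|s_t,a_t)$, and only the action distribution $q(a_t|s_t)$ is ours to choose:

```python
import numpy as np

def sample_from_q(p1, P, q_pi, T, rng):
    # p1[s]: initial state distribution, P[s, a, s']: true dynamics,
    # q_pi[s, a]: our variational action distribution q(a|s)
    num_states, num_actions, _ = P.shape
    s = rng.choice(num_states, p=p1)
    traj = []
    for _ in range(T):
        a = rng.choice(num_actions, p=q_pi[s])   # action from q
        traj.append((s, a))
        s = rng.choice(num_states, p=P[s, a])    # next state from the real dynamics
    return traj
```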

Let $x = O_{1:T}$ and $z = (s_{1:T}, a_{1:T})$.

The variational lower bound

$$\log p(x) \ge \mathbb{E}_{z \sim q(z)}[\log p(x,z) - \log q(z)]$$

Substituting in our definitions,

$$\begin{split} \log p(O_{1:T}) &\ge \mathbb{E}_{(s_{1:T}, a_{1:T}) \sim q}\Big[\log p(s_1) + \sum_{t=1}^T \log p(s_{t+1}|s_t,a_t) + \sum_{t=1}^T \log p(O_t|s_t,a_t) \\ &\qquad - \log p(s_1) - \sum_{t=1}^T \log p(s_{t+1}|s_t, a_t) - \sum_{t=1}^T \log q(a_t|s_t)\Big] \\ &= \mathbb{E}_{(s_{1:T}, a_{1:T}) \sim q}\Big[\sum_t r(s_t, a_t) - \log q(a_t|s_t)\Big] \\ &= \sum_t \mathbb{E}_{(s_t,a_t) \sim q}\big[r(s_t,a_t) + H(q(a_t|s_t))\big] \end{split}$$

⇒ maximize reward and maximize action entropy!
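In Monte-Carlo form (a sketch, with hypothetical rollout arrays), the bound is just reward minus log-probability of the chosen actions, summed over time and averaged over rollouts sampled from $q$:

```python
import numpy as np

def elbo_estimate(rewards, log_q_actions):
    # rewards, log_q_actions: arrays of shape (num_rollouts, T);
    # log_q_actions[i, t] = log q(a_t | s_t) for the action actually taken
    return np.mean(np.sum(rewards - log_q_actions, axis=1))
```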

Optimize Variational Lower Bound

Base case: solve for $q(a_T|s_T)$

$$\begin{split} q(a_T|s_T) &= \arg\max \; \mathbb{E}_{s_T \sim q(s_T)}\big[\mathbb{E}_{a_T \sim q(a_T|s_T)}[r(s_T, a_T)] + H(q(a_T|s_T))\big] \\ &= \arg\max \; \mathbb{E}_{s_T \sim q(s_T)}\big[\mathbb{E}_{a_T \sim q(a_T|s_T)}[r(s_T, a_T) - \log q(a_T|s_T)]\big] \end{split}$$

This is optimized when $q(a_T|s_T) \propto \exp(r(s_T,a_T))$:

$$q(a_T|s_T) = \frac{\exp(r(s_T,a_T))}{\int \exp(r(s_T, a))\, da} = \exp(Q(s_T,a_T) - V(s_T))$$

where $Q(s_T,a_T) = r(s_T,a_T)$ at the last timestep and

$$V(s_T) = \log \int \exp(Q(s_T,a_T))\, da_T$$

Therefore

$$\mathbb{E}_{s_T \sim q(s_T)}\big[\mathbb{E}_{a_T \sim q(a_T|s_T)}[r(s_T, a_T) - \log q(a_T|s_T)]\big] = \mathbb{E}_{s_T \sim q(s_T)}\big[\mathbb{E}_{a_T \sim q(a_T|s_T)}[V(s_T)]\big] = \mathbb{E}_{s_T \sim q(s_T)}[V(s_T)]$$
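Putting the base case into code for the discrete sketch (rows of `R` are states, columns are actions): the last-step policy is the normalized exponential of the reward, i.e. $\exp(Q - V)$ with $V$ a log-sum-exp:

```python
import numpy as np

def last_step_policy(R):
    Q = R                                                 # Q(s_T, a) = r(s_T, a)
    V = np.log(np.exp(Q).sum(axis=1, keepdims=True))      # V(s_T) = log sum_a exp(Q(s_T, a))
    return np.exp(Q - V)                                  # each row sums to 1
```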
⚠️ Recursing backward through time gives a dynamic programming solution! See Levine (2018), "Reinforcement Learning and Control as Probabilistic Inference: Tutorial and Review".

Q-Learning with Soft Optimality

Standard Q-Learning:

$$\phi \leftarrow \phi + \alpha \nabla_\phi Q_\phi(s,a)\big(r(s,a) + \gamma V(s') - Q_\phi(s,a)\big)$$

Standard Q-Learning Target

$$V(s') = \max_{a'} Q_\phi(s',a')$$

Soft Q-Learning

$$\phi \leftarrow \phi + \alpha \nabla_{\phi} Q_\phi(s,a)\big(r(s,a) + \gamma V(s') - Q_\phi(s,a)\big)$$

Soft Q-Learning Target

$$V(s') = \operatorname{soft\,max}_{a'} Q_\phi(s',a') = \log \int \exp\big(Q_\phi(s',a')\big)\, da'$$

Policy

$$\pi(a|s) = \exp(Q_\phi(s,a) - V(s)) = \exp(A(s,a))$$
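A minimal tabular soft Q-learning step (discrete actions assumed, so the soft max is a log-sum-exp over the next state's row; the learning rate and discount below are placeholder values):

```python
import numpy as np

def soft_q_update(Q, s, a, r, s_next, lr=0.1, gamma=0.99):
    m = Q[s_next].max()
    V_next = m + np.log(np.exp(Q[s_next] - m).sum())   # soft max_a' Q(s', a')
    Q[s, a] += lr * (r + gamma * V_next - Q[s, a])     # step on the soft Bellman error
    return Q
```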

Policy Gradient with Soft Optimality ("entropy-regularized" policy gradient)

$$\pi(a|s) = \exp(Q_\phi(s,a) - V(s))$$

This policy optimizes $\sum_t \mathbb{E}_{(s_t,a_t) \sim \pi}[r(s_t, a_t)] + \mathbb{E}_{s_t \sim \pi}[H(\pi(a_t|s_t))]$, i.e. expected reward plus policy entropy.
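A small numeric check (made-up Q values for one state): among a few candidate action distributions, the soft-optimal one $\pi \propto \exp(Q)$ scores highest on expected value plus entropy, and its score equals the log-sum-exp value $V$:

```python
import numpy as np

Q = np.array([1.0, 0.5, -0.2])

def objective(pi):
    return np.dot(pi, Q) - np.dot(pi, np.log(pi + 1e-12))   # E_pi[Q] + H(pi)

soft = np.exp(Q) / np.exp(Q).sum()       # pi ∝ exp(Q)
greedy = np.eye(3)[np.argmax(Q)]         # deterministic argmax policy
uniform = np.ones(3) / 3
print(objective(soft), objective(greedy), objective(uniform))
# ~1.65 (= log sum_a exp(Q(a))) vs 1.0 vs ~1.53
```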

Intuition:

Benefits of soft optimality

Suggested Readings (Soft Optimality)