On the choice of output in model reference control

Published February 9, 2023

Abstract
This note discusses how to choose the output function when input-output linearizing a system.

\[ \newcommand{\tuple}[1]{ #1 } \newcommand{\d}{\mathrm d} \newcommand{\tp}[1]{{#1}^{\mathrm T}} \newcommand{\R}{\mathbb R} \]

Let’s investigate this choice using the example system \[ m(x)\ddot x + r(x) \dot x + k(x) x = u, \] which in state-space form with state \(\tp{[x_1, x_2]} = \tp{[x, \dot x]}\) reads \[ \begin{bmatrix} \dot x_1 \\ \dot x_2 \end{bmatrix} = \begin{bmatrix} x_2 \\ \frac 1{m(x_1)}[u - r(x_1) x_2 - k(x_1) x_1] \end{bmatrix}. \]

If \(y = x_1 = x\), the input \(u\) first appears in the second derivative of the output: \[ \begin{aligned} \dot y &= \dot x_1 = x_2 \\ \ddot y &= \dot x_2 = \frac 1{m(x)}[u - r(x)\dot x - k(x)x] \\ \Leftrightarrow u &= m(x)\ddot y + r(x)\dot x + k(x)x. \end{aligned} \]

If \(y = x_2 = \dot x\), it already appears in the first derivative: \[ \begin{aligned} \dot y &= \dot x_2 = \frac 1{m(x)}[u - r(x)\dot x - k(x)x] \\ \Leftrightarrow u &= m(x)\dot y + r(x)\dot x + k(x)x. \end{aligned} \]

Since \(\ddot y\) in the first case and \(\dot y\) in the second case are both \(\ddot x\), the two choices of output give the same input-output linearizing \(u\), but they lead to input-output systems with different relative degrees (2 and 1, respectively). The sensible choice is therefore to use the highest state derivative as the output, which minimizes the relative degree and thus the number of differentiations needed to compute the control law.
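The derivation above is a special case of the general recipe: for \(\dot x = f(x) + g(x)u\) with output \(y = h(x)\), the relative degree \(\rho\) is the smallest integer with \(L_g L_f^{\rho-1} h \neq 0\), and the input that assigns \(y^{(\rho)} = v\) is \(u = (v - L_f^{\rho} h)/(L_g L_f^{\rho-1} h)\). The following NumPy sketch checks this numerically for the example system using finite-difference Lie derivatives. It does not use Dynax; the coefficient functions `mass`, `damping`, `stiffness` and all helper names are hypothetical choices made for illustration.

```python
import numpy as np

# Hypothetical coefficient functions m(x), r(x), k(x) (illustrative only)
def mass(x): return 1.0 + x**2
def damping(x): return 0.5 + 0.1 * x**2
def stiffness(x): return 2.0

def f(s):
    """Drift vector field for the state s = [x1, x2] = [x, x']."""
    x1, x2 = s
    return np.array([x2, (-damping(x1) * x2 - stiffness(x1) * x1) / mass(x1)])

def g(s):
    """Input vector field."""
    return np.array([0.0, 1.0 / mass(s[0])])

def grad(h, s, eps=1e-4):
    """Central-difference gradient of the scalar function h at s."""
    return np.array([(h(s + eps * e) - h(s - eps * e)) / (2 * eps)
                     for e in np.eye(len(s))])

def lie(h, vf):
    """Lie derivative of h along the vector field vf, returned as a function."""
    return lambda s: grad(h, s) @ vf(s)

def estimate_relative_degree(h, samples, max_deg=4, tol=1e-6):
    """Smallest rho with L_g L_f^(rho-1) h != 0 at every sample point.

    Returns rho together with the function L_f^(rho-1) h."""
    hk = h
    for rho in range(1, max_deg + 1):
        if all(abs(lie(hk, g)(s)) > tol for s in samples):
            return rho, hk
        hk = lie(hk, f)  # one more (nested) differentiation
    raise ValueError("relative degree not found")

def linearizing_input(h, s, v, samples):
    """Input u that makes the rho-th derivative of y = h(s) equal to v."""
    _, hk = estimate_relative_degree(h, samples)
    return (v - lie(hk, f)(s)) / lie(hk, g)(s)

h_pos = lambda s: s[0]  # y = x1 = x
h_vel = lambda s: s[1]  # y = x2 = x'
samples = np.random.default_rng(0).normal(size=(20, 2))
print(estimate_relative_degree(h_pos, samples)[0])  # 2
print(estimate_relative_degree(h_vel, samples)[0])  # 1

# Desired acceleration x'' = v: it is y'' for y = x1 and y' for y = x2
s, v = np.array([0.3, -0.7]), 1.5
u_pos = linearizing_input(h_pos, s, v, samples)
u_vel = linearizing_input(h_vel, s, v, samples)
print(u_pos, u_vel)  # equal up to finite-difference error
```

For \(y = x_1\) the law needs the nested derivative \(L_f^2 h\), whereas \(y = x_2\) needs only \(L_f h\); this is the extra differentiation that choosing the velocity as output avoids.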

Choosing the output this way can reduce JAX’s compilation time while leading to exactly the same control law. Let’s look at an example using Dynax¹.

¹ The full source code of the example can be found here. Dynax is not yet released.

import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from dynax import Flow, DynamicStateFeedbackSystem, ControlAffine
from dynax.linearize import relative_degree, input_output_linearize

We simulate a spring-mass-damper with nonlinear (quadratic) drag, \[ m \ddot x + r \dot x + r_2 \dot x|\dot x| + k x = u, \] and bring it into the control-affine form \(\dot x = f(x) + g(x)u\).
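Written out with the state \(\tp{[x_1, x_2]} = \tp{[x, \dot x]}\), dividing by \(m\) and separating the term that multiplies \(u\) gives the two vector fields implemented by the methods `f` and `g` of the class that follows:

\[ f(x) = \begin{bmatrix} x_2 \\ \frac 1m\left(-r x_2 - r_2 x_2|x_2| - k x_1\right) \end{bmatrix}, \qquad g(x) = \begin{bmatrix} 0 \\ \frac 1m \end{bmatrix}. \]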

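class NonlinearDrag(ControlAffine):
  """Spring-mass-damper with quadratic drag, state x = [position, velocity]."""
  m: float   # mass
  r: float   # linear damping coefficient
  r2: float  # quadratic drag coefficient
  k: float   # stiffness
  n_states = 2
  n_inputs = 1
  n_outputs = 1

  def f(self, x, u=None, t=None):
    # drift vector field: [velocity, acceleration without input]
    x1, x2 = x
    return jnp.array([
      x2,
      (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1) / self.m])

  def g(self, x, u=None, t=None):
    # input vector field: the force u only enters the acceleration
    return jnp.array([0.0, 1.0 / self.m])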

dyn = NonlinearDrag(m=1., r=1., r2=0.2, k=1.)
init_state = np.zeros(dyn.n_states)

The reference input is a 0.1 Hz sine with amplitude 10,

# design input signal: 100 s of a 0.1 Hz sine, sampled at 100 Hz
T = 100   # duration in s
sr = 100  # sampling rate in Hz
t = np.arange(int(T*sr))/sr
r = 10*np.sin(2*np.pi*t*0.1)

and the reference dynamics are given by the nonlinear system linearized around \(x=0\).

ref_dyn = dyn.linearize()

fig, ax = plt.subplots(nrows=2)
ax[0].plot(r, label='$r$')
us = []
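Linearizing around \(x = 0\) keeps the linear spring and damper but drops the drag term, because \(\partial(r_2 x_2|x_2|)/\partial x_2 = 2 r_2|x_2|\) vanishes at the origin. Assuming `dyn.linearize()` returns the Jacobian linearization \(\dot z = Az + Bu\) at the origin (its internals are not shown here), a NumPy sketch of that computation by central differences looks as follows; the `jacobian` helper is hypothetical and independent of Dynax.

```python
import numpy as np

# Parameters of the example: NonlinearDrag(m=1., r=1., r2=0.2, k=1.)
m, r, r2, k = 1.0, 1.0, 0.2, 1.0

def f(x):
    """Drift of the nonlinear drag system, x = [position, velocity]."""
    x1, x2 = x
    return np.array([x2, (-r * x2 - r2 * np.abs(x2) * x2 - k * x1) / m])

def jacobian(fun, x, eps=1e-6):
    """Central-difference Jacobian; column j is the derivative w.r.t. x[j]."""
    cols = [(fun(x + eps * e) - fun(x - eps * e)) / (2 * eps)
            for e in np.eye(len(x))]
    return np.stack(cols, axis=1)

A = jacobian(f, np.zeros(2))  # approximately [[0, 1], [-k/m, -r/m]]
B = np.array([0.0, 1.0 / m])  # g(x) is constant, so B = g(0)
print(A)
```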

Let’s input-output linearize the system with either \(y=x_1=x\) or \(y=x_2=\dot x\):

for output in [0, 1]:
  # estimate relative degree
  x_test = np.random.normal(size=(1000, 2))
  reldeg = relative_degree(dyn, x_test, output=output)
  print("Relative degree:", reldeg)

  # construct input-output linearized model
  feedbacklaw = input_output_linearize(dyn, reldeg, ref=ref_dyn, output=output)
  feedbacksys = DynamicStateFeedbackSystem(dyn, ref_dyn, feedbacklaw)

  # combine model with ODE solver
  feedbackmodel = jax.jit(Flow(feedbacksys).__call__)
  init_xz = jnp.concatenate((init_state, init_state))

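  # measure compile time plus one run (JAX dispatches asynchronously, so block on the result)
  start = time.perf_counter()
  xz, y_comp = jax.block_until_ready(feedbackmodel(init_xz, t, r))
  print(f"Compile+run: {time.perf_counter()-start}s")

  # measure jitted run time, averaged over 5 runs
  start = time.perf_counter()
  for _ in range(5):
    xz, y_comp = jax.block_until_ready(feedbackmodel(init_xz, t, r))
  print(f"jitted run: {(time.perf_counter()-start)/5}s")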

  # recompute u from states and input
  x, z = xz[:, :dyn.n_states], xz[:, dyn.n_states:]
  u = jax.vmap(feedbacklaw)(x, z, r)

  us.append(u)

  ax[0].plot(u, label=f'$u$ for $y=x[{output}]$')
  ax[1].plot(x[:, 0], label=f'$x$ for $y=x[{output}]$')
  ax[1].plot(z[:, 0], label=f'$z$ for $y=x[{output}]$')

for a in ax:
  a.legend()

plt.show()

which prints and plots:

Relative degree: 2
Compile+run: 2.477771520614624s
jitted run: 0.014502191543579101s
Relative degree: 1
Compile+run: 2.215977907180786s
jitted run: 0.014466238021850587s

Note that the control inputs \(u\) computed for the two output choices are exactly the same:

np.all(us[0] == us[1])  # -> True
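A short calculation suggests why. Assume (the exact law built by `input_output_linearize` is not shown here) that the model-reference law makes the \(\rho\)-th output derivative of the plant equal to that of the reference model \(\dot z_1 = z_2,\ \dot z_2 = \frac 1m(w - r z_2 - k z_1)\), where \(w\) is a name introduced here for the reference input (called `r` in the code). For \(y = x_1\) (\(\rho = 2\)) this requires \(\ddot x_1 = \ddot z_1\), and for \(y = x_2\) (\(\rho = 1\)) it requires \(\dot x_2 = \dot z_2\). Both are the same equation:

\[ \frac 1m\left(u - r x_2 - r_2 x_2|x_2| - k x_1\right) = \frac 1m\left(w - r z_2 - k z_1\right) \quad\Leftrightarrow\quad u = w + r(x_2 - z_2) + k(x_1 - z_1) + r_2 x_2|x_2|. \]

With matching initial states, \(x = z\) along the trajectory and the law reduces to \(u = w + r_2 x_2|x_2|\), i.e. the input cancels the drag term.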

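In this small example the saving in compilation time is modest (about 2.5 s versus 2.2 s), and the jitted run times are essentially equal. For larger models and time series, however, the difference in compilation time can be considerably larger.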
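For readers without access to Dynax, the closed loop can be mimicked in plain NumPy. The sketch below is a stand-in under stated assumptions, not Dynax's implementation: the hypothetical function `law` equates the plant and reference-model accelerations, a hand-written fixed-step RK4 integrator replaces `Flow`, and the time series is shortened to 20 s. Like `xz` in the Dynax example, it integrates the augmented state of plant and reference model. The reference input is called `w` because `r` denotes the damping.

```python
import numpy as np

m, r, r2, k = 1.0, 1.0, 0.2, 1.0  # plant parameters from the example

def plant(x, u):
    """Nonlinear drag plant, x = [position, velocity]."""
    return np.array([x[1], (u - r * x[1] - r2 * np.abs(x[1]) * x[1] - k * x[0]) / m])

def reference(z, w):
    """Reference model: the plant linearized at the origin (drag term dropped)."""
    return np.array([z[1], (w - r * z[1] - k * z[0]) / m])

def law(x, z, w):
    """Assumed model-reference law: plant acceleration = reference acceleration."""
    return w + r * (x[1] - z[1]) + k * (x[0] - z[0]) + r2 * np.abs(x[1]) * x[1]

def closed_loop(xz, w):
    """Right-hand side of the augmented state [x, z]."""
    x, z = xz[:2], xz[2:]
    return np.concatenate([plant(x, law(x, z, w)), reference(z, w)])

def simulate(xz0, t, w):
    """Fixed-step RK4, holding the reference input constant over each step."""
    xz = np.empty((len(t), len(xz0)))
    xz[0] = xz0
    for i in range(len(t) - 1):
        h, s, wi = t[i + 1] - t[i], xz[i], w[i]
        k1 = closed_loop(s, wi)
        k2 = closed_loop(s + h / 2 * k1, wi)
        k3 = closed_loop(s + h / 2 * k2, wi)
        k4 = closed_loop(s + h * k3, wi)
        xz[i + 1] = s + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return xz

sr = 100
t = np.arange(20 * sr) / sr           # 20 s instead of 100 s to keep it quick
w = 10 * np.sin(2 * np.pi * 0.1 * t)  # 0.1 Hz reference input
xz = simulate(np.zeros(4), t, w)
x, z = xz[:, :2], xz[:, 2:]
u = np.array([law(xi, zi, wi) for xi, zi, wi in zip(x, z, w)])
print(np.max(np.abs(x - z)))  # max tracking error
```

Because this law cancels the drag term exactly, \(x\) tracks \(z\) up to rounding error when both start at rest.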