Skip to content

Neural network estimators

import jax
from queso.sensors import Sensor
from queso.estimators import BayesianDNNEstimator
# Build the sensor. `n` and `k` are forwarded unchanged to queso's Sensor;
# their exact meaning is defined there (presumably qubit count and circuit
# depth -- TODO confirm against the Sensor documentation).
sensor = Sensor(n=4, k=2)

# Pull the sensor's current parameters out as plain names so they can be
# passed explicitly to `sample`.
theta = sensor.theta
phi = sensor.phi
mu = sensor.mu

# Draw 100 shots using those parameters. For this configuration the printed
# shape is (100, 4): one row per shot, with the column count presumably set
# by n=4 (confirm).
shots = sensor.sample(theta, phi, mu, n_shots=100)
print(shots.shape)
(100, 4)

# Build the estimator network. Judging from the printed parameters, each
# entry of `nn_dims` is the OUTPUT width of one Dense layer
# (Dense_0 -> 4, Dense_1 -> 24, Dense_2 -> 24). The input width of Dense_0 (4)
# is not listed here; it is inferred from the columns of `shots` at init time.
estimator = BayesianDNNEstimator(nn_dims=[4, 24, 24])

# Initialise the network weights from a fixed PRNG seed so the example is
# reproducible. `init` is called with an example input (`shots`) and returns a
# variables mapping; only its "params" collection is kept.
key = jax.random.PRNGKey(123)
params = estimator.init(key, shots)["params"]
print(params)

# To run a forward pass, hand the variables back to `apply` explicitly --
# assuming BayesianDNNEstimator is a flax.linen Module (the FrozenDict /
# Dense_N output suggests so; confirm), the call is:
#     outputs = estimator.apply({"params": params}, shots)
# NOTE(review): `estimator.apply(shots)` (without the variables dict) is not a
# valid Flax call. Check whether the estimator needs any extra call arguments.
FrozenDict({
    Dense_0: {
        kernel: Array([[-0.08365491,  0.3709868 ,  0.28520408, -0.61185217],
               [-0.20845719, -0.41968852, -0.15185066,  0.9351279 ],
               [-0.21638115,  0.55949545,  0.594238  ,  0.46812084],
               [-0.8318201 , -0.5416301 ,  0.55196613, -0.25835568]],      dtype=float32),
        bias: Array([0., 0., 0., 0.], dtype=float32),
    },
    Dense_1: {
        kernel: Array([[ 0.14937131, -0.18914305, -0.5080564 ,  0.15999958, -0.2066963 ,
                -0.1842265 , -0.15356986,  0.0172486 ,  0.32604164,  0.13108334,
                 0.1680638 ,  0.40037102,  0.15847811,  0.3265098 , -0.43761328,
                -0.22420444, -0.59605354,  0.1729926 ,  0.40469268,  0.18778297,
                 0.10173465,  0.1215571 , -0.20649342,  0.2442985 ],
               [-0.24034177,  0.41185507,  0.3816779 , -0.51110876,  0.03534057,
                -0.4202144 , -0.17775752,  0.05422129,  0.12789603,  0.04376542,
                 0.23808123, -0.2112978 ,  0.15149625, -0.20524533,  0.18419266,
                 0.11474168,  0.15253516, -0.3465841 , -0.23007385,  0.48777324,
                 0.3222706 , -0.23944817, -0.15965119, -0.07027519],
               [ 0.03383712, -0.08510572,  0.3917954 ,  0.00684621,  0.465896  ,
                -0.04678811, -0.3146166 ,  0.21349184, -0.5530806 ,  0.14162503,
                 0.45468837, -0.23694351,  0.48872536, -0.14468612, -0.03755883,
                -0.06868707, -0.12848948,  0.36772716, -0.1554714 , -0.04762909,
                 0.05237445, -0.091881  ,  0.26468495,  0.30069456],
               [ 0.38084552,  0.30903834,  0.37079817, -0.26668444, -0.18872268,
                 0.45161825, -0.4414809 , -0.09558421,  0.26246756, -0.25547796,
                -0.12401953,  0.43308026,  0.00134409, -0.2966461 , -0.12335505,
                 0.01382376, -0.08543437,  0.30213448,  0.37633315, -0.3479128 ,
                 0.11296993,  0.26790282,  0.38247675, -0.13598952]],      dtype=float32),
        bias: Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0., 0., 0., 0., 0.], dtype=float32),
    },
    Dense_2: {
        kernel: Array([[ 2.25491375e-01, -1.13865370e-02, -5.06403390e-03,
                -1.58415258e-01,  3.78137767e-01, -2.99580060e-02,
                -3.21684629e-01,  3.39279175e-02, -1.65456876e-01,
                -7.78309107e-02,  3.58399510e-01,  2.41241962e-01,
                -2.85243746e-02,  4.58118059e-02,  1.52763382e-01,
                 1.39430519e-02, -2.04812348e-01, -3.63810778e-01,
                -2.68103648e-02,  6.58722073e-02, -7.89321065e-02,
                -1.52091905e-02, -4.41068262e-02, -2.79235274e-01],
               [ 2.76903436e-02,  1.10579990e-02, -1.51234478e-01,
                -1.75035983e-01, -8.39567259e-02, -7.43359476e-02,
                 9.50256437e-02, -1.63660452e-01,  1.83771223e-01,
                 4.63409513e-01,  2.09477469e-01, -3.99626195e-02,
                 2.42376402e-02, -3.22242200e-01,  3.95426333e-01,
                 2.36458436e-01,  2.55024254e-01,  1.93468466e-01,
                 4.39392716e-01, -3.33608657e-01, -3.59660596e-01,
                -1.18314736e-02,  8.11895281e-02,  1.15287669e-01],
               [-1.63481180e-02, -8.19125697e-02,  1.26321062e-01,
                -2.77354240e-01,  2.79172629e-01,  5.04786745e-02,
                -3.55839044e-01, -3.05290967e-01,  1.93643734e-01,
                 8.51776972e-02,  2.70031095e-01,  1.36901736e-01,
                -5.02076894e-02, -1.56165063e-01, -2.80079693e-01,
                 1.00311950e-01,  3.13740402e-01, -1.65029421e-01,
                -2.16216400e-01,  4.33761984e-01, -3.02123129e-01,
                 8.17995146e-02,  2.01310307e-01,  8.15882441e-03],
               [ 1.04504995e-01, -3.11425298e-01, -4.11186934e-01,
                 1.55647144e-01, -7.59535804e-02, -6.88362792e-02,
                 3.83669257e-01,  2.98307627e-01,  4.25858200e-02,
                -1.85654029e-01,  1.81865722e-01, -4.09101725e-01,
                -1.56061724e-01,  2.19917938e-01, -3.02100629e-01,
                 3.29461008e-01,  1.01746380e-01,  3.46198440e-01,
                 2.03256577e-01, -5.44255823e-02, -1.78660229e-01,
                 1.48702130e-01, -1.70167431e-01,  2.49268323e-01],
               [ 2.65970439e-01,  3.53289634e-01, -3.26443404e-01,
                -3.25241804e-01,  1.27290320e-02,  3.80037725e-02,
                -2.69622505e-01,  1.51843280e-01,  1.42860830e-01,
                -5.38047068e-02, -1.84295505e-01,  4.90191877e-02,
                 1.32455662e-01, -3.13954622e-01,  2.51767397e-01,
                 1.50854260e-01, -3.41550522e-02,  4.55536664e-01,
                 8.28002207e-03, -8.03813115e-02,  4.35944051e-02,
                 2.26386815e-01, -5.73772267e-02,  7.07994998e-02],
               [ 4.01206106e-01, -2.70196974e-01,  1.12989388e-01,
                 2.94497106e-02,  2.27081195e-01, -7.41875619e-02,
                 3.41933101e-01, -2.75294125e-01,  1.09364204e-01,
                -1.76638961e-01,  2.42856070e-01, -1.66563958e-01,
                -2.23697782e-01, -9.79895666e-02,  1.13236167e-01,
                 4.51653570e-01,  1.32272661e-01,  1.08618416e-01,
                 1.41332904e-02,  2.84736663e-01,  3.03926140e-01,
                 3.00636888e-01,  5.67677580e-02, -1.93328723e-01],
               [-1.53933078e-01, -2.55336493e-01, -8.19650758e-03,
                 1.86908513e-01, -1.90767515e-02,  1.39956370e-01,
                 2.70861015e-02, -1.23284504e-01, -3.82062852e-01,
                 1.78689901e-02, -2.06412166e-01,  5.69417439e-02,
                 8.33688006e-02,  1.45140931e-01, -2.20353469e-01,
                -1.31755129e-01, -4.17521857e-02,  7.38363639e-02,
                -1.95212230e-01, -1.28135189e-01,  6.59373924e-02,
                 2.24778354e-02, -6.09202497e-02, -9.05505270e-02],
               [-1.80162087e-01, -3.64136636e-01,  7.23553728e-03,
                 3.71738911e-01, -7.84352124e-02,  1.63317956e-02,
                -4.21304345e-01, -4.18550879e-01,  1.12658493e-01,
                 1.26450434e-01,  8.91803857e-03, -1.31701946e-01,
                -3.03009838e-01,  4.18287292e-02,  2.41715312e-01,
                 1.47146344e-01,  2.56591767e-01,  2.02154756e-01,
                 1.55571327e-01,  1.05449550e-01, -7.92086720e-02,
                -7.51967579e-02,  3.77291411e-01, -1.85709295e-03],
               [-3.54883820e-01, -7.37906173e-02,  2.41412565e-01,
                -1.04516417e-01,  2.23844215e-01, -2.00551078e-01,
                 6.85574710e-02, -1.93590567e-01,  2.05923989e-01,
                -1.98893040e-01, -2.89795071e-01, -3.52679014e-01,
                -7.16579333e-02, -7.89012462e-02, -6.59139007e-02,
                -4.06321138e-01,  2.29897466e-03,  2.70919204e-01,
                 2.28084866e-02, -3.05252284e-01, -1.80219606e-01,
                -3.08506191e-01, -7.76899979e-02,  3.21081042e-01],
               [-1.12351499e-01,  1.10930078e-01, -8.08678269e-02,
                -2.50243515e-01,  1.79248154e-01, -1.59811139e-01,
                 3.01836103e-01, -2.38608733e-01, -3.31554383e-01,
                -1.23683535e-01, -1.17919989e-01,  3.11009973e-01,
                 1.59295455e-01, -5.36964685e-02, -2.12155879e-01,
                 2.35514805e-01,  3.24693769e-01,  2.14452714e-01,
                -3.48075569e-01,  9.54022445e-03, -4.12589222e-01,
                 1.52493184e-02,  1.74951643e-01, -2.71645635e-01],
               [-4.71246056e-02, -4.11490142e-01,  5.07123806e-02,
                 1.19829379e-01, -1.07441247e-02, -3.62271756e-01,
                 2.43651643e-01, -1.77882656e-01, -3.33583169e-02,
                -3.61576527e-01,  1.07067786e-01, -3.17357123e-01,
                -2.55576968e-01,  5.48290275e-02,  5.19542061e-02,
                -9.61242989e-03, -3.88955295e-01, -1.45109937e-01,
                 9.35693365e-03, -1.43300757e-01, -4.41279709e-01,
                 1.72140911e-01,  2.35007703e-01, -3.31197446e-03],
               [-1.33600116e-01, -1.54750451e-01,  5.64417578e-02,
                -1.65767055e-02,  3.82275224e-01,  7.35584050e-02,
                -2.27595180e-01,  2.92210672e-02,  1.67813540e-01,
                -5.18216789e-02,  1.30822882e-02,  1.40994489e-01,
                 4.61122394e-02, -3.37945670e-01,  1.96998149e-01,
                -2.66315378e-02,  3.25565875e-01,  1.92222059e-01,
                 1.32310972e-01, -2.43020818e-01,  3.44603866e-01,
                -1.85616761e-01, -7.08850026e-02, -4.04549763e-02],
               [ 1.39631957e-01, -4.32405114e-01, -3.27705257e-02,
                -1.96498334e-01,  1.90235689e-01,  1.14358060e-01,
                -1.58067092e-01,  1.17618933e-01, -2.77943254e-01,
                -6.88012242e-02, -5.78028113e-02, -3.33511591e-01,
                -1.54513866e-01,  2.48887530e-03, -1.62264466e-01,
                 2.14214846e-01,  3.47371578e-01,  1.10879473e-01,
                -4.01201695e-01,  3.68806832e-02, -2.19796315e-01,
                 1.63927466e-01,  9.02015045e-02, -1.53950319e-01],
               [ 1.69481844e-01,  5.85003383e-02, -2.22225964e-01,
                -2.16828793e-01,  7.01071694e-02,  3.12438309e-02,
                 2.93007493e-01, -3.49083841e-01, -1.71733990e-01,
                 1.86872363e-01,  3.51303101e-01, -4.45020139e-01,
                 1.69867963e-01, -2.08952576e-01, -2.57535100e-01,
                 1.27940893e-01,  2.10569724e-01, -9.74456295e-02,
                -2.86363691e-01,  3.56100388e-02,  4.57427697e-03,
                -5.32098338e-02, -5.82929328e-02, -1.27870783e-01],
               [-1.73843190e-01,  3.01884234e-01,  3.60280693e-01,
                 6.77487478e-02,  1.12952344e-01, -1.42433271e-01,
                 8.16462710e-02, -5.52017279e-02,  1.28058389e-01,
                 2.45652851e-02, -5.72223328e-02, -2.77937293e-01,
                 1.33385971e-01,  2.24444464e-01, -1.59416884e-01,
                 3.05836070e-02, -1.38374224e-01,  5.56492992e-02,
                 1.40427798e-02,  3.65309678e-02, -8.00843090e-02,
                 1.02197528e-01,  3.25496346e-01,  4.98737171e-02],
               [-9.66957882e-02, -1.11405179e-01,  8.50759596e-02,
                -1.02184564e-01, -2.06577219e-02, -2.61510134e-01,
                 9.22261104e-02, -1.35970548e-01,  3.23614269e-01,
                 3.80567014e-02,  2.23034754e-01, -2.90879250e-01,
                -3.33931558e-02, -1.10486351e-01,  8.78648311e-02,
                -4.17509764e-01, -9.04607400e-03, -2.20527545e-01,
                -1.01033665e-01, -6.21551387e-02,  2.94364356e-02,
                -2.42240634e-02,  2.21293136e-01,  2.12015688e-01],
               [-3.63979995e-01,  3.17663439e-02, -1.62746087e-01,
                 4.46584215e-03, -3.96597445e-01, -7.26754665e-02,
                 8.38287640e-03, -2.20433734e-02,  1.35798171e-01,
                 3.37653339e-01,  1.32088169e-01, -3.66047084e-01,
                 3.25076103e-01, -8.84882268e-03,  5.63119014e-04,
                 1.01472408e-01,  2.63151377e-01, -7.03234151e-02,
                 1.21566400e-01,  2.32302651e-01, -6.25573173e-02,
                -8.24496225e-02, -1.30635470e-01,  8.97181034e-02],
               [-9.45673510e-02, -2.43835207e-02,  6.54266998e-02,
                -3.70044895e-02, -5.31760640e-02, -3.62096846e-01,
                -1.50420055e-01,  9.35560092e-02,  2.24321932e-01,
                 3.67009193e-02, -3.31124812e-01, -8.06894228e-02,
                -1.98627040e-02, -9.74987671e-02,  4.50054705e-01,
                -8.62753987e-02,  3.21459144e-01,  2.99194425e-01,
                -3.47413808e-01,  2.86772698e-01, -3.11999589e-01,
                -2.49124870e-01,  5.47240935e-02, -3.89365032e-02],
               [ 1.96191400e-01, -1.66180357e-01,  1.14269428e-01,
                 2.28183463e-01,  1.15528323e-01,  1.20074466e-01,
                 2.64282793e-01,  1.43343374e-01, -1.75862223e-01,
                -1.76199526e-01, -1.88328221e-01, -1.39402315e-01,
                -1.26377558e-02, -2.41472259e-01,  1.31899178e-01,
                 8.29819590e-02, -2.47695848e-01,  7.01468661e-02,
                -2.22844198e-01, -2.80178607e-01, -2.02988148e-01,
                -1.15273558e-02,  1.26338452e-02,  7.07316957e-03],
               [-7.97300115e-02, -5.25886901e-02,  1.58460796e-01,
                -2.35498503e-01,  2.04430297e-01, -4.10622329e-01,
                -3.31815988e-01,  8.34110901e-02,  2.65484620e-02,
                 8.55713263e-02,  5.00390604e-02,  6.37989491e-02,
                 2.11597666e-01,  2.87527442e-01, -1.89522967e-01,
                 4.13127318e-02,  7.07679540e-02,  5.58287697e-03,
                 1.74191803e-01, -2.29688987e-01,  6.82202056e-02,
                 1.56064987e-01, -3.20060223e-01,  3.22367281e-01],
               [ 1.59097135e-01,  1.39410570e-01,  8.50875005e-02,
                 8.07980597e-02,  7.07337707e-02, -1.47244213e-02,
                -4.05400306e-01,  1.68567300e-01,  7.89226126e-03,
                 1.09932095e-01,  3.76471464e-04,  2.10060254e-01,
                -3.52591842e-01,  1.13672972e-01, -2.21537545e-01,
                -2.29707584e-01, -5.35412058e-02, -2.22828344e-01,
                 2.89548784e-01,  2.93076426e-01, -2.00659156e-01,
                 1.77129313e-01,  9.09424797e-02, -1.47566587e-01],
               [-1.30602211e-01, -3.42392951e-01, -2.81699032e-01,
                 1.14064831e-02,  1.05806813e-03, -1.21206455e-01,
                 4.62855436e-02,  4.14297432e-02,  2.13217407e-01,
                 1.53005973e-01, -1.80562243e-01, -5.85748851e-02,
                 1.12419151e-01,  4.87756394e-02, -1.13454066e-01,
                -3.28012645e-01, -4.05437469e-01, -1.86508641e-01,
                -4.41192478e-01, -2.06266835e-01,  2.52475947e-01,
                -7.46934563e-02,  3.24285090e-01, -2.29174852e-01],
               [-2.81270146e-02,  2.37445295e-01,  8.88779387e-02,
                 4.92694303e-02, -2.43233681e-01,  3.16045359e-02,
                -4.54478681e-01,  3.78209710e-01, -1.82452071e-02,
                -1.35280250e-03, -1.13483384e-01,  2.05291539e-01,
                 1.05845995e-01, -1.50708869e-01, -1.36057526e-01,
                 1.96010545e-01,  2.37198129e-01, -3.54070812e-01,
                -1.06457314e-02, -1.33693233e-01, -2.25637525e-01,
                -3.11158039e-02, -6.19137548e-02, -2.84459770e-01],
               [-5.47694787e-02,  2.20232755e-01, -1.83436871e-01,
                -1.80028495e-04,  2.96595655e-02, -7.82586411e-02,
                 1.58477515e-01,  5.68558797e-02,  2.55503297e-01,
                -2.02010885e-01, -3.16947728e-01, -1.83312342e-01,
                -2.22195789e-01,  1.19831428e-01,  1.22898407e-01,
                 8.92846137e-02, -3.40200514e-01,  1.21756166e-01,
                -2.13965341e-01, -2.11811978e-02, -5.96001297e-02,
                 1.88846871e-01, -1.52903005e-01,  4.52255428e-01]],      dtype=float32),
        bias: Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0., 0., 0., 0., 0.], dtype=float32),
    },
})